# Mirror page header: google-research / blur / blur_test.py
# (fork count 0; 467 lines, 15.8 KB).
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Unit tests for BLUR."""
17
import numpy as np
18
import tensorflow.compat.v1 as tf
19

20
from blur import blur
21
from blur import blur_env
22
from blur import blur_meta
23
from blur import genome_util
24
from blur import synapse_util
25

26

27
def d_sigmoid(x):
  """Returns the derivative of the logistic sigmoid, evaluated at `x`.

  Uses the identity sigmoid'(x) = sigmoid(x) * (1 - sigmoid(x)).
  """
  sig = tf.math.sigmoid(x)
  one_minus_sig = 1 - sig
  return one_minus_sig * sig
30

31

32
def sigmoid_with_grad(x):
  """Applies sigmoid to channel 0 and its derivative to channel 1 of `x`.

  Channels are indexed on the last axis of `x`; the two results are stacked
  back together along a new last axis.
  """
  activation = tf.math.sigmoid(x[Ellipsis, 0])
  derivative = d_sigmoid(x[Ellipsis, 1])
  return tf.stack([activation, derivative], axis=-1)
34

35

36
def random_dataset():
  """Builds a `tf.data.Dataset` with a single 'support' feature tuple.

  The tuple holds random-normal values (std 255, last axis 784) and random
  integers in [0, 10) (last axis 2), both cast to
  `blur_env.NP_FLOATING_TYPE`.
  """
  num_examples = 1000
  np_dtype = blur_env.NP_FLOATING_TYPE
  # Leading axis has size 0. `from_tensor_slices` slices along that axis, so
  # the dataset carries the element structure but yields no elements; this
  # appears sufficient because callers here only build the graph.
  features_shape = (0, 1, 1, num_examples, 784)
  targets_shape = (0, 1, 1, num_examples, 2)
  features = np.random.normal(0, 255, size=features_shape).astype(np_dtype)
  targets = np.random.randint(0, 10, size=targets_shape).astype(np_dtype)
  return tf.data.Dataset.from_tensor_slices({'support': (features, targets)})
47

48

49
def random_dense(num_in=10, num_out=5):
  """Draws random inputs, outputs and weights for a dense layer with bias.

  Args:
    num_in: Number of input neurons (not counting the bias).
    num_out: Number of output neurons.

  Returns:
    Tuple `(inp, out, ow)` of uniform [0, 1) arrays with shapes
    `(1, 1, 1, num_in, 1)`, `(1, 1, 1, num_out, 1)` and
    `(1, 1, num_in + 1, num_out)`.
  """
  vector_prefix = (1, 1, 1)
  inp = np.random.random(vector_prefix + (num_in, 1))
  out = np.random.random(vector_prefix + (num_out, 1))
  # One extra input row: the neuron layout is inputs + [1] + outputs, and the
  # constant-1 neuron carries the bias weights.
  ow = np.random.random((1, 1, num_in + 1, num_out))
  return inp, out, ow
58

59

60
def get_blur_state(env, inp, out, ow):
  """Builds BLUR neuron states, synapse, genome and spec for one dense layer.

  Args:
    env: BLUR environment (the tests pass `blur_env.tf_env` or
      `blur_env.jp_env`).
    inp: Input values; they become channel 0 of `pre`'s last axis.
    out: Output values; they become channel 1 of `post`'s last axis.
    ow: Weight matrix; presumably laid out like `random_dense`'s `ow`
      (inputs + bias row by outputs) -- confirm for other callers.

  Returns:
    Tuple `(pre, post, synapse, genome, network_spec)`:
      synapse: `ow` combined with `synapse_util.transpose_synapse(ow)` and
        synced over 2 states.
      genome: a backprop genome converted to TF variables.
      network_spec: a `blur.NetworkSpec` with symmetric synapses, no batch
        averaging and the 'multiplicative_second_state' backward update.
  """
  np_dtype = blur_env.NP_FLOATING_TYPE

  # The last axis holds two channels; the channel not filled by the data is
  # zero.
  pre = np.concatenate([inp, np.zeros_like(inp)], axis=-1).astype(np_dtype)
  post = np.concatenate([np.zeros_like(out), out], axis=-1).astype(np_dtype)

  # Give the weights a trailing state axis, then pair them with their
  # transposed copy to form the full synapse.
  in_out = ow.astype(np_dtype)[Ellipsis, None]
  out_in = synapse_util.transpose_synapse(in_out, env)
  combined = synapse_util.combine_in_out_synapses(in_out, out_in, env=env)
  synapse = synapse_util.sync_states_synapse(combined, env, num_states=2)

  backprop_genome = genome_util.create_backprop_genome(num_species=1)
  genome = genome_util.convert_genome_to_tf_variables(backprop_genome)

  network_spec = blur.NetworkSpec()
  network_spec.symmetric_synapses = True
  network_spec.batch_average = False
  network_spec.backward_update = 'multiplicative_second_state'

  return pre, post, synapse, genome, network_spec
80

81

82
def tf_gradients(inp, out, ow):
  """Computes reference gradients of `y = [inp, 1] @ ow` with TensorFlow.

  Args:
    inp: Input array; channel 0 of its last axis is used.
    out: Upstream gradient; channel 0 of its last axis is passed to
      `tf.gradients` as `grad_ys`.
    ow: Weight matrix; its last input row multiplies the appended constant 1.

  Returns:
    Tuple `(grad_weights, grad_image, y)`: the gradient w.r.t. `ow`, the
    gradient w.r.t. the bias-free input, and the forward output `y`.
  """
  x = tf.constant(inp[Ellipsis, 0])

  # A constant 1 is appended to the input so that the last row of `ow` acts
  # as a bias.
  one = tf.constant([[[[1]]]], x.dtype)
  x_with_one = tf.concat([x, one], axis=-1)
  weights = tf.constant(ow)
  y = tf.matmul(x_with_one, weights)

  grads = tf.gradients(y, [one, weights, x], out[Ellipsis, 0])
  grad_weights = grads[1]
  grad_image = grads[2]
  return grad_weights, grad_image, y
92

93

94
def verify_equal(update1, update2, hebbian_update, grad_weights, grad_image, y,
                 num_in, num_out):
  """Asserts that BLUR's updates match reference dense-layer gradients.

  Args:
    update1: First (backward) output of `blur.get_synaptic_update`. Channel 1
      of its last axis must equal `grad_image`; channel 0 must be all zeros.
    update2: Second (forward) output of `blur.get_synaptic_update`. Channel 0
      of its last axis must equal `y`.
    hebbian_update: Hebbian update whose last three axes are
      `[num_in + 1 + num_out, num_in + 1 + num_out, 2]`.
    grad_weights: Reference gradient w.r.t. the bias-augmented weights.
    grad_image: Reference gradient w.r.t. the input.
    y: Reference forward output.
    num_in: Number of input neurons, not counting the bias neuron.
    num_out: Number of output neurons.

  Raises:
    AssertionError: If any of the comparisons fails.
  """
  # Hebbian update is [#in + #out, #in + #out, 2] matrix
  # and it should look like this.
  # Z   U
  # Z   Z
  #
  # Z   Z
  # U^T Z
  # Where Z is zeros
  num_in_bias = num_in + 1
  # Leading (batch) dims of the Hebbian matrix, used to shape the zero blocks.
  lead_shape = tuple(np.shape(hebbian_update)[:-3])
  channel0 = hebbian_update[Ellipsis, 0]

  # The two channels are transposes of each other.
  np.testing.assert_allclose(
      channel0,
      np.swapaxes(hebbian_update[Ellipsis, 1], -1, -2),
      rtol=1e-5)

  # Upper-right block U holds the weight gradient.
  np.testing.assert_allclose(
      channel0[Ellipsis, :num_in_bias, num_in_bias:], grad_weights, rtol=1e-5)

  # Upper-left block (bias column excluded) must be zero. Its column count is
  # num_in, not a fixed constant, so the check works for any layer width.
  np.testing.assert_allclose(channel0[Ellipsis, :num_in_bias, :num_in],
                             np.zeros(lead_shape + (num_in_bias, num_in)))
  # Lower-right (output x output) block must be zero.
  np.testing.assert_allclose(channel0[Ellipsis, num_in_bias:, num_in_bias:],
                             np.zeros(lead_shape + (num_out, num_out)))

  np.testing.assert_allclose(update1[Ellipsis, 1], grad_image, rtol=1e-5)
  np.testing.assert_allclose(update1[Ellipsis, 0],
                             np.zeros_like(update1[Ellipsis, 0]))

  np.testing.assert_allclose(update2[Ellipsis, 0], y, rtol=1e-5)
121

122

123
class BlurTest(tf.test.TestCase):
  """Unit tests for BLUR's synaptic, neuron and Hebbian updates.

  The tests build TF1-style graphs and evaluate them with `tf.Session`, so
  each test resets the default graph and disables TF2 behavior itself. This
  keeps a test from depending on another test having run first.
  """

  def test_sync_in_out_synapses(self):
    """Synced synapse reproduces `in_out_synapse` in both submatrices."""
    tf.reset_default_graph()
    tf.disable_v2_behavior()
    num_in = 3
    num_out = 2
    num_states = 2
    env = blur_env.tf_env
    in_out_synapse = tf.random.normal(shape=(num_in + 1, num_out, num_states))
    out_in_synapse = tf.random.normal(shape=(num_out, num_in + 1, num_states))

    synapse = synapse_util.combine_in_out_synapses(in_out_synapse,
                                                   out_in_synapse, env)
    synapse_synced = synapse_util.sync_in_and_out_synapse(synapse, num_in, env)
    fwd_sync_submatrix = synapse_util.synapse_submatrix(
        synapse_synced,
        num_in,
        synapse_util.UpdateType.FORWARD,
        include_bias=True)

    bkw_sync_submatrix = synapse_util.synapse_submatrix(
        synapse_synced,
        num_in,
        synapse_util.UpdateType.BACKWARD,
        include_bias=True)

    with tf.Session() as s:
      bwd, fwd, inp = s.run([
          synapse_util.transpose_synapse(bkw_sync_submatrix, env),
          fwd_sync_submatrix, in_out_synapse
      ])
    self.assertAllEqual(fwd, inp)
    self.assertAllEqual(bwd, inp)

  def test_verify_gradient_match_tf(self):
    """BLUR updates (TF env) match TF's gradients for a dense layer."""
    num_in = 10
    num_out = 15
    tf.reset_default_graph()
    tf.disable_v2_behavior()
    inp, out, ow = random_dense(num_in, num_out)
    env = blur_env.tf_env
    pre, post, synapse, _, network_spec = get_blur_state(env, inp, out, ow)

    genome = genome_util.create_backprop_genome(num_species=1)

    update1, update2 = blur.get_synaptic_update(
        pre,
        post,
        synapse=synapse,
        input_transform_gn=genome.neuron.transform,
        update_type=synapse_util.UpdateType.BOTH,
        env=env)

    hebbian_update = blur.get_hebbian_update(
        pre,
        post,
        genome.synapse.transform,
        global_spec=network_spec,
        env=env)

    grad_weights, grad_image, y = tf_gradients(inp, out, ow)

    # Scoped print options keep failure output compact without leaking the
    # setting into other tests.
    with np.printoptions(precision=4, linewidth=200), tf.Session():
      verify_equal(update1.eval(), update2.eval(), hebbian_update.eval(),
                   grad_weights.eval(), grad_image.eval(), y.eval(), num_in,
                   num_out)

  def test_verify_gradient_match_jp(self):
    """BLUR updates (jp env) match TF's gradients for a dense layer."""
    tf.reset_default_graph()
    tf.disable_v2_behavior()
    num_in = 10
    num_out = 15
    inp, out, ow = random_dense(num_in, num_out)
    env = blur_env.jp_env
    pre, post, synapse, _, network_spec = get_blur_state(env, inp, out, ow)

    genome = genome_util.create_backprop_genome(num_species=1)

    update1, update2 = blur.get_synaptic_update(
        pre,
        post,
        synapse=synapse,
        input_transform_gn=genome.neuron.transform,
        update_type=synapse_util.UpdateType.BOTH,
        env=env)

    hebbian_update = blur.get_hebbian_update(
        pre,
        post,
        genome.synapse.transform,
        global_spec=network_spec,
        env=env)
    grad_weights, grad_image, y = tf_gradients(inp, out, ow)

    # Scoped print options keep failure output compact without leaking the
    # setting into other tests.
    with np.printoptions(precision=4, linewidth=200), tf.Session():
      verify_equal(update1, update2, hebbian_update, grad_weights.eval(),
                   grad_image.eval(), y.eval(), num_in, num_out)

  def test_get_synaptic_update_forward(self):
    """Forward update equals the bias-augmented input times the weights."""
    tf.reset_default_graph()
    tf.disable_v2_behavior()
    inp, out, w = random_dense()
    env = blur_env.tf_env
    pre, post, synapse, genome, _ = get_blur_state(env, inp, out, w)

    _, update_fwd = blur.get_synaptic_update(
        pre,
        post,
        synapse,
        input_transform_gn=genome.neuron.transform,
        update_type=synapse_util.UpdateType.FORWARD,
        env=env)

    inp = tf.constant(inp.astype(blur_env.NP_FLOATING_TYPE))
    ww = w.astype(blur_env.NP_FLOATING_TYPE)
    inp_with_bias = tf.concat([inp[Ellipsis, 0], [[[[1]]]]], axis=-1)
    exp_results = inp_with_bias @ ww

    with tf.Session() as s:
      s.run(tf.global_variables_initializer())
      self.assertAllClose(update_fwd[Ellipsis, 0], exp_results)
      self.assertAllClose(update_fwd[Ellipsis, 1], exp_results)

  def test_network_step_mix_forward(self):
    """Mixing enabled: Hebbian pre input comes from a 'forward' op."""
    # The ops below are looked up by absolute name, so start from a clean
    # graph to keep the 'step' scope from being uniquified.
    tf.reset_default_graph()
    tf.disable_v2_behavior()
    spec = blur.NetworkSpec(use_forward_activations_for_synapse_update=True)
    genome = genome_util.convert_genome_to_tf_variables(
        genome_util.create_backprop_genome(num_species=1))
    initializer = lambda params: 2 * tf.ones(params.shape, dtype=tf.float32)
    data = random_dataset()
    state = blur_meta.init_first_state(
        genome,
        synapse_initializer=initializer,
        data=data,
        hidden_layers=[256, 128])
    input_fn = data.make_one_shot_iterator().get_next
    data_support_fn, _ = blur_meta.episode_data_fn_split(input_fn)
    blur.network_step(
        state, genome, data_support_fn, network_spec=spec, env=blur_env.tf_env)
    g = tf.get_default_graph()

    synapse_pre = g.get_operation_by_name(
        'step/backward/synapse_update/hebbian_pre').inputs[0]
    synapse_post = g.get_operation_by_name(
        'step/backward/synapse_update/hebbian_post').inputs[0]
    self.assertIn('forward', synapse_pre.name)
    self.assertIn('backward', synapse_post.name)

    self.assertNotIn('backward', synapse_pre.name)
    self.assertNotIn('forward', synapse_post.name)

  def test_network_step_no_mix_forward(self):
    """Mixing disabled: both Hebbian inputs come from 'backward' ops."""
    # The ops below are looked up by absolute name, so start from a clean
    # graph to keep the 'step' scope from being uniquified.
    tf.reset_default_graph()
    tf.disable_v2_behavior()
    spec = blur.NetworkSpec(use_forward_activations_for_synapse_update=False)
    genome = genome_util.convert_genome_to_tf_variables(
        genome_util.create_backprop_genome(num_species=1))
    initializer = lambda params: 2 * tf.ones(params.shape, dtype=tf.float32)
    data = random_dataset()
    state = blur_meta.init_first_state(
        genome,
        synapse_initializer=initializer,
        data=data,
        hidden_layers=[256, 128])
    input_fn = data.make_one_shot_iterator().get_next
    data_support_fn, _ = blur_meta.episode_data_fn_split(input_fn)
    blur.network_step(
        state,
        genome,
        data_support_fn,
        data.make_one_shot_iterator().get_next,
        network_spec=spec,
        env=blur_env.tf_env)
    g = tf.get_default_graph()

    synapse_pre = g.get_operation_by_name(
        'step/backward/synapse_update/hebbian_pre').inputs[0]
    synapse_post = g.get_operation_by_name(
        'step/backward/synapse_update/hebbian_post').inputs[0]
    self.assertIn('backward', synapse_pre.name)
    self.assertIn('backward', synapse_post.name)
    self.assertNotIn('forward', synapse_pre.name)
    self.assertNotIn('forward', synapse_post.name)

  def test_get_synaptic_update_backward(self):
    """Backward update is `out @ ow^T` (bias dropped) in channel 1."""
    tf.reset_default_graph()
    tf.disable_v2_behavior()
    n_in, n_out = 10, 5
    inp, out, w = random_dense(n_in, n_out)
    env = blur_env.tf_env
    pre, post, synapse, genome, _ = get_blur_state(env, inp, out, w)

    update_bwd, _ = blur.get_synaptic_update(
        pre,
        post,
        synapse,
        input_transform_gn=genome.neuron.transform,
        update_type=synapse_util.UpdateType.BACKWARD,
        env=env)

    out = tf.constant(out.astype(blur_env.NP_FLOATING_TYPE))
    ww = w.astype(blur_env.NP_FLOATING_TYPE)
    exp_results = out[Ellipsis, 0] @ tf.transpose(ww, (0, 1, 3, 2))

    with tf.Session() as s:
      s.run(tf.global_variables_initializer())
      self.assertAllClose(update_bwd[Ellipsis, 0], tf.zeros((1, 1, 1, n_in)))
      self.assertAllClose(update_bwd[Ellipsis, 1], exp_results[Ellipsis, :-1])

  def test_neuron_update_fwd(self):
    """Forward neuron update applies sigmoid / d_sigmoid to the dense output."""
    tf.reset_default_graph()
    tf.disable_v2_behavior()
    n_in, n_out = 10, 5
    inp, out, w = random_dense(n_in, n_out)
    env = blur_env.tf_env
    pre, post, synapse, genome, network_spec = get_blur_state(env, inp, out, w)
    pre_fwd, post_fwd = blur.dense_neuron_update(
        pre,
        post,
        synapse,
        inp_act=None,
        out_act=sigmoid_with_grad,
        neuron_genome=genome.neuron,
        update_type=synapse_util.UpdateType.FORWARD,
        global_spec=network_spec,
        env=env)

    inp = tf.constant(inp.astype(blur_env.NP_FLOATING_TYPE))
    ww = w.astype(blur_env.NP_FLOATING_TYPE)
    inp_with_bias = tf.concat([inp[Ellipsis, 0], [[[[1]]]]], axis=-1)
    exp_results = inp_with_bias @ ww

    with tf.Session() as s:
      s.run(tf.global_variables_initializer())
      self.assertAllClose(pre_fwd, pre)
      self.assertAllClose(post_fwd[Ellipsis, 0], tf.math.sigmoid(exp_results))
      self.assertAllClose(post_fwd[Ellipsis, 1],
                          d_sigmoid(out[Ellipsis, 0] + exp_results))

  def test_neuron_update_bwd(self):
    """Backward neuron update scales `post_fwd[..., 1] @ ow^T` by pre grads."""
    tf.reset_default_graph()
    tf.disable_v2_behavior()
    n_in, n_out = 10, 5
    inp, out, w = random_dense(n_in, n_out)
    env = blur_env.tf_env
    pre, post, synapse, genome, network_spec = get_blur_state(env, inp, out, w)

    inp_act_fn = lambda x: x
    out_act_fn = sigmoid_with_grad

    pre_fwd, post_fwd = blur.dense_neuron_update(
        pre,
        post,
        synapse,
        inp_act=None,
        out_act=out_act_fn,
        neuron_genome=genome.neuron,
        update_type=synapse_util.UpdateType.FORWARD,
        global_spec=network_spec,
        env=env)

    pre_bkw, _ = blur.dense_neuron_update(
        pre_fwd,
        post_fwd,
        synapse=synapse,
        inp_act=inp_act_fn,
        out_act=out_act_fn,
        neuron_genome=genome.neuron,
        update_type=synapse_util.UpdateType.BACKWARD,
        global_spec=network_spec,
        env=env)

    ww = w.astype(blur_env.NP_FLOATING_TYPE)

    exp_result = post_fwd[Ellipsis, 1] @ tf.transpose(ww, (0, 1, 3, 2))

    with tf.Session() as s:
      s.run(tf.global_variables_initializer())
      self.assertAllClose(pre_bkw[Ellipsis, 0], pre_fwd[Ellipsis, 0])
      self.assertAllClose(pre_bkw[Ellipsis, 1],
                          exp_result[Ellipsis, :-1] * pre_fwd[Ellipsis, 1])

  def test_get_hebbian_update(self):
    """Forward Hebbian submatrix equals the outer product `[inp, 1]^T @ out`."""
    tf.reset_default_graph()
    tf.disable_v2_behavior()
    n_in, n_out = 10, 5
    inp, out, w = random_dense(n_in, n_out)
    env = blur_env.tf_env
    pre, post, _, genome, network_spec = get_blur_state(env, inp, out, w)

    hebbian_update = blur.get_hebbian_update(pre, post,
                                             genome.synapse.transform,
                                             network_spec, env)
    hebbian_update_submatrix = synapse_util.synapse_submatrix(
        hebbian_update, n_in, synapse_util.UpdateType.FORWARD)

    inp = tf.constant(inp.astype(blur_env.NP_FLOATING_TYPE))
    out = tf.constant(out.astype(blur_env.NP_FLOATING_TYPE))
    inp_transpose = tf.transpose(inp[Ellipsis, 0], (0, 1, 3, 2))
    exp_result = env.concat_row(inp_transpose) @ out[Ellipsis, 0]

    with tf.Session() as s:
      s.run(tf.global_variables_initializer())
      self.assertAllClose(hebbian_update_submatrix[Ellipsis, 0], exp_result)

  def test_synapse_derivative(self):
    """Hebbian update after a forward pass matches d(sigmoid(x @ w))/dw."""
    tf.reset_default_graph()
    tf.disable_v2_behavior()
    n_in, n_out = 10, 5
    inp, out, w = random_dense(n_in, n_out)

    env = blur_env.tf_env
    pre, post, synapse, genome, network_spec = get_blur_state(env, inp, out, w)
    post = np.concatenate(
        2 * [np.zeros_like(out)], axis=-1).astype(blur_env.NP_FLOATING_TYPE)

    pre_fwd, post_fwd = blur.dense_neuron_update(
        pre,
        post,
        synapse,
        inp_act=None,
        out_act=sigmoid_with_grad,
        neuron_genome=genome.neuron,
        update_type=synapse_util.UpdateType.FORWARD,
        global_spec=network_spec,
        env=env)

    hebbian_update = blur.get_hebbian_update(pre_fwd, post_fwd,
                                             genome.synapse.transform,
                                             network_spec, env)

    hebbian_update_submatrix = synapse_util.synapse_submatrix(
        hebbian_update, n_in, synapse_util.UpdateType.FORWARD)

    inp = tf.constant(inp.astype(blur_env.NP_FLOATING_TYPE))
    inp_with_bias = tf.concat([inp[Ellipsis, 0], [[[[1]]]]], axis=-1)
    ww = tf.constant(w.astype(blur_env.NP_FLOATING_TYPE))
    out = tf.nn.sigmoid(inp_with_bias @ ww)
    grad_w = tf.gradients(out, ww)

    with tf.Session() as s:
      s.run(tf.global_variables_initializer())
      hebb_update, grad_w_val = s.run(
          [hebbian_update_submatrix[Ellipsis, 0], grad_w[0]])
    self.assertAllClose(hebb_update, grad_w_val)
468

469

# Cookie notice from the hosting web page (not part of the test module),
# translated from Russian:
# "Use of cookies. We use cookie files in accordance with the Privacy Policy
# and the Cookie Policy. By clicking 'Accept', you give SberTech JSC consent
# to process your personal data in order to improve our website and the
# GitVerse service and to make them more convenient to use. You can disable
# cookies yourself in your browser settings."