google-research
467 строк · 15.8 Кб
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16"""Unit tests for BLUR."""
17import numpy as np18import tensorflow.compat.v1 as tf19
20from blur import blur21from blur import blur_env22from blur import blur_meta23from blur import genome_util24from blur import synapse_util25
26
def d_sigmoid(x):
  """Returns the derivative of the logistic sigmoid evaluated at `x`.

  Uses the identity sigmoid'(x) = sigmoid(x) * (1 - sigmoid(x)).

  Args:
    x: Value accepted by `tf.math.sigmoid`.

  Returns:
    Tensor holding the element-wise sigmoid derivative.
  """
  activation = tf.math.sigmoid(x)
  return activation * (1 - activation)

31
def sigmoid_with_grad(x):
  """Applies sigmoid to state 0 and the sigmoid derivative to state 1.

  Args:
    x: Tensor whose last axis has (at least) two entries; index 0 and index 1
      of that axis are transformed independently.

  Returns:
    Tensor stacking `sigmoid(x[..., 0])` and `d_sigmoid(x[..., 1])` along a
    new last axis.
  """
  activation = tf.math.sigmoid(x[Ellipsis, 0])
  derivative = d_sigmoid(x[Ellipsis, 1])
  return tf.stack([activation, derivative], axis=-1)

35
def random_dataset():
  """Builds a `tf.data.Dataset` with a single 'support' pair of arrays.

  Both arrays have a leading dimension of 0, so they hold no values; presumably
  only their shapes and dtype matter to the callers (TODO confirm).

  Returns:
    Dataset created by `from_tensor_slices` from a dict whose 'support' entry
    is a tuple of two arrays cast to `blur_env.NP_FLOATING_TYPE`, with shapes
    (0, 1, 1, 1000, 784) and (0, 1, 1, 1000, 2).
  """
  n = 1000
  dtype = blur_env.NP_FLOATING_TYPE
  # Drawn in the same order as the tuple below: normal first, randint second.
  features = np.random.normal(0, 255, size=(0, 1, 1, n, 784)).astype(dtype)
  labels = np.random.randint(0, 10, size=(0, 1, 1, n, 2)).astype(dtype)
  return tf.data.Dataset.from_tensor_slices({'support': (features, labels)})

48
def random_dense(num_in=10, num_out=5):
  """Draws random inputs, outputs and weights for a single dense layer.

  Args:
    num_in: Number of input neurons (bias excluded).
    num_out: Number of output neurons.

  Returns:
    Tuple `(inp, out, ow)` of uniform [0, 1) arrays with shapes
    (1, 1, 1, num_in, 1), (1, 1, 1, num_out, 1) and
    (1, 1, num_in + 1, num_out) respectively.
  """
  inp_shape = (1, 1, 1, num_in, 1)
  out_shape = (1, 1, 1, num_out, 1)
  # Neuron: input + [1] + output, hence one extra weight row for the bias.
  weight_shape = (1, 1, num_in + 1, num_out)

  # The draw order (inp, out, weights) is kept so seeded runs are reproducible.
  inp = np.random.random(inp_shape)
  out = np.random.random(out_shape)
  weights = np.random.random(weight_shape)
  return inp, out, weights

59
def get_blur_state(env, inp, out, ow):
  """Assembles the BLUR state for a dense layer from raw numpy arrays.

  Args:
    env: BLUR environment forwarded to the `synapse_util` helpers.
    inp: Input activations; a zero array of the same shape is appended on the
      last axis.
    out: Output values; a zero array of the same shape is prepended on the
      last axis.
    ow: Dense weights; a trailing axis is added before building the synapse.

  Returns:
    Tuple `(pre, post, synapse, genome, network_spec)`.
  """
  dtype = blur_env.NP_FLOATING_TYPE

  # `inp` goes in the first slot of the last axis, `out` in the second one.
  pre = np.concatenate([inp, np.zeros_like(inp)], axis=-1).astype(dtype)
  post = np.concatenate([np.zeros_like(out), out], axis=-1).astype(dtype)

  forward_weights = ow.astype(dtype)[Ellipsis, None]
  backward_weights = synapse_util.transpose_synapse(forward_weights, env)
  synapse = synapse_util.combine_in_out_synapses(
      forward_weights, backward_weights, env=env)
  synapse = synapse_util.sync_states_synapse(synapse, env, num_states=2)

  backprop_genome = genome_util.create_backprop_genome(num_species=1)
  genome = genome_util.convert_genome_to_tf_variables(backprop_genome)

  network_spec = blur.NetworkSpec()
  network_spec.symmetric_synapses = True
  network_spec.batch_average = False
  network_spec.backward_update = 'multiplicative_second_state'

  return pre, post, synapse, genome, network_spec

81
def tf_gradients(inp, out, ow):
  """Computes reference gradients of a biased dense layer with TensorFlow.

  Args:
    inp: Input array; only index 0 of its last axis is used.
    out: Array whose index 0 of the last axis is passed to `tf.gradients` as
      the initial gradient for `y`.
    ow: Weight array, multiplied on the right of the bias-extended input.

  Returns:
    Tuple `(grad_weights, grad_image, y)`: the gradients of `y` with respect to
    the weights and the input, and the layer output `y` itself.
  """
  tf_input = tf.constant(inp[Ellipsis, 0])

  # Append '1' to the end of the input.
  bias = tf.constant([[[[1]]]], tf_input.dtype)
  input_with_bias = tf.concat([tf_input, bias], axis=-1)
  tf_weights = tf.constant(ow)
  y = input_with_bias @ tf_weights

  # The bias gradient is computed but not needed by the callers.
  grads = tf.gradients(y, [bias, tf_weights, tf_input], out[Ellipsis, 0])
  grad_weights, grad_image = grads[1], grads[2]
  return grad_weights, grad_image, y

93
def verify_equal(update1, update2, hebbian_update, grad_weights, grad_image, y,
                 num_in, num_out):
  """Asserts that BLUR updates match the reference gradient computation.

  Args:
    update1: Array whose last-axis index 1 must equal `grad_image` and whose
      last-axis index 0 must be all zeros.
    update2: Array whose last-axis index 0 must equal `y`.
    hebbian_update: Array with two trailing-axis entries holding the Hebbian
      update matrix (see layout below).
    grad_weights: Expected value of the input(+bias)-to-output block.
    grad_image: Expected gradient with respect to the input.
    y: Expected forward output.
    num_in: Number of input neurons (bias excluded).
    num_out: Number of output neurons.

  Raises:
    AssertionError: If any of the comparisons fails.
  """
  # Hebbian update is [#in + #out, #in + #out, 2] matrix
  # and it should look like this.
  # Z U
  # Z Z
  #
  # Z Z
  # U^T Z
  # Where Z is zeros
  np.testing.assert_allclose(
      hebbian_update[Ellipsis, 0],
      np.swapaxes(hebbian_update[Ellipsis, 1], -1, -2),
      rtol=1e-5)

  num_in_bias = num_in + 1
  np.testing.assert_allclose(
      hebbian_update[Ellipsis, :num_in_bias, num_in_bias:, 0],
      grad_weights,
      rtol=1e-5)

  # The expected zero block is sized from `num_in`; a hard-coded width would
  # only be correct for one particular layer size.
  np.testing.assert_allclose(
      hebbian_update[Ellipsis, :num_in_bias, :num_in, 0],
      np.zeros((1, 1, num_in_bias, num_in)))
  np.testing.assert_allclose(
      hebbian_update[Ellipsis, num_in_bias:, num_in_bias:, 0],
      np.zeros((1, 1, num_out, num_out)))

  np.testing.assert_allclose(update1[Ellipsis, 1], grad_image, rtol=1e-5)
  np.testing.assert_allclose(update1[Ellipsis, 0],
                             np.zeros_like(update1[Ellipsis, 0]))

  np.testing.assert_allclose(update2[Ellipsis, 0], y, rtol=1e-5)

122
class BlurTest(tf.test.TestCase):
  """Tests BLUR synapse/neuron updates against TensorFlow reference results."""

  def setUp(self):
    """Puts every test in TF1 graph mode with a fresh default graph.

    Doing this once here, rather than in a subset of the tests, keeps each test
    independent of which other tests ran before it.
    """
    super().setUp()
    tf.disable_v2_behavior()
    tf.reset_default_graph()

  def test_sync_in_out_synapses(self):
    """Synced forward and (transposed) backward blocks equal the in->out one."""
    num_in = 3
    num_out = 2
    num_states = 2
    env = blur_env.tf_env
    in_out_synapse = tf.random.normal(shape=(num_in + 1, num_out, num_states))
    out_in_synapse = tf.random.normal(shape=(num_out, num_in + 1, num_states))

    synapse = synapse_util.combine_in_out_synapses(in_out_synapse,
                                                   out_in_synapse, env)
    synapse_synced = synapse_util.sync_in_and_out_synapse(synapse, num_in, env)
    fwd_sync_submatrix = synapse_util.synapse_submatrix(
        synapse_synced,
        num_in,
        synapse_util.UpdateType.FORWARD,
        include_bias=True)

    bkw_sync_submatrix = synapse_util.synapse_submatrix(
        synapse_synced,
        num_in,
        synapse_util.UpdateType.BACKWARD,
        include_bias=True)

    # All three values come from one `run` call so they share the same random
    # draw of `in_out_synapse`.
    with tf.Session() as s:
      bwd, fwd, inp = s.run([
          synapse_util.transpose_synapse(bkw_sync_submatrix, env),
          fwd_sync_submatrix, in_out_synapse
      ])
      self.assertAllEqual(fwd, inp)
      self.assertAllEqual(bwd, inp)

  def test_verify_gradient_match_tf(self):
    """BLUR updates in the TF environment match `tf.gradients`."""
    num_in = 10
    num_out = 15
    inp, out, ow = random_dense(num_in, num_out)
    env = blur_env.tf_env
    pre, post, synapse, _, network_spec = get_blur_state(env, inp, out, ow)

    # Uses the plain genome rather than the tf.Variable version returned by
    # get_blur_state, so no variable initialization is run below.
    genome = genome_util.create_backprop_genome(num_species=1)

    update1, update2 = blur.get_synaptic_update(
        pre,
        post,
        synapse=synapse,
        input_transform_gn=genome.neuron.transform,
        update_type=synapse_util.UpdateType.BOTH,
        env=env)

    hebbian_update = blur.get_hebbian_update(
        pre,
        post,
        genome.synapse.transform,
        global_spec=network_spec,
        env=env)

    grad_weights, grad_image, y = tf_gradients(inp, out, ow)

    np.set_printoptions(precision=4, linewidth=200)

    with tf.Session():
      verify_equal(update1.eval(), update2.eval(), hebbian_update.eval(),
                   grad_weights.eval(), grad_image.eval(), y.eval(), num_in,
                   num_out)

  def test_verify_gradient_match_jp(self):
    """BLUR updates in the `jp_env` environment match `tf.gradients`."""
    num_in = 10
    num_out = 15
    inp, out, ow = random_dense(num_in, num_out)
    env = blur_env.jp_env
    pre, post, synapse, _, network_spec = get_blur_state(env, inp, out, ow)

    genome = genome_util.create_backprop_genome(num_species=1)

    update1, update2 = blur.get_synaptic_update(
        pre,
        post,
        synapse=synapse,
        input_transform_gn=genome.neuron.transform,
        update_type=synapse_util.UpdateType.BOTH,
        env=env)

    hebbian_update = blur.get_hebbian_update(
        pre,
        post,
        genome.synapse.transform,
        global_spec=network_spec,
        env=env)
    grad_weights, grad_image, y = tf_gradients(inp, out, ow)

    np.set_printoptions(precision=4, linewidth=200)

    # Only the TensorFlow reference values need evaluating; the BLUR results
    # are passed to verify_equal as returned.
    with tf.Session():
      verify_equal(update1, update2, hebbian_update, grad_weights.eval(),
                   grad_image.eval(), y.eval(), num_in, num_out)

  def test_get_synaptic_update_forward(self):
    """Forward synaptic update equals `[inp, 1] @ w` in both states."""
    inp, out, w = random_dense()
    env = blur_env.tf_env
    pre, post, synapse, genome, _ = get_blur_state(env, inp, out, w)

    _, update_fwd = blur.get_synaptic_update(
        pre,
        post,
        synapse,
        input_transform_gn=genome.neuron.transform,
        update_type=synapse_util.UpdateType.FORWARD,
        env=env)

    inp = tf.constant(inp.astype(blur_env.NP_FLOATING_TYPE))
    ww = w.astype(blur_env.NP_FLOATING_TYPE)
    inp_with_bias = tf.concat([inp[Ellipsis, 0], [[[[1]]]]], axis=-1)
    exp_results = inp_with_bias @ ww

    with tf.Session() as s:
      s.run(tf.global_variables_initializer())
      self.assertAllClose(update_fwd[Ellipsis, 0], exp_results)
      self.assertAllClose(update_fwd[Ellipsis, 1], exp_results)

  def test_network_step_mix_forward(self):
    """With mixing on, hebbian_pre reads forward and hebbian_post backward."""
    spec = blur.NetworkSpec(use_forward_activations_for_synapse_update=True)
    genome = genome_util.convert_genome_to_tf_variables(
        genome_util.create_backprop_genome(num_species=1))
    initializer = lambda params: 2 * tf.ones(params.shape, dtype=tf.float32)
    data = random_dataset()
    state = blur_meta.init_first_state(
        genome,
        synapse_initializer=initializer,
        data=data,
        hidden_layers=[256, 128])
    input_fn = data.make_one_shot_iterator().get_next
    data_support_fn, _ = blur_meta.episode_data_fn_split(input_fn)
    blur.network_step(
        state, genome, data_support_fn, network_spec=spec, env=blur_env.tf_env)
    g = tf.get_default_graph()

    # Only the graph wiring is inspected (by op name); nothing is executed.
    synapse_pre = g.get_operation_by_name(
        'step/backward/synapse_update/hebbian_pre').inputs[0]
    synapse_post = g.get_operation_by_name(
        'step/backward/synapse_update/hebbian_post').inputs[0]
    self.assertIn('forward', synapse_pre.name)
    self.assertIn('backward', synapse_post.name)

    self.assertNotIn('backward', synapse_pre.name)
    self.assertNotIn('forward', synapse_post.name)

  def test_network_step_no_mix_forward(self):
    """With mixing off, both Hebbian inputs come from the backward pass."""
    spec = blur.NetworkSpec(use_forward_activations_for_synapse_update=False)
    genome = genome_util.convert_genome_to_tf_variables(
        genome_util.create_backprop_genome(num_species=1))
    initializer = lambda params: 2 * tf.ones(params.shape, dtype=tf.float32)
    data = random_dataset()
    state = blur_meta.init_first_state(
        genome,
        synapse_initializer=initializer,
        data=data,
        hidden_layers=[256, 128])
    input_fn = data.make_one_shot_iterator().get_next
    data_support_fn, _ = blur_meta.episode_data_fn_split(input_fn)
    blur.network_step(
        state,
        genome,
        data_support_fn,
        data.make_one_shot_iterator().get_next,
        network_spec=spec,
        env=blur_env.tf_env)
    g = tf.get_default_graph()

    # Only the graph wiring is inspected (by op name); nothing is executed.
    synapse_pre = g.get_operation_by_name(
        'step/backward/synapse_update/hebbian_pre').inputs[0]
    synapse_post = g.get_operation_by_name(
        'step/backward/synapse_update/hebbian_post').inputs[0]
    self.assertIn('backward', synapse_pre.name)
    self.assertIn('backward', synapse_post.name)
    self.assertNotIn('forward', synapse_pre.name)
    self.assertNotIn('forward', synapse_post.name)

  def test_get_synaptic_update_backward(self):
    """Backward update is zero in state 0 and `out @ w^T` in state 1."""
    n_in, n_out = 10, 5
    inp, out, w = random_dense(n_in, n_out)
    env = blur_env.tf_env
    pre, post, synapse, genome, _ = get_blur_state(env, inp, out, w)

    update_bwd, _ = blur.get_synaptic_update(
        pre,
        post,
        synapse,
        input_transform_gn=genome.neuron.transform,
        update_type=synapse_util.UpdateType.BACKWARD,
        env=env)

    out = tf.constant(out.astype(blur_env.NP_FLOATING_TYPE))
    ww = w.astype(blur_env.NP_FLOATING_TYPE)
    exp_results = out[Ellipsis, 0] @ tf.transpose(ww, (0, 1, 3, 2))

    with tf.Session() as s:
      s.run(tf.global_variables_initializer())
      self.assertAllClose(update_bwd[Ellipsis, 0], tf.zeros((1, 1, 1, n_in)))
      # The last column of the expectation belongs to the bias row of `w`.
      self.assertAllClose(update_bwd[Ellipsis, 1], exp_results[Ellipsis, :-1])

  def test_neuron_update_fwd(self):
    """Forward neuron update keeps `pre` and applies the output activation."""
    n_in, n_out = 10, 5
    inp, out, w = random_dense(n_in, n_out)
    env = blur_env.tf_env
    pre, post, synapse, genome, network_spec = get_blur_state(env, inp, out, w)
    pre_fwd, post_fwd = blur.dense_neuron_update(
        pre,
        post,
        synapse,
        inp_act=None,
        out_act=sigmoid_with_grad,
        neuron_genome=genome.neuron,
        update_type=synapse_util.UpdateType.FORWARD,
        global_spec=network_spec,
        env=env)

    inp = tf.constant(inp.astype(blur_env.NP_FLOATING_TYPE))
    ww = w.astype(blur_env.NP_FLOATING_TYPE)
    inp_with_bias = tf.concat([inp[Ellipsis, 0], [[[[1]]]]], axis=-1)
    exp_results = inp_with_bias @ ww

    with tf.Session() as s:
      s.run(tf.global_variables_initializer())
      self.assertAllClose(pre_fwd, pre)
      self.assertAllClose(post_fwd[Ellipsis, 0], tf.math.sigmoid(exp_results))
      self.assertAllClose(post_fwd[Ellipsis, 1],
                          d_sigmoid(out[Ellipsis, 0] + exp_results))

  def test_neuron_update_bwd(self):
    """Backward neuron update keeps state 0 and rescales state 1 of `pre`."""
    n_in, n_out = 10, 5
    inp, out, w = random_dense(n_in, n_out)
    env = blur_env.tf_env
    pre, post, synapse, genome, network_spec = get_blur_state(env, inp, out, w)

    inp_act_fn = lambda x: x
    out_act_fn = sigmoid_with_grad

    pre_fwd, post_fwd = blur.dense_neuron_update(
        pre,
        post,
        synapse,
        inp_act=None,
        out_act=out_act_fn,
        neuron_genome=genome.neuron,
        update_type=synapse_util.UpdateType.FORWARD,
        global_spec=network_spec,
        env=env)

    pre_bkw, _ = blur.dense_neuron_update(
        pre_fwd,
        post_fwd,
        synapse=synapse,
        inp_act=inp_act_fn,
        out_act=out_act_fn,
        neuron_genome=genome.neuron,
        update_type=synapse_util.UpdateType.BACKWARD,
        global_spec=network_spec,
        env=env)

    ww = w.astype(blur_env.NP_FLOATING_TYPE)

    exp_result = post_fwd[Ellipsis, 1] @ tf.transpose(ww, (0, 1, 3, 2))

    with tf.Session() as s:
      s.run(tf.global_variables_initializer())
      self.assertAllClose(pre_bkw[Ellipsis, 0], pre_fwd[Ellipsis, 0])
      # The last column of the expectation belongs to the bias row of `w`.
      self.assertAllClose(pre_bkw[Ellipsis, 1],
                          exp_result[Ellipsis, :-1] * pre_fwd[Ellipsis, 1])

  def test_get_hebbian_update(self):
    """Forward Hebbian block equals the outer product `[inp, 1]^T @ out`."""
    n_in, n_out = 10, 5
    inp, out, w = random_dense(n_in, n_out)
    env = blur_env.tf_env
    pre, post, _, genome, network_spec = get_blur_state(env, inp, out, w)

    hebbian_update = blur.get_hebbian_update(pre, post,
                                             genome.synapse.transform,
                                             network_spec, env)
    hebbian_update_submatrix = synapse_util.synapse_submatrix(
        hebbian_update, n_in, synapse_util.UpdateType.FORWARD)

    inp = tf.constant(inp.astype(blur_env.NP_FLOATING_TYPE))
    out = tf.constant(out.astype(blur_env.NP_FLOATING_TYPE))
    inp_transpose = tf.transpose(inp[Ellipsis, 0], (0, 1, 3, 2))
    exp_result = env.concat_row(inp_transpose) @ out[Ellipsis, 0]

    with tf.Session() as s:
      s.run(tf.global_variables_initializer())
      self.assertAllClose(hebbian_update_submatrix[Ellipsis, 0], exp_result)

  def test_synapse_derivative(self):
    """Hebbian update after a forward pass equals d sigmoid(x @ w) / d w."""
    n_in, n_out = 10, 5
    inp, out, w = random_dense(n_in, n_out)

    env = blur_env.tf_env
    pre, _, synapse, genome, network_spec = get_blur_state(env, inp, out, w)
    # This test starts from an all-zero post state instead of the one built by
    # get_blur_state.
    post = np.concatenate(
        2 * [np.zeros_like(out)], axis=-1).astype(blur_env.NP_FLOATING_TYPE)

    pre_fwd, post_fwd = blur.dense_neuron_update(
        pre,
        post,
        synapse,
        inp_act=None,
        out_act=sigmoid_with_grad,
        neuron_genome=genome.neuron,
        update_type=synapse_util.UpdateType.FORWARD,
        global_spec=network_spec,
        env=env)

    hebbian_update = blur.get_hebbian_update(pre_fwd, post_fwd,
                                             genome.synapse.transform,
                                             network_spec, env)

    hebbian_update_submatrix = synapse_util.synapse_submatrix(
        hebbian_update, n_in, synapse_util.UpdateType.FORWARD)

    inp = tf.constant(inp.astype(blur_env.NP_FLOATING_TYPE))
    inp_with_bias = tf.concat([inp[Ellipsis, 0], [[[[1]]]]], axis=-1)
    ww = tf.constant(w.astype(blur_env.NP_FLOATING_TYPE))
    out = tf.nn.sigmoid(inp_with_bias @ ww)
    grad_w = tf.gradients(out, ww)

    with tf.Session() as s:
      s.run(tf.global_variables_initializer())
      hebb_update, grad_w_val = s.run(
          [hebbian_update_submatrix[Ellipsis, 0], grad_w[0]])
      self.assertAllClose(hebb_update, grad_w_val)

469