google-research
260 lines · 8.0 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Utilities for synapse handling."""
17
18import dataclasses as dc19import enum20import functools as ft21from typing import Callable, List, Sequence, Text, Union, Optional22
23import jax.numpy as jp24import numpy as np25import tensorflow.compat.v1 as tf26
27from blur import blur_env28
# Convenience aliases: a synapse tensor may live in TF, NumPy, or JAX,
# depending on which backend environment is active.
TensorShape = tf.TensorShape
Tensor = Union[tf.Tensor, np.ndarray, jp.ndarray]
32
@dc.dataclass
class SynapseInitializerParams:
  """Shape information handed to a SynapseInitializer callable."""
  # Full shape of the synapse tensor to initialize.
  shape: TensorShape
  # Number of pre-synaptic (input) channels/neurons.
  in_neurons: int
  # Number of post-synaptic (output) channels/neurons.
  out_neurons: int
39
class UpdateType(enum.Enum):
  """Which propagation direction(s) an operation applies to."""
  FORWARD = 1   # in -> out; selects the forward submatrix.
  BACKWARD = 2  # out -> in; selects the backward submatrix.
  BOTH = 3
  NONE = 4
46
# A callable that produces an initial synapse tensor from the shape/size
# description in SynapseInitializerParams.
SynapseInitializer = Callable[[SynapseInitializerParams], Tensor]

# A callable that takes a sequence of layers and SynapseInitializer and creates
# appropriately shaped list of Synapses.
CreateSynapseFn = Callable[[Sequence[Tensor], SynapseInitializer], List[Tensor]]
53
def random_uniform_symmetric(shape, seed):
  """Draws uniform samples in [-1, 1), seeded for TF reproducibility."""
  centered = tf.random.uniform(shape, seed=seed) - 0.5
  return centered * 2
57
def random_initializer(start_seed=0,
                       scale_by_channels=False,
                       scale=1,
                       bias=0,
                       random_fn=random_uniform_symmetric):
  """Returns initializer that generates random sequence.

  Args:
    start_seed: value hashed to derive the initial seed; each call to the
      returned initializer advances the seed by one.
    scale_by_channels: whether to divide by sqrt(num_channels), where
      num_channels is shape[-2] (requires rank >= 3).
    scale: target scale (default: 1); may also be a callable taking the
      initializer params and returning the scale.
    bias: mean of the resulting distribution.
    random_fn: callable (shape, seed) -> random tensor.

  Returns:
    A SynapseInitializer that produces a fresh random tensor on every call.
  """
  # Mutable one-element list so the nested impl can advance the seed sequence
  # across calls.
  seed = [hash(str(start_seed))]

  def impl(params):
    # shape: species x (in+out) x (in+out) x states
    num_channels = int(params.shape[-2]) if len(params.shape) >= 3 else None
    seed[0] += 1
    v = random_fn(params.shape, seed[0])
    apply_scale = scale(params) if callable(scale) else scale
    r = v * apply_scale + bias
    if scale_by_channels:
      if num_channels is None:
        # Previously this path died with an obscure NameError; fail loudly.
        raise ValueError(
            'scale_by_channels=True requires a shape of rank >= 3, got %s' %
            (params.shape,))
      r = r / (num_channels**0.5)
    return r

  return impl
80
def _random_uniform_fn(start_seed):
  """Returns a shape -> tf.constant sampler backed by a private RandomState."""
  rng = np.random.RandomState(start_seed)

  def sample(shape):
    # Values come from the dedicated numpy RNG and are frozen into a TF
    # constant, so the sequence is independent of TF's global random state.
    return tf.constant(
        rng.uniform(low=-1, high=1, size=shape), dtype=np.float32)

  return sample
86
def fixed_random_initializer(start_seed=0,
                             scale_by_channels=False,
                             scale=1,
                             bias=0,
                             random_fn=None):
  """Returns an initializer that generates random (but fixed) sequence.

  The resulting tensors are backed by a constant so they produce the same
  value across all calls.

  This initializer uses its own random state that is independent of default
  random sequence.

  Args:
    start_seed: initial seed passed to np.random.RandomState.
    scale_by_channels: whether to scale by number of channels
      (shape[-2]; requires rank >= 3).
    scale: target scale (default: 1); may also be a callable taking the
      initializer params and returning the scale.
    bias: mean of the resulting distribution.
    random_fn: random generator; if None, _random_uniform_fn is used.

  Returns:
    callable that accepts params (with a `shape` field) and returns a
    tensorflow constant tensor.
  """
  if random_fn is None:
    random_fn = _random_uniform_fn(start_seed)

  def impl(params):
    # shape: species x (in+out) x (in+out) x states
    num_channels = int(params.shape[-2]) if len(params.shape) >= 3 else None
    v = random_fn(shape=params.shape)
    apply_scale = scale(params) if callable(scale) else scale
    r = v * apply_scale + bias
    if scale_by_channels:
      if num_channels is None:
        # Previously this path died with an obscure NameError; fail loudly.
        raise ValueError(
            'scale_by_channels=True requires a shape of rank >= 3, got %s' %
            (params.shape,))
      r = r / (num_channels**0.5)
    return r

  return impl
126
def create_synapse_init_fns(layers, initializer):
  """Generates network synapse initializers.

  Arguments:
    layers: Sequence of network layers (used for shape calculation).
    initializer: SynapseInitializer used to initialize synapse tensors.

  Returns:
    A list of functions that produce synapse tensors for all layers upon
    execution.
  """
  init_fns = []
  for pre, post in zip(layers[:-1], layers[1:]):
    # Layer shape: population_dims, batch_size, in_channels, neuron_state.
    population_dims = pre.shape[:-3]
    # -2 indexes the channel axis; the extra +1 accounts for the bias entry.
    joint_channels = pre.shape[-2] + post.shape[-2] + 1
    # -1 indexes the per-neuron state axis.
    synapse_shape = (
        *population_dims, joint_channels, joint_channels, pre.shape[-1])
    params = SynapseInitializerParams(
        shape=synapse_shape,
        in_neurons=pre.shape[-2],
        out_neurons=post.shape[-2])
    init_fns.append(ft.partial(initializer, params))
  return init_fns
155
def create_synapses(layers, initializer):
  """Generates arbitrary form synapses.

  Arguments:
    layers: Sequence of network layers (used for shape calculation).
    initializer: SynapseInitializer used to initialize synapse tensors.

  Returns:
    A list of created synapse tensors for all layers.
  """
  # Materialize every per-layer-pair initializer immediately.
  make_fns = create_synapse_init_fns(layers, initializer)
  return [make_synapse() for make_synapse in make_fns]
169
def transpose_synapse(synapse, env):
  """Swaps the two channel axes of a synapse, keeping batch and state axes."""
  leading = len(synapse.shape) - 3
  # Identity on the leading (batch) axes; exchange the next two; keep states.
  permutation = [*range(leading), leading + 1, leading, leading + 2]
  return env.transpose(synapse, permutation)
178
def synapse_submatrix(synapse,
                      in_channels,
                      update_type,
                      include_bias = True):
  """Returns a submatrix of a synapse matrix given the update type.

  Args:
    synapse: synapse tensor; the last three axes appear to be laid out as
      [in+bias+out, in+bias+out, states] with the bias entry at index
      `in_channels` — TODO confirm against combine_in_out_synapses.
    in_channels: number of input (pre-synaptic) channels.
    update_type: UpdateType selecting the forward or backward block.
    include_bias: whether the bias row/column participates in the slice.

  Returns:
    The forward (in->out) or backward (out->in) sub-block.  Returns None
    for UpdateType.BOTH / UpdateType.NONE (no branch handles them).
  """
  bias = 1 if include_bias else 0
  if update_type == UpdateType.FORWARD:
    # Rows: inputs (+ optional bias); columns: everything after them.
    return synapse[Ellipsis, :(in_channels + bias), (in_channels + bias):, :]
  if update_type == UpdateType.BACKWARD:
    # NOTE(review): rows start at in_channels + 1 (not + bias), i.e. the bias
    # row is skipped here even when include_bias is False — looks intentional
    # given the bias sits at index in_channels, but confirm with callers.
    return synapse[Ellipsis, (in_channels + 1):, :(in_channels + bias), :]
190
def combine_in_out_synapses(in_out_synapse, out_in_synapse, env):
  """Combines forward and backward synapses into a single matrix."""
  shape = in_out_synapse.shape
  batch_dims = shape[:-3]
  out_channels, in_channels, num_states = shape[-3:]
  # Top block row: [0 | forward]; bottom block row: [backward | 0].
  top = env.concat(
      [env.zeros((*batch_dims, out_channels, out_channels, num_states)),
       in_out_synapse],
      axis=-2)
  bottom = env.concat(
      [out_in_synapse,
       env.zeros((*batch_dims, in_channels, in_channels, num_states))],
      axis=-2)
  return env.concat([top, bottom], axis=-3)
211
def sync_all_synapses(synapses, layers, env):
  """Sync synapses across all layers.

  For each synapse, syncs its first state forward synapse with backward
  synapse and copies it across all the states.

  Args:
    synapses: list of synapses in the network.
    layers: list of layers in the network.
    env: Environment

  Returns:
    Synchronized synapses.
  """
  for index, synapse in enumerate(synapses):
    in_channels = layers[index].shape[-2]
    synapses[index] = sync_in_and_out_synapse(synapse, in_channels, env)
  return synapses
230
def sync_in_and_out_synapse(synapse, in_channels, env):
  """Copies forward synapse to backward one."""
  forward = synapse_submatrix(  # pytype: disable=wrong-arg-types  # use-enum-overlay
      synapse,
      in_channels=in_channels,
      update_type=UpdateType.FORWARD,
      include_bias=True)
  # The backward block is the transpose of the forward block.
  backward = transpose_synapse(forward, env)
  return combine_in_out_synapses(forward, backward, env)
241
def sync_states_synapse(synapse, env, num_states=None):
  """Broadcasts the synapse's first state across all the other states."""
  if num_states is None:
    num_states = synapse.shape[-1]
  first_state = synapse[Ellipsis, 0]
  return env.stack([first_state] * num_states, axis=-1)
248
def normalize_synapses(synapses,
                       rescale_to,
                       env,
                       axis = -3):
  """Normalizes synapses across a particular axis (across input by def.).

  Args:
    synapses: synapse tensor to normalize.
    rescale_to: if not None, multiply the unit-norm result by this factor.
    env: Environment providing `sum` and `sqrt`.
    axis: axis to reduce over; the default -3 is the input-neuron dimension.

  Returns:
    The normalized (and optionally rescaled) synapses.
  """
  # The epsilon keeps the division finite for all-zero slices.
  norm = env.sqrt(env.sum(synapses**2, axis=axis, keepdims=True) + 1e-9)
  synapses /= norm
  if rescale_to is not None:
    synapses *= rescale_to
  return synapses