google-research
253 lines · 7.6 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Utilities for genome handling."""
17
18import dataclasses as dc19import functools as ft20from typing import Any, Callable, Optional, Union21
22import numpy as np23import tensorflow.compat.v1 as tf24
25from blur import blur_env26
27
28
# A genome leaf value: either a TensorFlow tensor or a NumPy array.
Tensor = Union[tf.Tensor, np.ndarray]
31
@dc.dataclass
class NeuronGenome:
  """Genome parameters controlling the neuron state update."""
  # Neuron transformation matrix; genomes built by `create_random_genome`
  # make this (..., 2*num_states, 2*num_states).
  transform: Tensor
  # Coefficients mixing retained vs. newly computed neuron state
  # (consumed by the network step outside this file).
  keep: Union[float, Tensor] = 1.0
  update: Union[float, Tensor] = 1.0
  # Affine normalization parameters applied to neuron states
  # (semantics defined by the consumer; see blur network code).
  norm_multiplier: Union[float, Tensor] = 1.0
  norm_shift: Union[float, Tensor] = 0.0
40
@dc.dataclass
class HebbianTransform:
  """Pre/post-synaptic transforms of the Hebbian synapse update.

  See the `_grad_hebbian_genome` comment at the bottom of this file for how
  `pre` and `post` enter the update rule.
  """
  # Transform applied on the pre-synaptic side.
  pre: Tensor
  # Transform applied on the post-synaptic side.
  post: Tensor
  # Scale of the Oja's-rule term (name suggests Oja's rule; exact use is
  # defined by the consumer outside this file).
  ojas_multiplier: Union[float, Tensor] = 1.0
47
@dc.dataclass
class SynapticGenome:
  """Genome parameters controlling synaptic weight dynamics."""
  # Hebbian pre/post transforms used to compute the weight update.
  transform: HebbianTransform
  # Stddev for random synapse initialization.
  synapse_init_std: Union[float, Tensor] = 1e-1
  # Stddev for Xavier-style initialization (0.0 disables it).
  synapse_init_xavier_std: Union[float, Tensor] = 0.0
  # Coefficients mixing retained vs. newly computed weights
  # (consumed by the network step outside this file).
  keep: Union[float, Tensor] = 1.0
  update: Union[float, Tensor] = 1.0
  # Saturation level and rescaling target for weights; exact semantics are
  # defined by the consumer outside this file.
  saturation: Union[float, Tensor] = 1
  rescale_to: Union[float, Tensor] = 1.0
58
@dc.dataclass
class Genome:
  """Genome."""
  # Parameters of the neuron state update.
  neuron: NeuronGenome
  # Parameters of the (backward-pass) synapse update.
  synapse: SynapticGenome
  # Optional distinct synapse genome for the forward pass; if None it is
  # set to alias `synapse` in __post_init__.
  forward_synapse: Optional[SynapticGenome] = None

  def num_states_per_neuron(self):
    """Returns the number of per-neuron states encoded in this genome."""
    return get_num_states_in_genome(self)

  def num_species(self):
    """Returns the number of species, or None for an unbatched genome."""
    return get_num_species_in_genome(self)

  def __post_init__(self):
    # By default we start with the same forward pass synapse genome that is
    # used on the backward pass; whether to do synaptic weight update on the
    # forward pass is decided in `network_step` based on the value of
    # `forward_synapse_update` in the network specification.
    # NOTE: this aliases the same SynapticGenome object; it is not a copy.
    if self.forward_synapse is None:
      self.forward_synapse = self.synapse
80
81def _safe_shape(t):82if hasattr(t, 'shape'):83return t.shape84else:85return np.array(t).shape86
87
def get_num_states_in_genome(g):
  """Returns the per-neuron state count (last dim of the synaptic pre-transform)."""
  pre_transform = g.synapse.transform.pre
  return _safe_shape(pre_transform)[-1]
91
def transform_genome(g, map_fn, prefix=''):
  """Returns a copy of genome `g` with each leaf field mapped through `map_fn`.

  Nested dataclass fields are traversed recursively; `map_fn` receives
  `(value, path)` where `path` is the slash-separated field path. If
  `map_fn` returns None the original leaf value is kept unchanged.
  """
  replacements = {}
  for field_name, value in vars(g).items():
    if dc.is_dataclass(value):
      replacements[field_name] = transform_genome(
          value, map_fn=map_fn, prefix=f'{prefix}{field_name}/')
      continue
    new_value = map_fn(value, prefix + field_name)
    if new_value is not None:
      replacements[field_name] = new_value
  return dc.replace(g, **replacements)
104
def copy_genome(genome):
  """Returns a structural copy of `genome` (leaf values themselves are reused)."""
  def identity(value, unused_name):
    return value
  return transform_genome(genome, identity)
108
def get_genome_slice(g, i):
  """Returns genome `g` restricted to index `i` of the leading dimension.

  Every array-like genome leaf is indexed with `i`; plain Python scalars
  are passed through unchanged.

  Args:
    g: genome dataclass (e.g. `Genome`) to slice.
    i: index (or index expression) applied to each array-like leaf.

  Returns:
    A new genome of the same dataclass type with sliced leaf values.
  """
  def slice_fn(x, unused_name):
    # Plain scalars have no leading dimension to index.
    # Necessary to avoid issues with tests restoring checkpoints.
    if isinstance(x, (int, float)):
      return x
    return x[i]
  return transform_genome(g, slice_fn)
117
def get_genome(g, layer_index, per_layer_genome=False):
  """Returns the genome for a layer: a per-layer slice, or the shared genome."""
  if not per_layer_genome:
    return g
  return get_genome_slice(g, layer_index)
124
def convert_genome_to_tf_variables(g, prefix=''):
  """Converts genome to tensorflow variables with initialized to constant."""

  def to_variable(value, name):
    # Each leaf becomes a float32 tf.Variable named after its genome path.
    return tf.Variable(initial_value=value, dtype=tf.float32, name=name)

  return transform_genome(g, to_variable, prefix=prefix)
133
def convert_genome_to_dict(g):
  """Flattens genome into a dict of {slash-separated path: leaf value}."""
  result = {}

  def record(value, name):
    result[name] = value
    # Implicitly returns None so transform_genome leaves `g` unchanged.

  transform_genome(g, record)
  return result
140
141def _assign_from_values(v, name, values, index=None, prefix='', suffix=''):142key = prefix + name + suffix143if key not in values:144tf.logging.warning(f'Genome parameter "{key}" cannot be found in the '145'dictionary.')146return None147if hasattr(v, 'shape') and index is not None:148return values[key][index]149else:150return values[key]151
152
def get_num_species_in_genome(g):
  """Returns the species count (leading dim), or None for an unbatched genome."""
  pre_shape = _safe_shape(g.synapse.transform.pre)
  if len(pre_shape) == 3:
    return pre_shape[0]
  return None
157
def genome_from_dict(values, index=None, prefix='', suffix=''):
  """Reconstructs a genome from a flat {path: value} dictionary.

  A random genome of matching state count is created as a template and every
  leaf is overwritten from `values` (missing keys keep the random value and
  log a warning — see `_assign_from_values`).
  """
  num_states = _safe_shape(values['synapse/transform/pre'])[-1]
  assign_fn = ft.partial(
      _assign_from_values,
      values=values,
      index=index,
      prefix=prefix,
      suffix=suffix)
  return transform_genome(create_random_genome(num_states), assign_fn)
168
def replicate_across_dims(value, shared_update_params, num_species, num_layers):
  """Replicates a scalar across species and layer dimensions as needed.

  The species dimension is added only when parameters are not shared; the
  layer dimension (when requested) is always prepended.
  """
  if not shared_update_params and num_species is not None:
    value = np.array([value] * num_species)
  if num_layers is not None:
    value = np.array([value] * num_layers)
  return value
176
def create_random_genome(num_states,
                         num_species=None,
                         shared_update_params=True,
                         neuron_transform_std=1.0,
                         synapse_transform_std=1.0,
                         synapse_update=-1e-3,
                         synapse_init_std=1e-1,
                         separate_forward_synapse=False,
                         num_layers=None):
  """Creates random genome with that many species.

  Args:
    num_states: number of per-neuron states; synaptic transforms are
      (num_states x num_states) and the neuron transform is
      (2*num_states x 2*num_states).
    num_species: if not None, a species dimension of this size is prepended
      to every transform matrix.
    shared_update_params: if True, scalar parameters (keep/update/etc.) stay
      scalar instead of being replicated per species.
    neuron_transform_std: stddev of the random neuron transform entries.
    synapse_transform_std: stddev of the random Hebbian pre/post transforms.
    synapse_update: initial synaptic update coefficient.
    synapse_init_std: synapse initialization stddev stored in the genome.
    separate_forward_synapse: if True, an independent random synapse genome
      is drawn for the forward pass instead of aliasing the backward one.
    num_layers: if not None, a leading layer dimension of this size is
      prepended as well.

  Returns:
    A randomly initialized `Genome`.
  """

  species_dims = (num_species,) if num_species is not None else ()
  if num_layers is not None:
    # Layer dimension leads, followed by the (optional) species dimension.
    species_dims = (num_layers, *species_dims)

  # Replicates scalars across species/layer dims (no-op when shared).
  maybe_shared = ft.partial(replicate_across_dims,
                            shared_update_params=shared_update_params,
                            num_species=num_species,
                            num_layers=num_layers)

  # Builds one synapse genome from a given pair of Hebbian transforms.
  def _synaptic_genome(pre_transform, post_transform):
    return SynapticGenome(
        update=maybe_shared(synapse_update),
        keep=maybe_shared(1.0),
        synapse_init_std=maybe_shared(synapse_init_std),
        synapse_init_xavier_std=maybe_shared(0.0),
        saturation=maybe_shared(1.0),
        rescale_to=maybe_shared(1.0),
        transform=HebbianTransform(
            pre=pre_transform,
            post=post_transform,
            ojas_multiplier=maybe_shared(1.0)))

  matrix_shape = (*species_dims, num_states, num_states)
  o = np.ones(matrix_shape)
  z = np.zeros(matrix_shape)
  init_matrix = lambda: np.random.randn(*matrix_shape) * synapse_transform_std
  pre, post = init_matrix(), init_matrix()
  g = Genome(
      neuron=NeuronGenome(
          # The np.block mask zeroes the two diagonal (num_states x
          # num_states) blocks, so each half of the concatenated state only
          # projects onto the complementary half.
          transform=(
              neuron_transform_std *
              np.random.randn(*species_dims, 2 * num_states, 2 * num_states) *
              np.block([[z, o], [o, z]])),
          update=maybe_shared(1.0),
          keep=maybe_shared(1.0),
          norm_multiplier=maybe_shared(1.0),
          norm_shift=maybe_shared(0.0)),
      synapse=_synaptic_genome(pre, post))
  if separate_forward_synapse:
    # Draw a fresh transform pair so forward and backward synapses differ.
    fwd_pre, fwd_post = init_matrix(), init_matrix()
    g.forward_synapse = _synaptic_genome(fwd_pre, fwd_post)
  return g
230
231
232
# Neuron transformation matrix \mu before being fed to synapse.
# Rows describe contribution of corresponding state to all outputs;
# columns describe contribution of all inputs to a corresponding output.
#
# row 0: sensory(i) ('pre')
# row 1: feedback(i)
# row 2: sensory(j) ('post')
# row 3: feedback(j)
_grad_neuron_genome = np.array(
    [[0, 0, 1, 1],
     [0, 0, 0, 0],
     [1, 0, 0, 0],
     [0, 1, 0, 0]], dtype=blur_env.NP_FLOATING_TYPE)  # pyformat: disable

# ΔW(i, j, o) = Σ_{k, l} n(i, k) @ pre(i, o) @ post(o, l) @ n(j, l)
# where n(i, k) is concatenation of input and output activations.
_grad_hebbian_genome = HebbianTransform(
    pre=np.array([[1, 0],
                  [0, 1]], dtype=blur_env.NP_FLOATING_TYPE),
    post=np.array([[0, 1],
                   [1, 0]], dtype=blur_env.NP_FLOATING_TYPE))