# google-research
#
# Fork 0
# /
# genome_util.py
# 253 lines · 7.6 KB
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Utilities for genome handling."""
17

18
import dataclasses as dc
19
import functools as ft
20
from typing import Any, Callable, Optional, Union
21

22
import numpy as np
23
import tensorflow.compat.v1 as tf
24

25
from blur import blur_env
26

27

28

29
# Array-valued genome parameter: a TensorFlow tensor or a NumPy array.
# Used only in type annotations in this module (not enforced at runtime).
Tensor = Union[tf.Tensor, np.ndarray]
30

31

32
@dc.dataclass
class NeuronGenome:
  """Neuron-level genome parameters.

  Every field other than `transform` may be a Python scalar or an array/tensor
  (e.g. replicated per species or per layer by `create_random_genome`). This
  module only stores, slices and converts these values; their meaning is
  defined by the code that consumes the genome.
  """
  # Built by `create_random_genome` with shape
  # (*leading_dims, 2 * num_states, 2 * num_states).
  transform: Tensor
  keep: Union[float, Tensor] = 1.0
  update: Union[float, Tensor] = 1.0
  norm_multiplier: Union[float, Tensor] = 1.0
  norm_shift: Union[float, Tensor] = 0.0
39

40

41
@dc.dataclass
class HebbianTransform:
  """Pre- and post-synaptic transform matrices of a synaptic genome.

  `create_random_genome` builds `pre` and `post` with shape
  (*leading_dims, num_states, num_states); `get_num_states_in_genome` reads
  num_states from the last axis of `pre`.
  """
  pre: Tensor
  post: Tensor
  # NOTE(review): presumably scales an Oja's-rule term in the synaptic update;
  # the consuming code is outside this module -- confirm.
  ojas_multiplier: Union[float, Tensor] = 1.0
46

47

48
@dc.dataclass
class SynapticGenome:
  """Synapse-level genome parameters.

  Every field other than `transform` may be a Python scalar or an array/tensor
  (e.g. replicated per species or per layer by `create_random_genome`).
  """
  transform: HebbianTransform
  synapse_init_std: Union[float, Tensor] = 1e-1
  synapse_init_xavier_std: Union[float, Tensor] = 0.0
  keep: Union[float, Tensor] = 1.0
  update: Union[float, Tensor] = 1.0
  # Float default (numerically equal to the former int `1`), consistent with
  # the annotation, the other fields and `create_random_genome`.
  saturation: Union[float, Tensor] = 1.0
  rescale_to: Union[float, Tensor] = 1.0
57

58

59
@dc.dataclass
class Genome:
  """Complete genome: neuron parameters plus synapse genomes.

  Attributes:
    neuron: Neuron-level parameters.
    synapse: Synapse genome used on the backward pass.
    forward_synapse: Synapse genome used on the forward pass. If not given,
      it is set to the very same object as `synapse` in `__post_init__`.
  """
  neuron: NeuronGenome
  synapse: SynapticGenome
  forward_synapse: Optional[SynapticGenome] = None

  def num_states_per_neuron(self):
    """Returns the number of states per neuron encoded in this genome."""
    return get_num_states_in_genome(self)

  def num_species(self):
    """Returns the number of species, or None if there is no species axis."""
    return get_num_species_in_genome(self)

  def __post_init__(self):
    # By default we start with the same forward pass synapse genome that is
    # used on the backward pass; whether to do synaptic weight update on the
    # forward pass is decided in `network_step` based on the value of
    # `forward_synapse_update` in the network specification.
    if self.forward_synapse is None:
      self.forward_synapse = self.synapse
79

80

81
def _safe_shape(t):
82
  if hasattr(t, 'shape'):
83
    return t.shape
84
  else:
85
    return np.array(t).shape
86

87

88
def get_num_states_in_genome(g):
  """Returns the number of states per neuron encoded in genome `g`.

  Read from the last axis of `g.synapse.transform.pre`, which
  `create_random_genome` builds as (..., num_states, num_states).
  """
  return _safe_shape(g.synapse.transform.pre)[-1]
90

91

92
def transform_genome(g, map_fn, prefix=''):
  """Applies `map_fn` to every leaf field of a (nested) dataclass genome.

  Dataclass-valued fields are recursed into. Every other field value `v` is
  passed to `map_fn(v, name)`, where `name` is the slash-separated field path
  (e.g. 'synapse/transform/pre') prepended with `prefix`.

  Fields that alias the same object (e.g. `Genome.forward_synapse`, which
  defaults to `synapse`) are transformed independently, so the result no
  longer shares them.

  Args:
    g: Dataclass instance (e.g. a `Genome`) to transform.
    map_fn: Callable `(value, name) -> new_value`. Returning None keeps the
      field's original value.
    prefix: String prepended to every field path passed to `map_fn`.

  Returns:
    A new instance of the same dataclass (built with `dataclasses.replace`)
    holding the mapped values. The fields of `g` are not reassigned.
  """
  updates = {}
  for name, value in vars(g).items():
    if dc.is_dataclass(value):
      updates[name] = transform_genome(
          value, map_fn=map_fn, prefix=f'{prefix}{name}/')
      continue
    mapped = map_fn(value, prefix + name)
    if mapped is not None:
      updates[name] = mapped
  return dc.replace(g, **updates)
103

104

105
def copy_genome(genome):
  """Returns a structural copy of `genome`.

  Every dataclass node is re-created (via `transform_genome`), but leaf
  values (arrays, tensors, scalars) are the same objects as in the original;
  they are not duplicated.
  """
  return transform_genome(genome, lambda value, unused_name: value)
107

108

109
def get_genome_slice(g, i):
  """Returns the genome obtained by indexing every leaf of `g` with `i`.

  Python `int`/`float` leaves have nothing to index and are passed through
  unchanged; every other leaf `x` becomes `x[i]`.
  """

  def _slice(x, unused_name):
    # Plain Python numbers are passed through. This is necessary to avoid
    # issues with tests restoring checkpoints.
    if isinstance(x, (int, float)):
      return x
    return x[i]

  return transform_genome(g, _slice)
116

117

118
def get_genome(g, layer_index, per_layer_genome=False):
  """Returns the genome to use for layer `layer_index`.

  Args:
    g: Genome, whose leaves carry a leading per-layer axis when
      `per_layer_genome` is True.
    layer_index: Index of the layer whose genome is requested.
    per_layer_genome: If True, `g` is sliced with `get_genome_slice`;
      otherwise the shared genome `g` itself is returned.
  """
  if per_layer_genome:
    return get_genome_slice(g, layer_index)
  return g
123

124

125
def convert_genome_to_tf_variables(g, prefix=''):
  """Converts every genome leaf into a float32 `tf.Variable`.

  Each variable is initialized to the leaf's current value and named after
  the leaf's field path (see `transform_genome`), prepended with `prefix`.
  If `g.forward_synapse` is the same object as `g.synapse`, separate
  variables are still created for it (under 'forward_synapse/...').

  Args:
    g: Genome to convert.
    prefix: String prepended to every variable name.

  Returns:
    A new genome with the same structure whose leaves are `tf.Variable`s.
  """

  def map_fn(v, name):
    return tf.Variable(initial_value=v, dtype=tf.float32, name=name)

  return transform_genome(g, map_fn, prefix=prefix)
132

133

134
def convert_genome_to_dict(g):
  """Flattens genome `g` into a dict mapping field paths to leaf values.

  Keys are slash-separated field paths such as 'synapse/transform/pre',
  which is the key format `genome_from_dict` reads back.
  """
  res = {}

  def _record(value, name):
    # Returns None, so `transform_genome` leaves the genome untouched; this
    # callback is used only for its side effect of filling `res`.
    res[name] = value

  transform_genome(g, _record)
  return res
139

140

141
def _assign_from_values(v, name, values, index=None, prefix='', suffix=''):
142
  key = prefix + name + suffix
143
  if key not in values:
144
    tf.logging.warning(f'Genome parameter "{key}" cannot be found in the '
145
                       'dictionary.')
146
    return None
147
  if hasattr(v, 'shape') and index is not None:
148
    return values[key][index]
149
  else:
150
    return values[key]
151

152

153
def get_num_species_in_genome(g):
  """Returns the number of species in genome `g`, or None.

  The count is the leading axis of `g.synapse.transform.pre` when that array
  is rank 3 (num_species, num_states, num_states); any other rank yields
  None.
  """
  # NOTE(review): `create_random_genome(num_layers=...)` prepends a layer
  # axis. A layer-only genome is then rank 3 (its layer count would be
  # reported here), and a layer+species genome is rank 4 (-> None). Presumably
  # callers slice per-layer genomes first (see `get_genome`); confirm.
  shape = _safe_shape(g.synapse.transform.pre)
  return shape[0] if len(shape) == 3 else None
156

157

158
def genome_from_dict(values, index=None, prefix='', suffix=''):
  """Builds a `Genome` from a flat dict such as `convert_genome_to_dict` returns.

  A template genome from `create_random_genome` provides the structure. Each
  leaf is then replaced by `values[prefix + path + suffix]` (see
  `_assign_from_values`). Leaves missing from `values` keep the template's
  random or default value, and a warning is logged.

  Args:
    values: Mapping from (prefixed/suffixed) field paths to values.
    index: Optional index applied to stored values whose template leaf is an
      array.
    prefix: String prepended to every field path when looking up `values`.
    suffix: String appended to every field path when looking up `values`.

  Returns:
    A new `Genome`.
  """
  # The synapse `pre` matrix sizes the template. Look it up with the same
  # prefix/suffix used for every other key; fall back to the bare key (the
  # previous behavior) so existing callers keep working.
  pre_key = prefix + 'synapse/transform/pre' + suffix
  if pre_key not in values:
    pre_key = 'synapse/transform/pre'
  num_states = _safe_shape(values[pre_key])[-1]
  transform_fn = ft.partial(
      _assign_from_values,
      values=values,
      index=index,
      prefix=prefix,
      suffix=suffix)
  # NOTE(review): building the template draws from NumPy's global RNG, so
  # calling this function advances `np.random` state.
  return transform_genome(create_random_genome(num_states), transform_fn)
167

168

169
def replicate_across_dims(value, shared_update_params, num_species, num_layers):
  """Replicates `value` along optional species and layer axes.

  Args:
    value: Scalar or array to replicate.
    shared_update_params: If True, `value` is shared by all species and no
      species axis is added.
    num_species: Number of species, or None for no species axis.
    num_layers: Number of layers, or None for no layer axis.

  Returns:
    `value` unchanged when no axis is added. Otherwise a NumPy array of shape
    ([num_layers,] [num_species,] *np.shape(value)), with the layer axis
    (if any) outermost.
  """
  if num_species is not None and not shared_update_params:
    value = np.array([value] * num_species)
  if num_layers is not None:
    value = np.array([value] * num_layers)
  return value
175

176

177
def create_random_genome(num_states,
                         num_species=None,
                         shared_update_params=True,
                         neuron_transform_std=1.0,
                         synapse_transform_std=1.0,
                         synapse_update=-1e-3,
                         synapse_init_std=1e-1,
                         separate_forward_synapse=False,
                         num_layers=None):
  """Creates a random genome, optionally with species and layer axes.

  Args:
    num_states: Number of states per neuron.
    num_species: If not None, random matrices get a species axis of this
      size.
    shared_update_params: If False (and `num_species` is set), scalar
      parameters are also replicated per species; see
      `replicate_across_dims`.
    neuron_transform_std: Standard deviation of the random neuron transform.
    synapse_transform_std: Standard deviation of the random `pre`/`post`
      matrices.
    synapse_update: Initial value of `SynapticGenome.update`.
    synapse_init_std: Initial value of `SynapticGenome.synapse_init_std`.
    separate_forward_synapse: If True, the forward synapse genome gets its own
      random `pre`/`post` matrices. Otherwise it is the same object as
      `synapse` (see `Genome.__post_init__`).
    num_layers: If not None, every parameter gets an outermost layer axis of
      this size.

  Returns:
    A `Genome`. Synapse `pre`/`post` have shape
    ([num_layers,] [num_species,] num_states, num_states). The neuron
    transform has the same leading axes followed by
    (2 * num_states, 2 * num_states).
  """
  # Leading axes shared by all random matrices: ([num_layers,] [num_species,]).
  species_dims = (num_species,) if num_species is not None else ()
  if num_layers is not None:
    species_dims = (num_layers, *species_dims)

  maybe_shared = ft.partial(replicate_across_dims,
                            shared_update_params=shared_update_params,
                            num_species=num_species,
                            num_layers=num_layers)

  def _synaptic_genome(pre_transform, post_transform):
    # Wraps the given pre/post matrices into a SynapticGenome whose scalar
    # parameters are replicated as requested.
    return SynapticGenome(
        update=maybe_shared(synapse_update),
        keep=maybe_shared(1.0),
        synapse_init_std=maybe_shared(synapse_init_std),
        synapse_init_xavier_std=maybe_shared(0.0),
        saturation=maybe_shared(1.0),
        rescale_to=maybe_shared(1.0),
        transform=HebbianTransform(
            pre=pre_transform,
            post=post_transform,
            ojas_multiplier=maybe_shared(1.0)))

  matrix_shape = (*species_dims, num_states, num_states)
  o = np.ones(matrix_shape)
  z = np.zeros(matrix_shape)

  def init_matrix():
    return np.random.randn(*matrix_shape) * synapse_transform_std

  # The RNG draw order (pre, post, neuron transform, then forward pre/post)
  # determines the genome produced for a given np.random seed.
  pre, post = init_matrix(), init_matrix()
  g = Genome(
      neuron=NeuronGenome(
          # The block mask zeroes the two diagonal (num_states x num_states)
          # blocks and keeps the two off-diagonal blocks of the random
          # (2 * num_states, 2 * num_states) matrix.
          transform=(
              neuron_transform_std *
              np.random.randn(*species_dims, 2 * num_states, 2 * num_states) *
              np.block([[z, o], [o, z]])),
          update=maybe_shared(1.0),
          keep=maybe_shared(1.0),
          norm_multiplier=maybe_shared(1.0),
          norm_shift=maybe_shared(0.0)),
      synapse=_synaptic_genome(pre, post))
  if separate_forward_synapse:
    fwd_pre, fwd_post = init_matrix(), init_matrix()
    g.forward_synapse = _synaptic_genome(fwd_pre, fwd_post)
  return g
229

230

231

232

233
# Neuron transformation matrix \mu before being fed to synapse.
# Rows describe the contribution of the corresponding state to all outputs.
# Columns describe the contribution of all inputs to the corresponding output.
#
# row 0: sensory(i) ('pre')
# row 1: feedback(i)
# row 2: sensory(j) ('post')
# row 3: feedback(j)
241
_grad_neuron_genome = np.array(
    [[0, 0, 1, 1],
     [0, 0, 0, 0],
     [1, 0, 0, 0],
     [0, 1, 0, 0]], dtype=blur_env.NP_FLOATING_TYPE)  # pyformat: disable
246

247
# ΔW(i, j, o) = Σ_{k, l} n(i, k) @ pre(k, o) @ post(o, l) @ n(j, l)
# where n(i, k) is the concatenation of input and output activations.
249
_grad_hebbian_genome = HebbianTransform(
    pre=np.array([[1, 0],
                  [0, 1]], dtype=blur_env.NP_FLOATING_TYPE),
    post=np.array([[0, 1],
                   [1, 0]], dtype=blur_env.NP_FLOATING_TYPE))
254

# Cookie usage
#
# We use cookies in accordance with the Privacy Policy and the Cookie Policy.
#
# By clicking "Accept", you give JSC SberTech consent to the processing of your
# personal data for the purpose of improving our website and the GitVerse
# Service, as well as making them more convenient to use.
#
# You can disable cookies yourself in your browser settings.