google-research

Форк
0
/
synapse_util.py 
260 строк · 8.0 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Utilities for synapse handling."""
17

18
import dataclasses as dc
19
import enum
20
import functools as ft
21
from typing import Callable, List, Sequence, Text, Union, Optional
22

23
import jax.numpy as jp
24
import numpy as np
25
import tensorflow.compat.v1 as tf
26

27
from blur import blur_env
28

29
TensorShape = tf.TensorShape
30
Tensor = Union[tf.Tensor, np.ndarray, jp.ndarray]
31

32

33
@dc.dataclass
class SynapseInitializerParams:
  """Arguments handed to a SynapseInitializer for one synapse tensor.

  create_synapse_init_fns builds one of these for each pair of adjacent layers.
  """
  # Full shape of the synapse tensor to create. NOTE(review):
  # create_synapse_init_fns passes a plain tuple here, not a tf.TensorShape.
  shape: TensorShape
  # Channel count of the presynaptic layer (that layer's shape[-2]).
  in_neurons: int
  # Channel count of the postsynaptic layer (that layer's shape[-2]).
  out_neurons: int
38

39

40
class UpdateType(enum.Enum):
  """Selects which off-diagonal block of a synapse matrix to work with."""
  FORWARD = 1  # Input(+bias) rows x output columns (see synapse_submatrix).
  BACKWARD = 2  # Output rows x input(+bias) columns (see synapse_submatrix).
  BOTH = 3  # Not handled by synapse_submatrix, which returns None for it.
  NONE = 4  # Not handled by synapse_submatrix, which returns None for it.
45

46

47
SynapseInitializer = Callable[[SynapseInitializerParams], Tensor]
48

49
# A callable that takes a sequence of layers and SynapseInitializer and creates
50
# appropriately shaped list of Synapses.
51
CreateSynapseFn = Callable[[Sequence[Tensor], SynapseInitializer], List[Tensor]]
52

53

54
def random_uniform_symmetric(shape, seed):
  """Samples a tensor of the given shape, centered on zero.

  Draws from tf.random.uniform (default range [0, 1)) with the given op-level
  seed, then shifts and doubles the values so they land in [-1, 1).
  """
  unit_sample = tf.random.uniform(shape, seed=seed)
  centered = unit_sample - 0.5
  return centered * 2
56

57

58
def random_initializer(start_seed=0,
                       scale_by_channels=False,
                       scale=1,
                       bias=0,
                       random_fn=random_uniform_symmetric):
  """Returns an initializer that draws a fresh random tensor on every call.

  Every call of the returned initializer passes a new seed (previous seed + 1)
  to `random_fn`, so successive synapses differ, while the whole sequence is
  determined by `start_seed`.

  Args:
    start_seed: any value; its `str()` form determines the initial seed.
    scale_by_channels: if True, divide the result by sqrt(params.shape[-2]).
      Requires `params.shape` to have rank >= 3.
    scale: multiplier applied to the random values; either a number or a
      callable that takes the SynapseInitializerParams and returns one.
    bias: constant added after scaling.
    random_fn: callable `(shape, seed) -> tensor` producing the base values.

  Returns:
    A callable that accepts SynapseInitializerParams and returns a tensor.

  Raises:
    ValueError: (from the returned callable) if `scale_by_channels` is set and
      `params.shape` has rank < 3.
  """
  import zlib  # pylint: disable=g-import-not-at-top

  # Built-in hash() of a str is salted per process (PYTHONHASHSEED), which
  # would make the "seeded" sequence differ between runs. crc32 is stable.
  seed = zlib.crc32(str(start_seed).encode('utf-8'))

  def impl(params):
    nonlocal seed
    num_channels = None
    if scale_by_channels:
      # Validate before consuming a seed. Previously a rank < 3 shape
      # surfaced as an UnboundLocalError on `num_channels`.
      if len(params.shape) < 3:
        raise ValueError(
            'scale_by_channels requires a synapse shape of rank >= 3, got %s' %
            (params.shape,))
      # shape: species x (in+out) x (in+out) x states
      num_channels = int(params.shape[-2])
    seed += 1
    v = random_fn(params.shape, seed)
    apply_scale = scale(params) if callable(scale) else scale
    r = v * apply_scale + bias
    if scale_by_channels:
      r = r / (num_channels**0.5)
    return r

  return impl
79

80

81
def _random_uniform_fn(start_seed):
82
  rng = np.random.RandomState(start_seed)
83
  return lambda shape: tf.constant(  # pylint: disable=g-long-lambda
84
      rng.uniform(low=-1, high=1, size=shape), dtype=np.float32)
85

86

87
def fixed_random_initializer(start_seed=0,
                             scale_by_channels=False,
                             scale=1,
                             bias=0,
                             random_fn=None):
  """Returns an initializer that generates a random (but fixed) sequence.

  With the default `random_fn`, values are backed by constants drawn from a
  private np.random.RandomState. That state is independent of NumPy's default
  random sequence.

  Args:
    start_seed: initial seed passed to np.random.RandomState (only used when
      `random_fn` is None).
    scale_by_channels: if True, divide the result by sqrt(params.shape[-2]).
      Requires `params.shape` to have rank >= 3.
    scale: target scale (default: 1); either a number or a callable that takes
      the SynapseInitializerParams and returns one.
    bias: constant added after scaling (the mean of the resulting
      distribution when the base values are zero-mean).
    random_fn: callable invoked as `random_fn(shape=...)`; if None, uses
      _random_uniform_fn(start_seed).

  Returns:
    A callable that accepts SynapseInitializerParams and returns a tensor.

  Raises:
    ValueError: (from the returned callable) if `scale_by_channels` is set and
      `params.shape` has rank < 3.
  """
  if random_fn is None:
    random_fn = _random_uniform_fn(start_seed)

  def impl(params):
    num_channels = None
    if scale_by_channels:
      # Previously a rank < 3 shape surfaced as an UnboundLocalError on
      # `num_channels`. Also avoid int(shape[-2]) when it is not needed.
      if len(params.shape) < 3:
        raise ValueError(
            'scale_by_channels requires a synapse shape of rank >= 3, got %s' %
            (params.shape,))
      # shape: species x (in+out) x (in+out) x states
      num_channels = int(params.shape[-2])
    v = random_fn(shape=params.shape)
    apply_scale = scale(params) if callable(scale) else scale
    r = v * apply_scale + bias
    if scale_by_channels:
      r = r / (num_channels**0.5)
    return r

  return impl
125

126

127
def create_synapse_init_fns(
    layers,
    initializer):
  """Builds one deferred synapse initializer per pair of adjacent layers.

  Args:
    layers: Sequence of network layers; only their shapes are read. Each layer
      shape is laid out as (*population_dims, batch_size, channels, states).
    initializer: SynapseInitializer used to initialize synapse tensors.

  Returns:
    A list with len(layers) - 1 zero-argument callables (empty for fewer than
    two layers). Calling one produces the synapse tensor between that pair of
    layers.
  """

  def params_for(pre, post):
    # Leading population axes come from the presynaptic layer; the batch,
    # channel and state axes (last three) are not copied.
    population_dims = pre.shape[:-3]
    pre_channels = pre.shape[-2]
    post_channels = post.shape[-2]
    # Square over pre + post channels plus one extra slot (the bias slot
    # that synapse_submatrix's `include_bias` refers to).
    side = pre_channels + post_channels + 1
    return SynapseInitializerParams(
        shape=(*population_dims, side, side, pre.shape[-1]),
        in_neurons=pre_channels,
        out_neurons=post_channels)

  adjacent_pairs = zip(layers, layers[1:])
  return [ft.partial(initializer, params_for(pre, post))
          for pre, post in adjacent_pairs]
154

155

156
def create_synapses(layers,
                    initializer):
  """Creates synapse tensors for every pair of adjacent layers right away.

  Eager counterpart of create_synapse_init_fns: builds the deferred
  initializers and immediately invokes each of them in order.

  Args:
    layers: Sequence of network layers (used for shape calculation).
    initializer: SynapseInitializer used to initialize synapse tensors.

  Returns:
    A list of created synapse tensors for all layers.
  """
  deferred = create_synapse_init_fns(layers, initializer)
  return [make_synapse() for make_synapse in deferred]
168

169

170
def transpose_synapse(synapse, env):
  """Swaps the row and column axes (the third- and second-to-last axes).

  All leading axes and the trailing state axis keep their positions.
  """
  lead = len(synapse.shape[:-3])
  axes = list(range(lead + 3))
  axes[lead], axes[lead + 1] = axes[lead + 1], axes[lead]
  return env.transpose(synapse, axes)
177

178

179
def synapse_submatrix(synapse,
                      in_channels,
                      update_type,
                      include_bias = True):
  """Returns one off-diagonal block of a synapse matrix.

  The last three axes of `synapse` are (rows, cols, states). Along rows/cols,
  indices [0, in_channels) are input channels, index `in_channels` is the bias
  slot, and indices >= in_channels + 1 are output channels.

  Args:
    synapse: array whose last three axes are (rows, cols, states).
    in_channels: number of input channels, excluding the bias slot.
    update_type: UpdateType.FORWARD selects input(+bias) rows x output
      columns; UpdateType.BACKWARD selects output rows x input(+bias) columns.
    include_bias: whether the bias slot is included on the input side.

  Returns:
    The selected block, or None for update types other than FORWARD and
    BACKWARD.
  """
  bias = 1 if include_bias else 0
  input_end = in_channels + bias  # End of the input(+bias) index range.
  # Outputs always start after the bias slot. The FORWARD branch used to start
  # at `in_channels + bias`, which with include_bias=False wrongly included
  # the bias column as an output (shape (in, out+1) instead of (in, out)).
  output_start = in_channels + 1
  if update_type == UpdateType.FORWARD:
    return synapse[Ellipsis, :input_end, output_start:, :]
  if update_type == UpdateType.BACKWARD:
    return synapse[Ellipsis, output_start:, :input_end, :]
  # BOTH / NONE are not handled; None is kept for backward compatibility.
  return None
189

190

191
def combine_in_out_synapses(in_out_synapse, out_in_synapse,
                            env):
  """Assembles a square synapse matrix from its two off-diagonal blocks.

  With P = in_out_synapse.shape[-3] and Q = in_out_synapse.shape[-2], the
  result's last three axes are (P + Q, P + Q, states), laid out as

      [[ zeros (P x P),           in_out_synapse (P x Q) ],
       [ out_in_synapse (Q x P),  zeros (Q x Q)          ]]

  The diagonal blocks are always zero-filled. `out_in_synapse` is expected to
  have shape (..., Q, P, states).
  """
  lead_dims = in_out_synapse.shape[:-3]
  rows, cols, states = in_out_synapse.shape[-3:]

  upper_zeros = env.zeros((*lead_dims, rows, rows, states))
  upper = env.concat([upper_zeros, in_out_synapse], axis=-2)

  lower_zeros = env.zeros((*lead_dims, cols, cols, states))
  lower = env.concat([out_in_synapse, lower_zeros], axis=-2)

  return env.concat([upper, lower], axis=-3)
210

211

212
def sync_all_synapses(synapses, layers, env):
  """Makes each synapse's backward block mirror its forward block.

  Every `synapses[i]` is rebuilt by sync_in_and_out_synapse, using
  `layers[i].shape[-2]` as its input channel count. Only the forward/backward
  blocks are synchronized; the state axis is left untouched here (see
  sync_states_synapse for copying state 0 across states).

  Args:
    synapses: list of synapses in the network. Modified in place.
    layers: list of layers in the network; must have at least
      len(synapses) entries.
    env: Environment providing the array ops used by the sync helpers.

  Returns:
    The same `synapses` list, with every entry replaced by its synchronized
    version.
  """
  for idx, synapse in enumerate(synapses):
    # Assigning to the current index does not disturb the enumeration.
    synapses[idx] = sync_in_and_out_synapse(synapse, layers[idx].shape[-2],
                                            env)
  return synapses
229

230

231
def sync_in_and_out_synapse(synapse, in_channels, env):
  """Rebuilds `synapse` so its backward block is the forward block transposed.

  The forward block is taken with the bias slot included. The returned matrix
  comes from combine_in_out_synapses, so its diagonal blocks are zero-filled.
  """
  forward_block = synapse_submatrix(  # pytype: disable=wrong-arg-types  # use-enum-overlay
      synapse,
      in_channels=in_channels,
      update_type=UpdateType.FORWARD,
      include_bias=True)
  mirrored_block = transpose_synapse(forward_block, env)
  return combine_in_out_synapses(forward_block, mirrored_block, env)
240

241

242
def sync_states_synapse(synapse, env, num_states=None):
  """Replicates the synapse's state 0 along a new trailing state axis.

  Args:
    synapse: array whose last axis indexes states.
    env: Environment providing `stack`.
    num_states: how many copies to stack; defaults to synapse.shape[-1].

  Returns:
    Array whose last axis has `num_states` identical copies of
    synapse[..., 0].
  """
  copies = synapse.shape[-1] if num_states is None else num_states
  first_state = synapse[Ellipsis, 0]
  return env.stack([first_state] * copies, axis=-1)
247

248

249
def normalize_synapses(synapses,
                       rescale_to,
                       env,
                       axis = -3):
  """Scales synapses to unit L2 norm along `axis`, then optionally rescales.

  Args:
    synapses: array of synapse values. It is not modified; a new array is
      returned.
    rescale_to: if not None, the normalized values are multiplied by it.
    env: Environment providing `sum` and `sqrt`.
    axis: axis along which the L2 norm is taken. The default -3 normalizes
      across the input neuron dimension.

  Returns:
    synapses / sqrt(sum(synapses**2, axis) + 1e-9), times `rescale_to` when
    given.
  """
  squared = env.sum(synapses**2, axis=axis, keepdims=True)
  # Out-of-place arithmetic: `/=` would overwrite a caller-owned NumPy array
  # and raise a casting error for integer dtypes. The 1e-9 guards against
  # division by zero for all-zero slices.
  normalized = synapses / env.sqrt(squared + 1e-9)
  if rescale_to is not None:
    normalized = normalized * rescale_to
  return normalized
261

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.