# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=logging-format-interpolation
# pylint: disable=unused-import
# pylint: disable=invalid-unary-operand-type
# pylint: disable=g-long-lambda
# pylint: disable=g-direct-tensorflow-import

r"""Custom ops."""

import os
import sys

from absl import logging
import numpy as np
import tensorflow.compat.v1 as tf

from differentiable_data_selection import common_utils

USE_BFLOAT16 = False


def floatx():
  return tf.bfloat16 if USE_BFLOAT16 else tf.float32


def use_bfloat16():
  global USE_BFLOAT16
  USE_BFLOAT16 = True


def get_variable(name, shape, initializer, trainable=True,
                 convert_if_using_bfloat16=True):
  """Create variable and convert to `tf.bfloat16` if needed."""
  w = tf.get_variable(name=name,
                      shape=shape,
                      initializer=initializer,
                      trainable=trainable,
                      use_resource=True)
  if USE_BFLOAT16 and convert_if_using_bfloat16:
    w = tf.cast(w, tf.bfloat16)
  return w


def log_tensor(x, training):
  """Logs a tensor's name, device and shape."""
  if training:
    logging.info(f'{x.name:<90} {x.device} {x.shape}')


def _conv_kernel_initializer(shape, dtype=None, partition_info=None):
  """Initialization for convolutional kernels."""
  del partition_info
  kernel_height, kernel_width, _, out_filters = shape[-4:]
  fan_out = int(kernel_height * kernel_width * out_filters)
  return tf.random.normal(
      shape, mean=0.0, stddev=np.sqrt(2.0 / fan_out), dtype=dtype)
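
# Note on _conv_kernel_initializer: kernels are drawn from a zero-mean normal
# with stddev = sqrt(2 / fan_out), where fan_out = kernel_height *
# kernel_width * out_filters (He-style, fan-out-scaled initialization).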


def _dense_kernel_initializer(shape, dtype=None, partition_info=None):
  """Initialization for dense kernels."""
  del partition_info
  init_range = 1. / np.sqrt(shape[-1])
  return tf.random_uniform(shape, -init_range, init_range, dtype=dtype)
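
# Note on _dense_kernel_initializer: weights are drawn uniformly from
# [-1/sqrt(shape[-1]), 1/sqrt(shape[-1])]; for a [num_inputs, num_outputs]
# kernel this scales by the output dimension.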


def conv2d(x, filter_size, num_out_filters, stride=1,
           use_bias=False, padding='SAME', data_format='NHWC', name='conv2d',
           w=None, b=None):
  """Conv."""
  with tf.variable_scope(name):
    num_inp_filters = x.shape[-1].value

    w = tf.get_variable(
        name='kernel',
        shape=[filter_size, filter_size, num_inp_filters, num_out_filters],
        initializer=_conv_kernel_initializer,
        trainable=True,
        use_resource=True)

    if USE_BFLOAT16:
      w = tf.cast(w, tf.bfloat16)
    x = tf.nn.conv2d(x, w, strides=[1, stride, stride, 1],
                     padding=padding, data_format=data_format)

    if use_bias:
      if b is None:
        b = tf.get_variable(
            name='bias',
            shape=[num_out_filters],
            initializer=tf.initializers.zeros(),
            trainable=True,
            use_resource=True)
        if USE_BFLOAT16:
          b = tf.cast(b, tf.bfloat16)
      x = tf.nn.bias_add(x, b, name='bias_add')
    return x
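
# Usage sketch for conv2d (illustrative only; the scope name below is
# hypothetical):
#
#   y = conv2d(x, filter_size=3, num_out_filters=64, stride=2,
#              name='conv_3x3_down')
#
# This creates a 'kernel' variable of shape [3, 3, C_in, 64] under the scope
# 'conv_3x3_down' and applies a stride-2 SAME convolution. Note that the `w`
# argument in the signature is shadowed by the kernel variable created inside
# the scope; only `b` can be supplied externally, and it is only used when
# use_bias=True.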


def dense(x, num_outputs, use_bias=True, name='dense'):
  """Custom fully connected layer."""
  num_inputs = x.shape[-1].value
  with tf.variable_scope(name):
    w = tf.get_variable(
        name='kernel',
        shape=[num_inputs, num_outputs],
        initializer=_dense_kernel_initializer,
        trainable=True,
        use_resource=True)
    if USE_BFLOAT16:
      w = tf.cast(w, tf.bfloat16)

    x = tf.linalg.matmul(x, w)
    if use_bias:
      b = tf.get_variable(
          name='bias',
          shape=[num_outputs],
          initializer=tf.initializers.zeros(),
          trainable=True,
          use_resource=True)
      if USE_BFLOAT16:
        b = tf.cast(b, tf.bfloat16)
      x = tf.nn.bias_add(x, b, name='bias_add')
    return x


def avg_pool(x, filter_size, stride, padding='SAME', name='avg_pool'):
  """Avg pool."""
  x = tf.nn.avg_pool(
      x,
      ksize=[filter_size, filter_size],
      strides=[1, stride, stride, 1],
      padding=padding,
      name=name)
  return x


def max_pool(x, filter_size, stride, padding='SAME', name='max_pool'):
  """Max pool."""
  x = tf.nn.max_pool(
      x,
      ksize=[filter_size, filter_size],
      strides=[1, stride, stride, 1],
      padding=padding,
      name=name)
  return x


def relu(x, leaky=0.2, name='relu'):
  """Leaky ReLU."""
  return tf.nn.leaky_relu(x, alpha=leaky, name=name)
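
# Note on relu: the default is a leaky ReLU with slope 0.2 on the negative
# side; callers that want a plain ReLU pass leaky=0. (as resnet_block does
# below).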


def batch_norm(x, params, training, name='batch_norm', **kwargs):
  """Wrapped `batch_norm`."""
  return sync_batch_norm(x, params, training, name=name, **kwargs)


def sync_batch_norm(x, params, training, name='batch_norm'):
  """Sync batch_norm."""
  size = x.shape[-1].value

  with tf.variable_scope(name):
    gamma = tf.get_variable(name='gamma',
                            shape=[size],
                            initializer=tf.initializers.ones(),
                            trainable=True)
    beta = tf.get_variable(name='beta',
                           shape=[size],
                           initializer=tf.initializers.zeros(),
                           trainable=True)
    moving_mean = tf.get_variable(name='moving_mean',
                                  shape=[size],
                                  initializer=tf.initializers.zeros(),
                                  trainable=False)
    moving_variance = tf.get_variable(name='moving_variance',
                                      shape=[size],
                                      initializer=tf.initializers.ones(),
                                      trainable=False)

  x = tf.cast(x, tf.float32)
  if training:
    if params.use_tpu:
      num_replicas = params.num_replicas
      if num_replicas <= 8:
        group_assign = None
        group_shards = tf.cast(num_replicas, tf.float32)
      else:
        group_shards = max(8, num_replicas // 8)

        # round to nearest power of 2
        log_num_replicas = max(1, int(np.log(group_shards) / np.log(2.)))
        group_shards = int(np.power(2., log_num_replicas))

        group_assign = np.arange(num_replicas, dtype=np.int32)
        group_assign = group_assign.reshape([-1, group_shards])
        group_assign = group_assign.tolist()
        group_shards = tf.cast(group_shards, tf.float32)

      mean = tf.reduce_mean(x, [0, 1, 2])
      mean = tf.tpu.cross_replica_sum(mean / group_shards, group_assign)

      # Var[x] = E[x^2] - E[x]^2
      mean_sq = tf.reduce_mean(tf.math.square(x), [0, 1, 2])
      mean_sq = tf.tpu.cross_replica_sum(mean_sq / group_shards, group_assign)
      variance = mean_sq - tf.math.square(mean)
    else:
      mean, variance = tf.nn.moments(x, [0, 1, 2])

    x = tf.nn.batch_normalization(
        x, mean=mean, variance=variance, offset=beta, scale=gamma,
        variance_epsilon=params.batch_norm_epsilon)

    if USE_BFLOAT16:
      x = tf.cast(x, tf.bfloat16, name='batch_norm_recast')

    if (isinstance(moving_mean, tf.Variable) and
        isinstance(moving_variance, tf.Variable)):
      decay = tf.cast(1. - params.batch_norm_decay, tf.float32)
      def u(moving, normal, name):
        if params.use_tpu:
          num_replicas_fp = tf.cast(params.num_replicas, tf.float32)
          normal = tf.tpu.cross_replica_sum(normal) / num_replicas_fp
        diff = decay * (moving - normal)
        return tf.assign_sub(moving, diff, use_locking=True, name=name)
      tf.add_to_collection(tf.GraphKeys.UPDATE_OPS,
                           u(moving_mean, mean, name='moving_mean'))
      tf.add_to_collection(tf.GraphKeys.UPDATE_OPS,
                           u(moving_variance, variance, name='moving_variance'))
      return x
    else:
      return x, mean, variance
  else:
    if params.use_tpu:
      x = tf.nn.batch_normalization(
          x, mean=moving_mean, variance=moving_variance, offset=beta,
          scale=gamma, variance_epsilon=params.batch_norm_epsilon)
    else:
      x, _, _ = tf.nn.fused_batch_norm(
          x, scale=gamma, offset=beta, mean=moving_mean,
          variance=moving_variance, epsilon=params.batch_norm_epsilon,
          is_training=False)

    if USE_BFLOAT16:
      x = tf.cast(x, tf.bfloat16)
    return x
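
# Notes on sync_batch_norm:
#  * On TPU, batch statistics are averaged with tf.tpu.cross_replica_sum.
#    With more than 8 replicas, the replicas are partitioned into groups whose
#    size is rounded to a power of two, and statistics are synchronized only
#    within each group.
#  * The variance is recovered from the synchronized moments via
#    Var[x] = E[x^2] - E[x]^2.
#  * Moving statistics follow an exponential moving average:
#    moving <- batch_norm_decay * moving + (1 - batch_norm_decay) * batch_stat.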


def gpu_batch_norm(x, params, training=True, name='batch_norm'):
  """Async batch_norm."""
  shape = [x.shape[-1].value]
  with tf.variable_scope(name):
    gamma = get_variable('gamma', shape, tf.initializers.ones())
    beta = get_variable('beta', shape, tf.initializers.zeros())
    moving_mean = tf.get_variable(
        name='moving_mean',
        shape=shape,
        initializer=tf.initializers.zeros(),
        trainable=False,
        use_resource=True)
    moving_variance = tf.get_variable(
        name='moving_variance',
        shape=shape,
        initializer=tf.initializers.ones(),
        trainable=False,
        use_resource=True)

  if training:
    x, mean, variance = tf.nn.fused_batch_norm(
        x,
        scale=gamma,
        offset=beta,
        epsilon=params.batch_norm_epsilon,
        is_training=True,
    )

    def u(moving, normal):
      decay = tf.cast(1. - params.batch_norm_decay, tf.float32)
      return moving.assign_sub(decay * (moving - normal))

    tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, u(moving_mean, mean))
    tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, u(moving_variance, variance))
  else:
    x, _, _ = tf.nn.fused_batch_norm(
        x,
        scale=gamma,
        offset=beta,
        mean=moving_mean,
        variance=moving_variance,
        epsilon=params.batch_norm_epsilon,
        is_training=False,
    )
  return x
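
# Note on gpu_batch_norm: per-device (unsynchronized) batch norm built on
# tf.nn.fused_batch_norm. During training the moving statistics are updated
# with the same exponential moving average as above; at eval time the stored
# moving_mean / moving_variance are used.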


def wrn_block(x, params, num_out_filters, stride, training=True,
              name='wrn_block'):
  """WideResNet block."""

  with tf.variable_scope(name):
    num_inp_filters = x.shape[-1].value
    residual = x
    with tf.variable_scope('conv_3x3_1'):
      x = batch_norm(x, params, training)
      if stride == 2 or num_inp_filters != num_out_filters:
        residual = x
      x = relu(x)
      x = conv2d(x, 3, num_out_filters, stride)

    with tf.variable_scope('conv_3x3_2'):
      x = batch_norm(x, params, training)
      x = relu(x)
      x = conv2d(x, 3, num_out_filters, 1)

    with tf.variable_scope('residual'):
      if stride == 2 or num_inp_filters != num_out_filters:
        residual = relu(residual)
        residual = conv2d(residual, 1, num_out_filters, stride)

      x = x + residual
      log_tensor(x, True)
  return x
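
# Note on wrn_block: a pre-activation residual block (BN -> ReLU -> conv,
# twice). When the shape changes (stride == 2 or a different filter count),
# the shortcut is taken after the first batch norm and, after the activation,
# projected with a 1x1 convolution at the block's stride before the addition.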


def resnet_block(x, params, num_out_filters, stride, training=True,
                 bottleneck_rate=4, name='resnet_block'):
  """ResNet-50 bottleneck block."""

  num_bottleneck_filters = num_out_filters // bottleneck_rate
  with tf.variable_scope(name):
    residual = x
    num_inp_filters = residual.shape[-1].value
    with tf.variable_scope('conv_1x1_1'):
      x = conv2d(x, 1, num_bottleneck_filters, 1)
      x = batch_norm(x, params, training)
      x = relu(x, leaky=0.)

    with tf.variable_scope('conv_3x3'):
      x = conv2d(x, 3, num_bottleneck_filters, stride)
      x = batch_norm(x, params, training)
      x = relu(x, leaky=0.)

    with tf.variable_scope('conv_1x1_2'):
      x = conv2d(x, 1, num_out_filters, 1)
      x = batch_norm(x, params, training)

    with tf.variable_scope('residual'):
      if stride == 2 or num_inp_filters != num_out_filters:
        residual = conv2d(residual, 1, num_out_filters, stride)
        residual = batch_norm(residual, params, training)

      x = relu(x + residual, leaky=0.)
      log_tensor(x, True)
  return x
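
# Note on resnet_block: a standard bottleneck block. Channels are first
# reduced by `bottleneck_rate` with a 1x1 conv, processed by a (possibly
# strided) 3x3 conv, then expanded back to `num_out_filters` with a second
# 1x1 conv. The shortcut is projected with a 1x1 conv + batch norm at the
# block's stride whenever the shape changes, and a plain (non-leaky) ReLU
# follows the addition.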


def dropout(x, drop_rate, training):
  """Dropout."""
  if training:
    return tf.nn.dropout(x, rate=drop_rate)
  else:
    return x
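

if __name__ == '__main__':
  # Minimal smoke-test sketch (not part of the original pipeline). It assumes
  # a hyper-parameter object that exposes the fields these ops read:
  # use_tpu, num_replicas, batch_norm_epsilon and batch_norm_decay.
  tf.disable_v2_behavior()

  class _Params(object):
    use_tpu = False
    num_replicas = 1
    batch_norm_epsilon = 1e-5
    batch_norm_decay = 0.99

  images = tf.random.normal([2, 32, 32, 16])
  out = wrn_block(images, _Params(), num_out_filters=32, stride=2,
                  training=True, name='wrn_block_demo')
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print('wrn_block output shape:', sess.run(out).shape)  # (2, 16, 16, 32)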