google-research

Форк
0
/
resnet.py 
625 строк · 23.2 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Contains definitions for Residual Networks.
17

18
Residual networks ('v1' ResNets) were originally proposed in:
19
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
20
    Deep Residual Learning for Image Recognition. arXiv:1512.03385
21

22
The full preactivation 'v2' ResNet variant was introduced by:
23
[2] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
24
    Identity Mappings in Deep Residual Networks. arXiv: 1603.05027
25

26
The key difference of the full preactivation 'v2' variant compared to the
27
'v1' variant in [1] is the use of batch normalization before every weight layer
28
rather than after.
29
"""
30

31
from __future__ import print_function
32
import tensorflow.compat.v1 as tf
33
from tensorflow.contrib import layers as contrib_layers
34

35
# Hyperparameters shared by every batch normalization layer built by
# `batch_norm`.
_BATCH_NORM_DECAY = 0.997
_BATCH_NORM_EPSILON = 1e-5

# ResNet version used when `Model` is constructed without `resnet_version`.
DEFAULT_VERSION = 1

# dtype handling for `Model`: DEFAULT_DTYPE is used when no dtype is given;
# for dtypes in CASTABLE_TYPES, `Model._custom_dtype_getter` creates float32
# variables and casts them on read. ALLOWED_TYPES is the set `Model` accepts.
DEFAULT_DTYPE = tf.float32
CASTABLE_TYPES = (tf.float16,)
ALLOWED_TYPES = (DEFAULT_DTYPE,) + tuple(CASTABLE_TYPES)
41

42

43
def batch_norm(inputs, training, data_format):
  """Applies batch normalization with this module's standard settings.

  Args:
    inputs: Tensor to normalize.
    training: Forwarded unchanged as the `training` argument of
      `tf.compat.v1.layers.batch_normalization`.
    data_format: 'channels_first' normalizes over axis 1; any other value
      normalizes over axis 3.

  Returns:
    The output of `tf.compat.v1.layers.batch_normalization`.
  """
  channel_axis = 1 if data_format == 'channels_first' else 3
  # The fused kernel is requested purely for speed; see
  # https://www.tensorflow.org/performance/performance_guide#common_fused_ops
  bn_kwargs = dict(
      axis=channel_axis,
      momentum=_BATCH_NORM_DECAY,
      epsilon=_BATCH_NORM_EPSILON,
      center=True,
      scale=True,
      fused=True)
  return tf.compat.v1.layers.batch_normalization(
      inputs=inputs, training=training, **bn_kwargs)
56

57

58
def fixed_padding(inputs, kernel_size, data_format):
  """Pads height and width by an amount that depends only on `kernel_size`.

  Args:
    inputs: 4-D tensor laid out as [batch, channels, height, width] when
      `data_format` is 'channels_first', otherwise as
      [batch, height, width, channels].
    kernel_size: Positive integer kernel size of the conv2d / max_pool2d that
      will consume the padded tensor.
    data_format: 'channels_first' or 'channels_last'.

  Returns:
    `inputs` zero-padded on both spatial axes with `(kernel_size - 1) // 2`
    entries before and the remainder after; the spatial size is unchanged
    when `kernel_size == 1`.
  """
  total = kernel_size - 1
  before = total // 2
  spatial = [before, total - before]
  no_pad = [0, 0]
  if data_format == 'channels_first':
    paddings = [no_pad, no_pad, spatial, spatial]
  else:
    paddings = [no_pad, spatial, spatial, no_pad]
  return tf.pad(tensor=inputs, paddings=paddings)
85

86

87
def conv2d_fixed_padding(inputs, filters, kernel_size, strides, data_format,
                         weight_decay=0.0002):
  """Strided 2-D convolution with explicit, input-size-independent padding.

  When `strides > 1` the input is first padded with `fixed_padding` and the
  convolution uses 'VALID' padding. The amount of padding then depends only
  on `kernel_size` and not on the spatial size of `inputs`, unlike plain
  `tf.layers.conv2d` with 'SAME'. When `strides == 1`, 'SAME' padding is used.

  Args:
    inputs: 4-D input tensor in the layout given by `data_format`.
    filters: Number of output channels.
    kernel_size: Integer size of the convolution kernel.
    strides: Integer stride applied to both spatial dimensions.
    data_format: 'channels_first' or 'channels_last'.
    weight_decay: Scale of the L2 kernel regularizer. It defaults to 0.0002,
      the value previously hard-coded here. Pass None to build the
      convolution without a kernel regularizer.

  Returns:
    The convolution output tensor. No bias term is added.
  """
  if strides > 1:
    inputs = fixed_padding(inputs, kernel_size, data_format)
  # Built per call; None disables kernel regularization entirely.
  regularizer = (None if weight_decay is None else
                 contrib_layers.l2_regularizer(scale=weight_decay))
  return tf.layers.conv2d(
      inputs=inputs,
      filters=filters,
      kernel_size=kernel_size,
      strides=strides,
      padding=('SAME' if strides == 1 else 'VALID'),
      use_bias=False,
      kernel_initializer=tf.compat.v1.variance_scaling_initializer(),
      kernel_regularizer=regularizer,
      data_format=data_format)
104

105

106
def _building_block_v1(inputs, filters, training, projection_shortcut, strides,
                       data_format):
  """Basic (non-bottleneck) ResNet v1 block: (3x3 conv -> BN) twice.

  Follows the post-activation layout of He et al., "Deep Residual Learning for
  Image Recognition" (https://arxiv.org/pdf/1512.03385.pdf). A ReLU follows
  the first batch norm, and the final ReLU is applied after the residual
  addition.

  Args:
    inputs: 4-D tensor, [batch, channels, height, width] or
      [batch, height, width, channels] depending on `data_format`.
    filters: Output channels of both 3x3 convolutions.
    training: Forwarded to `batch_norm`.
    projection_shortcut: Optional callable applied to `inputs` to build the
      shortcut, whose result is then batch-normalized. If None, the identity
      shortcut is used.
    strides: Stride of the first 3x3 convolution; values > 1 downsample.
    data_format: 'channels_first' or 'channels_last'.

  Returns:
    The block's output tensor with `filters` channels. Its spatial size
    equals the input's only when `strides == 1`.
  """
  # Layer-creation order is significant in TF1 (it determines auto-generated
  # variable names), so the shortcut is built before the main path.
  if projection_shortcut is None:
    shortcut = inputs
  else:
    shortcut = batch_norm(
        inputs=projection_shortcut(inputs), training=training,
        data_format=data_format)

  net = inputs
  for layer_idx, conv_stride in enumerate((strides, 1)):
    net = conv2d_fixed_padding(
        inputs=net, filters=filters, kernel_size=3, strides=conv_stride,
        data_format=data_format)
    net = batch_norm(net, training, data_format)
    if layer_idx == 0:
      net = tf.nn.relu(net)
  return tf.nn.relu(net + shortcut)
157

158

159
def _building_block_v2(inputs, filters, training, projection_shortcut, strides,
                       data_format):
  """Basic (non-bottleneck) ResNet v2 block with full pre-activation.

  Each 3x3 convolution is preceded by batch normalization and ReLU, following
  He et al., "Identity Mappings in Deep Residual Networks"
  (https://arxiv.org/pdf/1603.05027.pdf). No activation is applied after the
  residual addition.

  Args:
    inputs: 4-D tensor, [batch, channels, height, width] or
      [batch, height, width, channels] depending on `data_format`.
    filters: Output channels of both 3x3 convolutions.
    training: Forwarded to `batch_norm`.
    projection_shortcut: Optional callable building the shortcut from the
      pre-activated input. If None, the raw `inputs` are the shortcut.
    strides: Stride of the first 3x3 convolution; values > 1 downsample.
    data_format: 'channels_first' or 'channels_last'.

  Returns:
    The block's output tensor with `filters` channels. Its spatial size
    equals the input's only when `strides == 1`.
  """
  preact = tf.nn.relu(batch_norm(inputs, training, data_format))

  # The projection reads the pre-activated tensor; the identity shortcut
  # reads the raw input.
  shortcut = (inputs if projection_shortcut is None
              else projection_shortcut(preact))

  net = conv2d_fixed_padding(
      inputs=preact, filters=filters, kernel_size=3, strides=strides,
      data_format=data_format)
  net = tf.nn.relu(batch_norm(net, training, data_format))
  net = conv2d_fixed_padding(
      inputs=net, filters=filters, kernel_size=3, strides=1,
      data_format=data_format)
  return net + shortcut
209

210

211
def _bottleneck_block_v1(inputs, filters, training, projection_shortcut,
                         strides, data_format):
  """Bottleneck ResNet v1 block: 1x1 -> 3x3 -> 1x1 convolutions.

  Uses the post-activation ordering of He et al., "Deep Residual Learning for
  Image Recognition" (https://arxiv.org/pdf/1512.03385.pdf). Each convolution
  is followed by batch norm. A ReLU follows the first two batch norms, and
  the final ReLU comes after the residual addition.

  Args:
    inputs: 4-D tensor, [batch, channels, height, width] or
      [batch, height, width, channels] depending on `data_format`.
    filters: Channels of the first two convolutions; the last 1x1
      convolution outputs `4 * filters` channels.
    training: Forwarded to `batch_norm`.
    projection_shortcut: Optional callable applied to `inputs` to build the
      shortcut, whose result is then batch-normalized. If None, the identity
      shortcut is used.
    strides: Stride of the middle 3x3 convolution; values > 1 downsample.
    data_format: 'channels_first' or 'channels_last'.

  Returns:
    The block's output tensor with `4 * filters` channels.
  """
  # Build the shortcut first: TF1 layer-creation order fixes variable names.
  if projection_shortcut is None:
    shortcut = inputs
  else:
    shortcut = batch_norm(
        inputs=projection_shortcut(inputs), training=training,
        data_format=data_format)

  # (output channels, kernel size, stride) for each convolution in order;
  # only the 3x3 convolution carries the block's stride.
  conv_specs = ((filters, 1, 1), (filters, 3, strides), (4 * filters, 1, 1))
  last_idx = len(conv_specs) - 1
  net = inputs
  for idx, (num_out, ksize, stride) in enumerate(conv_specs):
    net = conv2d_fixed_padding(
        inputs=net, filters=num_out, kernel_size=ksize, strides=stride,
        data_format=data_format)
    net = batch_norm(net, training, data_format)
    if idx != last_idx:
      net = tf.nn.relu(net)
  return tf.nn.relu(net + shortcut)
273

274

275
def _bottleneck_block_v2(inputs, filters, training, projection_shortcut,
                         strides, data_format):
  """Bottleneck ResNet v2 block: pre-activated 1x1 -> 3x3 -> 1x1 convolutions.

  Applies the bottleneck layout of He et al., "Deep Residual Learning for Image
  Recognition" (https://arxiv.org/pdf/1512.03385.pdf). It is rearranged to the
  BN -> ReLU -> conv pre-activation ordering of "Identity Mappings in Deep
  Residual Networks" (https://arxiv.org/pdf/1603.05027.pdf). No activation is
  applied after the residual addition.

  Args:
    inputs: 4-D tensor, [batch, channels, height, width] or
      [batch, height, width, channels] depending on `data_format`.
    filters: Channels of the first two convolutions; the last 1x1
      convolution outputs `4 * filters` channels.
    training: Forwarded to `batch_norm`.
    projection_shortcut: Optional callable building the shortcut from the
      pre-activated input. If None, the raw `inputs` are the shortcut.
    strides: Stride of the middle 3x3 convolution; values > 1 downsample.
    data_format: 'channels_first' or 'channels_last'.

  Returns:
    The block's output tensor with `4 * filters` channels.
  """
  net = tf.nn.relu(batch_norm(inputs, training, data_format))

  # The projection uses the pre-activated tensor; identity uses raw input.
  shortcut = inputs
  if projection_shortcut is not None:
    shortcut = projection_shortcut(net)

  # (output channels, kernel size, stride) for each convolution in order.
  conv_specs = ((filters, 1, 1), (filters, 3, strides), (4 * filters, 1, 1))
  for idx, (num_out, ksize, stride) in enumerate(conv_specs):
    if idx:
      # Every convolution after the first gets its own BN + ReLU.
      net = tf.nn.relu(batch_norm(net, training, data_format))
    net = conv2d_fixed_padding(
        inputs=net, filters=num_out, kernel_size=ksize, strides=stride,
        data_format=data_format)
  return net + shortcut
342

343

344
def block_layer(inputs, filters, bottleneck, block_fn, blocks, strides,
                training, name, data_format):
  """Stacks residual blocks into one stage of the ResNet.

  Only the first block receives the 1x1 projection shortcut and `strides`.
  Every following block uses stride 1 and the identity shortcut. At least one
  block is always built, even if `blocks` is less than 1.

  Args:
    inputs: 4-D tensor, [batch, channels, height, width] or
      [batch, height, width, channels] depending on `data_format`.
    filters: `filters` argument passed to every block.
    bottleneck: Whether `block_fn` is a bottleneck block. Bottleneck blocks
      output 4x `filters` channels, so the projection is widened to match.
    block_fn: Block builder with signature
      (inputs, filters, training, projection_shortcut, strides, data_format).
    blocks: Number of blocks in the stage.
    strides: Stride of the first block and of its projection shortcut.
    training: Forwarded to every block.
    name: Name given to the `tf.identity` wrapping the stage output.
    data_format: 'channels_first' or 'channels_last'.

  Returns:
    The output tensor of the last block, named `name`.
  """
  projected_channels = 4 * filters if bottleneck else filters

  def projection_shortcut(x):
    return conv2d_fixed_padding(
        inputs=x, filters=projected_channels, kernel_size=1, strides=strides,
        data_format=data_format)

  net = block_fn(inputs, filters, training, projection_shortcut, strides,
                 data_format)
  for _ in range(blocks - 1):
    net = block_fn(net, filters, training, None, 1, data_format)
  return tf.identity(net, name)
386

387

388
class Model(object):
  """Builds a ResNet feature encoder plus a dense "confidence" head."""

  def __init__(self,
               wd,
               resnet_size,
               bottleneck,
               num_classes,
               num_filters,
               kernel_size,
               conv_stride,
               first_pool_size,
               first_pool_stride,
               block_sizes,
               block_strides,
               feature_dim,
               resnet_version=DEFAULT_VERSION,
               data_format=None,
               dtype=DEFAULT_DTYPE,
               drop_rate=0.5):
    """Creates the model configuration; no graph ops are built here.

    Args:
      wd: Weight-decay coefficient of the L2 regularizer used by the dense
        layers in `confidence_model`. The convolutions built by `encoder` do
        not use `wd`; they get the L2 regularizer configured inside
        `conv2d_fixed_padding`.
      resnet_size: A single integer for the size of the ResNet model. It is
        stored as `self.resnet_size` but not otherwise used by this class.
      bottleneck: Use regular blocks (False) or bottleneck blocks (True).
      num_classes: The number of classes used as labels. It is stored as
        `self.num_classes` but not used by the methods of this class.
      num_filters: The number of filters to use for the first block layer of
        the model. This number is then doubled for each subsequent block
        layer.
      kernel_size: Kernel size of the initial convolution in `encoder`.
      conv_stride: Stride of the initial convolution in `encoder`.
      first_pool_size: Pool size of the initial max-pooling layer. If falsy
        (e.g. None), the pooling layer is skipped.
      first_pool_stride: Stride of the initial max-pooling layer. It is not
        used when first_pool_size is falsy.
      block_sizes: A list containing n values, where n is the number of sets
        of block layers desired. Each value is the number of blocks in the
        i-th set.
      block_strides: List of integers giving the stride of each set of block
        layers. It should be the same length as block_sizes.
      feature_dim: Number of units in both dense layers of
        `confidence_model`.
      resnet_version: Integer selecting the ResNet variant. Valid values are
        1 and 2.
      data_format: Input format ('channels_last', 'channels_first', or None).
        If None, 'channels_first' is used when TensorFlow is built with CUDA,
        and 'channels_last' otherwise.
      dtype: The TensorFlow dtype to use for calculations. It must be in
        ALLOWED_TYPES. If not specified, tf.float32 is used.
      drop_rate: Dropout rate used by `confidence_model`. It defaults to 0.5,
        the value previously hard-coded here.

    Raises:
      ValueError: if `resnet_version` is not 1 or 2, or `dtype` is not in
        ALLOWED_TYPES.
    """
    self.resnet_size = resnet_size

    if not data_format:
      data_format = ('channels_first'
                     if tf.test.is_built_with_cuda() else 'channels_last')

    self.resnet_version = resnet_version
    if resnet_version not in (1, 2):
      raise ValueError(
          'Resnet version should be 1 or 2. See README for citations.')

    self.bottleneck = bottleneck
    if bottleneck:
      if resnet_version == 1:
        self.block_fn = _bottleneck_block_v1
      else:
        self.block_fn = _bottleneck_block_v2
    else:
      if resnet_version == 1:
        self.block_fn = _building_block_v1
      else:
        self.block_fn = _building_block_v2

    if dtype not in ALLOWED_TYPES:
      raise ValueError('dtype must be one of: {}'.format(ALLOWED_TYPES))

    self.data_format = data_format
    self.num_classes = num_classes
    self.num_filters = num_filters
    self.kernel_size = kernel_size
    self.conv_stride = conv_stride
    self.first_pool_size = first_pool_size
    self.first_pool_stride = first_pool_stride
    self.block_sizes = block_sizes
    self.block_strides = block_strides
    self.dtype = dtype
    # Only consulted by the disabled final BN/ReLU step in `encoder`.
    self.pre_activation = resnet_version == 2
    # Regularizer and initializer for the dense layers of `confidence_model`.
    self.regularizer = contrib_layers.l2_regularizer(scale=wd)
    self.initializer = contrib_layers.xavier_initializer()
    self.drop_rate = drop_rate
    self.feature_dim = feature_dim

  def _custom_dtype_getter(self,  # pylint: disable=keyword-arg-before-vararg
                           getter,
                           name,
                           shape=None,
                           dtype=DEFAULT_DTYPE,
                           *args,
                           **kwargs):
    """Creates variables in fp32, then casts to fp16 if necessary.

    This function is a custom getter. A custom getter is a function with the
    same signature as tf.get_variable, except it has an additional getter
    parameter. Custom getters can be passed as the `custom_getter` parameter
    of tf.variable_scope. Then, tf.get_variable will call the custom getter,
    instead of directly getting a variable itself. This can be used to change
    the types of variables that are retrieved with tf.get_variable.
    The `getter` parameter is the underlying variable getter, that would have
    been called if no custom getter was used. Custom getters typically get a
    variable with `getter`, then modify it in some way.

    This custom getter will create an fp32 variable. If a low precision
    (e.g. float16) variable was requested it will then cast the variable to
    the requested dtype. The reason we do not directly create variables in low
    precision dtypes is that applying small gradients to such variables may
    cause the variable not to change.

    Args:
      getter: The underlying variable getter, that has the same signature as
        tf.get_variable and returns a variable.
      name: The name of the variable to get.
      shape: The shape of the variable to get.
      dtype: The dtype of the variable to get. Note that if this is a low
        precision dtype, the variable will be created as a tf.float32
        variable, then cast to the appropriate dtype.
      *args: Additional arguments to pass unmodified to getter.
      **kwargs: Additional keyword arguments to pass unmodified to getter.

    Returns:
      A variable which is cast to fp16 if necessary.
    """
    if dtype in CASTABLE_TYPES:
      var = getter(name, shape, tf.float32, *args, **kwargs)
      return tf.cast(var, dtype=dtype, name=name + '_cast')
    else:
      return getter(name, shape, dtype, *args, **kwargs)

  def _model_variable_scope(self):
    """Returns a variable scope that the model should be created under.

    If self.dtype is a castable type, model variables will be created in fp32
    then cast to self.dtype before being used.

    Returns:
      A variable scope named 'resnet_model' that uses
      `_custom_dtype_getter`.
    """

    return tf.compat.v1.variable_scope(
        'resnet_model', custom_getter=self._custom_dtype_getter)

  def confidence_model(self, mu, training):
    """Maps a batch of features `mu` to a `feature_dim`-wide output.

    The layer stack is: dropout -> dense -> ReLU -> batch norm -> dropout ->
    dense. Both dense layers have `feature_dim` units and use
    `self.initializer` and `self.regularizer`.

    NOTE(review): the layers are named 'fc_variance*', but the final dense
    layer has no activation, so the output can be negative. Confirm that the
    caller maps it to a valid variance (e.g. treats it as a log-variance).
    NOTE(review): these layers are not created under
    `_model_variable_scope`, so the fp32-variable custom getter does not
    apply to them.

    Args:
      mu: Feature tensor, presumably the output of `encoder`; 2-D input is
        assumed (TODO confirm).
      training: Enables dropout and training-mode batch norm.

    Returns:
      The output of the second dense layer, with `feature_dim` units.
    """
    out = tf.layers.dropout(mu, rate=self.drop_rate, training=training)
    out = tf.layers.dense(
        out,
        units=self.feature_dim,
        kernel_initializer=self.initializer,
        kernel_regularizer=self.regularizer,
        name='fc_variance')
    out = tf.nn.relu(out)
    out = tf.layers.batch_normalization(
        out, training=training, name='fc_variance_bn')
    out = tf.layers.dropout(out, rate=self.drop_rate, training=training)
    out = tf.layers.dense(
        out,
        units=self.feature_dim,
        kernel_initializer=self.initializer,
        kernel_regularizer=self.regularizer,
        name='fc_variance2')
    return out

  def encoder(self, inputs, training):
    """Builds the ResNet trunk and returns globally average-pooled features.

    Args:
      inputs: A batch of input images in channels_last (NHWC) layout. They
        are transposed to NCHW internally when `self.data_format` is
        'channels_first'. It is assumed the caller has already cast them to
        `self.dtype` (TODO confirm).
      training: A boolean. Set to True to add operations required only when
        training (batch-norm statistics updates).

    Returns:
      A feature Tensor of shape [batch_size, C], where C is the output
      channel count of the last block layer:
      `self.num_filters * 2**(len(self.block_sizes) - 1)`, multiplied by 4
      when bottleneck blocks are used. These are not class logits;
      `self.num_classes` is not used here.
    """

    with self._model_variable_scope():
      if self.data_format == 'channels_first':
        # Convert the inputs from channels_last (NHWC) to channels_first
        # (NCHW). This provides a large performance boost on GPU. See
        # https://www.tensorflow.org/performance/performance_guide#data_formats
        inputs = tf.transpose(a=inputs, perm=[0, 3, 1, 2])

      inputs = conv2d_fixed_padding(
          inputs=inputs,
          filters=self.num_filters,
          kernel_size=self.kernel_size,
          strides=self.conv_stride,
          data_format=self.data_format)
      inputs = tf.identity(inputs, 'initial_conv')

      # v1 applies BN + ReLU after the stem convolution; v2 defers them to the
      # pre-activation inside the first block.
      if self.resnet_version == 1:
        inputs = batch_norm(inputs, training, self.data_format)
        inputs = tf.nn.relu(inputs)

      if self.first_pool_size:
        inputs = tf.compat.v1.layers.max_pooling2d(
            inputs=inputs,
            pool_size=self.first_pool_size,
            strides=self.first_pool_stride,
            padding='SAME',
            data_format=self.data_format)
        inputs = tf.identity(inputs, 'initial_max_pool')

      for i, num_blocks in enumerate(self.block_sizes):
        num_filters = self.num_filters * (2**i)
        inputs = block_layer(
            inputs=inputs,
            filters=num_filters,
            bottleneck=self.bottleneck,
            block_fn=self.block_fn,
            blocks=num_blocks,
            strides=self.block_strides[i],
            training=training,
            name='block_layer{}'.format(i + 1),
            data_format=self.data_format)

      # The final BN + ReLU that pre-activation (v2) ResNets normally apply
      # here is disabled, so for v2 the pooled features come straight from the
      # last residual addition. NOTE(review): confirm this is intentional;
      # re-enabling it would add new batch-norm variables (checkpoint change).
      # if self.pre_activation:
      #   inputs = batch_norm(inputs, training, self.data_format)
      #   inputs = tf.nn.relu(inputs)

      # Global average pooling over the spatial axes. A reduce_mean is used
      # because it performs better than AveragePooling2D and is equivalent
      # when pooling over the full spatial extent.
      axes = [2, 3] if self.data_format == 'channels_first' else [1, 2]
      inputs = tf.reduce_mean(input_tensor=inputs, axis=axes, keepdims=True)
      inputs = tf.identity(inputs, 'final_reduce_mean')

      # Drop the two size-1 spatial axes, leaving [batch_size, C].
      inputs = tf.squeeze(inputs, axes)
      return inputs
626

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.