# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=logging-format-interpolation
# pylint: disable=g-direct-tensorflow-import

r"""Common utils."""

import os
import re
import threading

from absl import logging
import numpy as np
import tensorflow.compat.v1 as tf


def get_worker_name(worker_id):
  """Returns `/job:tpu_worker/task:{worker_id}`."""
  return f'/job:tpu_worker/task:{worker_id}'


def get_device_name(worker_id, core_id):
  """Returns `/job:tpu_worker/task:{worker_id}/device:TPU:{core_id}`."""
  return f'/job:tpu_worker/task:{worker_id}/device:TPU:{core_id}'


def count_params():
  """Count model params."""
  num_params = sum([np.prod([d.value for d in w.shape])
                    for w in tf.trainable_variables()
                    if 'teacher' not in w.name.lower()])
  return num_params


def strip_var_name(var_name):
  """Strips variable name of sub-strings blocking variable name matching.

  Removes sub-strings that should be ignored when matching checkpointed variable
  names to variable names in the training graph, namely:
  - trailing colon + number, e.g. "W:0" --> "W"
  - partitioning info., e.g. "/a/part_12/b" --> "/a/b".
  (Note that checkpointed variables do not have partitioning info in their name,
  while model variables do).

  Args:
    var_name: str, variable name.

  Returns:
    stripped variable name.
  """
  # Strip trailing number, e.g. convert "lstm/W_0:0" to "lstm/W_0".
  var_name = re.sub(r':\d+$', '', var_name)
  # Strip partitioning info, e.g. convert "W_0/part_3/Adagrad" to "W_0/Adagrad".
  var_name = re.sub(r'/part_\d+', '', var_name)
  return var_name


def get_saver(max_to_keep=1, restore_ema=False):
  """Constructs a `Saver`."""
  var_list = {}
  if restore_ema:
    logging.info('Restore EMA values')
    for v in tf.global_variables():
      if v.name.startswith('ema'):
        logging.fatal(f'wrong ema var name `{v.name}`')
      if 'global_step' in v.name:
        var_list['global_step'] = v
      else:
        var_list['ema/' + strip_var_name(v.name)] = v
  else:
    for v in tf.global_variables():
      var_list[strip_var_name(v.name)] = v
  saver = tf.train.Saver(var_list,
                         max_to_keep=max_to_keep,
                         save_relative_paths=True)
  return saver
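

# Illustrative sketch (not part of the original training scripts): to evaluate
# with the EMA weights written by `setup_ema`, one would typically build the
# eval graph and then restore the `ema/...` variables, e.g.
#
#   saver = get_saver(restore_ema=True)
#   with get_session(params) as sess:
#     saver.restore(sess, get_latest_checkpoint(params.output_dir))
#
# `params.output_dir` is assumed here to also be the checkpoint directory; the
# actual flag may differ.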


class AsyncCheckpoint(object):
  """Saves checkpoints using a separate thread."""

  def __init__(self, saver, ckpt_dir, max_to_keep=None):
    self._saver = saver
    self._ckpt_dir = ckpt_dir
    self._max_to_keep = max_to_keep
    self._thread = None
    self.latest_checkpoint = None

  def join(self):
    if self._thread is not None:
      self._thread.join()

  def save(self, sess, step):
    """Asynchronously saves a checkpoint for `step`."""

    def _save_fn():
      """Run the saver process."""
      raw_sess = sess if isinstance(sess, tf.Session) else sess.raw_session()
      ckpt_path = self._saver.save(
          raw_sess,
          save_path=os.path.join(self._ckpt_dir, 'ckpt'),
          global_step=step,
          write_meta_graph=False,
          write_state=False)
      self.latest_checkpoint = ckpt_path[len(self._ckpt_dir) + 1:]
      logging.info(f'Saved checkpoint `{ckpt_path}`')

      all_checkpoints = get_all_checkpoints(self._ckpt_dir)
      assert all_checkpoints is not None
      new_ckpt_content = [f'model_checkpoint_path: "{all_checkpoints[-1]}"']
      if (self._max_to_keep is not None and
          self._max_to_keep < len(all_checkpoints)):
        pattern = all_checkpoints[0] + '*'
        tf.io.gfile.BulkDelete(tf.io.gfile.Glob(pattern))
        # pylint: disable=invalid-unary-operand-type
        all_checkpoints = all_checkpoints[-self._max_to_keep:]
        # pylint: enable=invalid-unary-operand-type
      for ckpt_name in all_checkpoints:
        new_ckpt_content.append(f'all_model_checkpoint_paths: "{ckpt_name}"')
      checkpoint_file = os.path.join(self._ckpt_dir, 'checkpoint')
      with tf.io.gfile.GFile(checkpoint_file, 'w') as fout:
        fout.write('\n'.join(new_ckpt_content))

    if self._thread is not None:
      self._thread.join(timeout=0.1)
      if self._thread.is_alive():
        logging.info('Saver thread still in progress, skipping checkpoint.')
        return

    self._thread = threading.Thread(target=_save_fn)
    self._thread.start()
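

# Illustrative usage sketch (assumed, not taken from the original training
# loop): `get_saver` and `AsyncCheckpoint` are meant to be combined so that
# writing checkpoints does not block the TPU training session, e.g.
#
#   saver = get_saver(max_to_keep=5)
#   async_ckpt = AsyncCheckpoint(saver, params.output_dir, max_to_keep=5)
#   for step in range(1, params.num_train_steps + 1):
#     ...  # run one training step in `sess`
#     if step % save_every == 0:  # `save_every` is a hypothetical flag
#       async_ckpt.save(sess, step)
#   async_ckpt.join()  # wait for the last save to finish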


def should_log(params):
  """Returns a Boolean `tf.Tensor` dictating whether we should log values."""
  global_step = tf.train.get_or_create_global_step()
  first_run = tf.equal(global_step, 1)
  log_every = tf.equal(tf.floormod(global_step, params.log_every), 0)
  return tf.logical_or(first_run, log_every)


def get_all_checkpoints(ckpt_dir):
  """Returns a list of all checkpoints, e.g. `['ckpt-100', 'ckpt-500']`."""
  if not tf.io.gfile.IsDirectory(ckpt_dir):
    return []
  pattern = ckpt_dir + '/ckpt-*'
  s = len(ckpt_dir) + len('/ckpt-')
  checkpoints = [int(f.split('.')[0][s:]) for f in tf.io.gfile.Glob(pattern)]
  checkpoints = [os.path.join(ckpt_dir, 'ckpt-{0}'.format(v))
                 for v in sorted(set(checkpoints))]
  return checkpoints


def get_latest_checkpoint(ckpt_dir):
  """Returns the latest checkpoint in `ckpt_dir`, or `None` if there is none."""
  all_checkpoints = get_all_checkpoints(ckpt_dir)
  all_checkpoints = [ckpt for ckpt in all_checkpoints if 'temp' not in ckpt]
  if all_checkpoints:
    return all_checkpoints[-1]
  else:
    return None


def get_outfeed_ops(params, signature):
  """Create TPU outfeed ops."""
  outfeed_dtypes, outfeed_shapes = [], []
  for dtype, shape in signature.values():
    outfeed_dtypes.append(dtype)
    outfeed_shapes.append(shape)
  outfeed_ops = []
  outfeed_graph = tf.Graph()

  dev_assign = params.device_assignment
  host_to_tpus = {}
  for replica_id in range(params.num_replicas):
    host_device = dev_assign.host_device(replica=replica_id, logical_core=0)
    tpu_ordinal = dev_assign.tpu_ordinal(replica=replica_id, logical_core=0)
    if host_device not in host_to_tpus:
      host_to_tpus[host_device] = [tpu_ordinal]
    else:
      assert tpu_ordinal not in host_to_tpus[host_device]
      host_to_tpus[host_device].append(tpu_ordinal)

  with outfeed_graph.as_default():
    for host, tpus in host_to_tpus.items():
      with tf.device(host):
        for device_ordinal in tpus:
          device_outfeed = tf.raw_ops.OutfeedDequeueTuple(
              dtypes=outfeed_dtypes,
              shapes=outfeed_shapes,
              device_ordinal=device_ordinal)
          outfeed_ops.append(device_outfeed)

  return outfeed_ops, outfeed_graph


class InfeedThread(object):
  """InfeedThread wrapper."""

  def __init__(self, params, infeed_ops, infeed_graphs, name='infeed_thread'):
    if infeed_graphs is not None:
      assert isinstance(infeed_graphs, list)
      assert len(infeed_graphs) == len(infeed_ops)

    self.infeed_ops = infeed_ops
    self.infeed_graphs = infeed_graphs

    self.sessions = []
    for g in infeed_graphs:
      with g.as_default():
        sess = tf.Session(target=params.master, graph=g)
        self.sessions.append(sess)

    self.name = name
    self._threads = []

  def stop(self):
    self.join()
    for sess in self.sessions:
      sess.close()

  def join(self):
    for thread in self._threads:
      if thread is not None:
        thread.join(timeout=0.1)
        del thread

  def start(self, verbose=False):
    """Starts one daemon thread per infeed session."""
    if verbose:
      logging.info(f'Start thread for `{self.name}`')

    def _infeed_fn(sess, infeed_op, infeed_graph):
      """Run the infeed process."""
      with infeed_graph.as_default():
        sess.run(infeed_op)

    for sess, op, g in zip(self.sessions, self.infeed_ops, self.infeed_graphs):
      thread = threading.Thread(target=_infeed_fn, args=(sess, op, g))
      thread.daemon = True
      thread.start()
      self._threads.append(thread)


class OutfeedThread(object):
  """OutfeedThread wrapper."""

  def __init__(self, params, outfeed_ops, outfeed_graph, outfeed_signature,
               name='outfeed_thread'):
    self.params = params
    self.outfeed_ops = outfeed_ops
    self.outfeed_graph = outfeed_graph
    self.outfeed_signature = outfeed_signature

    with outfeed_graph.as_default():
      self.session = tf.Session(target=params.master, graph=outfeed_graph)

    self.name = name
    self._thread = None

  def join(self):
    if self._thread is not None:
      self._thread.join(timeout=0.1)
      self._thread = None
    self.session.close()

  def start(self, verbose=False):
    """Starts the thread that reads the outfeed and writes summaries."""
    if verbose:
      logging.info(f'Start thread for `{self.name}`')
    if self._thread is not None:
      return

    params = self.params
    outfeed_signature = self.outfeed_signature

    def _outfeed_fn():
      """Read from `outfeed_dequeue` and write `Summary`."""
      train_logdir = os.path.join(params.output_dir, 'logs', 'train')
      summary_writer = tf.summary.FileWriter(train_logdir)
      summary_tags = list(outfeed_signature.keys())
      while True:
        outfeeds = self.session.run(self.outfeed_ops)
        outfeeds = np.array(outfeeds).reshape([params.num_replicas, -1])
        outfeeds = np.sum(outfeeds, axis=0).tolist()
        summary_values = []
        for tag, value in zip(summary_tags, outfeeds):
          if tag == 'global_step':
            value /= params.num_replicas
            step = value
          else:
            summary_values.append(tf.Summary.Value(tag=tag, simple_value=value))
        summary_writer.add_summary(tf.Summary(value=summary_values), step)
        summary_writer.flush()
        if step >= params.num_train_steps:
          summary_writer.close()
          break

    self._thread = threading.Thread(target=_outfeed_fn)
    self._thread.daemon = True
    self._thread.start()
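

# Illustrative wiring sketch (assumed, not taken from the training scripts):
# `signature` maps each summary tag to its per-core `(dtype, shape)`, in the
# same order the TPU program enqueues values onto the outfeed. A `global_step`
# entry is required because `OutfeedThread` uses it as the summary step and as
# its stopping condition.
#
#   signature = {'global_step': (tf.float32, []),
#                'train/loss': (tf.float32, [])}
#   outfeed_ops, outfeed_graph = get_outfeed_ops(params, signature)
#   outfeed_thread = OutfeedThread(params, outfeed_ops, outfeed_graph,
#                                  signature)
#   outfeed_thread.start()
#   ...  # train; the thread exits once global_step reaches num_train_steps
#   outfeed_thread.join()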


def setup_ema(params, name_scope=None):
  """Create exponential moving average for all variables under `name_scope`."""
  logging.info(f'ema_decay with rate {params.ema_decay}')
  all_vars = tf.global_variables()
  ema_ops = []
  step = tf.cast(tf.train.get_or_create_global_step() - params.ema_start,
                 tf.float32)
  decay = 1. - tf.minimum(params.ema_decay, (step+1.) / (step+10.))
  decay = tf.cond(tf.train.get_or_create_global_step() < params.ema_start,
                  lambda: tf.constant(1, tf.float32), lambda: decay)

  def should_skip(v):
    key_words = ['momentum', 'rms', 'global_step', 'debug', 'adam', 'lars']
    conditions = [k in v.name.lower() for k in key_words]
    if name_scope is not None:
      conditions += [not v.name.lower().startswith(name_scope)]
    return any(conditions)

  def get_init(v_name):
    key_words = ['variance', 'beta']
    if any([k in v_name for k in key_words]):
      return tf.initializers.ones()
    return tf.initializers.zeros()

  with tf.variable_scope('ema'):
    for v in all_vars:
      if not should_skip(v):
        v_name = strip_var_name(v.name)
        with tf.device(v.device):
          ema_var = tf.get_variable(
              name=v_name,
              shape=v.shape.as_list(),
              initializer=get_init(v_name),
              trainable=False)
          ema_op = tf.assign_sub(ema_var, decay * (ema_var-v), use_locking=True)
        ema_ops.append(ema_op)
  ema_op = tf.group(*ema_ops)
  return ema_op
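

# Note on the update in `setup_ema`: the op
# `tf.assign_sub(ema_var, decay * (ema_var - v))` is equivalent to
#
#   ema_var <- (1 - decay) * ema_var + decay * v
#
# so `decay` plays the role of (1 - ema_decay). Before `params.ema_start` it
# is forced to 1, making the EMA variables plain copies of the weights;
# afterwards the (step + 1) / (step + 10) term lets the EMA track the weights
# quickly at first and then settle at the final rate of 1 - params.ema_decay.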


def get_session(params, isolate_session_state=True):
  """Builds and returns a `tf.Session`."""
  config = tf.ConfigProto(
      isolate_session_state=isolate_session_state,
      allow_soft_placement=True,
      graph_options=tf.GraphOptions(
          optimizer_options=tf.OptimizerOptions(
              opt_level=tf.OptimizerOptions.L0,
              do_common_subexpression_elimination=False,
              do_function_inlining=False,
              do_constant_folding=False)))
  return tf.Session(target=params.master, config=config)


def get_learning_rate(params, initial_lr=None, num_warmup_steps=None,
                      num_wait_steps=None):
  """Build learning rate."""
  global_step = tf.train.get_or_create_global_step()

  if initial_lr is None:
    initial_lr = params.lr
  initial_lr = initial_lr * params.train_batch_size / 256.

  if num_warmup_steps is None:
    num_warmup_steps = params.num_warmup_steps

  if num_wait_steps is not None:
    global_step = global_step - num_wait_steps

  if params.lr_decay_type == 'constant':
    lr = tf.constant(initial_lr, dtype=tf.float32)
  elif params.lr_decay_type == 'exponential':
    lr = tf.train.exponential_decay(
        learning_rate=initial_lr,
        global_step=global_step-num_warmup_steps,
        decay_steps=params.num_decay_steps,
        decay_rate=params.lr_decay_rate,
        staircase=True)
  elif params.lr_decay_type == 'cosine':
    if num_wait_steps is None:
      lr = tf.train.cosine_decay(
          learning_rate=initial_lr,
          global_step=global_step-num_warmup_steps,
          decay_steps=params.num_train_steps-num_warmup_steps,
          alpha=0.0)
    else:
      lr = tf.train.cosine_decay(
          learning_rate=initial_lr,
          global_step=global_step-num_warmup_steps,
          decay_steps=params.num_train_steps-num_warmup_steps-num_wait_steps,
          alpha=0.0)
  else:
    raise ValueError(f'Unknown lr_decay_type `{params.lr_decay_type}`')

  r = (tf.cast(global_step+1, tf.float32) /
       tf.cast(num_warmup_steps, tf.float32))
  warmup_lr = initial_lr * r
  lr = tf.cond(global_step < num_warmup_steps, lambda: warmup_lr, lambda: lr)

  if num_wait_steps is not None:
    lr = tf.cond(global_step < 0,
                 lambda: tf.constant(0., tf.float32), lambda: lr)

  return lr
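

# Worked example (hypothetical numbers, for illustration only): with
# params.lr = 0.1, params.train_batch_size = 1024, num_warmup_steps = 1000 and
# params.lr_decay_type = 'cosine', the schedule built above is
#
#   initial_lr = 0.1 * 1024 / 256 = 0.4                # linear batch scaling
#   step < 1000:   lr = 0.4 * (step + 1) / 1000        # linear warm-up
#   step >= 1000:  lr = 0.4 * 0.5 * (1 + cos(
#       pi * (step - 1000) / (params.num_train_steps - 1000)))  # cosine decay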


def get_optimizer(params, learning_rate=None):
  """Build optimizer."""
  if learning_rate is None:
    learning_rate = get_learning_rate(params)

  if params.optim_type.lower() == 'sgd':
    logging.info('Use SGD')
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate,
                                                  use_locking=True)
  elif params.optim_type.lower() == 'momentum':
    logging.info('Use Momentum')
    optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                           momentum=0.9,
                                           use_nesterov=True,
                                           use_locking=True)
  elif params.optim_type.lower() == 'rmsprop':
    optimizer = tf.train.RMSPropOptimizer(learning_rate=learning_rate,
                                          decay=params.rmsprop_rho,
                                          momentum=params.rmsprop_momentum,
                                          epsilon=params.rmsprop_epsilon,
                                          use_locking=True)
  elif params.optim_type.lower() == 'lars':
    class LARSOptimizer(tf.train.Optimizer):
      """Layer-wise Adaptive Rate Scaling for large batch training.

      Introduced by "Large Batch Training of Convolutional Networks" by Y. You,
      I. Gitman, and B. Ginsburg. (https://arxiv.org/abs/1708.03888)

      Implements the LARS learning rate scheme presented in the paper above.
      This optimizer is useful when scaling the batch size to up to 32K without
      significant performance degradation. It is recommended to use the
      optimizer in conjunction with:
          - Gradual learning rate warm-up
          - Linear learning rate scaling
          - Poly rule learning rate decay

      Note, LARS scaling is currently only enabled for dense tensors. Sparse
      tensors use the default momentum optimizer.
      """

      def __init__(
          self,
          learning_rate,
          momentum=0.9,
          weight_decay=0.0001,
          # The LARS coefficient is a hyperparameter
          eeta=0.001,
          epsilon=0.0,
          name='LARSOptimizer',
          # Enable skipping variables from LARS scaling.
          # TODO(sameerkm): Enable a direct mechanism to pass a
          # subset of variables to the optimizer.
          skip_list=None,
          use_nesterov=False):
        """Construct a new LARS Optimizer.

        Args:
          learning_rate: A `Tensor` or floating point value.
          momentum: A floating point value. Momentum hyperparameter.
          weight_decay: A floating point value. Weight decay hyperparameter.
          eeta: LARS coefficient as used in the paper. Default set to LARS
            coefficient from the paper. (eeta / weight_decay) determines the
            highest scaling factor in LARS.
          epsilon: Optional epsilon parameter to be set in models that have very
            small gradients. Default set to 0.0.
          name: Optional name prefix for variables and ops created.
          skip_list: List of strings to enable skipping variables from scaling.
            If any of the strings in skip_list is a subset of var.name, variable
            'var' is skipped from LARS scaling. For a typical classification
            model with batch normalization, the skip_list is
            ['batch_normalization', 'bias']
          use_nesterov: when set to True, nesterov momentum will be enabled

        Raises:
          ValueError: If a hyperparameter is set to a non-sensical value.
        """
        if momentum < 0.0:
          raise ValueError(f'momentum should be positive: {momentum}')
        if weight_decay < 0.0:
          raise ValueError(f'weight_decay should be positive: {weight_decay}')
        super(LARSOptimizer, self).__init__(use_locking=False, name=name)

        self._learning_rate = learning_rate
        self._momentum = momentum
        self._weight_decay = weight_decay
        self._eeta = eeta
        self._epsilon = epsilon
        self._name = name
        self._skip_list = skip_list
        self._use_nesterov = use_nesterov

      def _create_slots(self, var_list):
        for v in var_list:
          self._zeros_slot(v, 'momentum', self._name)

      def compute_lr(self, grad, var):
        scaled_lr = self._learning_rate
        if self._skip_list is None or not any(v in var.name
                                              for v in self._skip_list):
          w_norm = tf.norm(var, ord=2)
          g_norm = tf.norm(grad, ord=2)
          trust_ratio = tf.where(
              tf.math.greater(w_norm, 0),
              tf.where(
                  tf.math.greater(g_norm, 0),
                  (self._eeta * w_norm / (
                      g_norm + self._weight_decay * w_norm + self._epsilon)),
                  1.0),
              1.0)
          scaled_lr = self._learning_rate * trust_ratio
          # Add the weight regularization gradient
          grad = grad + self._weight_decay * var
        return scaled_lr, grad
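
      # For reference, `trust_ratio` above is the LARS scaling factor from the
      # paper:
      #
      #   trust_ratio = eeta * ||w|| / (||g|| + weight_decay * ||w|| + epsilon)
      #
      # with a fallback to 1.0 when either norm is zero. The effective update
      # then uses gradient `grad + weight_decay * var` with step size
      # `learning_rate * trust_ratio`.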

      def _apply_dense(self, grad, var):
        scaled_lr, grad = self.compute_lr(grad, var)
        mom = self.get_slot(var, 'momentum')
        return tf.raw_ops.ApplyMomentum(
            var=var,
            accum=mom,
            lr=tf.cast(1.0, var.dtype.base_dtype),
            grad=grad * scaled_lr,
            momentum=self._momentum,
            use_locking=False,
            use_nesterov=self._use_nesterov)

      def _resource_apply_dense(self, grad, var):
        scaled_lr, grad = self.compute_lr(grad, var)
        mom = self.get_slot(var, 'momentum')
        return tf.raw_ops.ResourceApplyMomentum(
            var=var.handle,
            accum=mom.handle,
            lr=tf.cast(1.0, var.dtype.base_dtype),
            grad=grad * scaled_lr,
            momentum=self._momentum,
            use_locking=False,
            use_nesterov=self._use_nesterov)

      # Fallback to momentum optimizer for sparse tensors
      def _apply_sparse(self, grad, var):
        mom = self.get_slot(var, 'momentum')
        return tf.raw_ops.SparseApplyMomentum(
            var=var,
            accum=mom,
            lr=tf.cast(self._learning_rate_tensor, var.dtype.base_dtype),
            grad=grad.values,
            indices=grad.indices,
            momentum=tf.cast(self._momentum_tensor, var.dtype.base_dtype),
            use_locking=self._use_locking,
            use_nesterov=self._use_nesterov).op

      def _resource_apply_sparse(self, grad, var, indices):
        mom = self.get_slot(var, 'momentum')
        return tf.raw_ops.ResourceSparseApplyMomentum(
            var=var.handle,
            accum=mom.handle,
            lr=tf.cast(self._learning_rate_tensor, grad.dtype),
            grad=grad,
            indices=indices,
            momentum=tf.cast(self._momentum_tensor, grad.dtype),
            use_locking=self._use_locking,
            use_nesterov=self._use_nesterov)

      def _prepare(self):
        learning_rate = self._learning_rate
        if callable(learning_rate):
          learning_rate = learning_rate()
        self._learning_rate_tensor = tf.convert_to_tensor(
            learning_rate, name='learning_rate')
        momentum = self._momentum
        if callable(momentum):
          momentum = momentum()
        self._momentum_tensor = tf.convert_to_tensor(momentum, name='momentum')

    optimizer = LARSOptimizer(
        learning_rate=learning_rate,
        weight_decay=params.weight_decay,
        skip_list=['batch_norm', 'batchnorm', 'gamma', 'beta', 'bias'],
        use_nesterov=True)
  else:
    raise ValueError(f'Unknown optim_type `{params.optim_type}`')
  return learning_rate, optimizer


def get_l2_loss(excluded_keywords=None):
  """Computes L2 loss over `tf.trainable_variables`, ignoring `batch_norm`."""
  def _is_excluded(v):
    """Guess whether a variable belongs to `batch_norm`."""
    keywords = ['batchnorm', 'batch_norm', 'bn',
                'layernorm', 'layer_norm']
    if excluded_keywords is not None:
      keywords += excluded_keywords
    return any([k in v.name.lower() for k in keywords])

  l2_losses = [tf.nn.l2_loss(v) for v in tf.trainable_variables()
               if not _is_excluded(v)]
  return tf.add_n(l2_losses)
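

# Illustrative usage (assumed, not from the original training code): the L2
# term is typically folded into the total loss with a weight-decay
# coefficient, e.g.
#
#   total_loss = cross_entropy + params.weight_decay * get_l2_loss()
#
# where `cross_entropy` stands in for the task loss; applying it this way is
# an assumption, though `params.weight_decay` itself is the same coefficient
# used by the LARS branch above.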
