# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=logging-format-interpolation
# pylint: disable=g-direct-tensorflow-import

r"""Common utils."""

import os
import re
import threading

from absl import logging
import numpy as np
import tensorflow.compat.v1 as tf

def get_worker_name(worker_id):
  """Returns `/job:tpu_worker/task:{worker_id}`."""
  return f'/job:tpu_worker/task:{worker_id}'

def get_device_name(worker_id, core_id):
  """Returns `/job:tpu_worker/task:{worker_id}/device:TPU:{core_id}`."""
  return f'/job:tpu_worker/task:{worker_id}/device:TPU:{core_id}'

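# For example (illustrative): `get_device_name(worker_id=1, core_id=3)`
# returns '/job:tpu_worker/task:1/device:TPU:3'.
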
def count_params():
  """Count model params."""
  num_params = sum([np.prod([d.value for d in w.shape])
                    for w in tf.trainable_variables()
                    if 'teacher' not in w.name.lower()])
  return num_params

def strip_var_name(var_name):
  """Strips variable name of sub-strings blocking variable name matching.

  Removes sub-strings that should be ignored when matching checkpointed
  variable names to variable names in the training graph, namely:
  - trailing colon + number, e.g. "W:0" --> "W"
  - partitioning info., e.g. "/a/part_12/b" --> "a/b".
  (Note that checkpointed variables do not have partitioning info in their
  name, while model variables do).

  Args:
    var_name: str, variable name.

  Returns:
    stripped variable name.
  """
  # Strip trailing number, e.g. convert "lstm/W_0:0" to "lstm/W_0".
  var_name = re.sub(r':\d+$', '', var_name)
  # Strip partitioning info, e.g. convert "W_0/part_3/Adagrad" to "W_0/Adagrad".
  var_name = re.sub(r'/part_\d+', '', var_name)
  return var_name

def get_saver(max_to_keep=1, restore_ema=False):
  """Constructs a `Saver`."""
  var_list = {}
  if restore_ema:
    logging.info('Restore EMA values')
    for v in tf.global_variables():
      if v.name.startswith('ema'):
        logging.fatal(f'wrong ema var name `{v.name}`')
      if 'global_step' in v.name:
        var_list['global_step'] = v
      else:
        var_list['ema/' + strip_var_name(v.name)] = v
  else:
    for v in tf.global_variables():
      var_list[strip_var_name(v.name)] = v
  saver = tf.train.Saver(var_list,
                         max_to_keep=max_to_keep,
                         save_relative_paths=True)
  return saver

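# Illustrative usage (a sketch; with `restore_ema=True` each model variable is
# restored from its `ema/...` shadow created by `setup_ema` below):
#   saver = get_saver(max_to_keep=5, restore_ema=True)
#   saver.restore(sess, get_latest_checkpoint(ckpt_dir))
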
class AsyncCheckpoint(object):
  """Saves checkpoints using a separate thread."""

  def __init__(self, saver, ckpt_dir, max_to_keep=None):
    self._saver = saver
    self._ckpt_dir = ckpt_dir
    self._max_to_keep = max_to_keep
    self._thread = None
    self.latest_checkpoint = None

  def join(self):
    if self._thread is not None:
      self._thread.join()

  def save(self, sess, step):
    """Saves a checkpoint at `step` on a background thread."""

    def _save_fn():
      """Run the saver process."""
      raw_sess = sess if isinstance(sess, tf.Session) else sess.raw_session()
      ckpt_path = self._saver.save(
          raw_sess,
          save_path=os.path.join(self._ckpt_dir, 'ckpt'),
          global_step=step,
          write_meta_graph=False,
          write_state=False)
      self.latest_checkpoint = ckpt_path[len(self._ckpt_dir) + 1:]
      logging.info(f'Saved checkpoint `{ckpt_path}`')

      all_checkpoints = get_all_checkpoints(self._ckpt_dir)
      assert all_checkpoints is not None
      new_ckpt_content = [f'model_checkpoint_path: "{all_checkpoints[-1]}"']
      if (self._max_to_keep is not None and
          self._max_to_keep < len(all_checkpoints)):
        pattern = all_checkpoints[0] + '*'
        tf.io.gfile.BulkDelete(tf.io.gfile.Glob(pattern))
        # pylint: disable=invalid-unary-operand-type
        all_checkpoints = all_checkpoints[-self._max_to_keep:]
        # pylint: enable=invalid-unary-operand-type
      for ckpt_name in all_checkpoints:
        new_ckpt_content.append(f'all_model_checkpoint_paths: "{ckpt_name}"')
      checkpoint_file = os.path.join(self._ckpt_dir, 'checkpoint')
      with tf.io.gfile.GFile(checkpoint_file, 'w') as fout:
        fout.write('\n'.join(new_ckpt_content))

    if self._thread is not None:
      self._thread.join(timeout=0.1)
      if self._thread.is_alive():
        logging.info('Saver thread still in progress, skipping checkpoint.')
        return

    self._thread = threading.Thread(target=_save_fn)
    self._thread.start()

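# Illustrative usage of `AsyncCheckpoint` (a sketch; `sess` and `step` come
# from the caller's training loop):
#   ckpt = AsyncCheckpoint(get_saver(), ckpt_dir, max_to_keep=3)
#   ckpt.save(sess, step)   # returns immediately; saving runs on a thread
#   ckpt.join()             # block until the in-flight save finishes
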
def should_log(params):
  """Returns a Boolean `tf.Tensor` dictating whether we should log values."""
  global_step = tf.train.get_or_create_global_step()
  first_run = tf.equal(global_step, 1)
  log_every = tf.equal(tf.floormod(global_step, params.log_every), 0)
  return tf.logical_or(first_run, log_every)

def get_all_checkpoints(ckpt_dir):
  """Returns a list of all checkpoints, eg `['ckpt-100', 'ckpt-500']`."""
  if not tf.io.gfile.IsDirectory(ckpt_dir):
    return []
  pattern = ckpt_dir + '/ckpt-*'
  s = len(ckpt_dir) + len('/ckpt-')
  checkpoints = [int(f.split('.')[0][s:]) for f in tf.io.gfile.Glob(pattern)]
  checkpoints = [os.path.join(ckpt_dir, 'ckpt-{0}'.format(v))
                 for v in sorted(set(checkpoints))]
  return checkpoints

def get_latest_checkpoint(ckpt_dir):
  """Returns the latest checkpoint in `ckpt_dir`, eg `'ckpt-500'`, or None."""
  all_checkpoints = get_all_checkpoints(ckpt_dir)
  all_checkpoints = [ckpt for ckpt in all_checkpoints if 'temp' not in ckpt]
  if all_checkpoints:
    return all_checkpoints[-1]
  else:
    return None

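# Illustrative behavior (a sketch): if `ckpt_dir` contains `ckpt-100.index`
# and `ckpt-500.index`, `get_all_checkpoints(ckpt_dir)` returns
# `[ckpt_dir + '/ckpt-100', ckpt_dir + '/ckpt-500']` and
# `get_latest_checkpoint(ckpt_dir)` returns `ckpt_dir + '/ckpt-500'`.
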
def get_outfeed_ops(params, signature):
  """Create TPU outfeed ops."""
  outfeed_dtypes, outfeed_shapes = [], []
  for dtype, shape in signature.values():
    outfeed_dtypes.append(dtype)
    outfeed_shapes.append(shape)
  outfeed_ops = []
  outfeed_graph = tf.Graph()

  dev_assign = params.device_assignment
  host_to_tpus = {}
  for replica_id in range(params.num_replicas):
    host_device = dev_assign.host_device(replica=replica_id, logical_core=0)
    tpu_ordinal = dev_assign.tpu_ordinal(replica=replica_id, logical_core=0)
    if host_device not in host_to_tpus:
      host_to_tpus[host_device] = [tpu_ordinal]
    else:
      assert tpu_ordinal not in host_to_tpus[host_device]
      host_to_tpus[host_device].append(tpu_ordinal)

  with outfeed_graph.as_default():
    for host, tpus in host_to_tpus.items():
      with tf.device(host):
        for device_ordinal in tpus:
          device_outfeed = tf.raw_ops.OutfeedDequeueTuple(
              dtypes=outfeed_dtypes,
              shapes=outfeed_shapes,
              device_ordinal=device_ordinal)
          outfeed_ops.append(device_outfeed)

  return outfeed_ops, outfeed_graph

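# Illustrative wiring of the outfeed pipeline (a sketch; `signature` is an
# ordered mapping of summary tag -> (dtype, shape) matching what the TPU
# program enqueues):
#   outfeed_ops, outfeed_graph = get_outfeed_ops(params, signature)
#   OutfeedThread(params, outfeed_ops, outfeed_graph, signature).start()
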
class InfeedThread(object):
  """InfeedThread wrapper."""

  def __init__(self, params, infeed_ops, infeed_graphs, name='infeed_thread'):
    if infeed_graphs is not None:
      assert isinstance(infeed_graphs, list)
      assert len(infeed_graphs) == len(infeed_ops)

    self.infeed_ops = infeed_ops
    self.infeed_graphs = infeed_graphs

    self.sessions = []
    for g in infeed_graphs:
      with g.as_default():
        sess = tf.Session(target=params.master, graph=g)
        self.sessions.append(sess)

    self.name = name
    self._threads = []

  def stop(self):
    self.join()
    for sess in self.sessions:
      sess.close()

  def join(self):
    for thread in self._threads:
      if thread is not None:
        thread.join(timeout=0.1)
      del thread

  def start(self, verbose=False):
    """Starts one infeed thread per session."""
    if verbose:
      logging.info(f'Start thread for `{self.name}`')

    def _infeed_fn(sess, infeed_op, infeed_graph):
      """Run the infeed process."""
      with infeed_graph.as_default():
        sess.run(infeed_op)

    for sess, op, g in zip(self.sessions, self.infeed_ops, self.infeed_graphs):
      thread = threading.Thread(target=_infeed_fn, args=(sess, op, g))
      thread.daemon = True
      thread.start()
      self._threads.append(thread)

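# Illustrative usage (a sketch; `infeed_ops[i]` must be built in
# `infeed_graphs[i]`, one graph per session):
#   infeed_thread = InfeedThread(params, infeed_ops, infeed_graphs)
#   infeed_thread.start()
#   ...  # run the training loop
#   infeed_thread.stop()
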
class OutfeedThread(object):
  """OutfeedThread wrapper."""

  def __init__(self, params, outfeed_ops, outfeed_graph, outfeed_signature,
               name='outfeed_thread'):
    self.params = params
    self.outfeed_ops = outfeed_ops
    self.outfeed_graph = outfeed_graph
    self.outfeed_signature = outfeed_signature

    with outfeed_graph.as_default():
      self.session = tf.Session(target=params.master, graph=outfeed_graph)

    self.name = name
    self._thread = None

  def join(self):
    if self._thread is not None:
      self._thread.join(timeout=0.1)
      self._thread = None
    self.session.close()

  def start(self, verbose=False):
    """Starts the outfeed-to-summary thread."""
    if verbose:
      logging.info(f'Start thread for `{self.name}`')
    if self._thread is not None:
      return

    params = self.params
    outfeed_signature = self.outfeed_signature

    def _outfeed_fn():
      """Read from `outfeed_dequeue` and write `Summary`."""
      train_logdir = os.path.join(params.output_dir, 'logs', 'train')
      summary_writer = tf.summary.FileWriter(train_logdir)
      summary_tags = list(outfeed_signature.keys())
      while True:
        outfeeds = self.session.run(self.outfeed_ops)
        outfeeds = np.array(outfeeds).reshape([params.num_replicas, -1])
        outfeeds = np.sum(outfeeds, axis=0).tolist()
        summary_values = []
        for tag, value in zip(summary_tags, outfeeds):
          if tag == 'global_step':
            value /= params.num_replicas
            step = value
          else:
            summary_values.append(
                tf.Summary.Value(tag=tag, simple_value=value))
        summary_writer.add_summary(tf.Summary(value=summary_values), step)
        summary_writer.flush()
        if step >= params.num_train_steps:
          summary_writer.close()
          break

    self._thread = threading.Thread(target=_outfeed_fn)
    self._thread.daemon = True
    self._thread.start()

def setup_ema(params, name_scope=None):
  """Create exponential moving average for all variables under `name_scope`."""
  logging.info(f'ema_decay with rate {params.ema_decay}')
  all_vars = tf.global_variables()
  ema_ops = []
  step = tf.cast(tf.train.get_or_create_global_step() - params.ema_start,
                 tf.float32)
  decay = 1. - tf.minimum(params.ema_decay, (step+1.) / (step+10.))
  decay = tf.cond(tf.train.get_or_create_global_step() < params.ema_start,
                  lambda: tf.constant(1, tf.float32), lambda: decay)

  def should_skip(v):
    key_words = ['momentum', 'rms', 'global_step', 'debug', 'adam', 'lars']
    conditions = [k in v.name.lower() for k in key_words]
    if name_scope is not None:
      conditions += [not v.name.lower().startswith(name_scope)]
    return any(conditions)

  def get_init(v_name):
    key_words = ['variance', 'beta']
    if any([k in v_name for k in key_words]):
      return tf.initializers.ones()
    return tf.initializers.zeros()

  with tf.variable_scope('ema'):
    for v in all_vars:
      if not should_skip(v):
        v_name = strip_var_name(v.name)
        with tf.device(v.device):
          ema_var = tf.get_variable(
              name=v_name,
              shape=v.shape.as_list(),
              initializer=get_init(v_name),
              trainable=False)
        ema_op = tf.assign_sub(ema_var, decay * (ema_var-v), use_locking=True)
        ema_ops.append(ema_op)
  ema_op = tf.group(*ema_ops)
  return ema_op

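# Note on the schedule implemented above: with `step = global_step - ema_start`
# each update applies
#   ema_var -= decay * (ema_var - v),
#   decay = 1 - min(ema_decay, (step + 1) / (step + 10)),
# i.e. ema_var = (1 - decay) * ema_var + decay * v. Early on decay is large
# (0.9 at step 0), so the EMA quickly tracks the raw weights; once
# (step + 1) / (step + 10) exceeds `ema_decay` it settles at 1 - ema_decay.
# Before `ema_start`, decay == 1 and the EMA copy simply mirrors the variable.
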
def get_session(params, isolate_session_state=True):
  """Builds and returns a `tf.Session`."""
  config = tf.ConfigProto(
      isolate_session_state=isolate_session_state,
      allow_soft_placement=True,
      graph_options=tf.GraphOptions(
          optimizer_options=tf.OptimizerOptions(
              opt_level=tf.OptimizerOptions.L0,
              do_common_subexpression_elimination=False,
              do_function_inlining=False,
              do_constant_folding=False)))
  return tf.Session(target=params.master, config=config)

def get_learning_rate(params, initial_lr=None, num_warmup_steps=None,
                      num_wait_steps=None):
  """Build learning rate."""
  global_step = tf.train.get_or_create_global_step()

  if initial_lr is None:
    initial_lr = params.lr
  initial_lr = initial_lr * params.train_batch_size / 256.

  if num_warmup_steps is None:
    num_warmup_steps = params.num_warmup_steps

  if num_wait_steps is not None:
    global_step = global_step - num_wait_steps

  if params.lr_decay_type == 'constant':
    lr = tf.constant(initial_lr, dtype=tf.float32)
  elif params.lr_decay_type == 'exponential':
    lr = tf.train.exponential_decay(
        learning_rate=initial_lr,
        global_step=global_step-num_warmup_steps,
        decay_steps=params.num_decay_steps,
        decay_rate=params.lr_decay_rate,
        staircase=True)
  elif params.lr_decay_type == 'cosine':
    if num_wait_steps is None:
      lr = tf.train.cosine_decay(
          learning_rate=initial_lr,
          global_step=global_step-num_warmup_steps,
          decay_steps=params.num_train_steps-num_warmup_steps,
          alpha=0.0)
    else:
      lr = tf.train.cosine_decay(
          learning_rate=initial_lr,
          global_step=global_step-num_warmup_steps,
          decay_steps=params.num_train_steps-num_warmup_steps-num_wait_steps,
          alpha=0.0)
  else:
    raise ValueError(f'Unknown lr_decay_type `{params.lr_decay_type}`')

  r = (tf.cast(global_step+1, tf.float32) /
       tf.cast(num_warmup_steps, tf.float32))
  warmup_lr = initial_lr * r
  lr = tf.cond(global_step < num_warmup_steps, lambda: warmup_lr, lambda: lr)

  if num_wait_steps is not None:
    lr = tf.cond(global_step < 0,
                 lambda: tf.constant(0., tf.float32), lambda: lr)

  return lr

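# Illustrative call (a sketch; assumes `params` provides `lr`,
# `train_batch_size`, `num_warmup_steps`, `num_train_steps` and
# `lr_decay_type`). The base rate is linearly scaled with batch size,
# `params.lr * params.train_batch_size / 256`, warmed up linearly for
# `num_warmup_steps` steps, then decayed:
#   lr = get_learning_rate(params)
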
def get_optimizer(params, learning_rate=None):
  """Build optimizer."""
  if learning_rate is None:
    learning_rate = get_learning_rate(params)

  if params.optim_type.lower() == 'sgd':
    logging.info('Use SGD')
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate,
                                                  use_locking=True)
  elif params.optim_type.lower() == 'momentum':
    logging.info('Use Momentum')
    optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                           momentum=0.9,
                                           use_nesterov=True,
                                           use_locking=True)
  elif params.optim_type.lower() == 'rmsprop':
    optimizer = tf.train.RMSPropOptimizer(learning_rate=learning_rate,
                                          decay=params.rmsprop_rho,
                                          momentum=params.rmsprop_momentum,
                                          epsilon=params.rmsprop_epsilon,
                                          use_locking=True)
  elif params.optim_type.lower() == 'lars':
    class LARSOptimizer(tf.train.Optimizer):
      """Layer-wise Adaptive Rate Scaling for large batch training.

      Introduced by "Large Batch Training of Convolutional Networks" by
      Y. You, I. Gitman, and B. Ginsburg. (https://arxiv.org/abs/1708.03888)

      Implements the LARS learning rate scheme presented in the paper above.
      This optimizer is useful when scaling the batch size to up to 32K
      without significant performance degradation. It is recommended to use
      the optimizer in conjunction with:
          - Gradual learning rate warm-up
          - Linear learning rate scaling
          - Poly rule learning rate decay

      Note, LARS scaling is currently only enabled for dense tensors. Sparse
      tensors use the default momentum optimizer.
      """

      def __init__(
          self,
          learning_rate,
          momentum=0.9,
          weight_decay=0.0001,
          # The LARS coefficient is a hyperparameter.
          eeta=0.001,
          epsilon=0.0,
          name='LARSOptimizer',
          # Enable skipping variables from LARS scaling.
          # TODO(sameerkm): Enable a direct mechanism to pass a
          # subset of variables to the optimizer.
          skip_list=None,
          use_nesterov=False):
        """Construct a new LARS Optimizer.

        Args:
          learning_rate: A `Tensor` or floating point value.
          momentum: A floating point value. Momentum hyperparameter.
          weight_decay: A floating point value. Weight decay hyperparameter.
          eeta: LARS coefficient as used in the paper. Default set to LARS
            coefficient from the paper. (eeta / weight_decay) determines the
            highest scaling factor in LARS.
          epsilon: Optional epsilon parameter to be set in models that have
            very small gradients. Default set to 0.0.
          name: Optional name prefix for variables and ops created.
          skip_list: List of strings to enable skipping variables from
            scaling. If any of the strings in skip_list is a subset of
            var.name, variable 'var' is skipped from LARS scaling. For a
            typical classification model with batch normalization, the
            skip_list is ['batch_normalization', 'bias'].
          use_nesterov: when set to True, nesterov momentum will be enabled.

        Raises:
          ValueError: If a hyperparameter is set to a non-sensical value.
        """
        if momentum < 0.0:
          raise ValueError(f'momentum should be positive: {momentum}')
        if weight_decay < 0.0:
          raise ValueError(f'weight_decay should be positive: {weight_decay}')
        super(LARSOptimizer, self).__init__(use_locking=False, name=name)

        self._learning_rate = learning_rate
        self._momentum = momentum
        self._weight_decay = weight_decay
        self._eeta = eeta
        self._epsilon = epsilon
        self._name = name
        self._skip_list = skip_list
        self._use_nesterov = use_nesterov

      def _create_slots(self, var_list):
        for v in var_list:
          self._zeros_slot(v, 'momentum', self._name)

      def compute_lr(self, grad, var):
        scaled_lr = self._learning_rate
        if self._skip_list is None or not any(v in var.name
                                              for v in self._skip_list):
          w_norm = tf.norm(var, ord=2)
          g_norm = tf.norm(grad, ord=2)
          trust_ratio = tf.where(
              tf.math.greater(w_norm, 0),
              tf.where(
                  tf.math.greater(g_norm, 0),
                  (self._eeta * w_norm / (
                      g_norm + self._weight_decay * w_norm + self._epsilon)),
                  1.0),
              1.0)
          scaled_lr = self._learning_rate * trust_ratio
          # Add the weight regularization gradient.
          grad = grad + self._weight_decay * var
        return scaled_lr, grad
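      # Note on `compute_lr` above: the trust ratio follows the LARS paper,
      #   trust_ratio = eeta * ||w|| / (||g|| + weight_decay * ||w|| + epsilon),
      # falling back to 1.0 when either norm is zero. The scaled rate then
      # multiplies the regularized gradient (g + weight_decay * w) inside the
      # momentum updates below.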
      def _apply_dense(self, grad, var):
        scaled_lr, grad = self.compute_lr(grad, var)
        mom = self.get_slot(var, 'momentum')
        # `tf.raw_ops` kernels only accept keyword arguments.
        return tf.raw_ops.ApplyMomentum(
            var=var,
            accum=mom,
            lr=tf.cast(1.0, var.dtype.base_dtype),
            grad=grad * scaled_lr,
            momentum=self._momentum,
            use_locking=False,
            use_nesterov=self._use_nesterov)

      def _resource_apply_dense(self, grad, var):
        scaled_lr, grad = self.compute_lr(grad, var)
        mom = self.get_slot(var, 'momentum')
        return tf.raw_ops.ResourceApplyMomentum(
            var=var.handle,
            accum=mom.handle,
            lr=tf.cast(1.0, var.dtype.base_dtype),
            grad=grad * scaled_lr,
            momentum=self._momentum,
            use_locking=False,
            use_nesterov=self._use_nesterov)

      # Fallback to momentum optimizer for sparse tensors.
      def _apply_sparse(self, grad, var):
        mom = self.get_slot(var, 'momentum')
        return tf.raw_ops.SparseApplyMomentum(
            var=var,
            accum=mom,
            lr=tf.cast(self._learning_rate_tensor, var.dtype.base_dtype),
            grad=grad.values,
            indices=grad.indices,
            momentum=tf.cast(self._momentum_tensor, var.dtype.base_dtype),
            use_locking=self._use_locking,
            use_nesterov=self._use_nesterov).op

      def _resource_apply_sparse(self, grad, var, indices):
        mom = self.get_slot(var, 'momentum')
        return tf.raw_ops.ResourceSparseApplyMomentum(
            var=var.handle,
            accum=mom.handle,
            lr=tf.cast(self._learning_rate_tensor, grad.dtype),
            grad=grad,
            indices=indices,
            momentum=tf.cast(self._momentum_tensor, grad.dtype),
            use_locking=self._use_locking,
            use_nesterov=self._use_nesterov)

      def _prepare(self):
        learning_rate = self._learning_rate
        if callable(learning_rate):
          learning_rate = learning_rate()
        self._learning_rate_tensor = tf.convert_to_tensor(
            learning_rate, name='learning_rate')
        momentum = self._momentum
        if callable(momentum):
          momentum = momentum()
        self._momentum_tensor = tf.convert_to_tensor(momentum, name='momentum')

    optimizer = LARSOptimizer(
        learning_rate=learning_rate,
        weight_decay=params.weight_decay,
        skip_list=['batch_norm', 'batchnorm', 'gamma', 'beta', 'bias'],
        use_nesterov=True)
  else:
    raise ValueError(f'Unknown optim_type `{params.optim_type}`')
  return learning_rate, optimizer

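# Illustrative usage (a sketch; `params.optim_type` selects one of the
# branches above; for the `lars` branch weight decay is already applied
# inside the optimizer):
#   learning_rate, optimizer = get_optimizer(params)
#   train_op = optimizer.minimize(
#       loss, global_step=tf.train.get_or_create_global_step())
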
def get_l2_loss(excluded_keywords=None):
  """Computes the L2 regularization over `tf.trainable_variables`.

  Variables matching normalization keywords (e.g. `batch_norm`) are ignored.
  """
  def _is_excluded(v):
    """Guess whether a variable belongs to `batch_norm`."""
    keywords = ['batchnorm', 'batch_norm', 'bn',
                'layernorm', 'layer_norm']
    if excluded_keywords is not None:
      keywords += excluded_keywords
    return any([k in v.name.lower() for k in keywords])

  l2_losses = [tf.nn.l2_loss(v) for v in tf.trainable_variables()
               if not _is_excluded(v)]
  return tf.add_n(l2_losses)
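# Illustrative usage (a sketch; how the term is weighted is up to the caller):
#   total_loss = cross_entropy + params.weight_decay * get_l2_loss()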