# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Gradual training for NMT.

This script trains a Transformer on a WMT dataset.
Gradual training refers to the periodic decrease in the
out-of-domain dataset size. This is similar to the
gradual finetuning proposed in dynamic data selection.
"""

# pytype: disable=wrong-arg-count
# pytype: disable=attribute-error

import functools
import os

from absl import app
from absl import flags
from absl import logging
from clu import metric_writers
from clu import periodic_actions
from flax import jax_utils
from flax import linen as nn
from flax import optim
from flax.training import checkpoints
from flax.training import common_utils
import jax
import jax.numpy as jnp
import numpy as np
import tensorflow as tf

from data_selection.wmt import common
from data_selection.wmt import decode
from data_selection.wmt import input_pipeline
from data_selection.wmt import models
from data_selection.wmt import train_util

FLAGS = flags.FLAGS
flags.adopt_module_key_flags(train_util)


def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  # Make sure tf does not allocate gpu memory.
  tf.config.experimental.set_visible_devices([], 'GPU')

  if FLAGS.jax_backend_target:
    jax.config.update('jax_xla_backend', 'tpu_driver')
    jax.config.update('jax_backend_target', FLAGS.jax_backend_target)

  # Number of local devices for this host.
  n_devices = jax.local_device_count()

  if jax.process_index() == 0:
    tf.io.gfile.makedirs(FLAGS.model_dir)

  if FLAGS.batch_size % n_devices:
    raise ValueError('Batch size must be divisible by the number of devices')

  vocab_path = FLAGS.vocab_path
  if vocab_path is None:
    vocab_path = os.path.join(FLAGS.model_dir, 'sentencepiece_model')
  tf.io.gfile.makedirs(os.path.split(vocab_path)[0])

  # Load Dataset
  # ---------------------------------------------------------------------------
  logging.info('Initializing dataset.')
  train_ds, eval_ds, predict_ds, encoder = input_pipeline.get_wmt_datasets(
      dataset_name=FLAGS.dataset_name,
      eval_dataset_name=FLAGS.eval_dataset_name,
      shard_idx=jax.process_index(),
      shard_count=jax.process_count(),
      data_dir=FLAGS.data_dir,
      vocab_path=vocab_path,
      target_vocab_size=FLAGS.vocab_size,
      batch_size=FLAGS.batch_size,
      max_length=FLAGS.max_target_length,
      max_eval_length=FLAGS.max_eval_target_length,
      paracrawl_size=FLAGS.paracrawl_size,
      is_scores_path=FLAGS.is_scores_path,
      num_to_keep=FLAGS.data_selection_size,
      pseudo_path=FLAGS.pseudo_path,
      repeat_count=FLAGS.repeat_count,
      newscommentary_size=FLAGS.newscommentary_size,
      split_tokenizer=FLAGS.split_tokenizer)

  if FLAGS.aux_eval_dataset:
    aux_datasets = []
    aux_names = FLAGS.aux_eval_dataset.split(',')
    for name in aux_names:
      _, aux_eval_ds, _, _ = input_pipeline.get_wmt_datasets(
          dataset_name=name,
          eval_dataset_name=None,
          shard_idx=jax.process_index(),
          shard_count=jax.process_count(),
          data_dir=FLAGS.data_dir,
          vocab_path=vocab_path,
          target_vocab_size=FLAGS.vocab_size,
          batch_size=FLAGS.batch_size,
          max_length=FLAGS.max_target_length,
          max_eval_length=FLAGS.max_eval_target_length,
          paracrawl_size=FLAGS.paracrawl_size,
          is_scores_path=FLAGS.is_scores_path,
          num_to_keep=FLAGS.data_selection_size,
          pseudo_path=FLAGS.pseudo_path,
          repeat_count=FLAGS.repeat_count,
          newscommentary_size=FLAGS.newscommentary_size)
      aux_datasets.append(aux_eval_ds)

  train_iter = iter(train_ds)
  vocab_size = int(encoder.vocab_size())
  eos_id = decode.EOS_ID  # Default Sentencepiece EOS token.
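
  # decode_tokens keeps everything up to and including the first EOS token
  # (np.argmax returns the index of the first True in `toks == eos_id`) and
  # detokenizes the result with the SentencePiece encoder.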
  def decode_tokens(toks):
    valid_toks = toks[:np.argmax(toks == eos_id) + 1].astype(np.int32)
    return encoder.detokenize(valid_toks).numpy().decode('utf-8')

  logging.info('Initializing model, optimizer, and step functions.')

  # Build Model and Optimizer
  # ---------------------------------------------------------------------------
  train_config = models.TransformerConfig(
      vocab_size=vocab_size,
      output_vocab_size=vocab_size,
      share_embeddings=FLAGS.share_embeddings,
      logits_via_embedding=FLAGS.logits_via_embedding,
      dtype=jnp.bfloat16 if FLAGS.use_bfloat16 else jnp.float32,
      emb_dim=FLAGS.emb_dim,
      num_heads=FLAGS.num_heads,
      num_layers=FLAGS.num_layers,
      qkv_dim=FLAGS.qkv_dim,
      mlp_dim=FLAGS.mlp_dim,
      max_len=max(FLAGS.max_target_length, FLAGS.max_eval_target_length),
      dropout_rate=FLAGS.dropout_rate,
      attention_dropout_rate=FLAGS.attention_dropout_rate,
      deterministic=False,
      decode=False,
      kernel_init=nn.initializers.xavier_uniform(),
      bias_init=nn.initializers.normal(stddev=1e-6))
  eval_config = train_config.replace(deterministic=True)
  predict_config = train_config.replace(deterministic=True, decode=True)

  start_step = 0
  rng = jax.random.PRNGKey(FLAGS.random_seed)
  rng, init_rng = jax.random.split(rng)
  # Note: these may be intended to use the per-device batch size.
  input_shape = (FLAGS.batch_size, FLAGS.max_target_length)
  target_shape = (FLAGS.batch_size, FLAGS.max_target_length)
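
  # Parameters are initialized with the deterministic eval_config, so no
  # dropout RNG is needed at init time; the dummy all-ones batches only fix
  # the input and target shapes.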
  m = models.Transformer(eval_config)
  initial_variables = jax.jit(m.init)(init_rng,
                                      jnp.ones(input_shape, jnp.float32),
                                      jnp.ones(target_shape, jnp.float32))

  # Apply an optimizer to this parameter tree.
  optimizer_def = optim.Adam(
      FLAGS.learning_rate,
      beta1=0.9,
      beta2=0.98,
      eps=1e-9,
      weight_decay=FLAGS.weight_decay)
  optimizer = optimizer_def.create(initial_variables['params'])
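
  # Note: flax.optim is the legacy Flax optimizer API (newer Flax code
  # typically uses optax); it is kept here to match the original setup.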
  # We access model params only from optimizer below via optimizer.target.
  del initial_variables

  if FLAGS.restore_checkpoints:
    logging.info('Restoring checkpoint.')
    # If we have a pretrained model, use that. Otherwise, continue where we
    # left off.
    model_path = (FLAGS.pretrained_model_dir if FLAGS.pretrained_model_dir
                  else FLAGS.model_dir)
    optimizer = checkpoints.restore_checkpoint(model_path, optimizer)
    # Grab last step.
    start_step = int(optimizer.state.step)

  writer = metric_writers.create_default_writer(
      FLAGS.model_dir, just_logging=jax.process_index() > 0)

  flag_key = [
      k for k in FLAGS.flags_by_module_dict().keys() if 'wmt.par' in k
  ]
  if flag_key:
    flag_key = flag_key[0]
    local_flags = {
        f.name: f.value for f in FLAGS.flags_by_module_dict()[flag_key]
    }
    writer.write_hparams(local_flags)

  # Replicate optimizer.
  optimizer = jax_utils.replicate(optimizer)

  learning_rate_fn = common.create_learning_rate_scheduler(
      base_learning_rate=FLAGS.learning_rate, warmup_steps=FLAGS.warmup_steps,
      steps_per_cycle=FLAGS.steps_per_cycle, init_step=start_step,
      finetune_lr=FLAGS.finetune_lr)

  # Compile multidevice versions of train/eval/predict step and cache init fn.
  p_train_step = jax.pmap(
      functools.partial(
          train_util.train_step,
          config=train_config,
          learning_rate_fn=learning_rate_fn,
          label_smoothing=FLAGS.label_smoothing),
      axis_name='batch',
      donate_argnums=(0,))  # pytype: disable=wrong-arg-types
  p_eval_step = jax.pmap(
      functools.partial(train_util.eval_step, config=eval_config),
      axis_name='batch')
  p_init_cache = jax.pmap(
      functools.partial(
          train_util.initialize_cache,
          max_decode_len=FLAGS.max_predict_length,
          config=predict_config),
      axis_name='batch')
  p_pred_step = jax.pmap(
      functools.partial(
          train_util.predict_step,
          config=predict_config,
          beam_size=FLAGS.beam_size),
      axis_name='batch',
      static_broadcasted_argnums=(3, 4))  # eos token, max_length are constant

  # Main Train Loop
  # ---------------------------------------------------------------------------

  # We init the first set of dropout PRNG keys, but update it afterwards inside
  # the main pmap'd training update for performance.
  dropout_rngs = jax.random.split(rng, jax.local_device_count())
  del rng

  logging.info('Starting training loop.')
  hooks = []
  report_progress = periodic_actions.ReportProgress(
      num_train_steps=FLAGS.num_train_steps, writer=writer)
  if jax.process_index() == 0:
    hooks += [
        report_progress,
        periodic_actions.Profile(logdir=FLAGS.model_dir, num_profile_steps=5)
    ]
  train_metrics = []
  total_steps = start_step + FLAGS.num_train_steps
  if FLAGS.eval_only:
    total_steps = start_step + 1
  best_eval_loss = 1000
  curr_eval_loss = 1000
  eval_loss_history = []
  last_eval_step = 0
  do_resample_data = False
  gradual_selection_size = FLAGS.data_selection_size
  dynamic_eval_freq = FLAGS.eval_frequency
  with metric_writers.ensure_flushes(writer):
    for step in range(start_step, total_steps):
      is_last_step = step == total_steps - 1

      # Resample training data for gradual FT.
      if do_resample_data:
        # Shrink the out-of-domain selection and rebuild the data pipeline.
        do_resample_data = False
        gradual_selection_size *= .7
        dynamic_eval_freq = int(gradual_selection_size / 1000 / 4)
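        # The eval-frequency heuristic above (selection_size / 1000 / 4)
        # appears to assume on the order of a thousand examples per step,
        # i.e. roughly four evals per pass over the shrinking selection; this
        # is inferred from the constants rather than documented behavior.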
        train_ds, eval_ds, predict_ds, encoder = input_pipeline.get_wmt_datasets(
            dataset_name=FLAGS.dataset_name,
            eval_dataset_name=FLAGS.eval_dataset_name,
            shard_idx=jax.process_index(),
            shard_count=jax.process_count(),
            data_dir=FLAGS.data_dir,
            vocab_path=vocab_path,
            target_vocab_size=FLAGS.vocab_size,
            batch_size=FLAGS.batch_size,
            max_length=FLAGS.max_target_length,
            max_eval_length=FLAGS.max_eval_target_length,
            paracrawl_size=FLAGS.paracrawl_size,
            is_scores_path=FLAGS.is_scores_path,
            num_to_keep=int(gradual_selection_size),
            pseudo_path=FLAGS.pseudo_path,
            repeat_count=FLAGS.repeat_count,
            newscommentary_size=FLAGS.newscommentary_size,
            split_tokenizer=FLAGS.split_tokenizer)
        train_iter = iter(train_ds)
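
      # Note: next(train_iter) below raises StopIteration once the finite
      # training stream is exhausted; the current step is then treated as the
      # last one.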
      # Shard data to devices and do a training step.
      if not FLAGS.eval_only:
        logging.info('Doing Training.')
        with jax.profiler.StepTraceAnnotation('train', step_num=step):
          try:
            batch = common_utils.shard(
                jax.tree_map(np.asarray, next(train_iter)))
            optimizer, metrics = p_train_step(
                optimizer, batch, dropout_rng=dropout_rngs)
            train_metrics.append(metrics)
          except StopIteration:
            is_last_step = True

      # Quick indication that training is happening.
      logging.log_first_n(logging.INFO, 'Finished training step %d.', 5, step)
      for h in hooks:
        h(step)

      # Periodic metric handling.
      if (step - start_step) % dynamic_eval_freq == 0 or is_last_step:
        if not FLAGS.eval_only:
          with report_progress.timed('training_metrics'):
            logging.info('Gathering training metrics.')
            train_metrics = common_utils.get_metrics(train_metrics)
            lr = train_metrics.pop('learning_rate').mean()
            metrics_sums = jax.tree_map(jnp.sum, train_metrics)
            denominator = metrics_sums.pop('denominator')
            summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
            summary['learning_rate'] = lr
            summary = {'train_' + k: v for k, v in summary.items()}
            writer.write_scalars(step, summary)
            train_metrics = []
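
        # Resampling criterion (standard eval branch below): if the eval loss
        # has improved by less than improvement_rate per step since the last
        # eval, flag the data for resampling so the selection shrinks.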
        if FLAGS.eval_only:
          p_eval_per_pos_step = jax.pmap(
              functools.partial(
                  train_util.eval_per_pos_step, config=eval_config),
              axis_name='batch')
          # Get per example loss.
          loss_filename = FLAGS.model_dir + '/test_losses.csv'
          train_util.write_per_example_losses(
              p_eval_step=p_eval_per_pos_step,
              target=optimizer.target,
              eval_ds=eval_ds,
              num_eval_steps=FLAGS.num_eval_steps,
              loss_filename=loss_filename)
        else:
          with report_progress.timed('eval'):
            eval_results = train_util.evaluate(
                p_eval_step=p_eval_step,
                target=optimizer.target,
                eval_ds=eval_ds,
                num_eval_steps=FLAGS.num_eval_steps)
            curr_eval_loss = eval_results['loss']
            eval_loss_history.append(curr_eval_loss)
            if len(eval_loss_history) > 1:
              improvement_rate = 0.000004
              orig_loss = eval_loss_history[-2]
              true_improvement = orig_loss - curr_eval_loss
              expected_improvement = (step - last_eval_step) * improvement_rate
              # percent_change = (orig_loss - curr_eval_loss) / orig_loss
              # percent_change *= 100
              if true_improvement < expected_improvement:  # percent_change < .1
                do_resample_data = True
            last_eval_step = step
            writer.write_scalars(
                step, {'eval_' + k: v for k, v in eval_results.items()})

        if FLAGS.aux_eval_dataset:
          for aux_i, aux_eval_ds in enumerate(aux_datasets):
            with report_progress.timed('aux_eval'):
              eval_results = train_util.evaluate(
                  p_eval_step=p_eval_step,
                  target=optimizer.target,
                  eval_ds=aux_eval_ds,
                  num_eval_steps=FLAGS.num_eval_steps)
              writer.write_scalars(
                  step, {
                      'aux' + str(aux_i) + '_eval_' + k: v
                      for k, v in eval_results.items()
                  })

        if FLAGS.compute_bleu:
          with report_progress.timed('translate_and_bleu'):
            decode_file = FLAGS.model_dir + '/decodes.csv'
            exemplars, bleu_score = train_util.translate_and_calculate_bleu(
                p_pred_step=p_pred_step,
                p_init_cache=p_init_cache,
                target=optimizer.target,
                predict_ds=predict_ds,
                decode_tokens=decode_tokens,
                max_predict_length=FLAGS.max_predict_length,
                num_eval_steps=FLAGS.num_eval_steps,
                decode_file=decode_file if FLAGS.eval_only else '')
            writer.write_scalars(step, {'bleu': bleu_score})
            writer.write_texts(step, {'samples': exemplars})

      # Save a checkpoint on one host after every checkpoint_freq steps.
      save_checkpoint = ((step - start_step) % FLAGS.checkpoint_freq == 0 or
                         is_last_step)
      if (FLAGS.save_checkpoints and save_checkpoint
          and jax.process_index() == 0):
        if curr_eval_loss < best_eval_loss:  # Only save better checkpoints.
          best_eval_loss = curr_eval_loss
          with report_progress.timed('checkpoint'):
            checkpoints.save_checkpoint(
                FLAGS.model_dir, jax_utils.unreplicate(optimizer),
                step, keep=FLAGS.chkpts_to_keep, overwrite=True)

      if is_last_step:
        break


if __name__ == '__main__':
  app.run(main)