# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16"""Finetune a NMT model with a mixture of in-domain and OOD data.
17
18This script trains a Transformer on a WMT dataset. This runner is
19intended for the special case where finetuning data is augmented
20with high quality out of domain data.
21"""

import functools
import os

from absl import app
from absl import flags
from absl import logging
from clu import metric_writers
from clu import periodic_actions
from flax import jax_utils
from flax import linen as nn
from flax import optim
from flax.training import checkpoints
from flax.training import common_utils
import jax
import jax.numpy as jnp
import numpy as np
import tensorflow as tf

from data_selection.wmt import common
from data_selection.wmt import decode
from data_selection.wmt import input_pipeline
from data_selection.wmt import models
from data_selection.wmt import train_util

FLAGS = flags.FLAGS
flags.adopt_module_key_flags(train_util)


def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  # Make sure tf does not allocate gpu memory.
  tf.config.experimental.set_visible_devices([], 'GPU')

  if FLAGS.jax_backend_target:
    jax.config.update('jax_xla_backend', 'tpu_driver')
    jax.config.update('jax_backend_target', FLAGS.jax_backend_target)

  # Number of local devices for this host.
  n_devices = jax.local_device_count()

  if jax.process_index() == 0:
    tf.io.gfile.makedirs(FLAGS.model_dir)

  if FLAGS.batch_size % n_devices:
    raise ValueError('Batch size must be divisible by the number of devices')

  vocab_path = FLAGS.vocab_path
  if vocab_path is None:
    vocab_path = os.path.join(FLAGS.model_dir, 'sentencepiece_model')
  tf.io.gfile.makedirs(os.path.split(vocab_path)[0])

  # Load Dataset
  # ---------------------------------------------------------------------------
  logging.info('Initializing dataset.')
  train_ds, eval_ds, predict_ds, encoder = input_pipeline.get_wmt_datasets(
      dataset_name=FLAGS.dataset_name,
      eval_dataset_name=FLAGS.eval_dataset_name,
      shard_idx=jax.process_index(),
      shard_count=jax.process_count(),
      data_dir=FLAGS.data_dir,
      vocab_path=vocab_path,
      target_vocab_size=FLAGS.vocab_size,
      batch_size=FLAGS.batch_size,
      max_length=FLAGS.max_target_length,
      max_eval_length=FLAGS.max_eval_target_length,
      paracrawl_size=FLAGS.paracrawl_size,
      is_scores_path=FLAGS.is_scores_path,
      num_to_keep=FLAGS.data_selection_size,
      pseudo_path=FLAGS.pseudo_path,
      repeat_count=FLAGS.repeat_count,
      newscommentary_size=FLAGS.newscommentary_size,
      split_tokenizer=FLAGS.split_tokenizer)

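  # Optionally build an eval-only dataset for each comma-separated auxiliary
  # eval dataset name.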
  if FLAGS.aux_eval_dataset:
    aux_datasets = []
    aux_names = FLAGS.aux_eval_dataset.split(',')
    for name in aux_names:
      _, aux_eval_ds, _, _ = input_pipeline.get_wmt_datasets(
          dataset_name=name,
          eval_dataset_name=None,
          shard_idx=jax.process_index(),
          shard_count=jax.process_count(),
          data_dir=FLAGS.data_dir,
          vocab_path=vocab_path,
          target_vocab_size=FLAGS.vocab_size,
          batch_size=FLAGS.batch_size,
          max_length=FLAGS.max_target_length,
          max_eval_length=FLAGS.max_eval_target_length,
          paracrawl_size=FLAGS.paracrawl_size,
          is_scores_path=FLAGS.is_scores_path,
          num_to_keep=FLAGS.data_selection_size,
          pseudo_path=FLAGS.pseudo_path,
          repeat_count=FLAGS.repeat_count,
          newscommentary_size=FLAGS.newscommentary_size)
      aux_datasets.append(aux_eval_ds)

  train_iter = iter(train_ds)
  vocab_size = int(encoder.vocab_size())
  eos_id = decode.EOS_ID  # Default Sentencepiece EOS token.

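  # Detokenize a single sequence of token ids, keeping everything up to and
  # including the first EOS token.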
  def decode_tokens(toks):
    valid_toks = toks[:np.argmax(toks == eos_id) + 1].astype(np.int32)
    return encoder.detokenize(valid_toks).numpy().decode('utf-8')

  logging.info('Initializing model, optimizer, and step functions.')

  # Build Model and Optimizer
  # ---------------------------------------------------------------------------
  train_config = models.TransformerConfig(
      vocab_size=vocab_size,
      output_vocab_size=vocab_size,
      share_embeddings=FLAGS.share_embeddings,
      logits_via_embedding=FLAGS.logits_via_embedding,
      dtype=jnp.bfloat16 if FLAGS.use_bfloat16 else jnp.float32,
      emb_dim=FLAGS.emb_dim,
      num_heads=FLAGS.num_heads,
      num_layers=FLAGS.num_layers,
      qkv_dim=FLAGS.qkv_dim,
      mlp_dim=FLAGS.mlp_dim,
      max_len=max(FLAGS.max_target_length, FLAGS.max_eval_target_length),
      dropout_rate=FLAGS.dropout_rate,
      attention_dropout_rate=FLAGS.attention_dropout_rate,
      deterministic=False,
      decode=False,
      kernel_init=nn.initializers.xavier_uniform(),
      bias_init=nn.initializers.normal(stddev=1e-6))
  eval_config = train_config.replace(deterministic=True)
  predict_config = train_config.replace(deterministic=True, decode=True)

  start_step = 0
  rng = jax.random.PRNGKey(FLAGS.random_seed)
  rng, init_rng = jax.random.split(rng)
  # Note: this may be intended to be the per-device batch size.
  input_shape = (FLAGS.batch_size, FLAGS.max_target_length)
  target_shape = (FLAGS.batch_size, FLAGS.max_target_length)

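  # Initialize the parameters with dummy inputs. The deterministic eval config
  # is used so that dropout is disabled during initialization.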
  m = models.Transformer(eval_config)
  initial_variables = jax.jit(m.init)(init_rng,
                                      jnp.ones(input_shape, jnp.float32),
                                      jnp.ones(target_shape, jnp.float32))

  # Apply an optimizer to the initialized parameter tree.
  optimizer_def = optim.Adam(
      FLAGS.learning_rate,
      beta1=0.9,
      beta2=0.98,
      eps=1e-9,
      weight_decay=FLAGS.weight_decay)
  optimizer = optimizer_def.create(initial_variables['params'])

  # We access model params only from optimizer below via optimizer.target.
  del initial_variables

  if FLAGS.restore_checkpoints:
    logging.info('Restoring checkpoint.')
    # If we have a pretrained model, use that. Else, continue where we left
    # off.
    model_path = (FLAGS.pretrained_model_dir if FLAGS.pretrained_model_dir
                  else FLAGS.model_dir)
    optimizer = checkpoints.restore_checkpoint(model_path, optimizer)
    # Grab the last step.
    start_step = int(optimizer.state.step)

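  # Only the first host writes summaries; the others just log.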
  writer = metric_writers.create_default_writer(
      FLAGS.model_dir, just_logging=jax.process_index() > 0)

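  # Record the hyperparameter flags defined by this runner's module.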
  flag_key = [
      k for k in FLAGS.flags_by_module_dict().keys() if 'wmt.par' in k
  ]
  if flag_key:
    flag_key = flag_key[0]
    local_flags = {
        f.name: f.value for f in FLAGS.flags_by_module_dict()[flag_key]
    }
    writer.write_hparams(local_flags)

  # Replicate optimizer.
  optimizer = jax_utils.replicate(optimizer)

  learning_rate_fn = common.create_learning_rate_scheduler(
      base_learning_rate=FLAGS.learning_rate, warmup_steps=FLAGS.warmup_steps,
      steps_per_cycle=FLAGS.steps_per_cycle, init_step=start_step,
      finetune_lr=FLAGS.finetune_lr)

  # Compile multidevice versions of train/eval/predict step and cache init fn.
  p_train_step = jax.pmap(
      functools.partial(
          train_util.train_step,
          config=train_config,
          learning_rate_fn=learning_rate_fn,
          label_smoothing=FLAGS.label_smoothing),
      axis_name='batch',
      donate_argnums=(0,))  # pytype: disable=wrong-arg-types
  p_eval_step = jax.pmap(
      functools.partial(train_util.eval_step, config=eval_config),
      axis_name='batch')
  p_init_cache = jax.pmap(
      functools.partial(
          train_util.initialize_cache,
          max_decode_len=FLAGS.max_predict_length,
          config=predict_config),
      axis_name='batch')
  p_pred_step = jax.pmap(
      functools.partial(
          train_util.predict_step,
          config=predict_config,
          beam_size=FLAGS.beam_size),
      axis_name='batch',
      static_broadcasted_argnums=(3, 4))  # eos token, max_length are constant

  # Main Train Loop
  # ---------------------------------------------------------------------------

  # We init the first set of dropout PRNG keys, but update it afterwards inside
  # the main pmap'd training update for performance.
  dropout_rngs = jax.random.split(rng, jax.local_device_count())
  del rng

  logging.info('Starting training loop.')
  hooks = []
  report_progress = periodic_actions.ReportProgress(
      num_train_steps=FLAGS.num_train_steps, writer=writer)
  if jax.process_index() == 0:
    hooks += [
        report_progress,
        periodic_actions.Profile(logdir=FLAGS.model_dir, num_profile_steps=5)
    ]
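  # Bookkeeping for gradual finetuning: the eval loss history determines when
  # the training data is resampled and how the selection size is adjusted.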
  train_metrics = []
  total_steps = start_step + FLAGS.num_train_steps
  if FLAGS.eval_only:
    total_steps = start_step + 1
  best_eval_loss = 1000
  curr_eval_loss = 1000
  eval_loss_history = []
  do_resample_data = False
  gradual_selection_size = FLAGS.data_selection_size
  eval_freq = FLAGS.eval_frequency
  with metric_writers.ensure_flushes(writer):
    for step in range(start_step, total_steps):
      is_last_step = step == total_steps - 1

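      # Resample training data for gradual FT: if the eval loss got worse
      # after the previous adjustment, grow the selection back (divide by
      # 0.75); otherwise shrink it by 25%. Small selections are evaluated
      # more frequently.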
      if do_resample_data:
        do_resample_data = False
        if eval_loss_history[-1] > eval_loss_history[-2]:
          gradual_selection_size = int(gradual_selection_size / .75)
        else:
          gradual_selection_size = int(.75 * gradual_selection_size)
        if gradual_selection_size < 500_000:
          eval_freq = int(gradual_selection_size / 100)

        train_ds, eval_ds, predict_ds, encoder = input_pipeline.get_wmt_datasets(
            dataset_name=FLAGS.dataset_name,
            eval_dataset_name=FLAGS.eval_dataset_name,
            shard_idx=jax.process_index(),
            shard_count=jax.process_count(),
            data_dir=FLAGS.data_dir,
            vocab_path=vocab_path,
            target_vocab_size=FLAGS.vocab_size,
            batch_size=FLAGS.batch_size,
            max_length=FLAGS.max_target_length,
            max_eval_length=FLAGS.max_eval_target_length,
            paracrawl_size=FLAGS.paracrawl_size,
            is_scores_path=FLAGS.is_scores_path,
            num_to_keep=gradual_selection_size,
            pseudo_path=FLAGS.pseudo_path,
            repeat_count=FLAGS.repeat_count,
            newscommentary_size=FLAGS.newscommentary_size,
            split_tokenizer=FLAGS.split_tokenizer)
        train_iter = iter(train_ds)
        logging.info('Set selection size to %d at step %d',
                     gradual_selection_size, step)

      # Shard data to devices and do a training step.
      if not FLAGS.eval_only:
        with jax.profiler.StepTraceAnnotation('train', step_num=step):
          try:
            batch = common_utils.shard(
                jax.tree_map(np.asarray, next(train_iter)))
            optimizer, metrics = p_train_step(
                optimizer, batch, dropout_rng=dropout_rngs)
            train_metrics.append(metrics)
          except StopIteration:
            # The training data is finite, so stop once it is exhausted.
            is_last_step = True

      # Quick indication that training is happening.
      logging.log_first_n(logging.INFO, 'Finished training step %d.', 5, step)
      for h in hooks:
        h(step)

      # Periodic metric handling.
      if (step - start_step) % eval_freq == 0 or is_last_step:
        if not FLAGS.eval_only:
          with report_progress.timed('training_metrics'):
            logging.info('Gathering training metrics.')
            train_metrics = common_utils.get_metrics(train_metrics)
            lr = train_metrics.pop('learning_rate').mean()
            metrics_sums = jax.tree_map(jnp.sum, train_metrics)
            denominator = metrics_sums.pop('denominator')
            summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
            summary['learning_rate'] = lr
            summary = {'train_' + k: v for k, v in summary.items()}
            writer.write_scalars(step, summary)
            train_metrics = []

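        # In eval-only mode, write per-example losses to disk; otherwise run
        # a standard eval and track the loss history for plateau detection.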
        if FLAGS.eval_only:
          p_eval_per_pos_step = jax.pmap(
              functools.partial(
                  train_util.eval_per_pos_step, config=eval_config),
              axis_name='batch')
          # Get per-example losses.
          loss_filename = FLAGS.model_dir + '/test_losses.csv'
          train_util.write_per_example_losses(
              p_eval_step=p_eval_per_pos_step,
              target=optimizer.target,
              eval_ds=eval_ds,
              num_eval_steps=FLAGS.num_eval_steps,
              loss_filename=loss_filename)
        else:
          with report_progress.timed('eval'):
            eval_results = train_util.evaluate(
                p_eval_step=p_eval_step,
                target=optimizer.target,
                eval_ds=eval_ds,
                num_eval_steps=FLAGS.num_eval_steps)
            curr_eval_loss = eval_results['loss']
            eval_loss_history.append(curr_eval_loss)
            if len(eval_loss_history) > 1:
              # Resample when the eval loss improves by less than 0.1%.
              orig_loss = eval_loss_history[-2]
              percent_change = (orig_loss - curr_eval_loss) / orig_loss
              percent_change *= 100
              if percent_change < .1:
                do_resample_data = True
            writer.write_scalars(
                step, {'eval_' + k: v for k, v in eval_results.items()})

        if FLAGS.aux_eval_dataset:
          for aux_i, aux_eval_ds in enumerate(aux_datasets):
            with report_progress.timed('aux_eval'):
              eval_results = train_util.evaluate(
                  p_eval_step=p_eval_step,
                  target=optimizer.target,
                  eval_ds=aux_eval_ds,
                  num_eval_steps=FLAGS.num_eval_steps)
              writer.write_scalars(
                  step, {
                      'aux' + str(aux_i) + '_eval_' + k: v
                      for k, v in eval_results.items()
                  })

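        # Periodically translate the prediction set and report BLEU.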
        if FLAGS.compute_bleu:
          with report_progress.timed('translate_and_bleu'):
            decode_file = FLAGS.model_dir + '/decodes.csv'
            exemplars, bleu_score = train_util.translate_and_calculate_bleu(
                p_pred_step=p_pred_step,
                p_init_cache=p_init_cache,
                target=optimizer.target,
                predict_ds=predict_ds,
                decode_tokens=decode_tokens,
                max_predict_length=FLAGS.max_predict_length,
                num_eval_steps=FLAGS.num_eval_steps,
                decode_file=decode_file if FLAGS.eval_only else '')
            writer.write_scalars(step, {'bleu': bleu_score})
            writer.write_texts(step, {'samples': exemplars})

      # Save a checkpoint on one host after every checkpoint_freq steps.
      save_checkpoint = ((step - start_step) % FLAGS.checkpoint_freq == 0 or
                         is_last_step)
      if (FLAGS.save_checkpoints and save_checkpoint and
          jax.process_index() == 0):
        if curr_eval_loss < best_eval_loss:  # only save better checkpoints
          best_eval_loss = curr_eval_loss
          with report_progress.timed('checkpoint'):
            checkpoints.save_checkpoint(
                FLAGS.model_dir, jax_utils.unreplicate(optimizer),
                step, keep=FLAGS.chkpts_to_keep, overwrite=True)

      if is_last_step:
        break


if __name__ == '__main__':
  app.run(main)