# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Finetune a NMT model with a mixture of in-domain and OOD data.
17

18
This script trains a Transformer on a WMT dataset. This runner is
19
intended for the special case where finetuning data is augmented
20
with high quality out of domain data.
21
"""

import functools
import os

from absl import app
from absl import flags
from absl import logging
from clu import metric_writers
from clu import periodic_actions
from flax import jax_utils
from flax import linen as nn
from flax import optim
from flax.training import checkpoints
from flax.training import common_utils
import jax
import jax.numpy as jnp
import numpy as np
import tensorflow as tf

from data_selection.wmt import common
from data_selection.wmt import decode
from data_selection.wmt import input_pipeline
from data_selection.wmt import models
from data_selection.wmt import train_util

FLAGS = flags.FLAGS
flags.adopt_module_key_flags(train_util)


def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  # Make sure tf does not allocate gpu memory.
  tf.config.experimental.set_visible_devices([], 'GPU')

  if FLAGS.jax_backend_target:
    jax.config.update('jax_xla_backend', 'tpu_driver')
    jax.config.update('jax_backend_target', FLAGS.jax_backend_target)

  # Number of local devices for this host.
  n_devices = jax.local_device_count()

  if jax.process_index() == 0:
    tf.io.gfile.makedirs(FLAGS.model_dir)

  if FLAGS.batch_size % n_devices:
    raise ValueError('Batch size must be divisible by the number of devices')

  vocab_path = FLAGS.vocab_path
  if vocab_path is None:
    vocab_path = os.path.join(FLAGS.model_dir, 'sentencepiece_model')
  tf.io.gfile.makedirs(os.path.split(vocab_path)[0])

  # Load Dataset
  # ---------------------------------------------------------------------------
  logging.info('Initializing dataset.')
  train_ds, eval_ds, predict_ds, encoder = input_pipeline.get_wmt_datasets(
      dataset_name=FLAGS.dataset_name,
      eval_dataset_name=FLAGS.eval_dataset_name,
      shard_idx=jax.process_index(),
      shard_count=jax.process_count(),
      data_dir=FLAGS.data_dir,
      vocab_path=vocab_path,
      target_vocab_size=FLAGS.vocab_size,
      batch_size=FLAGS.batch_size,
      max_length=FLAGS.max_target_length,
      max_eval_length=FLAGS.max_eval_target_length,
      paracrawl_size=FLAGS.paracrawl_size,
      is_scores_path=FLAGS.is_scores_path,
      num_to_keep=FLAGS.data_selection_size,
      pseudo_path=FLAGS.pseudo_path,
      repeat_count=FLAGS.repeat_count,
      newscommentary_size=FLAGS.newscommentary_size,
      split_tokenizer=FLAGS.split_tokenizer)

  if FLAGS.aux_eval_dataset:
    aux_datasets = []
    aux_names = FLAGS.aux_eval_dataset.split(',')
    for name in aux_names:
      _, aux_eval_ds, _, _ = input_pipeline.get_wmt_datasets(
          dataset_name=name,
          eval_dataset_name=None,
          shard_idx=jax.process_index(),
          shard_count=jax.process_count(),
          data_dir=FLAGS.data_dir,
          vocab_path=vocab_path,
          target_vocab_size=FLAGS.vocab_size,
          batch_size=FLAGS.batch_size,
          max_length=FLAGS.max_target_length,
          max_eval_length=FLAGS.max_eval_target_length,
          paracrawl_size=FLAGS.paracrawl_size,
          is_scores_path=FLAGS.is_scores_path,
          num_to_keep=FLAGS.data_selection_size,
          pseudo_path=FLAGS.pseudo_path,
          repeat_count=FLAGS.repeat_count,
          newscommentary_size=FLAGS.newscommentary_size)
      aux_datasets.append(aux_eval_ds)

  train_iter = iter(train_ds)
  vocab_size = int(encoder.vocab_size())
  eos_id = decode.EOS_ID  # Default Sentencepiece EOS token.

  def decode_tokens(toks):
    valid_toks = toks[:np.argmax(toks == eos_id) + 1].astype(np.int32)
    return encoder.detokenize(valid_toks).numpy().decode('utf-8')

  logging.info('Initializing model, optimizer, and step functions.')

  # Build Model and Optimizer
  # ---------------------------------------------------------------------------
  train_config = models.TransformerConfig(
      vocab_size=vocab_size,
      output_vocab_size=vocab_size,
      share_embeddings=FLAGS.share_embeddings,
      logits_via_embedding=FLAGS.logits_via_embedding,
      dtype=jnp.bfloat16 if FLAGS.use_bfloat16 else jnp.float32,
      emb_dim=FLAGS.emb_dim,
      num_heads=FLAGS.num_heads,
      num_layers=FLAGS.num_layers,
      qkv_dim=FLAGS.qkv_dim,
      mlp_dim=FLAGS.mlp_dim,
      max_len=max(FLAGS.max_target_length, FLAGS.max_eval_target_length),
      dropout_rate=FLAGS.dropout_rate,
      attention_dropout_rate=FLAGS.attention_dropout_rate,
      deterministic=False,
      decode=False,
      kernel_init=nn.initializers.xavier_uniform(),
      bias_init=nn.initializers.normal(stddev=1e-6))
  eval_config = train_config.replace(deterministic=True)
  predict_config = train_config.replace(deterministic=True, decode=True)

  start_step = 0
  rng = jax.random.PRNGKey(FLAGS.random_seed)
  rng, init_rng = jax.random.split(rng)
  # It's possible that this is supposed to be the per-device batch size.
  input_shape = (FLAGS.batch_size, FLAGS.max_target_length)
  target_shape = (FLAGS.batch_size, FLAGS.max_target_length)

  m = models.Transformer(eval_config)
  initial_variables = jax.jit(m.init)(init_rng,
                                      jnp.ones(input_shape, jnp.float32),
                                      jnp.ones(target_shape, jnp.float32))

  # apply an optimizer to this tree
  optimizer_def = optim.Adam(
      FLAGS.learning_rate,
      beta1=0.9,
      beta2=0.98,
      eps=1e-9,
      weight_decay=FLAGS.weight_decay)
  optimizer = optimizer_def.create(initial_variables['params'])

  # We access model params only from optimizer below via optimizer.target.
  del initial_variables

  if FLAGS.restore_checkpoints:
    logging.info('Restoring checkpoint.')
    # If we have a pretrained model, use that. Otherwise, continue from where
    # we left off.
    model_path = (FLAGS.pretrained_model_dir
                  if FLAGS.pretrained_model_dir else FLAGS.model_dir)
    optimizer = checkpoints.restore_checkpoint(model_path, optimizer)
    # Grab last step.
    start_step = int(optimizer.state.step)

  writer = metric_writers.create_default_writer(
      FLAGS.model_dir, just_logging=jax.process_index() > 0)

  flag_key = [
      k for k in FLAGS.flags_by_module_dict().keys() if 'wmt.par' in k
  ]
  if flag_key:
    flag_key = flag_key[0]
    local_flags = {
        f.name: f.value for f in FLAGS.flags_by_module_dict()[flag_key]
    }
    writer.write_hparams(local_flags)

  # Replicate optimizer.
  optimizer = jax_utils.replicate(optimizer)

  learning_rate_fn = common.create_learning_rate_scheduler(
      base_learning_rate=FLAGS.learning_rate, warmup_steps=FLAGS.warmup_steps,
      steps_per_cycle=FLAGS.steps_per_cycle, init_step=start_step,
      finetune_lr=FLAGS.finetune_lr)

  # compile multidevice versions of train/eval/predict step and cache init fn.
  p_train_step = jax.pmap(
      functools.partial(
          train_util.train_step,
          config=train_config,
          learning_rate_fn=learning_rate_fn,
          label_smoothing=FLAGS.label_smoothing),
      axis_name='batch',
      donate_argnums=(0,))  # pytype: disable=wrong-arg-types
  p_eval_step = jax.pmap(
      functools.partial(train_util.eval_step, config=eval_config),
      axis_name='batch')
  p_init_cache = jax.pmap(
      functools.partial(
          train_util.initialize_cache,
          max_decode_len=FLAGS.max_predict_length,
          config=predict_config),
      axis_name='batch')
  p_pred_step = jax.pmap(
      functools.partial(
          train_util.predict_step,
          config=predict_config,
          beam_size=FLAGS.beam_size),
      axis_name='batch',
      static_broadcasted_argnums=(3, 4))  # eos token, max_length are constant

  # Main Train Loop
  # ---------------------------------------------------------------------------

  # We init the first set of dropout PRNG keys, but update them afterwards
  # inside the main pmapped training update for performance.
  dropout_rngs = jax.random.split(rng, jax.local_device_count())
  del rng

  logging.info('Starting training loop.')
  hooks = []
  report_progress = periodic_actions.ReportProgress(
      num_train_steps=FLAGS.num_train_steps, writer=writer)
  if jax.process_index() == 0:
    hooks += [
        report_progress,
        periodic_actions.Profile(logdir=FLAGS.model_dir, num_profile_steps=5)
    ]
  train_metrics = []
  total_steps = start_step + FLAGS.num_train_steps
  if FLAGS.eval_only:
    total_steps = start_step + 1
  best_eval_loss = 1000
  curr_eval_loss = 1000
  eval_loss_history = []
  do_resample_data = False
  gradual_selection_size = FLAGS.data_selection_size
  eval_freq = FLAGS.eval_frequency
  with metric_writers.ensure_flushes(writer):
    for step in range(start_step, total_steps):
      is_last_step = step == total_steps - 1

      # Resample training data for gradual FT
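      # Gradual fine-tuning schedule: whenever eval loss fails to improve by
      # at least 0.1% between evaluations (see the eval block below), the
      # training data is re-selected with a new selection size. The size is
      # shrunk to 75% of its current value if the latest eval loss improved on
      # the previous one, and grown back (divided by 0.75) if it regressed.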
      if do_resample_data:
        # resample data
        do_resample_data = False
        if eval_loss_history[-1] > eval_loss_history[-2]:
          gradual_selection_size = int(gradual_selection_size / .75)
        else:
          gradual_selection_size = int(.75 * gradual_selection_size)
        if gradual_selection_size < 500_000:
          eval_freq = int(gradual_selection_size / 100)

        train_ds, eval_ds, predict_ds, encoder = input_pipeline.get_wmt_datasets(
            dataset_name=FLAGS.dataset_name,
            eval_dataset_name=FLAGS.eval_dataset_name,
            shard_idx=jax.process_index(),
            shard_count=jax.process_count(),
            data_dir=FLAGS.data_dir,
            vocab_path=vocab_path,
            target_vocab_size=FLAGS.vocab_size,
            batch_size=FLAGS.batch_size,
            max_length=FLAGS.max_target_length,
            max_eval_length=FLAGS.max_eval_target_length,
            paracrawl_size=FLAGS.paracrawl_size,
            is_scores_path=FLAGS.is_scores_path,
            num_to_keep=gradual_selection_size,
            pseudo_path=FLAGS.pseudo_path,
            repeat_count=FLAGS.repeat_count,
            newscommentary_size=FLAGS.newscommentary_size,
            split_tokenizer=FLAGS.split_tokenizer)
        train_iter = iter(train_ds)
        logging.info('Updated selection size to %d at step %d',
                     gradual_selection_size, step)

      # Shard data to devices and do a training step.
      if not FLAGS.eval_only:
        with jax.profiler.StepTraceAnnotation('train', step_num=step):
          try:
            batch = common_utils.shard(
                jax.tree_map(np.asarray, next(train_iter)))
            optimizer, metrics = p_train_step(
                optimizer, batch, dropout_rng=dropout_rngs)
            train_metrics.append(metrics)
          except StopIteration:
            is_last_step = True

      # Quick indication that training is happening.
      logging.log_first_n(logging.INFO, 'Finished training step %d.', 5, step)
      for h in hooks:
        h(step)

      # Periodic metric handling.
      if (step - start_step) % eval_freq == 0 or is_last_step:
        if not FLAGS.eval_only:
          with report_progress.timed('training_metrics'):
            logging.info('Gathering training metrics.')
            train_metrics = common_utils.get_metrics(train_metrics)
            lr = train_metrics.pop('learning_rate').mean()
            metrics_sums = jax.tree_map(jnp.sum, train_metrics)
            denominator = metrics_sums.pop('denominator')
            summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
            summary['learning_rate'] = lr
            summary = {'train_' + k: v for k, v in summary.items()}
            writer.write_scalars(step, summary)
            train_metrics = []

        if FLAGS.eval_only:
          p_eval_per_pos_step = jax.pmap(
              functools.partial(
                  train_util.eval_per_pos_step, config=eval_config),
              axis_name='batch')
          # Get per example loss
          loss_filename = FLAGS.model_dir + '/test_losses.csv'
          train_util.write_per_example_losses(
              p_eval_step=p_eval_per_pos_step,
              target=optimizer.target,
              eval_ds=eval_ds,
              num_eval_steps=FLAGS.num_eval_steps,
              loss_filename=loss_filename)
        else:
          with report_progress.timed('eval'):
            eval_results = train_util.evaluate(
                p_eval_step=p_eval_step,
                target=optimizer.target,
                eval_ds=eval_ds,
                num_eval_steps=FLAGS.num_eval_steps)
            curr_eval_loss = eval_results['loss']
            eval_loss_history.append(curr_eval_loss)
            if len(eval_loss_history) > 1:
              orig_loss = eval_loss_history[-2]
              percent_change = (orig_loss - curr_eval_loss) / orig_loss
              percent_change *= 100
              if percent_change < .1:
                do_resample_data = True
            writer.write_scalars(
                step, {'eval_' + k: v for k, v in eval_results.items()})

        if FLAGS.aux_eval_dataset:
          for aux_i, aux_eval_ds in enumerate(aux_datasets):
            with report_progress.timed('aux_eval'):
              eval_results = train_util.evaluate(
                  p_eval_step=p_eval_step,
                  target=optimizer.target,
                  eval_ds=aux_eval_ds,
                  num_eval_steps=FLAGS.num_eval_steps)
              writer.write_scalars(
                  step, {
                      'aux' + str(aux_i) + '_eval_' + k: v
                      for k, v in eval_results.items()
                  })

        if FLAGS.compute_bleu:
          with report_progress.timed('translate_and_bleu'):
            decode_file = FLAGS.model_dir + '/decodes.csv'
            exemplars, bleu_score = train_util.translate_and_calculate_bleu(
                p_pred_step=p_pred_step,
                p_init_cache=p_init_cache,
                target=optimizer.target,
                predict_ds=predict_ds,
                decode_tokens=decode_tokens,
                max_predict_length=FLAGS.max_predict_length,
                num_eval_steps=FLAGS.num_eval_steps,
                decode_file=decode_file if FLAGS.eval_only else '')
            writer.write_scalars(step, {'bleu': bleu_score})
            writer.write_texts(step, {'samples': exemplars})

      # Save a checkpoint on one host after every checkpoint_freq steps.
      save_checkpoint = ((step - start_step) % FLAGS.checkpoint_freq == 0 or
                         is_last_step)
      if (FLAGS.save_checkpoints and save_checkpoint and
          jax.process_index() == 0):
        if curr_eval_loss < best_eval_loss:  # only save better checkpoints
          best_eval_loss = curr_eval_loss
          with report_progress.timed('checkpoint'):
            checkpoints.save_checkpoint(
                FLAGS.model_dir, jax_utils.unreplicate(optimizer),
                step, keep=FLAGS.chkpts_to_keep, overwrite=True)

      if is_last_step:
        break


if __name__ == '__main__':
  app.run(main)