# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Gradual training for NMT.

This script trains a Transformer on a WMT dataset.
Gradual training refers to periodically decreasing the size of the
out-of-domain portion of the training data. This is similar to the
gradual fine-tuning proposed in dynamic data selection.
"""

# pytype: disable=wrong-arg-count
# pytype: disable=attribute-error

import functools
import os

from absl import app
from absl import flags
from absl import logging
from clu import metric_writers
from clu import periodic_actions
from flax import jax_utils
from flax import linen as nn
from flax import optim
from flax.training import checkpoints
from flax.training import common_utils
import jax
import jax.numpy as jnp
import numpy as np
import tensorflow as tf

from data_selection.wmt import common
from data_selection.wmt import decode
from data_selection.wmt import input_pipeline
from data_selection.wmt import models
from data_selection.wmt import train_util

FLAGS = flags.FLAGS
flags.adopt_module_key_flags(train_util)


def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  # Make sure tf does not allocate gpu memory.
  tf.config.experimental.set_visible_devices([], 'GPU')

  if FLAGS.jax_backend_target:
    jax.config.update('jax_xla_backend', 'tpu_driver')
    jax.config.update('jax_backend_target', FLAGS.jax_backend_target)

  # Number of local devices for this host.
  n_devices = jax.local_device_count()

  if jax.process_index() == 0:
    tf.io.gfile.makedirs(FLAGS.model_dir)

  if FLAGS.batch_size % n_devices:
    raise ValueError('Batch size must be divisible by the number of devices')

  vocab_path = FLAGS.vocab_path
  if vocab_path is None:
    vocab_path = os.path.join(FLAGS.model_dir, 'sentencepiece_model')
  tf.io.gfile.makedirs(os.path.split(vocab_path)[0])

  # Load Dataset
  # ---------------------------------------------------------------------------
  logging.info('Initializing dataset.')
  train_ds, eval_ds, predict_ds, encoder = input_pipeline.get_wmt_datasets(
      dataset_name=FLAGS.dataset_name,
      eval_dataset_name=FLAGS.eval_dataset_name,
      shard_idx=jax.process_index(),
      shard_count=jax.process_count(),
      data_dir=FLAGS.data_dir,
      vocab_path=vocab_path,
      target_vocab_size=FLAGS.vocab_size,
      batch_size=FLAGS.batch_size,
      max_length=FLAGS.max_target_length,
      max_eval_length=FLAGS.max_eval_target_length,
      paracrawl_size=FLAGS.paracrawl_size,
      is_scores_path=FLAGS.is_scores_path,
      num_to_keep=FLAGS.data_selection_size,
      pseudo_path=FLAGS.pseudo_path,
      repeat_count=FLAGS.repeat_count,
      newscommentary_size=FLAGS.newscommentary_size,
      split_tokenizer=FLAGS.split_tokenizer)

  if FLAGS.aux_eval_dataset:
    aux_datasets = []
    aux_names = FLAGS.aux_eval_dataset.split(',')
    for name in aux_names:
      _, aux_eval_ds, _, _ = input_pipeline.get_wmt_datasets(
          dataset_name=name,
          eval_dataset_name=None,
          shard_idx=jax.process_index(),
          shard_count=jax.process_count(),
          data_dir=FLAGS.data_dir,
          vocab_path=vocab_path,
          target_vocab_size=FLAGS.vocab_size,
          batch_size=FLAGS.batch_size,
          max_length=FLAGS.max_target_length,
          max_eval_length=FLAGS.max_eval_target_length,
          paracrawl_size=FLAGS.paracrawl_size,
          is_scores_path=FLAGS.is_scores_path,
          num_to_keep=FLAGS.data_selection_size,
          pseudo_path=FLAGS.pseudo_path,
          repeat_count=FLAGS.repeat_count,
          newscommentary_size=FLAGS.newscommentary_size)
      aux_datasets.append(aux_eval_ds)

  train_iter = iter(train_ds)
  vocab_size = int(encoder.vocab_size())
  eos_id = decode.EOS_ID  # Default Sentencepiece EOS token.

  def decode_tokens(toks):
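    # Keep everything up to and including the first EOS token. Note that
    # np.argmax over the boolean mask returns the index of the first True
    # entry, or 0 if no EOS is present.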
    valid_toks = toks[:np.argmax(toks == eos_id) + 1].astype(np.int32)
    return encoder.detokenize(valid_toks).numpy().decode('utf-8')

  logging.info('Initializing model, optimizer, and step functions.')

  # Build Model and Optimizer
  # ---------------------------------------------------------------------------
  train_config = models.TransformerConfig(
      vocab_size=vocab_size,
      output_vocab_size=vocab_size,
      share_embeddings=FLAGS.share_embeddings,
      logits_via_embedding=FLAGS.logits_via_embedding,
      dtype=jnp.bfloat16 if FLAGS.use_bfloat16 else jnp.float32,
      emb_dim=FLAGS.emb_dim,
      num_heads=FLAGS.num_heads,
      num_layers=FLAGS.num_layers,
      qkv_dim=FLAGS.qkv_dim,
      mlp_dim=FLAGS.mlp_dim,
      max_len=max(FLAGS.max_target_length, FLAGS.max_eval_target_length),
      dropout_rate=FLAGS.dropout_rate,
      attention_dropout_rate=FLAGS.attention_dropout_rate,
      deterministic=False,
      decode=False,
      kernel_init=nn.initializers.xavier_uniform(),
      bias_init=nn.initializers.normal(stddev=1e-6))
  eval_config = train_config.replace(deterministic=True)
  predict_config = train_config.replace(deterministic=True, decode=True)

  start_step = 0
  rng = jax.random.PRNGKey(FLAGS.random_seed)
  rng, init_rng = jax.random.split(rng)
  # Note: this may be intended to be the per-device batch size.
  input_shape = (FLAGS.batch_size, FLAGS.max_target_length)
  target_shape = (FLAGS.batch_size, FLAGS.max_target_length)

  m = models.Transformer(eval_config)
  initial_variables = jax.jit(m.init)(init_rng,
                                      jnp.ones(input_shape, jnp.float32),
                                      jnp.ones(target_shape, jnp.float32))

  # Apply an optimizer to this parameter tree.
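  # Adam with beta1=0.9, beta2=0.98 and eps=1e-9 matches the hyperparameters
  # of the original Transformer ("Attention Is All You Need") training recipe.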
  optimizer_def = optim.Adam(
      FLAGS.learning_rate,
      beta1=0.9,
      beta2=0.98,
      eps=1e-9,
      weight_decay=FLAGS.weight_decay)
  optimizer = optimizer_def.create(initial_variables['params'])

  # We access model params only from optimizer below via optimizer.target.
  del initial_variables

  if FLAGS.restore_checkpoints:
    logging.info('Restoring checkpoint.')
    # If we have a pretrained model, use that. Otherwise, continue from where
    # we left off.
    model_path = FLAGS.pretrained_model_dir if FLAGS.pretrained_model_dir else FLAGS.model_dir
    optimizer = checkpoints.restore_checkpoint(model_path, optimizer)
    # Grab last step.
    start_step = int(optimizer.state.step)

  writer = metric_writers.create_default_writer(
      FLAGS.model_dir, just_logging=jax.process_index() > 0)

  flag_key = [
      k for k in FLAGS.flags_by_module_dict().keys() if 'wmt.par' in k
  ]
  if flag_key:
    flag_key = flag_key[0]
    local_flags = {
        f.name: f.value for f in FLAGS.flags_by_module_dict()[flag_key]
    }
    writer.write_hparams(local_flags)

  # Replicate optimizer.
  optimizer = jax_utils.replicate(optimizer)

  learning_rate_fn = common.create_learning_rate_scheduler(
      base_learning_rate=FLAGS.learning_rate, warmup_steps=FLAGS.warmup_steps,
      steps_per_cycle=FLAGS.steps_per_cycle, init_step=start_step,
      finetune_lr=FLAGS.finetune_lr)

  # Compile multi-device versions of train/eval/predict step and cache init fn.
  p_train_step = jax.pmap(
      functools.partial(
          train_util.train_step,
          config=train_config,
          learning_rate_fn=learning_rate_fn,
          label_smoothing=FLAGS.label_smoothing),
      axis_name='batch',
      donate_argnums=(0,))  # pytype: disable=wrong-arg-types
  p_eval_step = jax.pmap(
      functools.partial(train_util.eval_step, config=eval_config),
      axis_name='batch')
  p_init_cache = jax.pmap(
      functools.partial(
          train_util.initialize_cache,
          max_decode_len=FLAGS.max_predict_length,
          config=predict_config),
      axis_name='batch')
  p_pred_step = jax.pmap(
      functools.partial(
          train_util.predict_step,
          config=predict_config,
          beam_size=FLAGS.beam_size),
      axis_name='batch',
      static_broadcasted_argnums=(3, 4))  # eos token, max_length are constant

  # Main Train Loop
  # ---------------------------------------------------------------------------

  # We init the first set of dropout PRNG keys, but update them afterwards
  # inside the main pmap'd training update for performance.
  dropout_rngs = jax.random.split(rng, jax.local_device_count())
  del rng

  logging.info('Starting training loop.')
  hooks = []
  report_progress = periodic_actions.ReportProgress(
      num_train_steps=FLAGS.num_train_steps, writer=writer)
  if jax.process_index() == 0:
    hooks += [
        report_progress,
        periodic_actions.Profile(logdir=FLAGS.model_dir, num_profile_steps=5)
    ]
  train_metrics = []
  total_steps = start_step + FLAGS.num_train_steps
  if FLAGS.eval_only:
    total_steps = start_step + 1
  best_eval_loss = 1000
  curr_eval_loss = 1000
  eval_loss_history = []
  last_eval_step = 0
  do_resample_data = False
  gradual_selection_size = FLAGS.data_selection_size
  dynamic_eval_freq = FLAGS.eval_frequency
  with metric_writers.ensure_flushes(writer):
    for step in range(start_step, total_steps):
      is_last_step = step == total_steps - 1

      # Resample training data for gradual fine-tuning.
      if do_resample_data:
        # resample data
        do_resample_data = False
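        # Shrink the out-of-domain selection to 70% of its previous size and
        # re-derive the eval frequency from the new size. The `/ 1000 / 4`
        # divisor appears to target an eval roughly every quarter pass over
        # the selected data, assuming on the order of 1000 examples per step.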
        gradual_selection_size *= .7
        dynamic_eval_freq = int(gradual_selection_size / 1000 / 4)

        train_ds, eval_ds, predict_ds, encoder = input_pipeline.get_wmt_datasets(
            dataset_name=FLAGS.dataset_name,
            eval_dataset_name=FLAGS.eval_dataset_name,
            shard_idx=jax.process_index(),
            shard_count=jax.process_count(),
            data_dir=FLAGS.data_dir,
            vocab_path=vocab_path,
            target_vocab_size=FLAGS.vocab_size,
            batch_size=FLAGS.batch_size,
            max_length=FLAGS.max_target_length,
            max_eval_length=FLAGS.max_eval_target_length,
            paracrawl_size=FLAGS.paracrawl_size,
            is_scores_path=FLAGS.is_scores_path,
            num_to_keep=int(gradual_selection_size),
            pseudo_path=FLAGS.pseudo_path,
            repeat_count=FLAGS.repeat_count,
            newscommentary_size=FLAGS.newscommentary_size,
            split_tokenizer=FLAGS.split_tokenizer)
        train_iter = iter(train_ds)

      # Shard data to devices and do a training step.
      if not FLAGS.eval_only:
        logging.info('Doing Training.')
        with jax.profiler.StepTraceAnnotation('train', step_num=step):
          try:
            batch = common_utils.shard(
                jax.tree_map(np.asarray, next(train_iter)))
            optimizer, metrics = p_train_step(
                optimizer, batch, dropout_rng=dropout_rngs)
            train_metrics.append(metrics)
          except StopIteration:
            is_last_step = True

      # Quick indication that training is happening.
      logging.log_first_n(logging.INFO, 'Finished training step %d.', 5, step)
      for h in hooks:
        h(step)

      # Periodic metric handling.
      if (step - start_step) % dynamic_eval_freq == 0 or is_last_step:
        if not FLAGS.eval_only:
          with report_progress.timed('training_metrics'):
            logging.info('Gathering training metrics.')
            train_metrics = common_utils.get_metrics(train_metrics)
            lr = train_metrics.pop('learning_rate').mean()
            metrics_sums = jax.tree_map(jnp.sum, train_metrics)
            denominator = metrics_sums.pop('denominator')
            summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
            summary['learning_rate'] = lr
            summary = {'train_' + k: v for k, v in summary.items()}
            writer.write_scalars(step, summary)
            train_metrics = []

        if FLAGS.eval_only:
          p_eval_per_pos_step = jax.pmap(
              functools.partial(
                  train_util.eval_per_pos_step, config=eval_config),
              axis_name='batch')
          # Get per-example losses.
          loss_filename = FLAGS.model_dir + '/test_losses.csv'
          train_util.write_per_example_losses(
              p_eval_step=p_eval_per_pos_step,
              target=optimizer.target,
              eval_ds=eval_ds,
              num_eval_steps=FLAGS.num_eval_steps,
              loss_filename=loss_filename)
        else:
          with report_progress.timed('eval'):
            eval_results = train_util.evaluate(
                p_eval_step=p_eval_step,
                target=optimizer.target,
                eval_ds=eval_ds,
                num_eval_steps=FLAGS.num_eval_steps)
            curr_eval_loss = eval_results['loss']
            eval_loss_history.append(curr_eval_loss)
            if len(eval_loss_history) > 1:
              improvement_rate = 0.000004
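              # Worked example: if 2,000 steps have passed since the last
              # eval, the expected improvement is 2000 * 0.000004 = 0.008; if
              # the eval loss dropped by less than that, trigger a resample.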
              orig_loss = eval_loss_history[-2]
              true_improvement = orig_loss - curr_eval_loss
              expected_improvement = (step - last_eval_step) * improvement_rate
              # percent_change = (orig_loss - curr_eval_loss) / orig_loss
              # percent_change *= 100
              if true_improvement < expected_improvement:  # percent_change<.1:
                do_resample_data = True
            last_eval_step = step
            writer.write_scalars(
                step, {'eval_' + k: v for k, v in eval_results.items()})

        if FLAGS.aux_eval_dataset:
          for aux_i, aux_eval_ds in enumerate(aux_datasets):
            with report_progress.timed('aux_eval'):
              eval_results = train_util.evaluate(
                  p_eval_step=p_eval_step,
                  target=optimizer.target,
                  eval_ds=aux_eval_ds,
                  num_eval_steps=FLAGS.num_eval_steps)
              writer.write_scalars(
                  step, {
                      'aux' + str(aux_i) + '_eval_' + k: v
                      for k, v in eval_results.items()
                  })

        if FLAGS.compute_bleu:
          with report_progress.timed('translate_and_bleu'):
            decode_file = FLAGS.model_dir + '/decodes.csv'
            exemplars, bleu_score = train_util.translate_and_calculate_bleu(
                p_pred_step=p_pred_step,
                p_init_cache=p_init_cache,
                target=optimizer.target,
                predict_ds=predict_ds,
                decode_tokens=decode_tokens,
                max_predict_length=FLAGS.max_predict_length,
                num_eval_steps=FLAGS.num_eval_steps,
                decode_file=decode_file if FLAGS.eval_only else '')
            writer.write_scalars(step, {'bleu': bleu_score})
            writer.write_texts(step, {'samples': exemplars})

      # Save a checkpoint on one host after every checkpoint_freq steps.
      save_checkpoint = ((step - start_step) % FLAGS.checkpoint_freq == 0 or
                         is_last_step)
      if (FLAGS.save_checkpoints and save_checkpoint and
          jax.process_index() == 0):
        if curr_eval_loss < best_eval_loss:  # only save better checkpoints
          best_eval_loss = curr_eval_loss
          with report_progress.timed('checkpoint'):
            checkpoints.save_checkpoint(
                FLAGS.model_dir, jax_utils.unreplicate(optimizer),
                step, keep=FLAGS.chkpts_to_keep, overwrite=True)

      if is_last_step:
        break


if __name__ == '__main__':
  app.run(main)
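
# Example invocation (illustrative only; the entry-point path and values are
# placeholders, and the flags themselves come from train_util and the modules
# it adopts):
#   python <this_script>.py \
#     --model_dir=/path/to/model_dir \
#     --dataset_name=<tfds_wmt_dataset> \
#     --is_scores_path=/path/to/importance_scores \
#     --data_selection_size=<initial_selection_size>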