# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Machine Translation example.

This script trains a Transformer on a WMT dataset.
"""
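# Example invocation (illustrative only; the module path is a guess, but every
# flag used below is defined in this file):
#   python -m data_selection.wmt.compute_is_scores \
#     --model_dir=/tmp/model \
#     --pretrained_model_dir=/tmp/pretrained_model \
#     --is_save_path=/tmp/is_scores \
#     --is_score_filename=scores \
#     --dataset_name=wmt17_translate/de-en \
#     --vocab_path=/tmp/model/sentencepiece_model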

import csv
import functools
import os
import time

from absl import app
from absl import flags
from absl import logging
from flax import jax_utils
from flax import linen as nn
from flax import optim
from flax.training import checkpoints
from flax.training import common_utils
import jax
import jax.nn
import jax.numpy as jnp
import numpy as np
import tensorflow as tf

from data_selection.wmt import common
from data_selection.wmt import input_pipeline
from data_selection.wmt import models

LAYERNORM_ADAPTER = 'LayerNorm'
ENDCODE_DECODE_B5 = 'encoderdecoderblock_5'
PASSTHRU = ''
NONE = 'None'
FLAGS = flags.FLAGS
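# Editorial note (inferred from the adapter handling in compute_is_scores
# below): the adapter constants above are substrings matched against parameter
# paths via optim.ModelParamTraversal; when --adapter selects one, only the
# matching parameters were finetuned, and checkpoint restore falls back to an
# optimizer focused on just those parameters.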

flags.DEFINE_string(
    'model_dir', default=None,
    help='Directory to store model data.')

flags.DEFINE_string(
    'is_save_path', default=None,
    help='Path to save IS scores to.')

flags.DEFINE_string(
    'is_score_filename', default=None,
    help='Filename to save IS scores to.')

flags.DEFINE_string(
    'is_diff_name', default=None,
    help='Filename to save IS diff scores to.')

flags.DEFINE_string(
    'base_log_loss_file', default=None,
    help='Filename of log loss from base run.')

flags.DEFINE_string(
    'pretrained_model_dir', default=None,
    help='Directory of pretrained model data.')

flags.DEFINE_string(
    'data_dir', default=None,
    help='TensorFlow datasets directory.')

flags.DEFINE_string(
    'vocab_path', default=None,
    help='Path to load or store sentencepiece vocab file.')

flags.DEFINE_integer(
    'vocab_size', default=32000,
    help='Vocabulary size if `vocab_path` is not given.')

flags.DEFINE_string(
    'dataset_name', default='wmt17_translate/de-en',
    help='Name of TFDS translation dataset to use.')

flags.DEFINE_integer(
    'batch_size', default=256,
    help='Per host batch size for training.')

flags.DEFINE_float(
    'learning_rate', default=0.0625,
    help='Base learning rate.')

flags.DEFINE_float(
    'label_smoothing', default=0.1,
    help='Cross entropy loss label smoothing.')

flags.DEFINE_float(
    'weight_decay', default=0.0,
    help='Decay factor for AdamW style weight decay.')

flags.DEFINE_integer(
    'max_target_length', default=256,
    help='Maximum length cutoff for training examples.')

flags.DEFINE_integer(
    'max_eval_target_length', default=256,
    help='Maximum length cutoff for eval examples.')

flags.DEFINE_bool(
    'share_embeddings', default=True,
    help='Inputs and targets share embedding.')

flags.DEFINE_bool(
    'logits_via_embedding', default=True,
    help='Final logit transform uses embedding matrix transpose.')

flags.DEFINE_integer(
    'num_layers', default=6,
    help='Number of transformer layers.')

flags.DEFINE_integer(
    'qkv_dim', default=1024,
    help='Size of query/key/value for attention.')

flags.DEFINE_integer(
    'emb_dim', default=1024,
    help='Size of embeddings.')

flags.DEFINE_integer(
    'mlp_dim', default=4096,
    help='Size of the MLP.')

flags.DEFINE_integer(
    'num_heads', default=16,
    help='Number of attention heads.')

flags.DEFINE_float(
    'dropout_rate', default=0.1,
    help='Dropout rate.')

flags.DEFINE_float(
    'attention_dropout_rate', default=0.1,
    help='Attention dropout rate.')

flags.DEFINE_integer(
    'random_seed', default=0,
    help='Integer for PRNG random seed.')


flags.DEFINE_bool(
    'save_checkpoints', default=True,
    help='Whether to save model checkpoints.')

flags.DEFINE_bool(
    'restore_checkpoints', default=True,
    help='Whether to restore from existing model checkpoints.')

flags.DEFINE_bool(
    'use_bfloat16', default=True,
    help=('Use bfloat16 mixed precision training instead of float32.'))

flags.DEFINE_string(
    'jax_backend_target', default=None,
    help=('TPU grpc target for use with cloud TPUs.'
          ' e.g. grpc://192.168.0.2:8470'))

flags.DEFINE_integer(
    'paracrawl_size', default=1200000,
    help='Number of examples to sample from paracrawl.')

flags.DEFINE_enum(
    'adapter', default=NONE, enum_values=[LAYERNORM_ADAPTER,
                                          ENDCODE_DECODE_B5,
                                          PASSTHRU,
                                          NONE],
    help='Whether to finetune only some parameters.')
flags.DEFINE_bool(
    'split_tokenizer', default=False,
    help='Separate tokenizer for each language.')


def compute_per_example_loss(logits,
                             targets,
                             weights=None,
                             label_smoothing=0.0):
  """Compute per-example weighted cross entropy for logits and targets.

  Args:
   logits: [batch, length, num_classes] float array.
   targets: categorical targets [batch, length] int array.
   weights: None or array of shape [batch, length].
   label_smoothing: label smoothing constant, used to determine the on and
     off values.

  Returns:
    [batch] float array of per-example loss, averaged over non-padding tokens.
  """
  if logits.ndim != targets.ndim + 1:
    raise ValueError('Incorrect shapes. Got shape %s logits and %s targets' %
                     (str(logits.shape), str(targets.shape)))
  vocab_size = logits.shape[-1]
  confidence = 1.0 - label_smoothing
  low_confidence = (1.0 - confidence) / (vocab_size - 1)
  # Entropy of the smoothed target distribution; subtracting it below turns
  # the smoothed cross entropy into a KL divergence that is zero at a perfect
  # fit.
  normalizing_constant = -(
      confidence * jnp.log(confidence) + (vocab_size - 1) *
      low_confidence * jnp.log(low_confidence + 1e-20))
  soft_targets = common_utils.onehot(
      targets, vocab_size, on_value=confidence, off_value=low_confidence)

  loss = -jnp.sum(soft_targets * nn.log_softmax(logits), axis=-1)
  loss = loss - normalizing_constant

  if weights is not None:
    loss = loss * weights
    return loss.sum(axis=-1) / weights.sum(axis=-1)
  # Without weights, average over all token positions.
  return loss.mean(axis=-1)
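# Usage sketch (editorial; mirrors eval_for_is_step below): with logits of
# shape [batch, length, vocab] and integer targets of shape [batch, length],
#   weights = jnp.where(targets > 0, 1.0, 0.0)
#   losses = compute_per_example_loss(logits, targets, weights)  # shape [batch]
# yields one length-normalized loss per example.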


def eval_for_is_step(params, batch, config, label_smoothing=0.0):
  """Compute per-example losses and token lengths for a batch."""
  inputs, targets = batch['inputs'], batch['targets']
  weights = jnp.where(targets > 0, 1.0, 0.0)
  logits = models.Transformer(config).apply({'params': params}, inputs,
                                            targets)
  losses = compute_per_example_loss(logits,
                                    targets,
                                    weights,
                                    label_smoothing)
  length = weights.sum(axis=-1)
  return losses, length


def compute_is_scores(filename):
  """Compute IS scores for training data."""

  # Make sure tf does not allocate gpu memory.
  tf.config.experimental.set_visible_devices([], 'GPU')

  if FLAGS.jax_backend_target:
    jax.config.update('jax_xla_backend', 'tpu_driver')
    jax.config.update('jax_backend_target', FLAGS.jax_backend_target)

  # Number of local devices for this host.
  n_devices = jax.local_device_count()

  if jax.host_id() == 0:
    tf.io.gfile.makedirs(FLAGS.model_dir)

  if FLAGS.batch_size % n_devices:
    raise ValueError('Batch size must be divisible by the number of devices')

  vocab_path = FLAGS.vocab_path
  if vocab_path is None:
    vocab_path = os.path.join(FLAGS.model_dir, 'sentencepiece_model')
  tf.io.gfile.makedirs(os.path.split(vocab_path)[0])

  # Load Dataset
  print('Loading data')
  logging.info('Initializing dataset.')
  train_ds, (_, encoder_tgt) = input_pipeline.get_wmt_is_datasets(
      n_devices=n_devices,
      dataset_name=FLAGS.dataset_name,
      shard_idx=jax.host_id(),
      shard_count=jax.host_count(),
      data_dir=FLAGS.data_dir,
      vocab_path=vocab_path,
      target_vocab_size=FLAGS.vocab_size,
      batch_size=FLAGS.batch_size,
      max_length=FLAGS.max_target_length,
      paracrawl_size=FLAGS.paracrawl_size,
      split_tokenizer=FLAGS.split_tokenizer)
  print('Datasets created')

  encoder = encoder_tgt
  train_iter = iter(train_ds)
  vocab_size = int(encoder.vocab_size())
  print('data iterators created')

  logging.info('Initializing model, optimizer, and step functions.')
  # Build Model and Optimizer
  # ---------------------------------------------------------------------------
  eval_config = models.TransformerConfig(
      vocab_size=vocab_size,
      output_vocab_size=vocab_size,
      share_embeddings=FLAGS.share_embeddings,
      logits_via_embedding=FLAGS.logits_via_embedding,
      dtype=jnp.bfloat16 if FLAGS.use_bfloat16 else jnp.float32,
      emb_dim=FLAGS.emb_dim,
      num_heads=FLAGS.num_heads,
      num_layers=FLAGS.num_layers,
      qkv_dim=FLAGS.qkv_dim,
      mlp_dim=FLAGS.mlp_dim,
      max_len=max(FLAGS.max_target_length, FLAGS.max_eval_target_length),
      dropout_rate=FLAGS.dropout_rate,
      attention_dropout_rate=FLAGS.attention_dropout_rate,
      deterministic=True,
      decode=False,
      kernel_init=nn.initializers.xavier_uniform(),
      bias_init=nn.initializers.normal(stddev=1e-6))

  rng = jax.random.PRNGKey(FLAGS.random_seed)
  rng, init_rng = jax.random.split(rng)
  # Note: this may be intended to be the per-device batch size.
  input_shape = (FLAGS.batch_size, FLAGS.max_target_length)
  target_shape = (FLAGS.batch_size, FLAGS.max_target_length)

  m = models.Transformer(eval_config)
  initial_variables = jax.jit(m.init)(init_rng,
                                      jnp.ones(input_shape, jnp.float32),
                                      jnp.ones(target_shape, jnp.float32))

  # Apply an optimizer to this tree.
  optimizer_def = optim.Adam(
      FLAGS.learning_rate,
      beta1=0.9,
      beta2=0.98,
      eps=1e-9,
      weight_decay=FLAGS.weight_decay)
  optimizer = optimizer_def.create(initial_variables['params'])

  # We access model params only from optimizer below via optimizer.target.
  del initial_variables

  if FLAGS.restore_checkpoints:
    logging.info('Restoring checkpoint.')
    # If we have a pretrained model, use that. Otherwise, continue from where
    # the run in model_dir left off.
    model_path = FLAGS.pretrained_model_dir if FLAGS.pretrained_model_dir else FLAGS.model_dir
    # When loading a checkpoint trained with adapters (i.e. frozen weights),
    # restoring from the base optimizer fails. We catch this error and create
    # the optimizer with frozen weights.
    try:
      optimizer = checkpoints.restore_checkpoint(model_path, optimizer)
      # Grab last step.
    except ValueError:
      adapter = optim.ModelParamTraversal(lambda path, _: FLAGS.adapter in path)
      optimizer = optimizer_def.create(optimizer.target, focus=adapter)
      optimizer = checkpoints.restore_checkpoint(model_path, optimizer)

  else:
    raise RuntimeError('Must restore checkpoint for IS')

  if FLAGS.adapter != NONE and not isinstance(optimizer, optim.MultiOptimizer):
    adapter = optim.ModelParamTraversal(lambda path, _: FLAGS.adapter in path)
    optimizer = optimizer_def.create(optimizer.target, focus=adapter)
  # Replicate optimizer.
  optimizer = jax_utils.replicate(optimizer)

  p_eval_step = jax.pmap(
      functools.partial(
          eval_for_is_step,
          config=eval_config),
      axis_name='batch')

  logging.info('Start scoring loop.')
  t_loop_start = time.time()

  # Eval Metrics
  logging.info('Gathering evaluation metrics.')
  save_file = FLAGS.is_save_path + '/' + filename + '-lengths.txt'
  length_fp = tf.io.gfile.GFile(save_file, 'w')
  lengths_writer = csv.writer(length_fp)

  save_file = FLAGS.is_save_path + '/' + filename + '.txt'
  with tf.io.gfile.GFile(save_file, 'w') as fp:
    writer = csv.writer(fp)

    for batch_idx, eval_batch in enumerate(train_iter):
      eval_batch = jax.tree_map(lambda x: x._numpy(), eval_batch)  # pylint: disable=protected-access
      cur_pred_batch_size = eval_batch['inputs'].shape[0]
      if cur_pred_batch_size % n_devices:
        # Pad the batch so it can be sharded evenly across devices.
        padded_size = int(
            np.ceil(cur_pred_batch_size / n_devices) * n_devices)
        eval_batch = jax.tree_map(
            lambda x: common.pad_examples(x, padded_size), eval_batch)  # pylint: disable=cell-var-from-loop
      eval_batch = common_utils.shard(eval_batch)
      losses, lengths = p_eval_step(optimizer.target, eval_batch)
      if jax.host_id() == 0:
        losses = common.tohost(losses)
        lengths = common.tohost(lengths)
        # Each CSV row holds the per-example losses (or lengths) for one
        # batch; padding examples added above are dropped before writing.
        if cur_pred_batch_size % n_devices:
          writer.writerow(losses[:cur_pred_batch_size])
          lengths_writer.writerow(lengths[:cur_pred_batch_size])
        else:
          writer.writerow(losses)
          lengths_writer.writerow(lengths)

      if batch_idx % 500 == 0:
        print('Batch', batch_idx)
        print(time.time() - t_loop_start)
  length_fp.close()
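# Editorial summary of the outputs above: compute_is_scores writes two CSV
# files under FLAGS.is_save_path, '<filename>.txt' with one row of per-example
# losses per batch and '<filename>-lengths.txt' with the matching non-padding
# token counts. main() below optionally subtracts a base run's losses from
# these scores to produce the IS diff file.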


def main(_):
  compute_is_scores(FLAGS.is_score_filename)

  if FLAGS.base_log_loss_file:
    beforefile = FLAGS.base_log_loss_file
    afterfile = FLAGS.is_save_path + '/' + FLAGS.is_score_filename + '.txt'
    before_scores = []
    after_scores = []
    with tf.io.gfile.GFile(beforefile, 'r') as f:
      reader = csv.reader(f)
      for row in reader:
        before_scores.extend(row)
    with tf.io.gfile.GFile(afterfile, 'r') as f:
      reader = csv.reader(f)
      for row in reader:
        after_scores.extend(row)

    beforefile = beforefile.replace('.txt', '-lengths.txt')
    afterfile = afterfile.replace('.txt', '-lengths.txt')
    before_length = []
    after_length = []
    with tf.io.gfile.GFile(beforefile, 'r') as f:
      reader = csv.reader(f)
      for row in reader:
        before_length.extend(row)
    with tf.io.gfile.GFile(afterfile, 'r') as f:
      reader = csv.reader(f)
      for row in reader:
        after_length.extend(row)

    # The IS diff score is this run's per-example loss minus the base run's.
    diff = [float(a) - float(b) for (a, b) in zip(after_scores, before_scores)]
    after_scores = [float(a) for a in after_scores]
    before_scores = [float(a) for a in before_scores]
    after_length = [float(a) for a in after_length]
    before_length = [float(b) for b in before_length]

    # Sanity check: both runs must have scored the same examples in order.
    for a, b in zip(before_length, after_length):
      assert a == b

    is_diff_name = FLAGS.is_save_path + '/' + FLAGS.is_diff_name
    with tf.io.gfile.GFile(is_diff_name, 'w') as f:
      writer = csv.writer(f)
      for val in diff:
        writer.writerow([val])

    with tf.io.gfile.GFile(
        is_diff_name.replace('.csv', '_length.csv'), 'w') as f:
      writer = csv.writer(f)
      for val in after_length:
        writer.writerow([int(val)])


if __name__ == '__main__':
  app.run(main)
