# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16"""Machine Translation example.
17
18This script trains a Transformer on a WMT dataset.
19"""
20
import csv
import functools
import os
import time

from absl import app
from absl import flags
from absl import logging
from flax import jax_utils
from flax import linen as nn
from flax import optim
from flax.training import checkpoints
from flax.training import common_utils
import jax
import jax.nn
import jax.numpy as jnp
import numpy as np
import tensorflow as tf

from data_selection.wmt import common
from data_selection.wmt import input_pipeline
from data_selection.wmt import models

LAYERNORM_ADAPTER = 'LayerNorm'
ENDCODE_DECODE_B5 = 'encoderdecoderblock_5'
PASSTHRU = ''
NONE = 'None'
FLAGS = flags.FLAGS
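
# The adapter constants above are parameter-path substrings: when --adapter is
# set, only parameters whose path contains the substring are selected (via
# optim.ModelParamTraversal below); NONE disables this filtering.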

flags.DEFINE_string(
    'model_dir', default=None,
    help='Directory to store model data.')

flags.DEFINE_string(
    'is_save_path', default=None,
    help='Path to save IS scores to.')

flags.DEFINE_string(
    'is_score_filename', default=None,
    help='Filename to save IS scores to.')

flags.DEFINE_string(
    'is_diff_name', default=None,
    help='Filename to save IS diff scores to.')

flags.DEFINE_string(
    'base_log_loss_file', default=None,
    help='Filename of log loss from base run.')

flags.DEFINE_string(
    'pretrained_model_dir', default=None,
    help='Directory of pretrained model data.')

flags.DEFINE_string(
    'data_dir', default=None,
    help='Tensorflow datasets directory.')

flags.DEFINE_string(
    'vocab_path', default=None,
    help='Path to load or store sentencepiece vocab file.')

flags.DEFINE_integer(
    'vocab_size', default=32000,
    help='Vocabulary size if `vocab_path` is not given.')

flags.DEFINE_string(
    'dataset_name', default='wmt17_translate/de-en',
    help='Name of TFDS translation dataset to use.')

flags.DEFINE_integer(
    'batch_size', default=256,
    help='Per host batch size for training.')

flags.DEFINE_float(
    'learning_rate', default=0.0625,
    help='Base learning rate.')

flags.DEFINE_float(
    'label_smoothing', default=0.1,
    help='Cross entropy loss label smoothing.')

flags.DEFINE_float(
    'weight_decay', default=0.0,
    help='Decay factor for AdamW style weight decay.')

flags.DEFINE_integer(
    'max_target_length', default=256,
    help='Maximum length cutoff for training examples.')

flags.DEFINE_integer(
    'max_eval_target_length', default=256,
    help='Maximum length cutoff for eval examples.')

flags.DEFINE_bool(
    'share_embeddings', default=True,
    help='Inputs and targets share embedding.')

flags.DEFINE_bool(
    'logits_via_embedding', default=True,
    help='Final logit transform uses embedding matrix transpose.')

flags.DEFINE_integer(
    'num_layers', default=6,
    help='Number of transformer layers.')

flags.DEFINE_integer(
    'qkv_dim', default=1024,
    help='Size of query/key/value for attention.')

flags.DEFINE_integer(
    'emb_dim', default=1024,
    help='Size of embeddings.')

flags.DEFINE_integer(
    'mlp_dim', default=4096,
    help='Size of the MLP.')

flags.DEFINE_integer(
    'num_heads', default=16,
    help='Number of attention heads.')

flags.DEFINE_float(
    'dropout_rate', default=0.1,
    help='Dropout rate.')

flags.DEFINE_float(
    'attention_dropout_rate', default=0.1,
    help='Attention dropout rate.')

flags.DEFINE_integer(
    'random_seed', default=0,
    help='Integer for PRNG random seed.')

flags.DEFINE_bool(
    'save_checkpoints', default=True,
    help='Whether to save model checkpoints.')

flags.DEFINE_bool(
    'restore_checkpoints', default=True,
    help='Whether to restore from existing model checkpoints.')

flags.DEFINE_bool(
    'use_bfloat16', default=True,
    help='Use bfloat16 mixed precision training instead of float32.')

flags.DEFINE_string(
    'jax_backend_target', default=None,
    help=('TPU grpc target for use with cloud TPUs.'
          ' e.g. grpc://192.168.0.2:8470'))

flags.DEFINE_integer(
    'paracrawl_size', default=1200000,
    help='Number of examples to sample from paracrawl.')

flags.DEFINE_enum(
    'adapter', default=NONE, enum_values=[LAYERNORM_ADAPTER,
                                          ENDCODE_DECODE_B5,
                                          PASSTHRU,
                                          NONE],
    help='Whether to finetune only some parameters.')

flags.DEFINE_bool(
    'split_tokenizer', default=False,
    help='Separate tokenizer for each language.')
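
# Example invocation (module name and all paths are hypothetical, shown for
# illustration only):
#
#   python -m data_selection.wmt.compute_is \
#       --model_dir=/tmp/is_run \
#       --pretrained_model_dir=/tmp/pretrained_wmt \
#       --is_save_path=/tmp/is_scores \
#       --is_score_filename=finetuned_loss \
#       --dataset_name=wmt17_translate/de-en \
#       --vocab_path=/tmp/pretrained_wmt/sentencepiece_model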


def compute_per_example_loss(logits,
                             targets,
                             weights=None,
                             label_smoothing=0.0):
  """Compute per-example label-smoothed cross entropy for logits and targets.

  Args:
    logits: [batch, length, num_classes] float array.
    targets: categorical targets [batch, length] int array.
    weights: None or array of shape [batch, length].
    label_smoothing: label smoothing constant, used to determine the on and
      off values.

  Returns:
    Per-example loss of shape [batch]: the token losses of each sequence,
    summed over length and normalized by the number of weighted tokens.
  """
  if logits.ndim != targets.ndim + 1:
    raise ValueError('Incorrect shapes. Got shape %s logits and %s targets' %
                     (str(logits.shape), str(targets.shape)))
  vocab_size = logits.shape[-1]
  confidence = 1.0 - label_smoothing
  low_confidence = (1.0 - confidence) / (vocab_size - 1)
  # Subtracting this constant makes the loss zero for a prediction that
  # exactly matches the smoothed target distribution.
  normalizing_constant = -(
      confidence * jnp.log(confidence) + (vocab_size - 1) *
      low_confidence * jnp.log(low_confidence + 1e-20))
  soft_targets = common_utils.onehot(
      targets, vocab_size, on_value=confidence, off_value=low_confidence)

  loss = -jnp.sum(soft_targets * nn.log_softmax(logits), axis=-1)
  loss = loss - normalizing_constant

  if weights is None:
    # Default to uniform weights so the normalization below is well defined.
    weights = jnp.ones_like(loss)
  loss = loss * weights

  return loss.sum(axis=-1) / weights.sum(axis=-1)
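
# Worked example of the smoothing above (illustrative numbers): with
# vocab_size=4 and label_smoothing=0.1, confidence=0.9 and
# low_confidence=0.1/3~=0.0333, so each soft target row is a permutation of
# [0.9, 0.0333, 0.0333, 0.0333]. The normalizing constant equals the cross
# entropy of that distribution with itself, so a model that predicts the soft
# targets exactly gets a per-token loss of ~0.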


def eval_for_is_step(params, batch, config, label_smoothing=0.0):
  """Calculates per-example losses and target lengths on a batch."""
  inputs, targets = batch['inputs'], batch['targets']
  # Non-padding target tokens get weight 1.0.
  weights = jnp.where(targets > 0, 1.0, 0.0)
  logits = models.Transformer(config).apply(
      {'params': params}, inputs, targets)
  losses = compute_per_example_loss(logits,
                                    targets,
                                    weights,
                                    label_smoothing)
  length = weights.sum(axis=-1)
  return losses, length
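
# Shape sketch for the step above (B = per-device batch, L = max length):
# inputs/targets are [B, L]; the returned losses and lengths are each [B].
# jax.pmap (see p_eval_step below) adds a leading device axis to all of these.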


def compute_is_scores(filename):
  """Compute IS scores for training data."""

  # Make sure tf does not allocate gpu memory.
  tf.config.experimental.set_visible_devices([], 'GPU')

  if FLAGS.jax_backend_target:
    jax.config.update('jax_xla_backend', 'tpu_driver')
    jax.config.update('jax_backend_target', FLAGS.jax_backend_target)

  # Number of local devices for this host.
  n_devices = jax.local_device_count()

  if jax.host_id() == 0:
    tf.io.gfile.makedirs(FLAGS.model_dir)

  if FLAGS.batch_size % n_devices:
    raise ValueError('Batch size must be divisible by the number of devices')

  vocab_path = FLAGS.vocab_path
  if vocab_path is None:
    vocab_path = os.path.join(FLAGS.model_dir, 'sentencepiece_model')
  tf.io.gfile.makedirs(os.path.split(vocab_path)[0])

  # Load Dataset
  print('Loading data')
  logging.info('Initializing dataset.')
  train_ds, (_, encoder_tgt) = input_pipeline.get_wmt_is_datasets(
      n_devices=n_devices,
      dataset_name=FLAGS.dataset_name,
      shard_idx=jax.host_id(),
      shard_count=jax.host_count(),
      data_dir=FLAGS.data_dir,
      vocab_path=vocab_path,
      target_vocab_size=FLAGS.vocab_size,
      batch_size=FLAGS.batch_size,
      max_length=FLAGS.max_target_length,
      paracrawl_size=FLAGS.paracrawl_size,
      split_tokenizer=FLAGS.split_tokenizer)
  print('Datasets created')

  encoder = encoder_tgt
  train_iter = iter(train_ds)
  vocab_size = int(encoder.vocab_size())
  print('data iterators created')

  logging.info('Initializing model, optimizer, and step functions.')
  # Build Model and Optimizer
  # ---------------------------------------------------------------------------
  eval_config = models.TransformerConfig(
      vocab_size=vocab_size,
      output_vocab_size=vocab_size,
      share_embeddings=FLAGS.share_embeddings,
      logits_via_embedding=FLAGS.logits_via_embedding,
      dtype=jnp.bfloat16 if FLAGS.use_bfloat16 else jnp.float32,
      emb_dim=FLAGS.emb_dim,
      num_heads=FLAGS.num_heads,
      num_layers=FLAGS.num_layers,
      qkv_dim=FLAGS.qkv_dim,
      mlp_dim=FLAGS.mlp_dim,
      max_len=max(FLAGS.max_target_length, FLAGS.max_eval_target_length),
      dropout_rate=FLAGS.dropout_rate,
      attention_dropout_rate=FLAGS.attention_dropout_rate,
      deterministic=True,
      decode=False,
      kernel_init=nn.initializers.xavier_uniform(),
      bias_init=nn.initializers.normal(stddev=1e-6))

  rng = jax.random.PRNGKey(FLAGS.random_seed)
  rng, init_rng = jax.random.split(rng)
  # Note: this may be intended to be the per-device batch size rather than
  # the per-host one.
  input_shape = (FLAGS.batch_size, FLAGS.max_target_length)
  target_shape = (FLAGS.batch_size, FLAGS.max_target_length)

  m = models.Transformer(eval_config)
  initial_variables = jax.jit(m.init)(init_rng,
                                      jnp.ones(input_shape, jnp.float32),
                                      jnp.ones(target_shape, jnp.float32))

  # Apply an optimizer to this tree.
  optimizer_def = optim.Adam(
      FLAGS.learning_rate,
      beta1=0.9,
      beta2=0.98,
      eps=1e-9,
      weight_decay=FLAGS.weight_decay)
  optimizer = optimizer_def.create(initial_variables['params'])

  # We access model params only from optimizer below via optimizer.target.
  del initial_variables
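
  # Note: flax.optim is the legacy Flax optimizer API; optimizer.target holds
  # the model parameters. No gradient steps are taken in this script; the
  # optimizer exists only so that checkpoints saved during training can be
  # restored into the same structure.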

  if FLAGS.restore_checkpoints:
    logging.info('Restoring checkpoint.')
    # If we have a pretrained model, use that. Else, continue from wherever
    # the previous run in model_dir left off.
    model_path = (FLAGS.pretrained_model_dir
                  if FLAGS.pretrained_model_dir else FLAGS.model_dir)
    # When loading a checkpoint trained with adapters (i.e. frozen weights),
    # restoring into the base optimizer fails. We catch this error and create
    # the optimizer with frozen weights instead.
    try:
      optimizer = checkpoints.restore_checkpoint(model_path, optimizer)
    except ValueError:
      adapter = optim.ModelParamTraversal(
          lambda path, _: FLAGS.adapter in path)
      optimizer = optimizer_def.create(optimizer.target, focus=adapter)
      optimizer = checkpoints.restore_checkpoint(model_path, optimizer)
  else:
    raise RuntimeError('Must restore checkpoint for IS')

  if FLAGS.adapter != NONE and not isinstance(optimizer, optim.MultiOptimizer):
    adapter = optim.ModelParamTraversal(lambda path, _: FLAGS.adapter in path)
    optimizer = optimizer_def.create(optimizer.target, focus=adapter)
  # Replicate optimizer across local devices.
  optimizer = jax_utils.replicate(optimizer)

  p_eval_step = jax.pmap(
      functools.partial(
          eval_for_is_step,
          config=eval_config),
      axis_name='batch')

  logging.info('Start scoring loop.')
  t_loop_start = time.time()

  # Eval Metrics
  logging.info('Gathering evaluation metrics.')
  save_file = FLAGS.is_save_path + '/' + filename + '-lengths.txt'
  length_fp = tf.io.gfile.GFile(save_file, 'w')
  lengths_writer = csv.writer(length_fp)

  save_file = FLAGS.is_save_path + '/' + filename + '.txt'
  with tf.io.gfile.GFile(save_file, 'w') as fp:
    writer = csv.writer(fp)

    for batch_idx, eval_batch in enumerate(train_iter):
      eval_batch = jax.tree_map(lambda x: x._numpy(), eval_batch)  # pylint: disable=protected-access
      cur_pred_batch_size = eval_batch['inputs'].shape[0]
      if cur_pred_batch_size % n_devices:
        # Pad the batch so it can be sharded evenly across devices.
        padded_size = int(
            np.ceil(cur_pred_batch_size / n_devices) * n_devices)
        eval_batch = jax.tree_map(
            lambda x: common.pad_examples(x, padded_size), eval_batch)  # pylint: disable=cell-var-from-loop
      eval_batch = common_utils.shard(eval_batch)
      losses, lengths = p_eval_step(optimizer.target, eval_batch)
      if jax.host_id() == 0:
        losses = common.tohost(losses)
        lengths = common.tohost(lengths)
        if cur_pred_batch_size % n_devices:
          # Drop the scores computed on padding examples.
          writer.writerow(losses[:cur_pred_batch_size])
          lengths_writer.writerow(lengths[:cur_pred_batch_size])
        else:
          writer.writerow(losses)
          lengths_writer.writerow(lengths)

      if batch_idx % 500 == 0:
        print('Batch', batch_idx)
        print(time.time() - t_loop_start)
  length_fp.close()
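
# Output format note: each writerow above emits one CSV row per batch with one
# column per example, so <filename>.txt holds the length-normalized
# per-example losses and <filename>-lengths.txt the matching target lengths.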


def main(_):
  compute_is_scores(FLAGS.is_score_filename)

  if FLAGS.base_log_loss_file:
    beforefile = FLAGS.base_log_loss_file
    afterfile = FLAGS.is_save_path + '/' + FLAGS.is_score_filename + '.txt'
    before_scores = []
    after_scores = []
    with tf.io.gfile.GFile(beforefile, 'r') as f:
      reader = csv.reader(f)
      for row in reader:
        before_scores.extend(row)
    with tf.io.gfile.GFile(afterfile, 'r') as f:
      reader = csv.reader(f)
      for row in reader:
        after_scores.extend(row)

    beforefile = beforefile.replace('.txt', '-lengths.txt')
    afterfile = afterfile.replace('.txt', '-lengths.txt')
    before_length = []
    after_length = []
    with tf.io.gfile.GFile(beforefile, 'r') as f:
      reader = csv.reader(f)
      for row in reader:
        before_length.extend(row)
    with tf.io.gfile.GFile(afterfile, 'r') as f:
      reader = csv.reader(f)
      for row in reader:
        after_length.extend(row)

    diff = [float(a) - float(b) for (a, b) in zip(after_scores, before_scores)]
    after_scores = [float(a) for a in after_scores]
    before_scores = [float(a) for a in before_scores]
    after_length = [float(a) for a in after_length]
    before_length = [float(b) for b in before_length]

    # The two runs must have scored the same examples in the same order.
    for a, b in zip(before_length, after_length):
      assert a == b

    is_diff_name = FLAGS.is_save_path + '/' + FLAGS.is_diff_name
    with tf.io.gfile.GFile(is_diff_name, 'w') as f:
      writer = csv.writer(f)
      for val in diff:
        writer.writerow([val])

    with tf.io.gfile.GFile(
        is_diff_name.replace('.csv', '_length.csv'), 'w') as f:
      writer = csv.writer(f)
      for val in after_length:
        writer.writerow([int(val)])
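
# Each row of the diff file is (current-run loss) - (base-run loss) for one
# example; downstream selection code can presumably rank examples by this
# difference.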


if __name__ == '__main__':
  app.run(main)