# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transformer LM trainer."""

import functools

from clu import periodic_actions
from flax.training import common_utils
from flax.training import train_state
import jax
from jax import numpy as jnp
import numpy as np
import optax

from sudoku_gpt import model
def lr_scheduler(
    n_tokens, learning_rate, warmup_tokens, final_tokens, config
):
  """Learning rate scheduler, adapted from Mikhail Grankin."""

  # Decay the learning rate based on our progress.
  progress = (n_tokens - warmup_tokens) / max(
      1,
      final_tokens - warmup_tokens,
  )
  lr_mult = jnp.where(
      n_tokens < warmup_tokens,
      # Linear warmup.
      n_tokens / jnp.fmax(1, warmup_tokens),
      # Cosine learning rate decay.
      jnp.fmax(config.end_lr_factor, 0.5 * (1.0 + jnp.cos(np.pi * progress))),
  )
  return learning_rate * lr_mult
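# A minimal usage sketch of `lr_scheduler` (the numbers are hypothetical and
# not taken from any config in this repo): with learning_rate=1e-3,
# warmup_tokens=100, final_tokens=1000 and config.end_lr_factor=0.1, the rate
# warms up linearly, peaks when warmup ends, then follows a cosine decay that
# is floored at learning_rate * end_lr_factor:
#
#   lr_scheduler(50, 1e-3, 100, 1000, config)    # ~5e-4 (mid-warmup)
#   lr_scheduler(100, 1e-3, 100, 1000, config)   # 1e-3  (peak)
#   lr_scheduler(1000, 1e-3, 100, 1000, config)  # 1e-4  (end_lr_factor floor)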
def train_step(state, batch, config, hyperparams, learning_rate_fn,
               dropout_rng=None):
  """One step of the training loop.

  Args:
    state: train state.
    batch: input batch
    config: experiment config
    hyperparams: hyperparameter dictionary
    learning_rate_fn: learning rate function
    dropout_rng: rng to be used for dropout

  Returns:
    A new train state, train metrics and computed model predictions.
  """

  inputs = batch[:, :-1]
  label = batch[:, 1:]

  dropout_rng = jax.random.fold_in(dropout_rng, state.step)
  dropout_rng_dict = {"dropout": dropout_rng}

  def loss_fn(params):
    corrupted_inputs = inputs
    pred_logits = model.TransformerLMHeadModel(config).apply(
        {"params": params}, corrupted_inputs, rngs=dropout_rng_dict)

    label_one_hot = jax.nn.one_hot(label, num_classes=config.vocab_size)

    assert label_one_hot.shape == pred_logits.shape, ("one hot label shape",
                                                      label_one_hot.shape,
                                                      label.shape,
                                                      pred_logits.shape)
    if "sudoku" in hyperparams.dataset:
      pred_logits_sol = pred_logits[:, :, :]
      label_one_hot_sol = label_one_hot[:, :, :]
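      # The mask built below weights the cross-entropy at sequence position j
      # by j itself, so position 0 gets weight 0 and later cells of the
      # sequence contribute more to the averaged loss.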
      ce_loss = optax.softmax_cross_entropy(
          logits=pred_logits_sol, labels=label_one_hot_sol
      )
      mask = np.repeat(
          np.arange(len(ce_loss[0])).reshape(1, -1), len(ce_loss), axis=0
      )
      avg_ce_loss = (ce_loss * mask).sum() / mask.sum()
      assert ce_loss.ndim == 2, ce_loss.shape
      return jnp.mean(avg_ce_loss), pred_logits
    elif hyperparams.dataset == "othello":
      ce_loss = optax.softmax_cross_entropy(
          logits=pred_logits, labels=label_one_hot
      )

      assert ce_loss.ndim == 2, ce_loss.shape
      return jnp.mean(ce_loss), pred_logits

  step = state.step
  lr = learning_rate_fn(step)
  (loss, pred_logits), grads = jax.value_and_grad(loss_fn,
                                                  has_aux=True)(state.params)
  grads = jax.lax.pmean(grads, "batch")
  new_state = state.apply_gradients(grads=grads)
  metrics = {
      "step": step, "loss": loss * inputs.shape[0], "learning_rate": lr,
      "pred_logits": pred_logits, "weights": inputs.shape[0]
  }
  return new_state, metrics, pred_logits
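# Illustrative sketch (not part of this module's API): `train_step` reduces
# gradients with `jax.lax.pmean(..., "batch")`, so callers are expected to wrap
# it with `jax.pmap` under that axis name before driving it from
# `train_one_step` below. The wrapper actually used in this repo also threads a
# `start_index` argument, so its exact signature may differ from this sketch:
#
#   p_train_step = jax.pmap(
#       functools.partial(
#           train_step,
#           config=config,
#           hyperparams=hyperparams,
#           learning_rate_fn=learning_rate_fn,
#       ),
#       axis_name="batch",
#   )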
def get_metrics_report_progress(config, workdir, writer):
  """Returns periodic hooks, a progress reporter and an empty metrics list."""
  hooks = []

  report_progress = periodic_actions.ReportProgress(
      num_train_steps=config.max_steps, writer=writer)

  if jax.process_index() == 0:
    hooks += [report_progress,
              periodic_actions.Profile(logdir=workdir, num_profile_steps=5)]
  train_metrics = []
  return hooks, report_progress, train_metrics
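# Typical usage of the hooks returned above (a sketch only; the real loop lives
# in the caller): each `clu.periodic_actions` hook is a callable invoked once
# per step, e.g.
#
#   hooks, report_progress, train_metrics = get_metrics_report_progress(
#       config, workdir, writer)
#   for step in range(config.max_steps):
#     ...
#     for hook in hooks:
#       hook(step)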
def get_state(config, net, initial_variables):
  """Get the train state given an experiment config, a model and initial variables."""
  lr_scheduler_fn = functools.partial(
      lr_scheduler,
      learning_rate=config.learning_rate,
      warmup_tokens=config.warmup_tokens,
      final_tokens=config.max_steps,
      config=config,
  )
  optim_fn = None
  if config.optimizer == "adamw":
    optim_fn = optax.adamw(
        lr_scheduler_fn, weight_decay=config.weight_decay, b1=0.9, b2=0.95
    )
  elif config.optimizer == "lion":
    optim_fn = optax.lion(lr_scheduler_fn, weight_decay=config.weight_decay)

  optimizer = optax.chain(optax.clip_by_global_norm(1), optim_fn)

  state = train_state.TrainState.create(
      apply_fn=net.apply, params=initial_variables["params"],
      tx=optimizer
  )

  return state, lr_scheduler_fn
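# A minimal usage sketch for `get_state` (the `config` fields and dummy input
# shape are assumptions for illustration, not taken from this repo's configs):
#
#   net = model.TransformerLMHeadModel(config)
#   rng = jax.random.PRNGKey(0)
#   dummy_inputs = jnp.ones((config.minibatch_size, config.seq_len), jnp.int32)
#   initial_variables = net.init(
#       {"params": rng, "dropout": rng}, dummy_inputs)
#   state, lr_scheduler_fn = get_state(config, net, initial_variables)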
def train_one_step(
    p_train_step,
    config,
    state,
    step,
    dropout_rngs,
    train_data_iter,
):
  """Single step of the training loop."""
  with jax.profiler.StepTraceAnnotation("train", step_num=step):

    batch = next(train_data_iter)
    inputs = None
    start_index = None
    if "sudoku" in config.dataset:
      inputs = common_utils.shard(jax.tree_util.tree_map(np.asarray, batch[0]))
      if "dependent" in config.start_index:
        start_index = common_utils.shard(
            jax.tree_util.tree_map(np.asarray, batch[2])
        )
      else:
        start_index = np.ones(len(batch[2])) * config.start_index
        start_index = common_utils.shard(
            jax.tree_util.tree_map(np.asarray, start_index)
        )

    elif config.dataset == "othello":
      inputs = common_utils.shard(jax.tree_util.tree_map(np.asarray, batch))

    state, metrics, _ = p_train_step(
        state, inputs, start_index, dropout_rng=dropout_rngs
    )

  return state, metrics
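# Illustrative outer loop (a sketch only; it omits checkpointing, evaluation
# and metric writing, and `flax.jax_utils` is not imported by this module):
#
#   state, lr_scheduler_fn = get_state(config, net, initial_variables)
#   state = flax.jax_utils.replicate(state)
#   dropout_rngs = jax.random.split(
#       jax.random.PRNGKey(0), jax.local_device_count())
#   for step in range(config.max_steps):
#     state, metrics = train_one_step(
#         p_train_step, config, state, step, dropout_rngs, train_data_iter)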