# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Transformer LM trainer."""

import functools

from clu import periodic_actions
from flax.training import common_utils
from flax.training import train_state
import jax
from jax import numpy as jnp
import numpy as np
import optax

from sudoku_gpt import model

def lr_scheduler(
    n_tokens, learning_rate, warmup_tokens, final_tokens, config
    ):
  """Learning rate scheduler, adapted from Mikhail Grankin."""

  # Decay the learning rate based on our progress.
  progress = (n_tokens - warmup_tokens) / max(
      1,
      final_tokens - warmup_tokens,
  )
  lr_mult = jnp.where(
      n_tokens < warmup_tokens,
      # Linear warmup.
      n_tokens / jnp.fmax(1, warmup_tokens),
      # Cosine learning rate decay.
      jnp.fmax(config.end_lr_factor, 0.5 * (1.0 + jnp.cos(np.pi * progress))),
  )
  return learning_rate * lr_mult
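

# Illustrative sketch, not part of the original trainer: shows the shape of
# the schedule above under assumed values (peak LR 6e-4, 1e4 warmup tokens,
# 1e6 final tokens, end_lr_factor 0.1). The SimpleNamespace below merely
# stands in for the real experiment config.
def _example_lr_schedule_sketch():
  import types
  cfg = types.SimpleNamespace(end_lr_factor=0.1)
  # Mid-warmup: the linear ramp gives roughly half the peak rate (~3e-4).
  mid_warmup = lr_scheduler(5e3, 6e-4, 1e4, 1e6, cfg)
  # End of training: the cosine decay is floored at end_lr_factor * peak
  # (~6e-5 here).
  end_of_training = lr_scheduler(1e6, 6e-4, 1e4, 1e6, cfg)
  return mid_warmup, end_of_training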


def train_step(state, batch, config, hyperparams, learning_rate_fn,
               dropout_rng=None):
  """One step of the training loop.

  Args:
    state: train state.
    batch: input batch.
    config: experiment config.
    hyperparams: hyperparameter dictionary.
    learning_rate_fn: learning rate function.
    dropout_rng: rng to be used for dropout.

  Returns:
    A new train state, train metrics and computed model predictions.
  """

  inputs = batch[:, :-1]
  label = batch[:, 1:]

  dropout_rng = jax.random.fold_in(dropout_rng, state.step)
  dropout_rng_dict = {"dropout": dropout_rng}

  def loss_fn(params):
    corrupted_inputs = inputs
    pred_logits = model.TransformerLMHeadModel(config).apply(
        {"params": params}, corrupted_inputs, rngs=dropout_rng_dict)

    label_one_hot = jax.nn.one_hot(label, num_classes=config.vocab_size)

    assert label_one_hot.shape == pred_logits.shape, ("one hot label shape",
                                                      label_one_hot.shape,
                                                      label.shape,
                                                      pred_logits.shape)
    if "sudoku" in hyperparams.dataset:
      pred_logits_sol = pred_logits[:, :, :]
      label_one_hot_sol = label_one_hot[:, :, :]

      ce_loss = optax.softmax_cross_entropy(
          logits=pred_logits_sol, labels=label_one_hot_sol
      )
      assert ce_loss.ndim == 2, ce_loss.shape
      # Position-indexed weights: later positions in the sequence get larger
      # weight and position 0 contributes nothing to the loss.
      mask = np.repeat(
          np.arange(len(ce_loss[0])).reshape(1, -1), len(ce_loss), axis=0
      )
      # Weighted average over the (batch, sequence) cross-entropy map; the
      # result is a scalar, so it is returned directly.
      avg_ce_loss = (ce_loss * mask).sum() / mask.sum()
      return avg_ce_loss, pred_logits
    elif hyperparams.dataset == "othello":
      ce_loss = optax.softmax_cross_entropy(
          logits=pred_logits, labels=label_one_hot
      )

      assert ce_loss.ndim == 2, ce_loss.shape
      return jnp.mean(ce_loss), pred_logits

  step = state.step
  lr = learning_rate_fn(step)
  (loss, pred_logits), grads = jax.value_and_grad(loss_fn,
                                                  has_aux=True)(state.params)
  grads = jax.lax.pmean(grads, "batch")
  new_state = state.apply_gradients(grads=grads)
  metrics = {
      "step": step, "loss": loss * inputs.shape[0], "learning_rate": lr,
      "pred_logits": pred_logits, "weights": inputs.shape[0]
  }
  return new_state, metrics, pred_logits
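

# Illustrative sketch, not part of the original trainer: train_step averages
# gradients with jax.lax.pmean over a "batch" axis, so it is meant to run
# under jax.pmap with that axis name. The wrapper below is only one plausible
# way to bind the static arguments; train_one_step further down also threads a
# start_index argument, which this version of train_step does not consume, so
# the sketch simply drops it.
def _example_pmap_train_step(config, hyperparams, lr_scheduler_fn):
  def _step(state, batch, start_index, dropout_rng=None):
    del start_index  # Not used by the train_step shown in this file.
    return train_step(state, batch, config, hyperparams, lr_scheduler_fn,
                      dropout_rng=dropout_rng)
  return jax.pmap(_step, axis_name="batch")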


def get_metrics_report_progress(config, workdir, writer):
  """Builds periodic-action hooks for progress reporting and profiling."""
  hooks = []

  report_progress = periodic_actions.ReportProgress(
      num_train_steps=config.max_steps, writer=writer)

  if jax.process_index() == 0:
    hooks += [report_progress,
              periodic_actions.Profile(logdir=workdir, num_profile_steps=5)]
  train_metrics = []
  return hooks, report_progress, train_metrics


def get_state(config, net, initial_variables):
  """Get the train state given an experiment config, a model and initial variables."""
  lr_scheduler_fn = functools.partial(
      lr_scheduler,
      learning_rate=config.learning_rate,
      warmup_tokens=config.warmup_tokens,
      final_tokens=config.max_steps,
      config=config,
  )
  optim_fn = None
  if config.optimizer == "adamw":
    optim_fn = optax.adamw(
        lr_scheduler_fn, weight_decay=config.weight_decay, b1=0.9, b2=0.95
    )
  elif config.optimizer == "lion":
    optim_fn = optax.lion(lr_scheduler_fn, weight_decay=config.weight_decay)

  optimizer = optax.chain(optax.clip_by_global_norm(1), optim_fn)

  state = train_state.TrainState.create(
      apply_fn=net.apply, params=initial_variables["params"],
      tx=optimizer
      )

  return state, lr_scheduler_fn
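

# Illustrative sketch, not part of the original trainer: one plausible way to
# build the initial variables that get_state expects. The dummy input shape
# (a single sequence of config.block_size token ids) and the config attributes
# used here (seed, block_size) are assumptions about the surrounding
# experiment setup, not something defined in this file.
def _example_get_state(config):
  net = model.TransformerLMHeadModel(config)
  rng = jax.random.PRNGKey(config.seed)
  dummy_inputs = jnp.ones((1, config.block_size), dtype=jnp.int32)
  initial_variables = jax.jit(net.init)(
      {"params": rng, "dropout": rng}, dummy_inputs)
  return get_state(config, net, initial_variables)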


def train_one_step(
    p_train_step,
    config,
    state,
    step,
    dropout_rngs,
    train_data_iter,
):
  """Single step of the training loop."""
  with jax.profiler.StepTraceAnnotation("train", step_num=step):

    batch = next(train_data_iter)
    inputs = None
    start_index = None
    if "sudoku" in config.dataset:
      inputs = common_utils.shard(jax.tree_util.tree_map(np.asarray, batch[0]))
      if "dependent" in config.start_index:
        start_index = common_utils.shard(
            jax.tree_util.tree_map(np.asarray, batch[2])
        )
      else:
        start_index = np.ones(len(batch[2])) * config.start_index
        start_index = common_utils.shard(
            jax.tree_util.tree_map(np.asarray, start_index)
        )

    elif config.dataset == "othello":
      inputs = common_utils.shard(jax.tree_util.tree_map(np.asarray, batch))

    state, metrics, _ = p_train_step(
        state, inputs, start_index, dropout_rng=dropout_rngs
    )

  return state, metrics
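

# Illustrative sketch, not part of the original trainer: one plausible outer
# loop around train_one_step. State is replicated across local devices and a
# per-device dropout rng is split once up front, matching the pmean("batch")
# and common_utils.shard usage above; the hooks come from
# get_metrics_report_progress. p_train_step and train_data_iter are assumed to
# be built by the surrounding setup (see _example_pmap_train_step above).
def _example_training_loop(config, state, p_train_step, train_data_iter,
                           workdir, writer):
  from flax import jax_utils
  hooks, _, train_metrics = get_metrics_report_progress(config, workdir,
                                                        writer)
  state = jax_utils.replicate(state)
  dropout_rngs = jax.random.split(jax.random.PRNGKey(0),
                                  jax.local_device_count())
  for step in range(config.max_steps):
    state, metrics = train_one_step(p_train_step, config, state, step,
                                    dropout_rngs, train_data_iter)
    train_metrics.append(metrics)
    for hook in hooks:
      hook(step)
  return jax_utils.unreplicate(state), train_metrics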