# Source page: google-research (fork count: 0) / transformer_trainer.py
# (374 lines, 13.5 KB).
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Transformer for Layout Trainer."""
17

18
# pylint: disable=g-import-not-at-top
19
# pylint: disable=g-bad-import-order
20

21
import sys
22
sys.path.append("..")
23

24
import functools
25
from typing import Any, Dict, Optional, Tuple
26

27
from absl import logging
28
from clu import metrics
29
from clu import parameter_overview
30
import flax
31
import flax.jax_utils as flax_utils
32
from flax.training import common_utils
33
import jax
34
import jax.numpy as jnp
35
from layout_blt import input_pipeline
36
from layout_blt.nets import transformer
37
from layout_blt.trainers import base_trainer
38
from layout_blt.utils import layout_fast_decode
39
from layout_blt.utils import task_manager
40
import numpy as np
41

42

43
@flax.struct.dataclass
class TrainState:
  """Checkpointable training state of the model.

  Attributes:
    step: number of optimizer updates applied so far.
    optimizer: flax optimizer; its `target` holds the model parameters.
    model_state: the model's non-"params" variable collections (possibly an
      empty dict).
  """
  step: int
  optimizer: flax.optim.Optimizer
  model_state: Optional[Any]
49

50

51
@flax.struct.dataclass
class TrainMetrics(metrics.Collection):
  """Metrics collected during the training process.

  Attributes:
    loss: running average of the model output passed as `loss=` to
      `gather_from_model_output`.
  """
  loss: metrics.Average.from_output("loss")
55

56

57
@flax.struct.dataclass
class EvalMetrics(metrics.Collection):
  """Metrics collected during the evaluation process.

  Attributes:
    eval_loss: running average of the model output passed as `eval_loss=` to
      `gather_from_model_output`.
  """
  eval_loss: metrics.Average.from_output("eval_loss")
61

62

63
class TransformerTrainer(base_trainer.LayoutBaseTrainer):
  """Transformer for Layout Trainer.

  Trains an autoregressive `transformer.TransformerDecoder` on tokenized
  layouts and samples layouts from a restored checkpoint.
  """

  def preprocess_batch(self, batch, batch_size, dataset, use_vertical=False):
    """Shards a host batch across local devices.

    Args:
      batch: host batch whose leading dimension is the batch dimension.
      batch_size: expected batch size.
      dataset: dataset name (unused by this trainer).
      use_vertical: whether vertical information is used (unused here).

    Returns:
      (sharded_batch, label). Both are None when the batch does not contain
      exactly `batch_size` rows. `label` is always None for this trainer.
    """
    label = None
    # When we reach the end of the dataset iterator, the batch size may not be
    # the expected one. In that case we simply skip it.
    if batch.shape[0] != batch_size:
      return None, None
    return common_utils.shard(batch), label

  def create_train_state(
      self,
      rng,
      inputs,
  ):
    """Initializes the model variables and the optimizer.

    Args:
      rng: JAX PRNG key used for initialization.
      inputs: dict with "inputs" and "labels" arrays used to trace the model.

    Returns:
      (model_dict, train_state). `model_dict["model"]` is the
      `TransformerDecoder` constructor partially applied with `self.config`.
      `train_state` holds step 0, the optimizer over the initial parameters,
      and the remaining (non-"params") variable collections.
    """
    model = functools.partial(
        transformer.TransformerDecoder, config=self.config)
    param_rng, latent_rng = jax.random.split(rng, 2)
    model_variables = model(deterministic=True).init(param_rng,
                                                     inputs["inputs"],
                                                     inputs["labels"],
                                                     latent_rng)

    # Split trainable parameters from every other variable collection.
    model_state = dict(model_variables)
    model_params = model_state.pop("params")
    logging.info("logging model parameters")
    parameter_overview.log_parameter_overview(model_params)
    optimizer = self.create_optimizer().create(model_params)
    model_dict = dict(model=model)
    train_state = TrainState(
        step=0,
        optimizer=optimizer,
        model_state=model_state)
    return model_dict, train_state

  def train_step(
      self,
      rng,
      state,
      batch,
      label,
      learning_rate_fn,
      model_dict,
      logits_mask
  ):
    """Perform a single training step.

    Must run under a mapped axis named "batch" (gradients are `pmean`-ed).

    Args:
      rng: The random seed.
      state: State of the model (optimizer and state).
      batch: Training inputs for this step.
      label: Training input vertical info (always None for now).
      learning_rate_fn: The learning scheduler.
      model_dict: The model used in training.
      logits_mask: Logits mask for each step, or None.

    Returns:
      The new model state and dictionary with metrics.
    """
    logging.info("train_step(batch=%s)", batch)
    step = state.step + 1
    lr = learning_rate_fn(state.step)
    model = model_dict["model"]
    # Next-token prediction: the target is the input shifted left by one.
    dec_target = batch[:, 1:]
    if logits_mask is not None:
      logits_mask = logits_mask[:, :-1, :]
    # Targets with id <= 0 are treated as padding and excluded from the loss.
    pad_mask = jnp.where(dec_target > 0, 1, 0).astype(jnp.float32)

    def loss_fn(params):
      dropout_rng, latent_rng = jax.random.split(rng)
      # Apply with params plus the non-param collections carried in
      # `state.model_state`, so those collections are not re-created from
      # scratch every step.
      variables = {"params": params}
      variables.update(state.model_state)
      (logits, _), new_variables = model().apply(
          variables,
          batch,
          label,
          latent_rng,
          rngs={"dropout": dropout_rng},
          mutable=True)
      recon_loss, num_tokens = self.compute_weighted_cross_entropy(
          logits, dec_target, pad_mask, self.config.label_smoothing,
          logits_mask)
      loss = recon_loss / num_tokens
      new_model_state = dict(new_variables)
      new_model_state.pop("params")
      return loss, new_model_state

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, new_model_state), grad = grad_fn(state.optimizer.target)
    # Average gradients over all devices on the "batch" axis.
    grad = jax.lax.pmean(grad, "batch")
    new_optimizer = state.optimizer.apply_gradient(grad, learning_rate=lr)
    new_state = state.replace(
        step=step, optimizer=new_optimizer, model_state=new_model_state)
    metrics_update = TrainMetrics.gather_from_model_output(loss=loss)
    return new_state, metrics_update

  def eval_step(self, rng, state, batch, label, model_dict, logits_mask):
    """Computes the per-token evaluation loss for one batch.

    Args:
      rng: random key passed to the model as its latent rng argument.
      state: TrainState holding the optimizer (parameters) and model state.
      batch: evaluation inputs.
      label: vertical info (always None for now).
      model_dict: dict holding the model constructor under "model".
      logits_mask: logits mask for each step, or None.

    Returns:
      EvalMetrics update with `eval_loss`.
    """
    model = model_dict["model"]
    dec_target = batch[:, 1:]
    # The mask is optional, exactly as in train_step.
    if logits_mask is not None:
      logits_mask = logits_mask[:, :-1, :]
    pad_mask = jnp.where(dec_target > 0, 1, 0).astype(jnp.float32)
    variables = {"params": state.optimizer.target}
    variables.update(state.model_state)
    logits, _ = model(deterministic=True).apply(
        variables,
        batch,
        label,
        rng)
    recon_loss, num_tokens = self.compute_weighted_cross_entropy(
        logits, dec_target, pad_mask, self.config.label_smoothing, logits_mask)
    loss = recon_loss / num_tokens
    return EvalMetrics.gather_from_model_output(eval_loss=loss)

  def test(self,
           sampling_method="topp",
           conditional="none",
           eos_id=2,
           batch_size=1,
           sample_step_num=1,
           max_decode_len=128,
           use_vertical=False,
           vertical_idx=0):
    """Runs a test run.

    Restores the checkpoint from `config.test_checkpoint_dir` and generates
    layouts for at most `sample_step_num` batches of the "test.json" split.

    Args:
      sampling_method: str: how to generate the current token.
      conditional: str: "none": unconditional generation, "a": asset
        conditional generation, "a+s": asset + size conditional generation.
      eos_id: int: the index of the eos symbol.
      batch_size: int: batch size of generation at one time; must be divisible
        by `jax.local_device_count()`.
      sample_step_num: int: maximum number of batches to generate, in every
        `conditional` mode.
      max_decode_len: int: the maximum number of tokens during generation.
      use_vertical: bool: whether to use vertical information (always False).
      vertical_idx: int: vertical index.

    Returns:
      generated_samples: generated layouts stacked over batches, with the bos
        token dropped and per-position vocabulary offsets subtracted.
      real_samples: the corresponding real layouts, processed the same way.

    Raises:
      ValueError: if no batch was generated (e.g. `sample_step_num` is 0 or
        the first test batch is smaller than `batch_size`).
    """
    assert batch_size % jax.local_device_count() == 0
    rng = jax.random.PRNGKey(self.config.seed)
    # Make sure each host uses a different RNG.
    rng = jax.random.fold_in(rng, jax.process_index())
    rng, model_rng, _ = jax.random.split(rng, 3)
    dataset = self.config.dataset

    test_ds, vocab_size, pos_info = input_pipeline.get_dataset(
        batch_size,
        self.config.dataset_path,
        jax.local_device_count(),
        "test.json",
        max_length=max_decode_len,
        dataset_name=dataset)

    init_batch = jnp.ones((batch_size, self.config.max_length))
    init_label = jnp.ones((batch_size, 1))
    init_batch = dict(inputs=init_batch, labels=init_label)
    model_dict, state = self.create_train_state(model_rng, init_batch)
    ckpt_path = self.config.test_checkpoint_dir
    state = task_manager.restore_checkpoint(state, ckpt_path)
    state = flax_utils.replicate(state)
    sample_one_batch_fn = functools.partial(
        self.sample_one_batch,
        pos_info=pos_info,
        batch_size=batch_size // jax.local_device_count(),
        conditional=conditional,
        eos_id=eos_id,
        max_decode_len=max_decode_len,
        sampling_method=sampling_method)
    p_generate_batch = jax.pmap(
        functools.partial(
            sample_one_batch_fn,
            model_dict=model_dict,
        ),
        axis_name="batch")

    test_iter = iter(test_ds)  # pytype: disable=wrong-arg-types

    def tohost(x):
      """Collect batches from all devices to host and flatten batch dimensions."""
      n_device, n_batch, *remaining_dims = x.shape
      return np.array(x).reshape((n_device * n_batch,) + tuple(remaining_dims))

    _, sample_offset = self.make_mask(vocab_size, pos_info, max_decode_len,
                                      self.layout_dim)
    generated_sample_list, real_sample_list = [], []
    for idx, test_batch in enumerate(test_iter):
      # `sample_step_num` bounds the number of generated batches in every
      # `conditional` mode.
      if idx == sample_step_num:
        break
      rng, sample_rng = jax.random.split(rng, 2)
      p_rng = jax.random.split(sample_rng, jax.local_device_count())
      test_batch = jax.tree_map(lambda x: x._numpy(), test_batch)  # pylint: disable=protected-access
      test_batch, test_label = self.preprocess_batch(test_batch, batch_size,
                                                     dataset, use_vertical)
      # A batch smaller than `batch_size` marks the end of the usable data.
      if test_batch is None:
        break

      if conditional == "none":
        if use_vertical:
          # NOTE(review): preprocess_batch returns a None label, so confirm
          # jnp.full_like accepts it before enabling use_vertical.
          test_label = jnp.full_like(test_label, vertical_idx)
        sample_layouts = p_generate_batch(None, p_rng, state, label=test_label)
      else:
        # Condition on the real layout without its first (bos) token.
        sample_layouts = p_generate_batch(test_batch[Ellipsis, 1:], p_rng,
                                          state, label=test_label)
      # We do not need the bos symbol.
      sample_layouts = tohost(sample_layouts)[Ellipsis, 1:]
      real_layouts = tohost(test_batch)[Ellipsis, 1:]
      _, real_offset = self.make_mask(self.config.vocab_size, pos_info,
                                      real_layouts.shape[-1], self.layout_dim)
      generated_sample_list.append(sample_layouts - sample_offset[Ellipsis, :-1])
      real_sample_list.append(real_layouts - real_offset)

    if not generated_sample_list:
      raise ValueError(
          "No test batch was generated; check sample_step_num and that the "
          "test set holds at least batch_size examples.")
    generated_samples = jnp.concatenate(generated_sample_list, axis=0)
    real_samples = jnp.concatenate(real_sample_list, axis=0)

    return generated_samples, real_samples

  def sample_step(self, rng, state, model_dict, pos_info):
    """Samples layouts just for visualization during training (not done yet)."""
    # TODO(xiang): sample some images during training.
    return None

  def fast_decode(self,
                  rng,
                  variables,
                  model,
                  pos_info,
                  label=None,
                  batch=None,
                  batch_size=1,
                  conditional="none",
                  eos_id=2,
                  max_decode_len=100,
                  sampling_method="topp"):
    """Fast layout generation decoding method.

    Args:
      rng: jax random state.
      variables: model variables ("params" plus other collections).
      model: layout generation model constructor.
      pos_info: vocabulary segmentation information.
      label: vertical information (always None for now).
      batch: real layouts batch for conditional generation; a batch of ones
        is used when None.
      batch_size: number of layouts to generate at one time.
      conditional: conditional type.
      eos_id: index of the eos symbol (currently not forwarded to the decoder).
      max_decode_len: maximum number of tokens to generate.
      sampling_method: sampling method during generation (argmax or sampling).
    Returns:
      seqs: generated layouts.
    """
    del eos_id  # Not forwarded to layout_fast_decode.decode.
    eval_model = model(deterministic=True, is_train=False)
    # Every consumer gets its own key; `rng` is handed only to decode.
    init_rng, rng, latent_rng, z_rng = jax.random.split(rng, 4)
    init_batch = jnp.ones((batch_size, max_decode_len))
    init_label = jnp.ones((batch_size, 1))
    initial_vars = eval_model.init(init_rng, init_batch, init_label, latent_rng)
    # Assumes `init` returns a flax FrozenDict, whose pop() yields
    # (remaining, popped): keep the non-param collections (the decoding
    # cache) and discard the freshly initialized params.
    cache_dict, _ = initial_vars.pop("params")

    def tokens_to_logits(xi, cache_dict, decode_step, initial_z):
      logits, cache_dict = eval_model.apply(
          {
              **variables,
              **cache_dict
          },
          xi,
          label,
          initial_z,
          decode_step,
          mutable=["cache"],
          method=transformer.TransformerDecoder.decode)
      return logits, cache_dict

    logit_masks, _ = self.make_mask(self.config.vocab_size, pos_info,
                                    self.total_dim, self.layout_dim)
    initial_z = jax.random.normal(z_rng,
                                  (batch_size, eval_model.config.emb_dim))
    tokens_to_logits_fn = functools.partial(
        tokens_to_logits, initial_z=initial_z)
    batch = init_batch if batch is None else batch

    seqs = layout_fast_decode.decode(
        batch,
        cache_dict,
        tokens_to_logits_fn,
        max_decode_len=max_decode_len,
        sampling_method=sampling_method,
        rng=rng,
        logit_masks=logit_masks,
        conditional=conditional)
    return seqs

  def sample_one_batch(self, batch, rng, state, model_dict, pos_info, label,
                       batch_size, conditional, eos_id, max_decode_len,
                       sampling_method):
    """Samples one batch for eval using the params and model state in `state`."""
    model = model_dict["model"]
    variables = {"params": state.optimizer.target}
    variables.update(state.model_state)

    return self.fast_decode(rng, variables, model, pos_info, label, batch,
                            batch_size, conditional, eos_id, max_decode_len,
                            sampling_method)
375

# Use of cookies
#
# We use cookies in accordance with the Privacy Policy and the Cookie Policy.
#
# By clicking "Accept", you give SberTech JSC consent to the processing of your
# personal data in order to improve our website and the GitVerse service and
# to make them more convenient to use.
#
# You can disable cookies yourself in your browser settings.