# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Layout Base Trainer."""

import abc
import functools
import os
from typing import Dict

from absl import logging
from clu import metric_writers
from clu import periodic_actions
import flax
from flax import linen as nn
import flax.jax_utils as flax_utils
from flax.training import common_utils
import jax
import jax.numpy as jnp
from layout_blt import input_pipeline
from layout_blt.utils import metrics
from layout_blt.utils import task_manager
import ml_collections
import numpy as np
import tensorflow as tf


class LayoutBaseTrainer(abc.ABC):
  """Base Trainer for layout generation."""

  def __init__(self, config, workdir):
    self.config = config
    self.workdir = workdir
    self.rng = jax.random.PRNGKey(config.seed)
    self.dtype, self.data_dtype = self.get_dtype()
    self.layout_dim = self.config.layout_dim
    self.total_dim = self.layout_dim * 2 + 1

  def get_dtype(self):
    if self.config.dtype == "bfloat16":
      return jnp.bfloat16, tf.bfloat16
    else:
      return jnp.float32, tf.float32

  def merge_model_state(self, state):
    if jax.tree_leaves(state.model_state):
      cross_replica_mean = jax.pmap(lambda x: jax.lax.pmean(x, "x"), "x")
      return state.replace(model_state=cross_replica_mean(state.model_state))
    else:
      return state

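  # A minimal sketch (assumption, not part of the original file) of what
  # `merge_model_state` does when the model keeps mutable collections such as
  # batch statistics: `state.model_state` is replicated across devices with a
  # leading device axis, and the pmapped `pmean` averages every leaf over that
  # axis so all replicas hold identical statistics before evaluation or
  # checkpointing, e.g.
  #
  #   cross_replica_mean = jax.pmap(lambda x: jax.lax.pmean(x, "x"), "x")
  #   merged_stats = cross_replica_mean(replicated_stats)
  #
  # `replicated_stats` is a hypothetical pytree whose leaves have shape
  # [num_devices, ...]; after the call every device slice holds the same
  # averaged values.
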
  @abc.abstractmethod
  def create_train_state(
      self,
      rng,
      inputs,
  ):
    pass

  def create_optimizer(self):
    """Creates the Adam optimizer with weight decay.

    Returns:
      An instance of `Optimizer`.
    """
    opt_def = flax.optim.Adam(
        learning_rate=self.config.optimizer.lr,
        beta1=self.config.optimizer.beta1,
        beta2=self.config.optimizer.beta2,
        weight_decay=self.config.optimizer.weight_decay)
    return opt_def

  def make_mask(self, vocab_size, pos_info, seq_len, layout_dim):
    """Creates the logits mask and token offset for each step of the sequence.

    Our vocabulary is the combination of special symbols, asset classes,
    positions and sizes. At each step, only a part of the vocabulary is
    possible. For example, the first step is the asset type, so only asset
    class candidates can be generated.

    Args:
      vocab_size: vocabulary size.
      pos_info: start index and number of candidates for each vocabulary
        segment. For example, [[2, 3], [6, 2]] denotes that the first segment
        starts at index 2 in the vocab and contains 3 elements.
      seq_len: the total length of the input sequence.
      layout_dim: the layout dimension.
    Returns:
      asset_logit_masks: [1, seq_len, vocab_size]: logits mask for each step.
      asset_offset: [1, seq_len]: offset to map the output token ids back to
        their original ids.
    """
    total_dim = layout_dim * 2 + 1
    logit_masks = []
    offset = jnp.array([pi[0] for pi in pos_info])
    offset = jnp.expand_dims(offset, 0)

    asset_offset = jnp.tile(offset, (1, seq_len // total_dim))
    # In our current model, the order of the asset representation is
    # [asset, width, height, x, y]. We create masks for each of them.
    for idx, pi in enumerate(pos_info):
      # Creates a logit mask for the current segment. The logit shape from the
      # model is [batch_size, seq_len, vocab_size]. At a given step, the logit
      # shape is [batch_size, 1, vocab_size]; since all logits in the batch
      # share the same masking at a given position, we just create a
      # [1, 1, vocab_size] mask which can broadcast to the whole batch.
      logit_mask = jnp.ones((1, 1, vocab_size))
      # pi[1] denotes the number of elements in the current segment.
      # The shape of pos_mask is [1, 1, #elements in the current segment].
      # For example, if we have four possible asset classes, the pos_mask
      # should be [1, 1, 4].
      pos_mask = jnp.zeros((1, 1, pi[1]))
      # pi[0] is the start index of the current segment in the vocabulary.
      # Here, we set indices [pi[0]: pi[0] + pi[1]] to zero.
      logit_mask = jax.lax.dynamic_update_slice(logit_mask, pos_mask,
                                                (0, 0, pi[0]))
      # At asset positions, we could also produce the eos symbol.
      if idx == 0:
        logit_mask = logit_mask.at[:, :, 2].set(0)
      logit_masks.append(logit_mask)
    # We have created masks for all segments; concatenate them into the mask
    # for one asset: [1, 5, vocab_size].
    logit_masks = jnp.concatenate(logit_masks, axis=1)
    # We extend the above mask to all positions in the sequence.
    asset_logit_masks = jnp.tile(logit_masks, (1, seq_len // total_dim, 1))
    # Concatenates masks for the remaining positions.
    if seq_len % total_dim > 0:
      asset_logit_masks = jnp.concatenate(
          (asset_logit_masks, logit_masks[:, :(seq_len % total_dim), :]),
          axis=1)
      asset_offset = jnp.concatenate(
          (asset_offset, offset[:, :(seq_len % total_dim)]), axis=1)

    return asset_logit_masks, asset_offset

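  # A toy usage sketch (assumption, not from the original file). Suppose a
  # vocabulary of 13 tokens where ids 0-2 are special symbols (with eos at
  # id 2), ids 3-6 are asset classes, and ids 7-12 are shared by width,
  # height, x and y:
  #
  #   pos_info = [[3, 4], [7, 6], [7, 6], [7, 6], [7, 6]]
  #   masks, offsets = self.make_mask(13, pos_info, seq_len=10, layout_dim=2)
  #
  # `masks` has shape [1, 10, 13], with 0 marking allowed ids and 1 marking
  # forbidden ones: at steps 0 and 5 (asset steps) ids 3-6 and the eos id 2
  # are allowed, while at the remaining steps only ids 7-12 are allowed.
  # `offsets` has shape [1, 10] and repeats [3, 7, 7, 7, 7].
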
  def create_learning_rate_scheduler(self, learning_rate=1., warmup_steps=4000):
    """Creates the learning rate scheduler for the transformer.

    First increases the learning rate linearly for the first warmup_steps
    training steps, and decreases it thereafter proportionally to the inverse
    square root of the step number.
    Args:
      learning_rate: float, the starting constant for the lr schedule.
      warmup_steps: int, how many steps to warm up (> 0).
    Returns:
      A function to calculate the learning rate given the current step.
    """
    def step_fn(step):
      cur_lr = learning_rate * jnp.minimum(1.0, step / warmup_steps) / jnp.sqrt(
          jnp.maximum(step, warmup_steps))
      return jnp.asarray(cur_lr, dtype=jnp.float32)

    return step_fn

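  # A worked example (assumption, not part of the original file): with
  # learning_rate=1.0 and warmup_steps=4000 the schedule is
  #   lr(step) = learning_rate * min(1, step / 4000) / sqrt(max(step, 4000)),
  # so lr(2000) = 0.5 / sqrt(4000) ~= 7.9e-3, the peak lr(4000) =
  # 1 / sqrt(4000) ~= 1.6e-2, and lr(16000) = 1 / sqrt(16000) ~= 7.9e-3.
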
  def compute_weighted_cross_entropy(self,
                                     logits,
                                     targets,
                                     mask=None,
                                     label_smoothing=0.0,
                                     logits_mask=None):
    """Computes weighted cross entropy between logits and targets.

    Args:
      logits: [batch, length, vocab_size] float array.
      targets: [batch, length] int array.
      mask: None or array of shape [batch, length].
      label_smoothing: label smoothing constant, used to determine the on and
        off values.
      logits_mask: [batch, length, vocab_size] float array: logits masking to
        ignore impossible candidates at each step.

    Returns:
      loss: the summed cross entropy loss.
      normalizing_factor: the denominator used to average the loss (the number
        of target tokens, or the mask sum when a mask is given).
    """
    if logits_mask is not None:
      logits = jnp.where(logits_mask > 0, -1e7, logits)
    if logits.ndim != targets.ndim + 1:
      raise ValueError("Incorrect shapes. Got shape %s logits and %s targets" %
                       (str(logits.shape), str(targets.shape)))
    vocab_size = logits.shape[-1]
    confidence = 1.0 - label_smoothing
    low_confidence = label_smoothing / (vocab_size - 1)
    soft_targets = common_utils.onehot(
        targets, vocab_size, on_value=confidence, off_value=low_confidence)

    loss = -jnp.sum(soft_targets * nn.log_softmax(logits), axis=-1)
    # Calculates the best (lowest) possible value of cross entropy, and
    # subtracts it from the cross entropy loss.
    normalizing_constant = -(
        confidence * jnp.log(confidence) +
        (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20))
    loss = loss - normalizing_constant

    normalizing_factor = np.prod(targets.shape)
    if mask is not None:
      loss = loss * mask
      normalizing_factor = mask.sum()

    return loss.sum(), normalizing_factor

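  # A small worked example (assumption, not part of the original file): with
  # vocab_size=4 and label_smoothing=0.1, confidence=0.9 and
  # low_confidence=0.1/3, so each one-hot target becomes
  # [0.9, 0.0333, 0.0333, 0.0333] (up to ordering), and the normalizing
  # constant subtracted from the loss is
  #   -(0.9 * log(0.9) + 3 * 0.0333 * log(0.0333 + 1e-20)) ~= 0.435,
  # i.e. the entropy of the smoothed target distribution, so a perfect
  # prediction yields a loss of roughly zero.
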
  def evaluate(self,
               p_eval_step,
               state,
               rng,
               eval_ds,
               batch_size=0,
               use_vertical=False,
               num_eval_steps=None,
               dataset="RICO"):
    """Evaluates the model and returns a dictionary with the metrics."""
    logging.info("Gathering evaluation metrics.")
    eval_metrics = None
    eval_iter = iter(eval_ds)  # pytype: disable=wrong-arg-types
    for _, eval_batch in zip(range(num_eval_steps), eval_iter):
      eval_batch = jax.tree_map(lambda x: x._numpy(), eval_batch)  # pylint: disable=protected-access
      eval_batch, eval_label = self.preprocess_batch(eval_batch, batch_size,
                                                     dataset, use_vertical)
      if eval_batch is None:
        continue
      # eval_batch = common_utils.shard(eval_batch)
      metrics_update = p_eval_step(rng, state, eval_batch, eval_label)
      metrics_update = flax.jax_utils.unreplicate(metrics_update)
      eval_metrics = (
          metrics_update
          if eval_metrics is None else eval_metrics.merge(metrics_update))
    return eval_metrics

  def train(self):
    """Training loop."""

    tf.io.gfile.makedirs(self.workdir)
    checkpoint_dir = os.path.join(self.workdir, "checkpoints")
    n_devices = jax.local_device_count()
    task_manager_csv = task_manager.TaskManagerWithCsvResults(checkpoint_dir)
    rng, data_rng = jax.random.split(self.rng)
    # Make sure each host uses a different RNG for the training data.
    data_rng = jax.random.fold_in(data_rng, jax.process_index())
    batch_size = self.config.batch_size
    dataset = self.config.dataset
    use_vertical = self.config.use_vertical_info

    train_ds, eval_ds, _, vocab_size, pos_info = input_pipeline.get_all_dataset(
        batch_size,
        self.config.dataset_path,
        n_devices,
        add_bos=self.config.autoregressive,
        max_length=self.config.max_length,
        dataset_name=dataset)
    train_ds = train_ds.repeat()
    self.config.vocab_size = vocab_size
    train_iter = iter(train_ds)  # pytype: disable=wrong-arg-types
    num_train_steps = self.config.num_train_steps
    logging.info("total_steps=%d", num_train_steps)

    rng, model_rng = jax.random.split(rng)
    init_batch = jnp.ones(
        (batch_size, self.config.max_length))
    init_label = jnp.ones((batch_size, 1))
    init_batch = dict(inputs=init_batch, labels=init_label)
    model_dict, state = self.create_train_state(model_rng, init_batch)
    learning_rate_fn = self.create_learning_rate_scheduler(
        learning_rate=self.config.optimizer.lr,
        warmup_steps=self.config.optimizer.warmup_steps)
    state = task_manager.restore_checkpoint(state, checkpoint_dir)
    # Creates the logits mask.
    logits_mask, _ = self.make_mask(vocab_size, pos_info,
                                    self.config.max_length,
                                    self.config.layout_dim)
    initial_step = int(state.step) + 1
    # Warm-start from a checkpoint.
    if initial_step == 1 and self.config.get(
        "checkpoint_path") and self.config.checkpoint_path:
      state = task_manager.restore_from_path(
          state, self.config.checkpoint_path)

    state = flax_utils.replicate(state)

    writer = metric_writers.create_default_writer(
        self.workdir, just_logging=jax.process_index() > 0)
    if initial_step == 1:
      writer.write_hparams(dict(self.config))

    logging.info("Starting training loop at step %d.", initial_step)
    hooks = []
    report_progress = periodic_actions.ReportProgress(
        num_train_steps=self.config.num_train_steps, writer=writer)
    if jax.process_index() == 0:
      hooks += [
          report_progress,
          periodic_actions.Profile(num_profile_steps=5, logdir=self.workdir)
      ]

    p_train_step = jax.pmap(
        functools.partial(
            self.train_step,
            model_dict=model_dict,
            learning_rate_fn=learning_rate_fn,
            logits_mask=logits_mask
        ),
        axis_name="batch")

    p_eval_step = jax.pmap(
        functools.partial(
            self.eval_step,
            model_dict=model_dict,
            logits_mask=logits_mask
        ),
        axis_name="batch")

    train_metrics = None
    rng = jax.random.fold_in(rng, jax.process_index())
    rng, train_rng, sample_rng = jax.random.split(rng, 3)  # pylint: disable=unused-variable

    with metric_writers.ensure_flushes(writer):
      for step in range(initial_step, num_train_steps + 1):
        # `step` is a Python integer. `state.step` is a JAX integer on the
        # GPU/TPU devices.
        is_last_step = step == self.config.num_train_steps
        with jax.profiler.StepTraceContext("train", step_num=step):
          batch = jax.tree_map(np.asarray, next(train_iter))
          batch, label = self.preprocess_batch(batch, batch_size, dataset,
                                               use_vertical)
          if batch is None:
            continue

          step_rng = jax.random.fold_in(train_rng, step)
          step_rngs = jax.random.split(step_rng, jax.local_device_count())
          state, metrics_update = p_train_step(step_rngs, state, batch, label)
          metric_update = flax.jax_utils.unreplicate(metrics_update)
          train_metrics = (
              metric_update
              if train_metrics is None else train_metrics.merge(metric_update))

        # Quick indication that training is happening.
        logging.log_first_n(logging.INFO, "Finished training step %d.", 5, step)
        for h in hooks:
          h(step)

        if step % self.config.log_every_steps == 0:
          logging.info("Finish training step %d.", step)
          writer.write_scalars(step, train_metrics.compute())
          train_metrics = None
        if step % self.config.eval_every_steps == 0 or is_last_step:
          logging.info("eval step")
          state = self.merge_model_state(state)
          sample_rngs = jax.random.split(sample_rng, jax.local_device_count())
          eval_metrics = self.evaluate(p_eval_step, state, sample_rngs, eval_ds,
                                       batch_size,
                                       use_vertical,
                                       self.config.eval_num_steps,
                                       dataset)
          if eval_metrics is not None:
            writer.write_scalars(step, eval_metrics.compute())

        if step % self.config.checkpoint_every_steps == 0 or is_last_step:
          with report_progress.timed("checkpoint"):
            state = self.merge_model_state(state)
            task_manager.save_checkpoint(state, checkpoint_dir)
      logging.info("Finishing training at step %d", num_train_steps)
    if jax.process_index() == 0:
      task_manager_csv.mark_training_done()

  def evaluate_metrics(self,
                       generated_samples,
                       real_samples,
                       eos_id=-2,
                       conditional="a+s"):
    """Computes layout metrics."""
    def convert_format(layouts, eos_id):
      new_layouts = []
      for sample in layouts:
        sample = np.array(sample)
        if np.nonzero(sample == eos_id)[0].shape[0] > 0:
          real_len = np.nonzero(sample == eos_id)[0][0]
          sample = sample[:real_len]
          new_layouts.append(sample.reshape(-1, 5))
      return new_layouts
    generated_samples = convert_format(generated_samples, eos_id)

    iou = []
    overlap = []
    alignment = []
    for sample in generated_samples:
      iou.append(metrics.get_layout_iou(sample))
      overlap.append(metrics.get_overlap_index(sample))
      align_loss = metrics.get_alignment_loss(sample)
      if align_loss > 0:
        alignment.append(align_loss)
    def avg(a):
      return sum(a) / len(a)
    rst = {
        "iou": avg(iou),
        "overlap": avg(overlap),
        "alignment": avg(alignment)
    }

    if conditional != "unconditional":
      real_samples = convert_format(real_samples, eos_id)
      similarity = metrics.conditional_distance(generated_samples, real_samples,
                                                conditional)
      rst["similarity"] = similarity
    else:
      # Keeps the `diveristy` spelling used by the metrics module.
      diversity = metrics.diveristy(generated_samples)
      rst["diversity"] = diversity

    return rst

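  # An illustrative sketch (assumption, not part of the original file): each
  # sample passed to `evaluate_metrics` is a flat token sequence of
  # [asset, width, height, x, y] groups terminated by `eos_id`, so e.g.
  #
  #   sample = [3, 10, 12, 0, 0, 4, 5, 5, 10, 12, -2]
  #
  # is cut at the eos token and reshaped by `convert_format` into two boxes,
  # [[3, 10, 12, 0, 0], [4, 5, 5, 10, 12]]; samples without an eos token are
  # dropped before the metrics are computed.
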
  @abc.abstractmethod
  def train_step(self, rng, state, batch, model_dict, logits_mask):
    pass

  @abc.abstractmethod
  def eval_step(self, rng, state, batch, model_dict, logits_mask):
    pass

  @abc.abstractmethod
  def preprocess_batch(self, batch, batch_size, dataset, use_vertical):
    pass
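

# A minimal sketch (assumption, not part of the original file) of how a
# concrete trainer might fill in the abstract methods; the class, method
# bodies, and exact signatures below are hypothetical and only illustrate the
# interface implied by how `train()` invokes the pmapped step functions.
#
#   class MyLayoutTrainer(LayoutBaseTrainer):
#
#     def create_train_state(self, rng, inputs):
#       ...  # build the model, return (model_dict, train_state)
#
#     def train_step(self, rng, state, batch, label, model_dict,
#                    learning_rate_fn, logits_mask):
#       ...  # one optimization step, return (new_state, metrics_update)
#
#     def eval_step(self, rng, state, batch, label, model_dict, logits_mask):
#       ...  # return a metrics_update for this batch
#
#     def preprocess_batch(self, batch, batch_size, dataset, use_vertical):
#       ...  # return (device-sharded batch, label), or (None, None) to skip
#
#   trainer = MyLayoutTrainer(config, workdir)  # config: an ml_collections
#   trainer.train()                             # ConfigDict (assumed)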
