bert_layout_trainer.py
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""BERT Layout Trainer."""

#  pylint: disable=g-import-not-at-top
import functools
import sys
sys.path.append("..")

import time
from typing import Any, Dict, Optional

from absl import logging
from clu import metrics
from clu import parameter_overview
import flax
import flax.jax_utils as flax_utils
from flax.training import common_utils
import jax
import jax.numpy as jnp
from layout_blt import input_pipeline
from layout_blt.nets import na_layout_net
from layout_blt.trainers import base_trainer
from layout_blt.utils import layout_bert_fast_decode
from layout_blt.utils import task_manager
import numpy as np


@flax.struct.dataclass
class TrainState:
  """Data structure for checkpointing the model."""
  step: int
  optimizer: flax.optim.Optimizer
  model_state: Optional[Any]


@flax.struct.dataclass
class TrainMetrics(metrics.Collection):
  """Metrics during training process."""
  loss: metrics.Average.from_output("loss")


@flax.struct.dataclass
class EvalMetrics(metrics.Collection):
  """Metrics during evaluation process."""
  eval_loss: metrics.Average.from_output("eval_loss")


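# The TrainMetrics/EvalMetrics collections above are typically accumulated
# across steps with the usual clu.metrics pattern. A minimal sketch
# (illustrative only; the surrounding training loop lives in the base trainer,
# and names such as `p_train_step` and `train_iter` are assumptions):
#
#   train_metrics = None
#   for batch in train_iter:
#     state, metrics_update = p_train_step(state, batch)
#     update = flax_utils.unreplicate(metrics_update)
#     train_metrics = (update if train_metrics is None
#                      else train_metrics.merge(update))
#   logging.info("metrics: %s", train_metrics.compute())  # e.g. {"loss": ...}

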
def rmlm_masking(inputs, mask_token, pad_token):
  """Random-length masking function.

  Different from standard BERT masking, which uses a fixed mask ratio, we
  follow the masking process in Mask-Predict
  (https://arxiv.org/abs/1904.09324): a random mask ratio in [0, 1) is sampled
  first, and input sequence tokens are masked based on this ratio.

  Args:
    inputs: input layout sequences.
    mask_token: the index of the mask token.
    pad_token: the index of the pad token.
  Returns:
    dictionary of masked inputs, original inputs and mask weights.
  """
  targets = inputs

  rng = jax.random.PRNGKey(jnp.sum(inputs, dtype="int32"))

  # Gets positions to leave untouched.
  is_pad = inputs == pad_token
  lens = jnp.sum(~is_pad, axis=-1)
  # Randomly samples a mask ratio.
  mask_rate = 1 - jax.random.uniform(rng, lens.shape)
  # Obtains the ceiling of the lens to make sure we can mask at least one token.
  mask_lens = jax.lax.ceil(lens * mask_rate)
  # Positions to mask.
  rng, subrng = jax.random.split(rng)
  # Randomly generates the mask score uniformly.
  should_mask = jax.random.uniform(subrng, shape=inputs.shape)
  # Doesn't mask out padding.
  should_mask = jnp.where(is_pad, 2., should_mask)
  # should_mask = jnp.where(is_pad | (~target_mask), 2., should_mask)

  sorted_should_mask = jnp.sort(should_mask, axis=-1)

  # Obtains the cutoff score for the mask lens.
  cut_off = jnp.take_along_axis(
      sorted_should_mask, jnp.expand_dims(mask_lens-1, 1), axis=-1)
  cut_off = jnp.repeat(cut_off, inputs.shape[1], axis=1)

  # Scores smaller than the cutoff will be masked.
  should_mask = jnp.where(should_mask <= cut_off, 1., 0.)

  # Full array of MASK tokens.
  fullmask = jnp.full_like(inputs, mask_token)

  # Only replaces positions where `should_mask` is set.
  masked_inputs = jnp.where(should_mask, fullmask, inputs)
  weights = should_mask
  return dict(masked_inputs=masked_inputs, targets=targets, weights=weights)


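# Illustrative usage sketch for `rmlm_masking` (not part of the original
# module; the token ids follow the convention used by `preprocess_batch`
# below, i.e. mask_token=3 and pad_token=0):
#
#   toy_inputs = jnp.array([[5, 6, 7, 8, 9, 0, 0, 0]])
#   out = rmlm_masking(toy_inputs, mask_token=3, pad_token=0)
#   # out["masked_inputs"]: inputs with a random-length subset replaced by 3.
#   # out["targets"]: the untouched inputs.
#   # out["weights"]: 1.0 exactly at the masked (non-pad) positions.

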
def attribute_random_masking(inputs, mask_token, pad_token, layout_dim):
  """Attribute-wise masking process.

  Different from standard BERT masking, which uses a fixed mask ratio, each
  time we only mask one of three attributes (category, size and position). A
  random mask ratio in [0, 1) is then sampled, and the tokens at this
  attribute's positions are masked based on this ratio.

  Args:
    inputs: input layout sequences.
    mask_token: the index of the mask token.
    pad_token: the index of the pad token.
    layout_dim: the dimension of layout.
  Returns:
    dictionary of masked inputs, original inputs and mask weights.
  """
  targets = inputs
  total_dim = layout_dim * 2 + 1

  rng = jax.random.PRNGKey(jnp.sum(inputs, dtype="int32"))

  # Gets positions to leave untouched.
  is_pad = inputs == pad_token
  position_ids = jnp.arange(inputs.shape[-1])[None, :]
  is_asset = position_ids % total_dim == 0
  # is_size = (position_ids % 5 == 1) | (position_ids % 5 == 2)
  # is_position = (position_ids % 5 == 3) | (position_ids % 5 == 4)
  is_size = functools.reduce(
      lambda x, y: x | y,
      [position_ids % total_dim == i for i in range(1, layout_dim + 1)])
  is_position = functools.reduce(
      lambda x, y: x | y,
      [position_ids % total_dim == i for i in range(layout_dim + 1, total_dim)])
  # Three-step masking.
  rand = jax.random.uniform(rng, (inputs.shape[0], 1))

  target_mask = (~is_pad) & is_asset
  target_mask = jnp.where(
      jnp.logical_and(rand >= 0.2, rand < 0.4),
      (is_asset | is_size) & (~is_pad), target_mask)
  # target_mask = jnp.where(
  #     jnp.logical_and(rand >= 0.4, rand < 0.6),
  #     (is_asset | is_position) & (~is_pad), target_mask)
  target_mask = jnp.where(rand >= 0.4, ~is_pad, target_mask)
  should_mask = target_mask

  # Full array of MASK tokens.
  fullmask = jnp.full_like(inputs, mask_token)
  fullmask = jnp.where(is_pad, pad_token, fullmask)

  # Only replaces positions where `should_mask` is set.
  pre_masked_inputs = jnp.where(should_mask, inputs, fullmask)
  weights = is_asset & (~is_pad)
  weights = jnp.where(
      jnp.logical_and(rand >= 0.2, rand < 0.4), is_size & (~is_pad), weights)
  weights = jnp.where(
      jnp.logical_and(rand >= 0.4, rand < 0.6), is_position & (~is_pad),
      weights)
  weights = jnp.where(
      jnp.logical_and(rand >= 0.6, rand < 0.8), is_size & (~is_pad), weights)
  weights = jnp.where(rand >= 0.8, is_asset & (~is_pad), weights)

  # lens = jnp.sum(target_mask & (~is_pad), axis=-1)
  lens = jnp.sum(weights, axis=-1)
  rng, subrng = jax.random.split(rng)
  mask_rate = 1 - jax.random.uniform(subrng, lens.shape)

  # Obtains the ceiling of the lens to make sure we can mask at least one token.
  mask_lens = jax.lax.ceil(lens * mask_rate)
  # Positions to mask.
  rng, subrng = jax.random.split(rng)
  # Randomly generates the mask score uniformly.
  should_mask = jax.random.uniform(subrng, shape=inputs.shape)
  # Doesn't mask out padding.
  should_mask = jnp.where(weights, should_mask, 2.)

  sorted_should_mask = jnp.sort(should_mask, axis=-1)

  # Obtains the cutoff score for the mask lens.
  cut_off = jnp.take_along_axis(
      sorted_should_mask, jnp.expand_dims(mask_lens-1, 1), axis=-1)
  cut_off = jnp.repeat(cut_off, inputs.shape[1], axis=1)

  # Scores smaller than the cutoff will be masked.
  should_mask = jnp.where(should_mask <= cut_off, 1., 0.)

  # Full array of MASK tokens.
  fullmask = jnp.full_like(inputs, mask_token)

  # Only replaces positions where `should_mask` is set.
  masked_inputs = jnp.where(should_mask, fullmask, pre_masked_inputs)
  weights = jnp.where(is_pad, 0, should_mask)
  return dict(masked_inputs=masked_inputs, targets=targets, weights=weights)


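# Illustrative usage sketch for `attribute_random_masking` (not part of the
# original module). With layout_dim=2, each element occupies
# total_dim = 2 * 2 + 1 = 5 tokens ordered as [category, size..., position...]
# per the is_asset/is_size/is_position masks above, and only one attribute
# group is masked per call:
#
#   toy_inputs = jnp.array(
#       [[5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 0, 0, 0, 0, 0]])
#   out = attribute_random_masking(
#       toy_inputs, mask_token=3, pad_token=0, layout_dim=2)
#   # out["weights"] is nonzero only at the chosen attribute's positions.

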
class BERTLayoutTrainer(base_trainer.LayoutBaseTrainer):
  """BERT-style Layout Trainer."""

  def preprocess_batch(self, batch, batch_size, dataset, use_vertical=False):
    """Masks and shards one batch; incomplete final batches are skipped."""
    label = None
    # When we reach the end of the dataset iterator, the batch size may not
    # be the expected one. In that case, we simply skip the batch.
    if batch.shape[0] != batch_size:
      return None, None
    batch = attribute_random_masking(
        batch, mask_token=3, pad_token=0, layout_dim=self.layout_dim)
    batch = common_utils.shard(batch)
    return batch, label

  def create_train_state(
      self,
      rng,
      inputs,
  ):
    """Creates the model and the initial train state (optimizer and params)."""
    model = functools.partial(
        na_layout_net.NALayoutNet,
        use_vertical=self.config.use_vertical_info,
        vocab_size=self.config.vocab_size,
        hidden_size=self.config.qkv_dim,
        num_hidden_layers=self.config.num_layers,
        num_attention_heads=self.config.num_heads,
        intermediate_size=self.config.mlp_dim,
        pad_token_id=0,
        layout_dim=self.layout_dim)
    param_rng, dropout_rng, rng = jax.random.split(rng, 3)
    model_variables = model().init(
        {"params": param_rng, "dropout": dropout_rng},
        inputs["inputs"],
        inputs["labels"],
        deterministic=False)
    model_state = dict(model_variables)
    model_params = model_state.pop("params")
    logging.info("logging model parameters")
    parameter_overview.log_parameter_overview(model_params)
    optimizer = self.create_optimizer().create(model_params)
    model_dict = dict(model=model)
    train_state = TrainState(
        step=0,
        optimizer=optimizer,
        model_state=model_state)
    return model_dict, train_state

  def train_step(self, rng, state, batch, label, learning_rate_fn, model_dict,
                 logits_mask):
    """Performs a single training step.

    Args:
      rng: The random seed.
      state: State of the model (optimizer and state).
      batch: Training inputs for this step.
      label: Training input vertical info (always None for now).
      learning_rate_fn: The learning rate scheduler.
      model_dict: The model used in training.
      logits_mask: Logits mask for each step.

    Returns:
      The new model state and a dictionary with metrics.
    """
    logging.info("train_step(batch=%s)", batch)
    step = state.step + 1
    lr = learning_rate_fn(state.step)
    model = model_dict["model"]

    def loss_fn(params):
      variables = {"params": params}
      variables.update(state.model_state)

      logits, new_variables = model().apply(
          {"params": params},
          batch["masked_inputs"],
          labels=label,
          deterministic=False,
          rngs={"dropout": rng},
          mutable=True)
      ce_loss, num_tokens = self.compute_weighted_cross_entropy(
          logits, batch["targets"], batch["weights"],
          self.config.label_smoothing, logits_mask)

      loss = ce_loss / num_tokens
      new_model_state = dict(new_variables)
      new_model_state.pop("params")
      return loss, new_model_state

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, new_model_state), grad = grad_fn(state.optimizer.target)
    grad = jax.lax.pmean(grad, "batch")
    new_optimizer = state.optimizer.apply_gradient(grad, learning_rate=lr)
    new_state = state.replace(
        step=step, optimizer=new_optimizer, model_state=new_model_state)
    metrics_update = TrainMetrics.gather_from_model_output(loss=loss)
    return new_state, metrics_update

  def eval_step(self, rng, state, batch, label, model_dict, logits_mask):
    """Computes the evaluation loss for one batch."""
    model = model_dict["model"]
    logits = model().apply(
        {"params": state.optimizer.target},
        batch["masked_inputs"],
        labels=label,
        deterministic=True)
    ce_loss, num_tokens = self.compute_weighted_cross_entropy(
        logits, batch["targets"], batch["weights"], self.config.label_smoothing,
        logits_mask)

    loss = ce_loss / num_tokens
    metrics_update = EvalMetrics.gather_from_model_output(eval_loss=loss)
    return metrics_update

  def test(self,
           batch_size=1,
           iterative_nums=None,
           conditional="none",
           max_decode_len=128,
           use_vertical=False,
           sample_step_num=10,
           prior=None,
           max_asset_num=22,
           vertical_idx=0):
    """Runs a decoding pass and returns generated (and real) layout samples."""
    rng = jax.random.PRNGKey(self.config.seed)
    np.random.seed(self.config.seed)
    # Make sure each host uses a different RNG.
    rng = jax.random.fold_in(rng, jax.process_index())
    rng, model_rng, data_rng = jax.random.split(rng, 3)
    data_rng = jax.random.fold_in(data_rng, jax.process_index())
    dataset = self.config.dataset
    test_ds, vocab_size, pos_info = input_pipeline.get_dataset(
        batch_size,
        self.config.dataset_path,
        jax.local_device_count(),
        "test.json",
        max_decode_len,
        add_bos=False,
        dataset_name=dataset)
    logits_mask, offset = self.make_mask(vocab_size, pos_info,
                                         max_decode_len,
                                         self.layout_dim)
    init_batch = jnp.ones(
        (batch_size, max_decode_len))
    init_label = jnp.ones((batch_size, 1))
    init_batch = dict(inputs=init_batch, labels=init_label)
    model_dict, state = self.create_train_state(model_rng, init_batch)
    ckpt_path = self.config.test_checkpoint_dir
    state = task_manager.restore_checkpoint(state, ckpt_path)
    state = flax_utils.replicate(state)

    sample_one_batch_fn = functools.partial(
        self.sample_one_batch,
        pos_info=pos_info,
        iterative_num=iterative_nums,
        conditional=conditional,
        logits_mask=logits_mask)
    p_generate_batch = jax.pmap(
        functools.partial(
            sample_one_batch_fn,
            model_dict=model_dict,
        ),
        axis_name="batch")
    test_iter = iter(test_ds)  # pytype: disable=wrong-arg-types
    generated_sample_list, real_sample_list = [], []
    assert iterative_nums is not None and len(iterative_nums) == 3
    iterative_nums = np.array(iterative_nums)
    def tohost(x):
      """Collects batches from all devices to host and flattens batch dims."""
      n_device, n_batch, *remaining_dims = x.shape
      return np.array(x).reshape((n_device * n_batch,) + tuple(remaining_dims))

    if conditional == "none":
      total_time = 0.
      for idx in range(sample_step_num):
        if use_vertical:
          test_label = jnp.full((batch_size, 1), vertical_idx)
        else:
          test_label = None
        asset_num = np.random.choice(max_asset_num, batch_size, p=prior) + 1
        rng, sample_rng = jax.random.split(rng, 2)
        p_rng = jax.random.split(sample_rng, jax.local_device_count())

        # All mask symbols.
        asset_num = jnp.array(asset_num, dtype="int32")[Ellipsis, None]
        # element_num = asset_num * 5
        element_num = asset_num * self.total_dim
        masked_batch = jnp.full((batch_size, 128), 3)

        position_ids = jnp.arange(masked_batch.shape[-1])[None, :]
        # Pads masked batch.
        masked_batch = jnp.where(position_ids >= element_num, 0, masked_batch)
        masked_batch = common_utils.shard(masked_batch)
        test_label = common_utils.shard(test_label)

        p_rng = jax.random.split(rng, jax.local_device_count())
        start_time = time.time()

        sample_layouts = p_generate_batch(masked_batch, test_label, p_rng,
                                          state)
        total_time += time.time() - start_time
        start_time = time.time()
        sample_layouts = tohost(sample_layouts)
        generated_sample_list.append(sample_layouts - offset)

      generated_samples = jnp.concatenate(generated_sample_list, axis=0)
      real_samples = None
      logging.info("decoding time: (%.4f)", total_time)
      return generated_samples, real_samples

    for idx, test_batch in enumerate(test_iter):
      if idx >= sample_step_num:
        break
      asset_num = np.random.choice(max_asset_num, batch_size, p=prior) + 1
      rng, sample_rng = jax.random.split(rng, 2)
      p_rng = jax.random.split(sample_rng, jax.local_device_count())
      test_batch = jax.tree_map(lambda x: x._numpy(), test_batch)  # pylint: disable=protected-access
      test_batch, _ = self.preprocess_batch(test_batch, batch_size, dataset,
                                            use_vertical)
      if test_batch is None or (conditional == "none" and
                                idx == sample_step_num):
        break
      test_batch = tohost(test_batch["targets"])
      if use_vertical:
        test_label = jnp.full((batch_size, 1), vertical_idx)
      else:
        test_label = None

      # All mask symbols.
      if conditional != "none":
        # asset_num = jnp.sum(test_batch > 0, axis=1, keepdims=True) // 5
        asset_num = jnp.sum(
            test_batch > 0, axis=1, keepdims=True) // self.total_dim
      else:
        asset_num = jnp.array(asset_num, dtype="int32")[Ellipsis, None]
      # element_num = asset_num * 5
      element_num = asset_num * self.total_dim
      masked_batch = jnp.full_like(test_batch, 3)

      position_ids = jnp.arange(masked_batch.shape[-1])[None, :]
      # is_asset = position_ids % 5 == 0
      # is_size = (position_ids % 5 == 1) | (position_ids % 5 == 2)
      is_asset = position_ids % self.total_dim == 0
      is_size = functools.reduce(lambda x, y: x | y, [
          position_ids % self.total_dim == i
          for i in range(1, self.layout_dim + 1)
      ])
      if conditional == "a+s":
        masked_batch = jnp.where(is_asset | is_size, test_batch, masked_batch)
      elif conditional == "a":
        masked_batch = jnp.where(is_asset, test_batch, masked_batch)
      # Pads masked batch.
      masked_batch = jnp.where(position_ids >= element_num, 0, masked_batch)
      masked_batch = common_utils.shard(masked_batch)

      p_rng = jax.random.split(rng, jax.local_device_count())
      test_label = common_utils.shard(test_label)
      sample_layouts = p_generate_batch(masked_batch, test_label, p_rng, state)
      sample_layouts = tohost(sample_layouts)
      generated_sample_list.append(sample_layouts - offset)
      real_sample_list.append(test_batch - offset)
    generated_samples = jnp.concatenate(generated_sample_list, axis=0)
    real_samples = jnp.concatenate(real_sample_list, axis=0)
    return generated_samples, real_samples

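  # Hypothetical call sketch for `test` (argument values are assumptions;
  # `conditional` may be "none" (unconditional), "a" (category tokens given)
  # or "a+s" (category and size tokens given), and `iterative_nums` must have
  # three entries, as asserted above):
  #
  #   generated, real = trainer.test(
  #       batch_size=64, iterative_nums=[10, 10, 10], conditional="a+s")
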
  def sample_step(self, rng, state, model_dict, pos_info):
    """Samples layouts just for visualization during training."""
    pass

  def incremental_decode(self,
                         rng,
                         variables,
                         model,
                         pos_info,
                         batch=None,
                         label=None,
                         iterative_nums=None,
                         conditional="none",
                         logits_mask=None
                         ):
    """Decodes non-autoregressively by iteratively refining masked inputs."""

    def tokens_to_logits(masked_batch):
      logits = model().apply(variables, masked_batch, label, deterministic=True)
      return logits

    seqs = layout_bert_fast_decode.decode(batch, tokens_to_logits,
                                          self.config.sampling_method, rng,
                                          logits_mask,
                                          iterative_nums=iterative_nums,
                                          layout_dim=self.layout_dim)
    return seqs

  def sample_one_batch(self, batch, label, rng, state, model_dict, pos_info,
                       iterative_num, conditional, logits_mask):
    """Samples one batch for eval."""
    model = model_dict["model"]
    variables = {"params": state.optimizer.target}
    variables.update(state.model_state)

    x = self.incremental_decode(rng, variables, model, pos_info, batch, label,
                                iterative_num, conditional,
                                logits_mask)
    return x