# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""BERT Layout Trainer."""
17
18# pylint: disable=g-import-not-at-top
19import functools
20import sys
21sys.path.append("..")
22
23import time
24from typing import Any, Dict, Optional
25
26from absl import logging
27from clu import metrics
28from clu import parameter_overview
29import flax
30import flax.jax_utils as flax_utils
31from flax.training import common_utils
32import jax
33import jax.numpy as jnp
34from layout_blt import input_pipeline
35from layout_blt.nets import na_layout_net
36from layout_blt.trainers import base_trainer
37from layout_blt.utils import layout_bert_fast_decode
38from layout_blt.utils import task_manager
39import numpy as np
40
41
@flax.struct.dataclass
class TrainState:
  """Data structure for checkpointing the model."""
  step: int
  optimizer: flax.optim.Optimizer
  model_state: Optional[Any]


@flax.struct.dataclass
class TrainMetrics(metrics.Collection):
  """Metrics collected during training."""
  loss: metrics.Average.from_output("loss")


@flax.struct.dataclass
class EvalMetrics(metrics.Collection):
  """Metrics collected during evaluation."""
  eval_loss: metrics.Average.from_output("eval_loss")


def rmlm_masking(inputs, mask_token, pad_token):
  """Random-length masking function.

  Unlike standard BERT masking, which uses a fixed mask ratio, we follow the
  masking process of Mask-Predict (https://arxiv.org/abs/1904.09324): a mask
  ratio is sampled uniformly from [0, 1) first, and input sequence tokens are
  masked according to that ratio.

  Args:
    inputs: input layout sequences.
    mask_token: the index of the mask token.
    pad_token: the index of the pad token.

  Returns:
    A dictionary of masked inputs, original inputs and mask weights.
  """
  targets = inputs

  rng = jax.random.PRNGKey(jnp.sum(inputs, dtype="int32"))

  # Gets positions to leave untouched.
  is_pad = inputs == pad_token
  lens = jnp.sum(~is_pad, axis=-1)
  # Randomly samples a mask ratio.
  mask_rate = 1 - jax.random.uniform(rng, lens.shape)
  # Takes the ceiling so that at least one token is masked.
  mask_lens = jax.lax.ceil(lens * mask_rate).astype("int32")
  # Positions to mask.
  rng, subrng = jax.random.split(rng)
  # Randomly generates the mask scores uniformly.
  should_mask = jax.random.uniform(subrng, shape=inputs.shape)
  # Doesn't mask out padding.
  should_mask = jnp.where(is_pad, 2., should_mask)
  # should_mask = jnp.where(is_pad | (~target_mask), 2., should_mask)

  sorted_should_mask = jnp.sort(should_mask, axis=-1)

  # Obtains the cutoff score for the mask lens.
  cut_off = jnp.take_along_axis(
      sorted_should_mask, jnp.expand_dims(mask_lens - 1, 1), axis=-1)
  cut_off = jnp.repeat(cut_off, inputs.shape[1], axis=1)

  # Scores smaller than the cutoff will be masked.
  should_mask = jnp.where(should_mask <= cut_off, 1., 0.)

  # Full array of MASK tokens.
  fullmask = jnp.full_like(inputs, mask_token)

  # Only replaces positions where `should_mask` is set.
  masked_inputs = jnp.where(should_mask, fullmask, inputs)
  weights = should_mask
  return dict(masked_inputs=masked_inputs, targets=targets, weights=weights)
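

# A minimal usage sketch, not part of the original trainer: it shows what
# `rmlm_masking` produces on a toy batch. `_demo_rmlm_masking` is a
# hypothetical helper added here for illustration; the token conventions
# (pad=0, mask=3) follow `preprocess_batch` below.
def _demo_rmlm_masking():
  toy_batch = jnp.array([[5, 6, 7, 0, 0],
                         [8, 9, 10, 11, 0]], dtype="int32")
  out = rmlm_masking(toy_batch, mask_token=3, pad_token=0)
  # A random-length subset of the non-pad positions is replaced by the mask
  # token, and `weights` is 1. exactly at those supervised positions.
  print(out["masked_inputs"])
  print(out["weights"])

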
def attribute_random_masking(inputs, mask_token, pad_token, layout_dim):
  """Attribute-wise masking process.

  Unlike standard BERT masking, which uses a fixed mask ratio, each time we
  only mask one group of the three attributes (category, size and position).
  A mask ratio is then sampled uniformly from [0, 1) and the tokens at this
  attribute's positions are masked according to that ratio.

  Args:
    inputs: input layout sequences.
    mask_token: the index of the mask token.
    pad_token: the index of the pad token.
    layout_dim: the dimension of the layout.

  Returns:
    A dictionary of masked inputs, original inputs and mask weights.
  """
  targets = inputs
  total_dim = layout_dim * 2 + 1

  rng = jax.random.PRNGKey(jnp.sum(inputs, dtype="int32"))

  # Gets positions to leave untouched.
  is_pad = inputs == pad_token
  position_ids = jnp.arange(inputs.shape[-1])[None, :]
  is_asset = position_ids % total_dim == 0
  # is_size = (position_ids % 5 == 1) | (position_ids % 5 == 2)
  # is_position = (position_ids % 5 == 3) | (position_ids % 5 == 4)
  is_size = functools.reduce(
      lambda x, y: x | y,
      [position_ids % total_dim == i for i in range(1, layout_dim + 1)])
  is_position = functools.reduce(
      lambda x, y: x | y,
      [position_ids % total_dim == i for i in range(layout_dim + 1, total_dim)])
  # Three-step masking: samples which attribute group to mask per example.
  rand = jax.random.uniform(rng, (inputs.shape[0], 1))

  target_mask = (~is_pad) & is_asset
  target_mask = jnp.where(
      jnp.logical_and(rand >= 0.2, rand < 0.4),
      (is_asset | is_size) & (~is_pad), target_mask)
  # target_mask = jnp.where(
  #     jnp.logical_and(rand >= 0.4, rand < 0.6),
  #     (is_asset | is_position) & (~is_pad), target_mask)
  target_mask = jnp.where(rand >= 0.4, ~is_pad, target_mask)
  should_mask = target_mask

  # Full array of MASK tokens.
  fullmask = jnp.full_like(inputs, mask_token)
  fullmask = jnp.where(is_pad, pad_token, fullmask)

  # Only keeps positions where `should_mask` is set.
  pre_masked_inputs = jnp.where(should_mask, inputs, fullmask)
  weights = is_asset & (~is_pad)
  weights = jnp.where(
      jnp.logical_and(rand >= 0.2, rand < 0.4), is_size & (~is_pad), weights)
  weights = jnp.where(
      jnp.logical_and(rand >= 0.4, rand < 0.6), is_position & (~is_pad),
      weights)
  weights = jnp.where(
      jnp.logical_and(rand >= 0.6, rand < 0.8), is_size & (~is_pad), weights)
  weights = jnp.where(rand >= 0.8, is_asset & (~is_pad), weights)

  # lens = jnp.sum(target_mask & (~is_pad), axis=-1)
  lens = jnp.sum(weights, axis=-1)
  rng, subrng = jax.random.split(rng)
  mask_rate = 1 - jax.random.uniform(subrng, lens.shape)

  # Takes the ceiling so that at least one token is masked.
  mask_lens = jax.lax.ceil(lens * mask_rate).astype("int32")
  # Positions to mask.
  rng, subrng = jax.random.split(rng)
  # Randomly generates the mask scores uniformly.
  should_mask = jax.random.uniform(subrng, shape=inputs.shape)
  # Only candidate positions (the selected attribute group) keep their scores.
  should_mask = jnp.where(weights, should_mask, 2.)

  sorted_should_mask = jnp.sort(should_mask, axis=-1)

  # Obtains the cutoff score for the mask lens.
  cut_off = jnp.take_along_axis(
      sorted_should_mask, jnp.expand_dims(mask_lens - 1, 1), axis=-1)
  cut_off = jnp.repeat(cut_off, inputs.shape[1], axis=1)

  # Scores smaller than the cutoff will be masked.
  should_mask = jnp.where(should_mask <= cut_off, 1., 0.)

  # Full array of MASK tokens.
  fullmask = jnp.full_like(inputs, mask_token)

  # Only replaces positions where `should_mask` is set.
  masked_inputs = jnp.where(should_mask, fullmask, pre_masked_inputs)
  weights = jnp.where(is_pad, 0, should_mask)
  return dict(masked_inputs=masked_inputs, targets=targets, weights=weights)
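

# A minimal usage sketch, not part of the original trainer:
# `_demo_attribute_random_masking` is a hypothetical helper for illustration.
# With layout_dim=2 each asset takes total_dim=5 tokens (one category token,
# two size tokens, two position tokens), so two assets plus padding look like
# the toy row below.
def _demo_attribute_random_masking():
  toy_batch = jnp.array(
      [[7, 12, 13, 20, 21, 7, 14, 15, 22, 23, 0, 0, 0, 0, 0]], dtype="int32")
  out = attribute_random_masking(
      toy_batch, mask_token=3, pad_token=0, layout_dim=2)
  # Depending on the sampled group, only category, size or position slots
  # get masked; `weights` marks exactly the masked positions.
  print(out["masked_inputs"])
  print(out["weights"])

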
class BERTLayoutTrainer(base_trainer.LayoutBaseTrainer):
  """BERT-style Layout Trainer."""

  def preprocess_batch(self, batch, batch_size, dataset, use_vertical=False):
    label = None
    # When we reach the end of the dataset iterator, the batch size may not
    # match the expected one. In that case, we simply skip the batch.
    if batch.shape[0] != batch_size:
      return None, None
    batch = attribute_random_masking(
        batch, mask_token=3, pad_token=0, layout_dim=self.layout_dim)
    batch = common_utils.shard(batch)
    return batch, label

  def create_train_state(
      self,
      rng,
      inputs,
  ):
    model = functools.partial(
        na_layout_net.NALayoutNet,
        use_vertical=self.config.use_vertical_info,
        vocab_size=self.config.vocab_size,
        hidden_size=self.config.qkv_dim,
        num_hidden_layers=self.config.num_layers,
        num_attention_heads=self.config.num_heads,
        intermediate_size=self.config.mlp_dim,
        pad_token_id=0,
        layout_dim=self.layout_dim)
    param_rng, dropout_rng, rng = jax.random.split(rng, 3)
    model_variables = model().init(
        {"params": param_rng, "dropout": dropout_rng},
        inputs["inputs"],
        inputs["labels"],
        deterministic=False)
    model_state = dict(model_variables)
    model_params = model_state.pop("params")
    logging.info("logging model parameters")
    parameter_overview.log_parameter_overview(model_params)
    optimizer = self.create_optimizer().create(model_params)
    model_dict = dict(model=model)
    train_state = TrainState(
        step=0,
        optimizer=optimizer,
        model_state=model_state)
    return model_dict, train_state

  def train_step(self, rng, state, batch, label, learning_rate_fn, model_dict,
                 logits_mask):
    """Performs a single training step.

    Args:
      rng: The random seed.
      state: State of the model (optimizer and state).
      batch: Training inputs for this step.
      label: Training input vertical info (always None for now).
      learning_rate_fn: The learning rate scheduler.
      model_dict: The model used in training.
      logits_mask: Logits mask for each step.

    Returns:
      The new model state and a dictionary with metrics.
    """
    logging.info("train_step(batch=%s)", batch)
    step = state.step + 1
    lr = learning_rate_fn(state.step)
    model = model_dict["model"]

    def loss_fn(params):
      # Includes any mutable model state alongside the parameters.
      variables = {"params": params}
      variables.update(state.model_state)

      logits, new_variables = model().apply(
          variables,
          batch["masked_inputs"],
          labels=label,
          deterministic=False,
          rngs={"dropout": rng},
          mutable=True)
      ce_loss, num_tokens = self.compute_weighted_cross_entropy(
          logits, batch["targets"], batch["weights"],
          self.config.label_smoothing, logits_mask)

      loss = ce_loss / num_tokens
      new_model_state = dict(new_variables)
      new_model_state.pop("params")
      return loss, new_model_state

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, new_model_state), grad = grad_fn(state.optimizer.target)
    grad = jax.lax.pmean(grad, "batch")
    new_optimizer = state.optimizer.apply_gradient(grad, learning_rate=lr)
    new_state = state.replace(
        step=step, optimizer=new_optimizer, model_state=new_model_state)
    metrics_update = TrainMetrics.gather_from_model_output(loss=loss)
    return new_state, metrics_update

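  # The base trainer is expected to run `train_step` under `jax.pmap`: the
  # `jax.lax.pmean(grad, "batch")` call above only works inside a mapped axis
  # named "batch". A hedged sketch of the wiring (names assumed):
  #
  #   p_train_step = jax.pmap(
  #       functools.partial(
  #           trainer.train_step,
  #           learning_rate_fn=learning_rate_fn,
  #           model_dict=model_dict,
  #           logits_mask=logits_mask),
  #       axis_name="batch")
  #   state, metrics_update = p_train_step(rngs, state, batch, label)
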
  def eval_step(self, rng, state, batch, label, model_dict, logits_mask):
    model = model_dict["model"]
    logits = model().apply(
        {"params": state.optimizer.target},
        batch["masked_inputs"],
        labels=label,
        deterministic=True)
    ce_loss, num_tokens = self.compute_weighted_cross_entropy(
        logits, batch["targets"], batch["weights"],
        self.config.label_smoothing, logits_mask)

    loss = ce_loss / num_tokens
    metrics_update = EvalMetrics.gather_from_model_output(eval_loss=loss)
    return metrics_update

  def test(self,
           batch_size=1,
           iterative_nums=None,
           conditional="none",
           max_decode_len=128,
           use_vertical=False,
           sample_step_num=10,
           prior=None,
           max_asset_num=22,
           vertical_idx=0):
    """Runs test-time sampling."""
    rng = jax.random.PRNGKey(self.config.seed)
    np.random.seed(self.config.seed)
    # Makes sure each host uses a different RNG.
    rng = jax.random.fold_in(rng, jax.process_index())
    rng, model_rng, data_rng = jax.random.split(rng, 3)
    data_rng = jax.random.fold_in(data_rng, jax.process_index())
    dataset = self.config.dataset
    test_ds, vocab_size, pos_info = input_pipeline.get_dataset(
        batch_size,
        self.config.dataset_path,
        jax.local_device_count(),
        "test.json",
        max_decode_len,
        add_bos=False,
        dataset_name=dataset)
    logits_mask, offset = self.make_mask(vocab_size, pos_info, max_decode_len,
                                         self.layout_dim)
    init_batch = jnp.ones((batch_size, max_decode_len))
    init_label = jnp.ones((batch_size, 1))
    init_batch = dict(inputs=init_batch, labels=init_label)
    model_dict, state = self.create_train_state(model_rng, init_batch)
    ckpt_path = self.config.test_checkpoint_dir
    state = task_manager.restore_checkpoint(state, ckpt_path)
    state = flax_utils.replicate(state)

    sample_one_batch_fn = functools.partial(
        self.sample_one_batch,
        pos_info=pos_info,
        iterative_num=iterative_nums,
        conditional=conditional,
        logits_mask=logits_mask)
    p_generate_batch = jax.pmap(
        functools.partial(
            sample_one_batch_fn,
            model_dict=model_dict,
        ),
        axis_name="batch")
    test_iter = iter(test_ds)  # pytype: disable=wrong-arg-types
    generated_sample_list, real_sample_list = [], []
    assert iterative_nums is not None and len(iterative_nums) == 3
    iterative_nums = np.array(iterative_nums)

    def tohost(x):
      """Collects batches from all devices and flattens batch dimensions."""
      n_device, n_batch, *remaining_dims = x.shape
      return np.array(x).reshape((n_device * n_batch,) + tuple(remaining_dims))

    if conditional == "none":
      total_time = 0.
      for _ in range(sample_step_num):
        if use_vertical:
          test_label = jnp.full((batch_size, 1), vertical_idx)
        else:
          test_label = None
        asset_num = np.random.choice(max_asset_num, batch_size, p=prior) + 1
        rng, sample_rng = jax.random.split(rng, 2)
        p_rng = jax.random.split(sample_rng, jax.local_device_count())

        # All mask symbols.
        asset_num = jnp.array(asset_num, dtype="int32")[Ellipsis, None]
        # element_num = asset_num * 5
        element_num = asset_num * self.total_dim
        masked_batch = jnp.full((batch_size, max_decode_len), 3)

        position_ids = jnp.arange(masked_batch.shape[-1])[None, :]
        # Pads the masked batch beyond each layout's element count.
        masked_batch = jnp.where(position_ids >= element_num, 0, masked_batch)
        masked_batch = common_utils.shard(masked_batch)
        test_label = common_utils.shard(test_label)

        start_time = time.time()
        sample_layouts = p_generate_batch(masked_batch, test_label, p_rng,
                                          state)
        total_time += time.time() - start_time
        sample_layouts = tohost(sample_layouts)
        generated_sample_list.append(sample_layouts - offset)

      generated_samples = jnp.concatenate(generated_sample_list, axis=0)
      real_samples = None
      logging.info("decoding time: (%.4f)", total_time)
      return generated_samples, real_samples

    for idx, test_batch in enumerate(test_iter):
      if idx >= sample_step_num:
        break
      asset_num = np.random.choice(max_asset_num, batch_size, p=prior) + 1
      rng, sample_rng = jax.random.split(rng, 2)
      p_rng = jax.random.split(sample_rng, jax.local_device_count())
      test_batch = jax.tree_map(lambda x: x._numpy(), test_batch)  # pylint: disable=protected-access
      test_batch, _ = self.preprocess_batch(test_batch, batch_size, dataset,
                                            use_vertical)
      if test_batch is None or (conditional == "none" and
                                idx == sample_step_num):
        break
      test_batch = tohost(test_batch["targets"])
      if use_vertical:
        test_label = jnp.full((batch_size, 1), vertical_idx)
      else:
        test_label = None

      # All mask symbols.
      if conditional != "none":
        # asset_num = jnp.sum(test_batch > 0, axis=1, keepdims=True) // 5
        asset_num = jnp.sum(
            test_batch > 0, axis=1, keepdims=True) // self.total_dim
      else:
        asset_num = jnp.array(asset_num, dtype="int32")[Ellipsis, None]
      # element_num = asset_num * 5
      element_num = asset_num * self.total_dim
      masked_batch = jnp.full_like(test_batch, 3)

      position_ids = jnp.arange(masked_batch.shape[-1])[None, :]
      # is_asset = position_ids % 5 == 0
      # is_size = (position_ids % 5 == 1) | (position_ids % 5 == 2)
      is_asset = position_ids % self.total_dim == 0
      is_size = functools.reduce(lambda x, y: x | y, [
          position_ids % self.total_dim == i
          for i in range(1, self.layout_dim + 1)
      ])
      if conditional == "a+s":
        masked_batch = jnp.where(is_asset | is_size, test_batch, masked_batch)
      elif conditional == "a":
        masked_batch = jnp.where(is_asset, test_batch, masked_batch)
      # Pads the masked batch beyond each layout's element count.
      masked_batch = jnp.where(position_ids >= element_num, 0, masked_batch)
      masked_batch = common_utils.shard(masked_batch)

      test_label = common_utils.shard(test_label)
      sample_layouts = p_generate_batch(masked_batch, test_label, p_rng, state)
      sample_layouts = tohost(sample_layouts)
      generated_sample_list.append(sample_layouts - offset)
      real_sample_list.append(test_batch - offset)
    generated_samples = jnp.concatenate(generated_sample_list, axis=0)
    real_samples = jnp.concatenate(real_sample_list, axis=0)
    return generated_samples, real_samples

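  # A hedged usage sketch (illustrative values, assuming a configured
  # trainer instance):
  #
  #   trainer = BERTLayoutTrainer(config, workdir)
  #   gen, real = trainer.test(batch_size=32, iterative_nums=[3, 3, 3],
  #                            conditional="a+s", max_decode_len=128)
  #
  # conditional="none" samples layouts from scratch, "a" keeps the ground
  # truth category tokens, and "a+s" keeps both category and size tokens.
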
  def sample_step(self, rng, state, model_dict, pos_info):
    """Samples layouts just for visualization during training."""
    pass

  def incremental_decode(self,
                         rng,
                         variables,
                         model,
                         pos_info,
                         batch=None,
                         label=None,
                         iterative_nums=None,
                         conditional="none",
                         logits_mask=None):
    """Iteratively refines the masked inputs to decode non-autoregressively."""

    def tokens_to_logits(masked_batch):
      logits = model().apply(
          variables, masked_batch, label, deterministic=True)
      return logits

    seqs = layout_bert_fast_decode.decode(
        batch,
        tokens_to_logits,
        self.config.sampling_method,
        rng,
        logits_mask,
        iterative_nums=iterative_nums,
        layout_dim=self.layout_dim)
    return seqs

  def sample_one_batch(self, batch, label, rng, state, model_dict, pos_info,
                       iterative_num, conditional, logits_mask):
    """Samples one batch for eval."""
    model = model_dict["model"]
    variables = {"params": state.optimizer.target}
    variables.update(state.model_state)

    x = self.incremental_decode(rng, variables, model, pos_info, batch, label,
                                iterative_num, conditional, logits_mask)
    return x
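

# A minimal sketch of the mask-predict style iterative decoding that
# `layout_bert_fast_decode.decode` performs (the real implementation also
# handles grouped attributes, sampling methods and logits masks). This
# `_sketch_mask_predict` helper is illustrative only: it greedily fills all
# masked slots, then re-masks the lowest-confidence predictions for a fixed
# number of refinement iterations.
def _sketch_mask_predict(tokens_to_logits, init_batch, num_iters=10,
                         mask_token=3, pad_token=0):
  is_pad = init_batch == pad_token
  # Positions that were observed from the start (pads and conditioning
  # tokens) are always kept and never re-masked.
  keep = is_pad | (init_batch != mask_token)
  num_masked = jnp.sum(~keep, axis=-1)
  seqs = init_batch
  for step in range(num_iters):
    probs = jax.nn.softmax(tokens_to_logits(seqs), axis=-1)
    preds = jnp.argmax(probs, axis=-1)
    conf = jnp.max(probs, axis=-1)
    preds = jnp.where(keep, init_batch, preds)
    conf = jnp.where(keep, 2., conf)
    # Linearly decays the number of tokens that stay masked.
    num_to_mask = jnp.ceil(
        num_masked * (1. - (step + 1.) / num_iters)).astype("int32")
    # Re-masks the positions with the lowest confidence.
    sorted_conf = jnp.sort(conf, axis=-1)
    cut_off = jnp.take_along_axis(sorted_conf, num_to_mask[:, None], axis=-1)
    seqs = jnp.where(conf < cut_off, mask_token, preds)
  return seqs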