# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Transformer for Layout Trainer."""

# pylint: disable=g-import-not-at-top
# pylint: disable=g-bad-import-order

import sys

# Make the parent directory importable so the `layout_blt` package resolves
# when this file is run from inside the repository tree.
sys.path.append("..")

import functools
from typing import Any, Dict, Optional, Tuple

from absl import logging
from clu import metrics
from clu import parameter_overview
import flax
import flax.jax_utils as flax_utils
from flax.training import common_utils
import jax
import jax.numpy as jnp
from layout_blt import input_pipeline
from layout_blt.nets import transformer
from layout_blt.trainers import base_trainer
from layout_blt.utils import layout_fast_decode
from layout_blt.utils import task_manager
import numpy as np

@flax.struct.dataclass
class TrainState:
  """Data structure for checkpointing the model."""
  # Number of completed optimizer steps.
  step: int
  # Flax optimizer; `optimizer.target` holds the model parameters.
  optimizer: flax.optim.Optimizer
  # Non-parameter variable collections returned by `Module.init`
  # (everything except "params"); may be None/empty.
  model_state: Optional[Any]

@flax.struct.dataclass
class TrainMetrics(metrics.Collection):
  """Metrics during training process."""
  # Running average of the training loss, fed from the "loss" output of
  # `train_step` via `TrainMetrics.gather_from_model_output(loss=...)`.
  loss: metrics.Average.from_output("loss")

@flax.struct.dataclass
class EvalMetrics(metrics.Collection):
  """Metrics during evaluation process."""
  # Running average of the evaluation loss, fed from the "eval_loss" output
  # of `eval_step` via `EvalMetrics.gather_from_model_output(eval_loss=...)`.
  eval_loss: metrics.Average.from_output("eval_loss")

class TransformerTrainer(base_trainer.LayoutBaseTrainer):
  """Transformer for Layout Trainer.

  Trains an autoregressive TransformerDecoder over layout token sequences and
  supports conditional / unconditional layout generation at test time.
  """

  def preprocess_batch(self, batch, batch_size, dataset, use_vertical=False):
    """Shards a host batch across local devices.

    Args:
      batch: Array of layout token sequences with leading batch dimension.
      batch_size: Expected per-host batch size.
      dataset: Dataset name (unused by this implementation).
      use_vertical: Whether vertical info is used (unused; label stays None).

    Returns:
      (sharded_batch, label): the batch reshaped with a leading device axis,
      and label (always None here); (None, None) when the batch is short.
    """
    label = None
    # The last batch drawn from a dataset iterator may be smaller than the
    # expected batch size; skip it so it can be sharded evenly across devices.
    if batch.shape[0] != batch_size:
      return None, None
    batch = common_utils.shard(batch)
    return batch, label

  def create_train_state(
      self,
      rng,
      inputs,
  ):
    """Creates the model and the initial TrainState.

    Args:
      rng: jax PRNG key.
      inputs: dict with "inputs" and "labels" arrays, used only to trace
        shapes during flax module initialization.

    Returns:
      model_dict: {"model": functools.partial-wrapped TransformerDecoder}.
      train_state: TrainState at step 0 holding the created optimizer (with
        the initialized parameters) and the non-parameter collections.
    """
    model = functools.partial(
        transformer.TransformerDecoder, config=self.config)
    param_rng, latent_rng = jax.random.split(rng, 2)
    model_variables = model(deterministic=True).init(param_rng,
                                                     inputs["inputs"],
                                                     inputs["labels"],
                                                     latent_rng)
    # Split the initialized variables into trainable params and the rest.
    model_state = dict(model_variables)
    model_params = model_state.pop("params")
    logging.info("logging model parameters")
    parameter_overview.log_parameter_overview(model_params)
    optimizer = self.create_optimizer().create(model_params)
    model_dict = dict(model=model)
    train_state = TrainState(
        step=0,
        optimizer=optimizer,
        model_state=model_state)
    return model_dict, train_state

  def train_step(
      self,
      rng,
      state,
      batch,
      label,
      learning_rate_fn,
      model_dict,
      logits_mask
  ):
    """Perform a single training step.

    Args:
      rng: The random seed.
      state: State of the model (optimizer and state).
      batch: Training inputs for this step.
      label: Training input vertical info (always None for now).
      learning_rate_fn: The learning rate scheduler.
      model_dict: The model used in training.
      logits_mask: Logits mask for each step.

    Returns:
      The new model state and a TrainMetrics update with the step's loss.
    """
    logging.info("train_step(batch=%s)", batch)
    step = state.step + 1
    lr = learning_rate_fn(state.step)
    model = model_dict["model"]
    # Targets are the inputs shifted left by one (drop the BOS position).
    dec_target = batch[:, 1:]
    if logits_mask is not None:
      logits_mask = logits_mask[:, :-1, :]
    # Tokens > 0 are real; 0 is treated as padding and excluded from the loss.
    pad_mask = jnp.where(dec_target > 0, 1, 0).astype(jnp.float32)

    def loss_fn(params):
      dropout_rng, latent_rng = jax.random.split(rng)
      # NOTE(review): `variables` (params merged with state.model_state) is
      # built here but never used — apply() below receives only
      # {"params": params}, so state.model_state is not threaded through.
      # Looks unintended; confirm against the original training results.
      variables = {"params": params}
      variables.update(state.model_state)
      (logits, _), new_variables = model().apply(
          {"params": params},
          batch,
          label,
          latent_rng,
          rngs={"dropout": dropout_rng},
          mutable=True)
      recon_loss, num_tokens = self.compute_weighted_cross_entropy(
          logits, dec_target, pad_mask, self.config.label_smoothing,
          logits_mask)
      # Normalize by the number of non-padding tokens.
      recon_loss = recon_loss / num_tokens
      loss = recon_loss
      # Keep every mutated collection except the parameters themselves.
      new_model_state = dict(new_variables)
      new_model_state.pop("params")
      return loss, (recon_loss, new_model_state)

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, (recon_loss,
            new_model_state)), grad = grad_fn(state.optimizer.target)
    del recon_loss
    # Average gradients across devices; requires pmap(axis_name="batch").
    grad = jax.lax.pmean(grad, "batch")
    new_optimizer = state.optimizer.apply_gradient(grad, learning_rate=lr)
    new_state = state.replace(
        step=step, optimizer=new_optimizer, model_state=new_model_state)
    metrics_update = TrainMetrics.gather_from_model_output(loss=loss)
    return new_state, metrics_update

  def eval_step(self, rng, state, batch, label, model_dict, logits_mask):
    """Computes the evaluation loss for one batch (mirrors train_step)."""
    model = model_dict["model"]
    dec_target = batch[:, 1:]
    logits_mask = logits_mask[:, :-1, :]
    pad_mask = jnp.where(dec_target > 0, 1, 0).astype(jnp.float32)
    (logits, _) = model(deterministic=True).apply(
        {"params": state.optimizer.target},
        batch,
        label,
        rng)
    recon_loss, num_tokens = self.compute_weighted_cross_entropy(
        logits, dec_target, pad_mask, self.config.label_smoothing, logits_mask)
    recon_loss = recon_loss / num_tokens
    loss = recon_loss
    metrics_update = EvalMetrics.gather_from_model_output(eval_loss=loss)
    return metrics_update

  def test(self,
           sampling_method="topp",
           conditional="none",
           eos_id=2,
           batch_size=1,
           sample_step_num=1,
           max_decode_len=128,
           use_vertical=False,
           vertical_idx=0):
    """Runs a test run.

    Args:
      sampling_method: str: how to generate the current token.
      conditional: str: "none": unconditional generation, "a": asset
        conditional generation, "a+s": asset + size conditional generation.
      eos_id: int: the index of the eos symbol.
      batch_size: int: batch size of generation at one time.
      sample_step_num: int: how many batches to generate.
      max_decode_len: int: the maximum number of tokens during generation.
      use_vertical: bool: whether to use vertical information (always False).
      vertical_idx: int: vertical index.

    Returns:
      generated_samples: [sample_step_num*batch_size, max_decode_len]:
        generated layouts.
      real_samples: [sample_step_num*batch_size, max_decode_len]: real
        layouts.
    """
    # Generation is pmapped, so the batch must split evenly across devices.
    assert batch_size % jax.local_device_count() == 0
    rng = jax.random.PRNGKey(self.config.seed)
    # Make sure each host uses a different RNG.
    rng = jax.random.fold_in(rng, jax.process_index())
    rng, model_rng, data_rng = jax.random.split(rng, 3)
    data_rng = jax.random.fold_in(data_rng, jax.process_index())
    dataset = self.config.dataset

    test_ds, vocab_size, pos_info = input_pipeline.get_dataset(
        batch_size,
        self.config.dataset_path,
        jax.local_device_count(),
        "test.json",
        max_length=max_decode_len,
        dataset_name=dataset)

    # Dummy batch used only to build the model/optimizer shapes before
    # restoring from checkpoint.
    init_batch = jnp.ones((batch_size, self.config.max_length))
    init_label = jnp.ones((batch_size, 1))
    init_batch = dict(inputs=init_batch, labels=init_label)
    model_dict, state = self.create_train_state(model_rng, init_batch)
    ckpt_path = self.config.test_checkpoint_dir
    state = task_manager.restore_checkpoint(state, ckpt_path)
    state = flax_utils.replicate(state)
    sample_one_batch_fn = functools.partial(
        self.sample_one_batch,
        pos_info=pos_info,
        batch_size=batch_size//jax.local_device_count(),
        conditional=conditional,
        eos_id=eos_id,
        max_decode_len=max_decode_len,
        sampling_method=sampling_method)
    p_generate_batch = jax.pmap(
        functools.partial(
            sample_one_batch_fn,
            model_dict=model_dict,
        ),
        axis_name="batch")

    test_iter = iter(test_ds)  # pytype: disable=wrong-arg-types

    def tohost(x):
      """Collect batches from all devices to host and flatten batch dims."""
      n_device, n_batch, *remaining_dims = x.shape
      return np.array(x).reshape((n_device * n_batch,) + tuple(remaining_dims))

    _, sample_offset = self.make_mask(vocab_size, pos_info, max_decode_len,
                                      self.layout_dim)
    generated_sample_list, real_sample_list = [], []
    for idx, test_batch in enumerate(test_iter):
      # NOTE(review): this break fires for every `conditional` mode, which
      # makes the second `conditional == "none"` check below unreachable —
      # confirm whether conditional modes were meant to consume the full
      # test set.
      if idx == sample_step_num:
        break
      rng, sample_rng = jax.random.split(rng, 2)
      p_rng = jax.random.split(sample_rng, jax.local_device_count())
      test_batch = jax.tree_map(lambda x: x._numpy(), test_batch)  # pylint: disable=protected-access
      test_batch, test_label = self.preprocess_batch(test_batch, batch_size,
                                                     dataset, use_vertical)
      # For unconditional generation, we stop the process according to
      # sample_step_num; otherwise, we use the whole test set.
      if test_batch is None or (conditional == "none" and
                                idx == sample_step_num):
        break

      if conditional == "none":
        if use_vertical:
          # NOTE(review): test_label is None here (preprocess_batch always
          # returns None), so jnp.full_like would fail if use_vertical were
          # True — consistent with use_vertical being "always False".
          test_label = jnp.full_like(test_label, vertical_idx)
        sample_layouts = p_generate_batch(None, p_rng, state, label=test_label)
      else:
        sample_layouts = p_generate_batch(test_batch[Ellipsis, 1:], p_rng, state,
                                          label=test_label)
      # We do not need the BOS symbol.
      sample_layouts = tohost(sample_layouts)[Ellipsis, 1:]
      real_layouts = None
      if test_batch is not None:
        real_layouts = tohost(test_batch)[Ellipsis, 1:]
        # Undo the per-field vocabulary offset to recover raw layout values.
        _, real_offset = self.make_mask(self.config.vocab_size, pos_info,
                                        real_layouts.shape[-1], self.layout_dim)
        real_layouts = real_layouts - real_offset
      generated_sample_list.append(sample_layouts - sample_offset[Ellipsis, :-1])
      real_sample_list.append(real_layouts)
    generated_samples = jnp.concatenate(generated_sample_list, axis=0)
    real_samples = jnp.concatenate(real_sample_list, axis=0)

    return generated_samples, real_samples

  def sample_step(self, rng, state, model_dict, pos_info):
    """Samples layouts just for visualization during training."""
    # TODO(xiang): sample some images during training.
    return None

  def fast_decode(self,
                  rng,
                  variables,
                  model,
                  pos_info,
                  label=None,
                  batch=None,
                  batch_size=1,
                  conditional="none",
                  eos_id=2,
                  max_decode_len=100,
                  sampling_method="topp"):
    """Fast layout generation decoding method.

    Args:
      rng: jax random state.
      variables: model parameters.
      model: layout generation model.
      pos_info: vocabulary segmentation information.
      label: vertical information (always None for now).
      batch: real layouts batch for conditional generation.
      batch_size: number of layouts to generate at one time.
      conditional: conditional type.
      eos_id: index of the eos symbol.
      max_decode_len: maximum number of tokens to generate.
      sampling_method: sampling method during generation (argmax or sampling).

    Returns:
      seqs: generated layouts.
    """
    eval_model = model(deterministic=True, is_train=False)
    init_rng, rng, latent_rng = jax.random.split(rng, 3)
    init_batch = jnp.ones((batch_size, max_decode_len))
    init_label = jnp.ones((batch_size, 1))
    # Initialize once only to materialize the decoding cache; flax's
    # FrozenDict.pop returns (rest, popped), so cache_dict is everything
    # except "params".
    initial_vars = eval_model.init(init_rng, init_batch, init_label, latent_rng)
    cache_dict, _ = initial_vars.pop("params")

    def tokens_to_logits(xi, cache_dict, decode_step, initial_z):
      # One decoding step: consumes the current token(s) and the mutable
      # cache, returns next-token logits plus the updated cache.
      logits, cache_dict = eval_model.apply(
          {
              **variables,
              **cache_dict
          },
          xi,
          label,
          initial_z,
          decode_step,
          mutable=["cache"],
          method=transformer.TransformerDecoder.decode)
      return logits, cache_dict

    logit_masks, _ = self.make_mask(self.config.vocab_size, pos_info,
                                    self.total_dim, self.layout_dim)
    # Latent code conditioning the decoder, drawn fresh for this call.
    initial_z = jax.random.normal(rng, (batch_size, eval_model.config.emb_dim))
    tokens_to_logits_fn = functools.partial(
        tokens_to_logits, initial_z=initial_z)
    # Unconditional generation starts from the all-ones dummy batch.
    batch = init_batch if batch is None else batch

    seqs = layout_fast_decode.decode(
        batch,
        cache_dict,
        tokens_to_logits_fn,
        max_decode_len=max_decode_len,
        sampling_method=sampling_method,
        rng=rng,
        logit_masks=logit_masks,
        conditional=conditional)
    return seqs

  def sample_one_batch(self, batch, rng, state, model_dict, pos_info, label,
                       batch_size, conditional, eos_id, max_decode_len,
                       sampling_method):
    """Samples one batch for eval; thin wrapper around fast_decode."""
    model = model_dict["model"]
    # Merge trained parameters with any non-parameter collections.
    variables = {"params": state.optimizer.target}
    variables.update(state.model_state)

    x = self.fast_decode(rng, variables, model, pos_info, label, batch,
                         batch_size, conditional, eos_id, max_decode_len,
                         sampling_method)
    return x