google-research
432 lines · 16.5 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Layout Base Trainer."""
17
18import abc
19import functools
20import os
21from typing import Dict
22
23from absl import logging
24from clu import metric_writers
25from clu import periodic_actions
26import flax
27from flax import linen as nn
28import flax.jax_utils as flax_utils
29from flax.training import common_utils
30import jax
31import jax.numpy as jnp
32from layout_blt import input_pipeline
33from layout_blt.utils import metrics
34from layout_blt.utils import task_manager
35import ml_collections
36import numpy as np
37import tensorflow as tf
38
39
40class LayoutBaseTrainer(abc.ABC):
41"""Base Trainer for layout generation."""
42
43def __init__(self, config, workdir):
44self.config = config
45self.workdir = workdir
46self.rng = jax.random.PRNGKey(config.seed)
47self.dtype, self.data_dtype = self.get_dtype()
48self.layout_dim = self.config.layout_dim
49self.total_dim = self.layout_dim * 2 + 1
50
def get_dtype(self):
  """Returns the (JAX, TF) dtype pair selected by `config.dtype`.

  Returns:
    A `(jnp_dtype, tf_dtype)` tuple: bfloat16 pair when `config.dtype` is
    "bfloat16", float32 pair otherwise.
  """
  wants_bf16 = self.config.dtype == "bfloat16"
  return (jnp.bfloat16, tf.bfloat16) if wants_bf16 else (jnp.float32,
                                                         tf.float32)
56
def merge_model_state(self, state):
  """Averages mutable model state (e.g. batch-norm stats) across replicas.

  Args:
    state: replicated train state whose `model_state` pytree may be empty.

  Returns:
    `state` unchanged when `model_state` has no leaves; otherwise a copy
    whose `model_state` is the cross-replica mean.
  """
  leaves = jax.tree_leaves(state.model_state)
  if not leaves:
    return state
  pmean_fn = jax.pmap(lambda t: jax.lax.pmean(t, "x"), "x")
  return state.replace(model_state=pmean_fn(state.model_state))
63
@abc.abstractmethod
def create_train_state(
    self,
    rng,
    inputs,
):
  """Creates the model dict and the initial train state.

  Args:
    rng: PRNG key used to initialize model parameters.
    inputs: batch dict used to trace shapes for initialization (see the
      `init_batch` built in `train`: {"inputs": ..., "labels": ...}).

  Returns:
    A `(model_dict, state)` tuple, as consumed by `train`.
  """
  pass
71
def create_optimizer(self):
  """Creates optimizers separated for weights using decay.

  NOTE(review): `flax.optim` was deprecated and later removed from Flax in
  favor of `optax`; this code requires an older pinned Flax release —
  confirm before upgrading dependencies.

  Returns:
    An instance of `Optimizer` definition (`flax.optim.Adam`), not yet
    bound to parameters.
  """
  opt_def = flax.optim.Adam(
      learning_rate=self.config.optimizer.lr,
      beta1=self.config.optimizer.beta1,
      beta2=self.config.optimizer.beta2,
      weight_decay=self.config.optimizer.weight_decay)
  return opt_def
84
def make_mask(self, vocab_size, pos_info, seq_len, layout_dim):
  """Builds per-step logit masks and token-id offsets for decoding.

  The vocabulary concatenates special symbols, asset classes, sizes and
  positions. At any given decoding step only one vocabulary segment is
  legal (e.g. the first token of an asset must be an asset class), so each
  step gets a mask that disables every other segment.

  Args:
    vocab_size: total vocabulary size.
    pos_info: per-segment `[start_index, num_candidates]` pairs, e.g.
      `[[2, 3], [6, 2]]` means segment 0 starts at vocab index 2 and has 3
      candidates.
    seq_len: total length of the input sequence.
    layout_dim: layout dimensionality (2 for 2-D layouts).

  Returns:
    asset_logit_masks: [1, seq_len, vocab_size] mask, 1 where a logit must
      be suppressed and 0 where it is allowed.
    asset_offset: [1, seq_len] start index of the legal segment at each
      step, used to map emitted token ids back to segment-local ids.
  """
  total_dim = layout_dim * 2 + 1
  # Start index of each segment, shaped [1, total_dim] for broadcasting.
  segment_starts = jnp.expand_dims(jnp.array([seg[0] for seg in pos_info]), 0)
  asset_offset = jnp.tile(segment_starts, (1, seq_len // total_dim))

  # One [1, 1, vocab_size] mask per step of a single asset, in the order
  # [asset, width, height, x, y]. The same mask broadcasts over the batch.
  per_step_masks = []
  for step_idx, seg in enumerate(pos_info):
    start, size = seg[0], seg[1]
    mask = jnp.ones((1, 1, vocab_size))
    # Open up (zero out) the `size` entries of this segment.
    mask = jax.lax.dynamic_update_slice(mask, jnp.zeros((1, 1, size)),
                                        (0, 0, start))
    if step_idx == 0:
      # Wherever an asset class may appear, eos (token id 2) is also legal.
      mask = mask.at[:, :, 2].set(0)
    per_step_masks.append(mask)

  # Mask covering one full asset: [1, total_dim, vocab_size].
  one_asset_mask = jnp.concatenate(per_step_masks, axis=1)
  # Repeat for every complete asset in the sequence.
  asset_logit_masks = jnp.tile(one_asset_mask, (1, seq_len // total_dim, 1))
  # Append masks/offsets for the trailing partial asset, if any.
  remainder = seq_len % total_dim
  if remainder > 0:
    asset_logit_masks = jnp.concatenate(
        (asset_logit_masks, one_asset_mask[:, :remainder, :]), axis=1)
    asset_offset = jnp.concatenate(
        (asset_offset, segment_starts[:, :remainder]), axis=1)

  return asset_logit_masks, asset_offset
148
def create_learning_rate_scheduler(self, learning_rate=1., warmup_steps=4000):
  """Creates the inverse-sqrt learning-rate schedule for the transformer.

  The rate grows linearly for the first `warmup_steps` steps, then decays
  proportionally to the inverse square root of the step number.

  Args:
    learning_rate: float, base constant of the schedule.
    warmup_steps: int, number of warmup steps (> 0).

  Returns:
    A function mapping the current step to a float32 learning rate.
  """

  def step_fn(step):
    warmup_frac = jnp.minimum(1.0, step / warmup_steps)
    decay = jnp.sqrt(jnp.maximum(step, warmup_steps))
    return jnp.asarray(learning_rate * warmup_frac / decay, dtype=jnp.float32)

  return step_fn
167
def compute_weighted_cross_entropy(self,
                                   logits,
                                   targets,
                                   mask=None,
                                   label_smoothing=0.0,
                                   logits_mask=None):
  """Computes weighted cross entropy between logits and targets.

  Args:
    logits: [batch, length, vocab_size] float array.
    targets: [batch, length] int array.
    mask: None or array of shape [batch, length]; 0 drops a position from
      both the loss and the normalizing factor.
    label_smoothing: label smoothing constant, used to determine the on and
      off values of the soft targets.
    logits_mask: [batch, length, vocab_size] float array: positive entries
      mark impossible candidates to be suppressed at each step.

  Returns:
    A `(loss_sum, normalizing_factor)` tuple; divide to get the mean loss.
  """
  if logits_mask is not None:
    # Impossible candidates get a large negative logit so they receive
    # ~zero probability after the softmax.
    logits = jnp.where(logits_mask > 0, -1e7, logits)
  if logits.ndim != targets.ndim + 1:
    raise ValueError("Incorrect shapes. Got shape %s logits and %s targets" %
                     (str(logits.shape), str(targets.shape)))
  vocab_size = logits.shape[-1]
  confidence = 1.0 - label_smoothing
  low_confidence = label_smoothing / (vocab_size - 1)
  # Label-smoothed one-hot targets; numerically identical to
  # flax.training.common_utils.onehot(targets, vocab_size, on_value=...,
  # off_value=...) but built from jnp primitives already in scope.
  one_hot = jnp.asarray(
      targets[..., None] == jnp.arange(vocab_size), dtype=jnp.float32)
  soft_targets = one_hot * confidence + (1.0 - one_hot) * low_confidence

  loss = -jnp.sum(soft_targets * jax.nn.log_softmax(logits), axis=-1)
  # Subtract the entropy of the smoothed target distribution (the lowest
  # attainable cross entropy) so a perfect model scores zero. Both logs get
  # a 1e-20 epsilon: previously only the `low_confidence` term had it, so
  # label_smoothing == 1.0 (confidence == 0) produced 0 * log(0) == NaN.
  normalizing_constant = -(
      confidence * jnp.log(confidence + 1e-20) +
      (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20))
  loss = loss - normalizing_constant

  normalizing_factor = np.prod(targets.shape)
  if mask is not None:
    loss = loss * mask
    normalizing_factor = mask.sum()

  return loss.sum(), normalizing_factor
213
def evaluate(self,
             p_eval_step,
             state,
             rng,
             eval_ds,
             batch_size=0,
             use_vertical=False,
             num_eval_steps=None,
             dataset="RICO"):
  """Evaluates the model and returns a merged metrics object.

  Args:
    p_eval_step: pmapped eval step, called as (rng, state, batch, label).
    state: replicated train state.
    rng: PRNG key(s) forwarded to `p_eval_step`.
    eval_ds: evaluation dataset yielding TF tensors.
    batch_size: per-host batch size, forwarded to `preprocess_batch`.
    use_vertical: forwarded to `preprocess_batch`.
    num_eval_steps: maximum number of batches to evaluate; None means
      evaluate the whole dataset.
    dataset: dataset name, forwarded to `preprocess_batch`.

  Returns:
    The merged CLU metrics object, or None if no batch was evaluated.
  """
  logging.info("Gathering evaluation metrics.")
  eval_metrics = None
  eval_iter = iter(eval_ds)  # pytype: disable=wrong-arg-types
  for step, eval_batch in enumerate(eval_iter):
    # Bug fix: the previous `zip(range(num_eval_steps), ...)` raised a
    # TypeError when `num_eval_steps` was left at its default of None; None
    # now evaluates the full dataset.
    if num_eval_steps is not None and step >= num_eval_steps:
      break
    eval_batch = jax.tree_map(lambda x: x._numpy(), eval_batch)  # pylint: disable=protected-access
    eval_batch, eval_label = self.preprocess_batch(eval_batch, batch_size,
                                                   dataset, use_vertical)
    # `preprocess_batch` returns None for batches that cannot be used
    # (e.g. wrong size); skip them.
    if eval_batch is None:
      continue
    metrics_update = p_eval_step(rng, state, eval_batch, eval_label)
    metrics_update = flax.jax_utils.unreplicate(metrics_update)
    eval_metrics = (
        metrics_update
        if eval_metrics is None else eval_metrics.merge(metrics_update))
  return eval_metrics
240
def train(self):
  """Training loop.

  Builds the input pipeline, initializes or restores the train state, then
  alternates pmapped train steps with periodic logging, evaluation and
  checkpointing until `config.num_train_steps` is reached.
  """

  tf.io.gfile.makedirs(self.workdir)
  checkpoint_dir = os.path.join(self.workdir, "checkpoints")
  n_devices = jax.local_device_count()
  task_manager_csv = task_manager.TaskManagerWithCsvResults(checkpoint_dir)
  rng, data_rng = jax.random.split(self.rng)
  # Make sure each host uses a different RNG for the training data.
  data_rng = jax.random.fold_in(data_rng, jax.process_index())
  batch_size = self.config.batch_size
  dataset = self.config.dataset
  use_vertical = self.config.use_vertical_info

  train_ds, eval_ds, _, vocab_size, pos_info = input_pipeline.get_all_dataset(
      batch_size,
      self.config.dataset_path,
      n_devices,
      add_bos=self.config.autoregressive,
      max_length=self.config.max_length,
      dataset_name=dataset)
  train_ds = train_ds.repeat()
  # Vocab size is only known after the pipeline is built; record it so
  # `create_train_state` implementations can read it from the config.
  self.config.vocab_size = vocab_size
  train_iter = iter(train_ds)  # pytype: disable=wrong-arg-types
  num_train_steps = self.config.num_train_steps
  logging.info("total_steps=%d", num_train_steps)

  rng, model_rng = jax.random.split(rng)
  # Dummy batch/label used only to trace shapes for parameter init.
  init_batch = jnp.ones(
      (batch_size, self.config.max_length))
  init_label = jnp.ones((batch_size, 1))
  init_batch = dict(inputs=init_batch, labels=init_label)
  model_dict, state = self.create_train_state(model_rng, init_batch)
  learning_rate_fn = self.create_learning_rate_scheduler(
      learning_rate=self.config.optimizer.lr,
      warmup_steps=self.config.optimizer.warmup_steps)
  state = task_manager.restore_checkpoint(state, checkpoint_dir)
  # Creates logits mask (same mask at every step; captured by the pmapped
  # train/eval closures below).
  logits_mask, _ = self.make_mask(vocab_size, pos_info,
                                  self.config.max_length,
                                  self.config.layout_dim)
  initial_step = int(state.step) + 1
  # Warm-start from a checkpoint only when not resuming a previous run.
  if initial_step == 1 and self.config.get(
      "checkpoint_path") and self.config.checkpoint_path:
    state = task_manager.restore_from_path(
        state, self.config.checkpoint_path)

  state = flax_utils.replicate(state)

  writer = metric_writers.create_default_writer(
      self.workdir, just_logging=jax.process_index() > 0)
  if initial_step == 1:

    writer.write_hparams(dict(self.config))

  logging.info("Starting training loop at step %d.", initial_step)
  hooks = []
  report_progress = periodic_actions.ReportProgress(
      num_train_steps=self.config.num_train_steps, writer=writer)
  # Progress/profiling hooks run on the first host only.
  if jax.process_index() == 0:
    hooks += [
        report_progress,
        periodic_actions.Profile(num_profile_steps=5, logdir=self.workdir)
    ]

  p_train_step = jax.pmap(
      functools.partial(
          self.train_step,
          model_dict=model_dict,
          learning_rate_fn=learning_rate_fn,
          logits_mask=logits_mask
      ),
      axis_name="batch")

  p_eval_step = jax.pmap(
      functools.partial(
          self.eval_step,
          model_dict=model_dict,
          logits_mask=logits_mask
      ),
      axis_name="batch")

  train_metrics = None
  rng = jax.random.fold_in(rng, jax.process_index())
  rng, train_rng, sample_rng = jax.random.split(rng, 3)  # pylint: disable=unused-variable

  with metric_writers.ensure_flushes(writer):
    for step in range(initial_step, num_train_steps + 1):
      # `step` is a Python integer. `state.step` is JAX integer on the GPU/TPU
      # devices.
      is_last_step = step == self.config.num_train_steps
      # NOTE(review): StepTraceContext was renamed StepTraceAnnotation in
      # newer JAX releases — confirm against the pinned jax version.
      with jax.profiler.StepTraceContext("train", step_num=step):
        batch = jax.tree_map(np.asarray, next(train_iter))
        batch, label = self.preprocess_batch(batch, batch_size, dataset,
                                             use_vertical)
        if batch is None:
          continue

        # Per-step, per-device RNGs so every device gets independent noise.
        step_rng = jax.random.fold_in(train_rng, step)
        step_rngs = jax.random.split(step_rng, jax.local_device_count())
        state, metrics_update = p_train_step(step_rngs, state, batch, label)
        metric_update = flax.jax_utils.unreplicate(metrics_update)
        train_metrics = (
            metric_update
            if train_metrics is None else train_metrics.merge(metric_update))

      # Quick indication that training is happening.
      logging.log_first_n(logging.INFO, "Finished training step %d.", 5, step)
      for h in hooks:
        h(step)

      if step % self.config.log_every_steps == 0:
        logging.info("Finish training step %d.", step)
        writer.write_scalars(step, train_metrics.compute())
        train_metrics = None
      if step % self.config.eval_every_steps == 0 or is_last_step:
        logging.info("eval step")
        # Sync mutable model state (e.g. batch stats) before evaluation.
        state = self.merge_model_state(state)
        sample_rngs = jax.random.split(sample_rng, jax.local_device_count())
        eval_metrics = self.evaluate(p_eval_step, state, sample_rngs, eval_ds,
                                     batch_size,
                                     use_vertical,
                                     self.config.eval_num_steps,
                                     dataset)
        if eval_metrics is not None:
          writer.write_scalars(step, eval_metrics.compute())

      if step % self.config.checkpoint_every_steps == 0 or is_last_step:
        with report_progress.timed("checkpoint"):
          state = self.merge_model_state(state)
          task_manager.save_checkpoint(state, checkpoint_dir)
  logging.info("Finishing training at step %d", num_train_steps)
  if jax.process_index() == 0:
    task_manager_csv.mark_training_done()
376
def evaluate_metrics(self,
                     generated_samples,
                     real_samples,
                     eos_id=-2,
                     conditional="a+s"):
  """Computes layout quality metrics for generated samples.

  Args:
    generated_samples: sequences of token ids; each sample flattens
      per-asset 5-tuples and may contain an eos marker.
    real_samples: ground-truth sequences in the same format.
    eos_id: id marking the end of a sequence; tokens at and after the first
      occurrence are dropped.
    conditional: conditioning mode; "unconditional" reports diversity,
      anything else reports similarity against `real_samples`.

  Returns:
    Dict with "iou", "overlap", "alignment" and either "similarity" or
    "diversity".
  """

  def convert_format(layouts, eos_id):
    """Truncates each sample at eos and reshapes it to [num_assets, 5]."""
    new_layouts = []
    for sample in layouts:
      sample = np.array(sample)
      if np.nonzero(sample == eos_id)[0].shape[0] > 0:
        real_len = np.nonzero(sample == eos_id)[0][0]
        sample = sample[:real_len]
      new_layouts.append(sample.reshape(-1, 5))
    return new_layouts

  generated_samples = convert_format(generated_samples, eos_id)

  iou = []
  overlap = []
  alignment = []
  for sample in generated_samples:
    iou.append(metrics.get_layout_iou(sample))
    overlap.append(metrics.get_overlap_index(sample))
    align_loss = metrics.get_alignment_loss(sample)
    # Only strictly positive alignment losses are averaged, so `alignment`
    # can legitimately end up empty.
    if align_loss > 0:
      alignment.append(align_loss)

  def avg(a):
    # Bug fix: previously divided by len(a) unconditionally, raising
    # ZeroDivisionError when a metric list (notably `alignment`) was empty.
    return sum(a) / len(a) if a else 0.0

  rst = {
      "iou": avg(iou),
      "overlap": avg(overlap),
      "alignment": avg(alignment)
  }

  if conditional != "unconditional":
    real_samples = convert_format(real_samples, eos_id)
    similarity = metrics.conditional_distance(generated_samples, real_samples,
                                              conditional)
    rst["similarity"] = similarity
  else:
    diveristy = metrics.diveristy(generated_samples)
    rst["diversity"] = diveristy

  return rst
421
@abc.abstractmethod
def train_step(self, rng, state, batch, model_dict, logits_mask):
  """Runs one optimization step; pmapped and invoked from `train`.

  NOTE(review): `train` binds `model_dict`, `learning_rate_fn` and
  `logits_mask` via functools.partial and then calls the result with
  (step_rngs, state, batch, label) — so concrete implementations also take
  `label` and `learning_rate_fn`; confirm against subclasses.

  Args:
    rng: per-device PRNG key for this step.
    state: replicated train state.
    batch: preprocessed training batch.
    model_dict: model container produced by `create_train_state`.
    logits_mask: per-step vocabulary mask from `make_mask`.

  Returns:
    An updated `(state, metrics_update)` pair (see usage in `train`).
  """
  pass
425
@abc.abstractmethod
def eval_step(self, rng, state, batch, model_dict, logits_mask):
  """Runs one evaluation step; pmapped and invoked from `evaluate`.

  NOTE(review): `train` binds `model_dict` and `logits_mask` via
  functools.partial and `evaluate` calls the result with
  (rng, state, batch, label) — so concrete implementations also take
  `label`; confirm against subclasses.

  Args:
    rng: per-device PRNG key.
    state: replicated train state.
    batch: preprocessed evaluation batch.
    model_dict: model container produced by `create_train_state`.
    logits_mask: per-step vocabulary mask from `make_mask`.

  Returns:
    A replicated metrics update (unreplicated and merged in `evaluate`).
  """
  pass
429
@abc.abstractmethod
def preprocess_batch(self, batch, batch_size, dataset, use_vertical):
  """Prepares a raw pipeline batch for the train/eval step.

  Args:
    batch: raw numpy batch from the input pipeline.
    batch_size: expected per-host batch size.
    dataset: dataset name (e.g. "RICO").
    use_vertical: whether vertical (layout orientation) info is used.

  Returns:
    A `(batch, label)` pair; callers in `train` and `evaluate` skip the
    step when the returned batch is None.
  """
  pass
433