1"""Tensorflow trainer class."""
2
3import datetime
4import logging
5import math
6import os
7import sys
8import warnings
9from typing import Callable, Dict, Optional, Tuple
10
11import numpy as np
12import tensorflow as tf
13from packaging.version import parse
14
15from .modeling_tf_utils import TFPreTrainedModel
16from .optimization_tf import GradientAccumulator, create_optimizer
17from .trainer_utils import PREFIX_CHECKPOINT_DIR, EvalPrediction, PredictionOutput, is_wandb_available, set_seed
18from .training_args_tf import TFTrainingArguments
19
20
21if is_wandb_available():
22import wandb
23
24
25logger = logging.getLogger(__name__)
26
27
28if parse(tf.__version__).release < (2, 2, 0):
29logger.info(
30"You need to run the TensorFlow trainer with at least the version 2.2.0, your version is {}".format(
31tf.__version__
32)
33)
34sys.exit(1)
35
36
37class TFTrainer:
38"""
39TFTrainer is a simple but feature-complete training and eval loop for TensorFlow,
40optimized for 🤗 Transformers.
41
42Args:
43model (:class:`~transformers.TFPreTrainedModel`):
44The model to train, evaluate or use for predictions.
45args (:class:`~transformers.TFTrainingArguments`):
46The arguments to tweak training.
47train_dataset (:class:`~tf.data.Dataset`, `optional`):
48The dataset to use for training.
49eval_dataset (:class:`~tf.data.Dataset`, `optional`):
50The dataset to use for evaluation.
51compute_metrics (:obj:`Callable[[EvalPrediction], Dict]`, `optional`):
52The function that will be used to compute metrics at evaluation. Must take a
53:class:`~transformers.EvalPrediction` and return a dictionary string to metric values.
54prediction_loss_only (:obj:`bool`, `optional`, defaults to `False`):
55When performing evaluation and predictions, only returns the loss.
56tb_writer (:obj:`tf.summary.SummaryWriter`, `optional`):
57Object to write to TensorBoard.
58optimizers (:obj:`Tuple[tf.keras.optimizers.Optimizer, tf.keras.optimizers.schedules.LearningRateSchedule]`, `optional`):
59A tuple containing the optimizer and the scheduler to use. The optimizer default to an instance of
60:class:`tf.keras.optimizers.Adam` if :obj:`args.weight_decay_rate` is 0 else an instance of
61:class:`~transformers.AdamWeightDecay`. The scheduler will default to an instance of
62:class:`tf.keras.optimizers.schedules.PolynomialDecay` if :obj:`args.num_warmup_steps` is 0 else
63an instance of :class:`~transformers.WarmUp`.
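
    Example (a minimal sketch, assuming ``train_dataset`` is a ``tf.data.Dataset`` of
    ``(features, labels)`` pairs with an asserted cardinality; the checkpoint name is
    illustrative)::

        from transformers import TFAutoModelForSequenceClassification, TFTrainingArguments

        model = TFAutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
        training_args = TFTrainingArguments(output_dir="./output")
        trainer = TFTrainer(model=model, args=training_args, train_dataset=train_dataset)
        trainer.train()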
64"""
65
66def __init__(
67self,
68model: TFPreTrainedModel,
69args: TFTrainingArguments,
70train_dataset: Optional[tf.data.Dataset] = None,
71eval_dataset: Optional[tf.data.Dataset] = None,
72compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
73prediction_loss_only=False,
74tb_writer: Optional[tf.summary.SummaryWriter] = None,
75optimizers: Tuple[tf.keras.optimizers.Optimizer, tf.keras.optimizers.schedules.LearningRateSchedule] = (
76None,
77None,
78),
79):
80self.model = model
81self.args = args
82self.train_dataset = train_dataset
83self.eval_dataset = eval_dataset
84self.compute_metrics = compute_metrics
85self.prediction_loss_only = prediction_loss_only
86self.optimizer, self.lr_scheduler = optimizers
87self.gradient_accumulator = GradientAccumulator()
88self.global_step = 0
89self.epoch_logging = 0
90
91if tb_writer is not None:
92self.tb_writer = tb_writer
93else:
94self.tb_writer = tf.summary.create_file_writer(self.args.logging_dir)
95
96if is_wandb_available():
97self.setup_wandb()
98elif os.environ.get("WANDB_DISABLED") != "true":
99logger.info(
100"You are instantiating a Trainer but W&B is not installed. To use wandb logging, "
101"run `pip install wandb; wandb login` see https://docs.wandb.com/huggingface."
102)
103
104set_seed(self.args.seed)
105
106def get_train_tfdataset(self) -> tf.data.Dataset:
107"""
108Returns the training :class:`~tf.data.Dataset`.
109
110Subclass and override this method if you want to inject some custom behavior.
111"""
112if self.train_dataset is None:
113raise ValueError("Trainer: training requires a train_dataset.")
114
115self.total_train_batch_size = self.args.train_batch_size * self.args.gradient_accumulation_steps
116self.num_train_examples = tf.data.experimental.cardinality(self.train_dataset).numpy()
117
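        # tf.data.experimental.cardinality returns a negative sentinel
        # (e.g. tf.data.experimental.UNKNOWN_CARDINALITY) when the size cannot be
        # inferred; apply tf.data.experimental.assert_cardinality to the dataset first.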
        if self.num_train_examples < 0:
            raise ValueError("The training dataset must have an asserted cardinality.")

        ds = (
            self.train_dataset.repeat()
            .shuffle(self.num_train_examples, seed=self.args.seed)
            .batch(self.total_train_batch_size, drop_remainder=self.args.dataloader_drop_last)
            .prefetch(tf.data.experimental.AUTOTUNE)
        )

        return self.args.strategy.experimental_distribute_dataset(ds)

    def get_eval_tfdataset(self, eval_dataset: Optional[tf.data.Dataset] = None) -> Tuple[tf.data.Dataset, int, int]:
        """
        Returns the evaluation :class:`~tf.data.Dataset`, the number of steps, and the number of examples.

        Args:
            eval_dataset (:class:`~tf.data.Dataset`, `optional`):
                If provided, will override `self.eval_dataset`.

        Subclass and override this method if you want to inject some custom behavior.
        """
        if eval_dataset is None and self.eval_dataset is None:
            raise ValueError("Trainer: evaluation requires an eval_dataset.")

        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
        num_examples = tf.data.experimental.cardinality(eval_dataset).numpy()

        if num_examples < 0:
            raise ValueError("The evaluation dataset must have an asserted cardinality.")

        approx = math.floor if self.args.dataloader_drop_last else math.ceil
        steps = approx(num_examples / self.args.eval_batch_size)
        ds = (
            eval_dataset.repeat()
            .batch(self.args.eval_batch_size, drop_remainder=self.args.dataloader_drop_last)
            .prefetch(tf.data.experimental.AUTOTUNE)
        )

        return self.args.strategy.experimental_distribute_dataset(ds), steps, num_examples

    def get_test_tfdataset(self, test_dataset: tf.data.Dataset) -> Tuple[tf.data.Dataset, int, int]:
        """
        Returns a test :class:`~tf.data.Dataset`, the number of steps, and the number of examples.

        Args:
            test_dataset (:class:`~tf.data.Dataset`): The dataset to use.

        Subclass and override this method if you want to inject some custom behavior.
        """
        num_examples = tf.data.experimental.cardinality(test_dataset).numpy()

        if num_examples < 0:
            raise ValueError("The test dataset must have an asserted cardinality.")

        approx = math.floor if self.args.dataloader_drop_last else math.ceil
        steps = approx(num_examples / self.args.eval_batch_size)
        ds = (
            test_dataset.repeat()
            .batch(self.args.eval_batch_size, drop_remainder=self.args.dataloader_drop_last)
            .prefetch(tf.data.experimental.AUTOTUNE)
        )

        return self.args.strategy.experimental_distribute_dataset(ds), steps, num_examples

    def create_optimizer_and_scheduler(self, num_training_steps: int):
        """
        Setup the optimizer and the learning rate scheduler.

        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in
        the TFTrainer's init through :obj:`optimizers`, or subclass and override this method.
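
        Example (a sketch of passing a custom pair at init time; both objects are plain
        Keras APIs, and ``model``/``args`` are assumed to exist)::

            lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
                initial_learning_rate=5e-5, decay_steps=10000, end_learning_rate=0.0
            )
            optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
            trainer = TFTrainer(model=model, args=args, optimizers=(optimizer, lr_schedule))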
190"""
191if not self.optimizer and not self.lr_scheduler:
192self.optimizer, self.lr_scheduler = create_optimizer(
193self.args.learning_rate,
194num_training_steps,
195self.args.warmup_steps,
196adam_beta1=self.args.adam_beta1,
197adam_beta2=self.args.adam_beta2,
198adam_epsilon=self.args.adam_epsilon,
199weight_decay_rate=self.args.weight_decay,
200)
201
202def setup_wandb(self):
203"""
204Setup the optional Weights & Biases (`wandb`) integration.
205
206One can subclass and override this method to customize the setup if needed. Find more information
207`here <https://docs.wandb.com/huggingface>`__. You can also override the following environment variables:
208
209Environment:
210WANDB_PROJECT:
211(Optional): str - "huggingface" by default, set this to a custom string to store results in a different project
212WANDB_DISABLED:
213(Optional): boolean - defaults to false, set to "true" to disable wandb entirely
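
        Example (a sketch; set these before instantiating the trainer, the project name
        is hypothetical)::

            import os

            os.environ["WANDB_PROJECT"] = "my-project"
            os.environ["WANDB_DISABLED"] = "true"  # or disable logging entirely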
214"""
215if hasattr(self, "_setup_wandb"):
216warnings.warn(
217"The `_setup_wandb` method is deprecated and won't be called in a future version, define `setup_wandb` in your subclass.",
218FutureWarning,
219)
220return self._setup_wandb()
221
222logger.info('Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"')
223wandb.init(project=os.getenv("WANDB_PROJECT", "huggingface"), config=vars(self.args))
224
225def prediction_loop(
226self,
227dataset: tf.data.Dataset,
228steps: int,
229num_examples: int,
230description: str,
231prediction_loss_only: Optional[bool] = None,
232) -> PredictionOutput:
233"""
234Prediction/evaluation loop, shared by :func:`~transformers.TFTrainer.evaluate` and
235:func:`~transformers.TFTrainer.predict`.
236
237Works both with or without labels.
238"""
239if hasattr(self, "_prediction_loop"):
240warnings.warn(
241"The `_prediction_loop` method is deprecated and won't be called in a future version, define `prediction_loop` in your subclass.",
242FutureWarning,
243)
244return self._prediction_loop(
245dataset, steps, num_examples, description, prediction_loss_only=prediction_loss_only
246)
247
248prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else self.prediction_loss_only
249
250logger.info("***** Running %s *****", description)
251logger.info(" Num examples = %d", num_examples)
252logger.info(" Batch size = %d", self.args.eval_batch_size)
253
254label_ids: np.ndarray = None
255preds: np.ndarray = None
256self.eval_loss = tf.keras.metrics.Sum()
257
258# Reset the past mems state at the beginning of the evaluation if necessary.
259if self.args.past_index >= 0:
260self._past = None
261
262for step, batch in enumerate(dataset):
263logits = self.distributed_prediction_steps(batch)
264_, labels = batch
265
266if not prediction_loss_only:
267if isinstance(logits, tuple):
268logits = logits[0]
269
270if isinstance(labels, tuple):
271labels = labels[0]
272
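                # Under a multi-replica strategy, `strategy.run` returns PerReplica
                # objects; their `.values` attribute holds one tensor per device, which
                # are gathered on the host here as numpy arrays.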
                if self.args.n_replicas > 1:
                    for val in logits.values:
                        if preds is None:
                            preds = val.numpy()
                        else:
                            preds = np.append(preds, val.numpy(), axis=0)

                    for val in labels.values:
                        if label_ids is None:
                            label_ids = val.numpy()
                        else:
                            label_ids = np.append(label_ids, val.numpy(), axis=0)
                else:
                    if preds is None:
                        preds = logits.numpy()
                    else:
                        preds = np.append(preds, logits.numpy(), axis=0)

                    if label_ids is None:
                        label_ids = labels.numpy()
                    else:
                        label_ids = np.append(label_ids, labels.numpy(), axis=0)

            if step == steps - 1:
                break

        if self.compute_metrics is not None and preds is not None and label_ids is not None:
            metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids))
        else:
            metrics = {}

        metrics["eval_loss"] = self.eval_loss.result().numpy() / (steps * self.args.eval_batch_size)

        for key in list(metrics.keys()):
            if not key.startswith("eval_"):
                metrics[f"eval_{key}"] = metrics.pop(key)

        if self.args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of evaluation.
            delattr(self, "_past")

        return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)

    def log(self, logs: Dict[str, float]) -> None:
        """
        Log :obj:`logs` on the various objects watching training.

        Subclass and override this method to inject custom behavior.

        Args:
            logs (:obj:`Dict[str, float]`):
                The values to log.
        """
        if hasattr(self, "_log"):
            warnings.warn(
                "The `_log` method is deprecated and won't be called in a future version, define `log` in your subclass.",
                FutureWarning,
            )
            return self._log(logs)

        logs["epoch"] = self.epoch_logging

        if self.tb_writer:
            with self.tb_writer.as_default():
                for k, v in logs.items():
                    tf.summary.scalar(k, v, step=self.global_step)
            self.tb_writer.flush()

        if is_wandb_available():
            wandb.log(logs, step=self.global_step)

        output = {**logs, **{"step": self.global_step}}

        logger.info(output)

    def evaluate(self, eval_dataset: Optional[tf.data.Dataset] = None) -> Dict[str, float]:
        """
        Run evaluation and returns metrics.

        The calling script will be responsible for providing a method to compute metrics, as they are
        task-dependent (pass it to the init :obj:`compute_metrics` argument).

        Args:
            eval_dataset (:class:`~tf.data.Dataset`, `optional`):
                Pass a dataset if you wish to override :obj:`self.eval_dataset`.

        Returns:
            A dictionary containing the evaluation loss and the potential metrics computed from the predictions.
        """
        eval_ds, steps, num_examples = self.get_eval_tfdataset(eval_dataset)

        output = self.prediction_loop(eval_ds, steps, num_examples, description="Evaluation")
        logs = {**output.metrics}
        logs["epoch"] = self.epoch_logging

        self.log(logs)

        return output.metrics

    def prediction_step(self, features: tf.Tensor, labels: tf.Tensor) -> tf.Tensor:
        """
        Compute the prediction on features and update the loss with labels.

        Subclass and override to inject some custom behavior.
        """
        per_example_loss, logits = self.run_model(features, labels, False)

        self.eval_loss.update_state(per_example_loss)

        return logits

    @tf.function
    def distributed_prediction_steps(self, batch):
        logits = self.args.strategy.run(self.prediction_step, batch)

        return logits

    def train(self) -> None:
        """
        Train method to train the model.
        """
        train_ds = self.get_train_tfdataset()

        if self.args.debug:
            tf.summary.trace_on(graph=True, profiler=True)

        self.gradient_accumulator.reset()

        if self.args.max_steps > 0:
            t_total = self.args.max_steps
            self.steps_per_epoch = self.args.max_steps
        else:
            approx = math.floor if self.args.dataloader_drop_last else math.ceil
            self.steps_per_epoch = approx(self.num_train_examples / self.total_train_batch_size)
            t_total = self.steps_per_epoch * self.args.num_train_epochs

        with self.args.strategy.scope():
            self.create_optimizer_and_scheduler(num_training_steps=t_total)
            iterations = self.optimizer.iterations
            self.global_step = iterations.numpy()
            folder = os.path.join(self.args.output_dir, PREFIX_CHECKPOINT_DIR)
            ckpt = tf.train.Checkpoint(optimizer=self.optimizer, model=self.model)
            self.model.ckpt_manager = tf.train.CheckpointManager(ckpt, folder, max_to_keep=self.args.save_total_limit)

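            # If a checkpoint exists, the optimizer's iteration counter says how far
            # training previously got; the epochs/steps already done are derived from
            # it so that training resumes where it left off.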
            if self.model.ckpt_manager.latest_checkpoint:
                epochs_trained = self.global_step // (
                    self.num_train_examples // self.args.gradient_accumulation_steps
                )
                steps_trained_in_current_epoch = self.global_step % (
                    self.num_train_examples // self.args.gradient_accumulation_steps
                )

                logger.info("  Continuing training from checkpoint, will skip to saved global_step")
                logger.info("  Continuing training from epoch %d", epochs_trained)
                logger.info("  Continuing training from global step %d", self.global_step)
                logger.info("  Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
                logger.info(
                    "Checkpoint file %s found and restoring from checkpoint", self.model.ckpt_manager.latest_checkpoint
                )

                ckpt.restore(self.model.ckpt_manager.latest_checkpoint).expect_partial()
            else:
                epochs_trained = 1

            tf.summary.experimental.set_step(iterations)

            epochs = 1 if self.args.max_steps > 0 else self.args.num_train_epochs

            if self.args.fp16:
                policy = tf.keras.mixed_precision.experimental.Policy("mixed_float16")
                tf.keras.mixed_precision.experimental.set_policy(policy)

            with self.tb_writer.as_default():
                tf.summary.text("args", self.args.to_json_string())

            self.tb_writer.flush()

            logger.info("***** Running training *****")
            logger.info("  Num examples = %d", self.num_train_examples)
            logger.info("  Num Epochs = %d", epochs)
            logger.info("  Instantaneous batch size per device = %d", self.args.per_device_train_batch_size)
            logger.info(
                "  Total train batch size (w. parallel, distributed & accumulation) = %d", self.total_train_batch_size
            )
            logger.info("  Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps)
            logger.info("  Steps per epoch = %d", self.steps_per_epoch)
            logger.info("  Total optimization steps = %d", t_total)

            self.train_loss = tf.keras.metrics.Sum()
            start_time = datetime.datetime.now()

            for epoch_iter in range(epochs_trained, int(epochs + 1)):
                # Reset the past mems state at the beginning of each epoch if necessary.
                if self.args.past_index >= 0:
                    self._past = None

                for step, batch in enumerate(train_ds):
                    self.global_step = iterations.numpy()
                    self.epoch_logging = epoch_iter - 1 + (step + 1) / self.steps_per_epoch

                    self.distributed_training_steps(batch)

                    training_loss = self.train_loss.result() / ((step + 1) * self.total_train_batch_size)

                    if self.args.debug:
                        logs = {}
                        logs["loss"] = training_loss.numpy()
                        logs["epoch"] = self.epoch_logging

                        self.log(logs)

                    if self.global_step == 1 and self.args.debug:
                        with self.tb_writer.as_default():
                            tf.summary.trace_export(
                                name="training", step=self.global_step, profiler_outdir=self.args.logging_dir
                            )

                    if (
                        self.global_step > 0
                        and self.args.evaluate_during_training
                        and self.global_step % self.args.eval_steps == 0
                    ):
                        self.evaluate()

                    if (self.global_step > 0 and self.global_step % self.args.logging_steps == 0) or (
                        self.global_step == 1 and self.args.logging_first_step
                    ):
                        logs = {}
                        logs["loss"] = training_loss.numpy()
                        logs["learning_rate"] = self.lr_scheduler(self.global_step).numpy()
                        logs["epoch"] = self.epoch_logging

                        self.log(logs)

                    if self.global_step > 0 and self.global_step % self.args.save_steps == 0:
                        ckpt_save_path = self.model.ckpt_manager.save()

                        logger.info("Saving checkpoint for step {} at {}".format(self.global_step, ckpt_save_path))

                    if self.global_step > 0 and self.global_step % self.steps_per_epoch == 0:
                        break

                self.train_loss.reset_states()

            end_time = datetime.datetime.now()

            logger.info("Training took: {}".format(str(end_time - start_time)))

        if self.args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of training.
            delattr(self, "_past")

    def training_step(self, features, labels):
        """
        Perform a training step on features and labels.

        Subclass and override to inject some custom behavior.
        """
        per_example_loss, _ = self.run_model(features, labels, True)
        scaled_loss = per_example_loss / self.total_train_batch_size
        gradients = tf.gradients(scaled_loss, self.model.trainable_variables)
        gradients = [
            g if g is not None else tf.zeros_like(v) for g, v in zip(gradients, self.model.trainable_variables)
        ]

        if self.args.gradient_accumulation_steps > 1:
            self.gradient_accumulator(gradients)

        self.train_loss.update_state(per_example_loss)

        if self.args.gradient_accumulation_steps == 1:
            return gradients

    def apply_gradients(self, features, labels):
        if self.args.gradient_accumulation_steps == 1:
            gradients = self.training_step(features, labels)

            self.optimizer.apply_gradients(list(zip(gradients, self.model.trainable_variables)))
        else:
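            # Each accumulation micro-step consumes the first
            # train_batch_size // n_replicas examples and rotates them to the end, so
            # successive iterations see successive slices of the global batch; features
            # and labels are rotated together to stay aligned.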
            for _ in tf.range(self.args.gradient_accumulation_steps):
                reduced_features = features[: self.args.train_batch_size // self.args.n_replicas]
                reduced_labels = labels[: self.args.train_batch_size // self.args.n_replicas]

                self.training_step(reduced_features, reduced_labels)

                features = tf.concat(
                    [features[self.args.train_batch_size // self.args.n_replicas :], reduced_features], axis=0
                )
                labels = tf.concat(
                    [labels[self.args.train_batch_size // self.args.n_replicas :], reduced_labels], axis=0
                )

            gradients = self.gradient_accumulator.gradients
            gradients = [
                (tf.clip_by_value(grad, -self.args.max_grad_norm, self.args.max_grad_norm)) for grad in gradients
            ]

            self.optimizer.apply_gradients(list(zip(gradients, self.model.trainable_variables)))
            self.gradient_accumulator.reset()

    @tf.function
    def distributed_training_steps(self, batch):
        with self.args.strategy.scope():
            self.args.strategy.run(self.apply_gradients, batch)

    def run_model(self, features, labels, training):
        """
        Computes the loss of the given features and labels pair.

        Subclass and override this method if you want to inject some custom behavior.

        Args:
            features (:obj:`tf.Tensor`): A batch of input features.
            labels (:obj:`tf.Tensor`): A batch of labels.
            training (:obj:`bool`): Whether or not to run the model in training mode.

        Returns:
            A tuple of two :obj:`tf.Tensor`: The loss and logits.
        """
        if hasattr(self, "_run_model"):
            warnings.warn(
                "The `_run_model` method is deprecated and won't be called in a future version, define `run_model` in your subclass.",
                FutureWarning,
            )
            return self._run_model(features, labels, training)

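        # Models with a recurrent memory (e.g. XLNet or Transformer-XL) return their
        # past hidden states at position `past_index`; feed the saved ones back in
        # through the `mems` input.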
        if self.args.past_index >= 0 and getattr(self, "_past", None) is not None:
            features["mems"] = self._past

        if isinstance(labels, dict):
            outputs = self.model(features, training=training, **labels)
        else:
            outputs = self.model(features, labels=labels, training=training)

        loss, logits = outputs[:2]

        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]

        return loss, logits

    def predict(self, test_dataset: tf.data.Dataset) -> PredictionOutput:
        """
        Run prediction and returns predictions and potential metrics.

        Depending on the dataset and your use case, your test dataset may contain labels.
        In that case, this method will also return metrics, like in :obj:`evaluate()`.

        Args:
            test_dataset (:class:`~tf.data.Dataset`):
                Dataset to run the predictions on.

        Returns:
            `NamedTuple`:
                predictions (:obj:`np.ndarray`):
                    The predictions on :obj:`test_dataset`.
                label_ids (:obj:`np.ndarray`, `optional`):
                    The labels (if the dataset contained some).
                metrics (:obj:`Dict[str, float]`, `optional`):
                    The potential dictionary of metrics (if the dataset contained labels).
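
        Example (a sketch, assuming ``trainer`` exists and ``test_dataset`` is built
        like the training data, i.e. with an asserted cardinality)::

            output = trainer.predict(test_dataset)
            print(output.predictions.shape)
            print(output.metrics)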
626"""
627test_ds, steps, num_examples = self.get_test_tfdataset(test_dataset)
628
629return self.prediction_loop(test_ds, steps, num_examples, description="Prediction")
630
631def save_model(self, output_dir: Optional[str] = None):
632"""
633Will save the model, so you can reload it using :obj:`from_pretrained()`.
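
        Example (a sketch; the class used to reload must match the saved architecture)::

            trainer.save_model("./my_model")
            model = TFAutoModelForSequenceClassification.from_pretrained("./my_model")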
634"""
635output_dir = output_dir if output_dir is not None else self.args.output_dir
636
637logger.info("Saving model in {}".format(output_dir))
638
639if not isinstance(self.model, TFPreTrainedModel):
640raise ValueError("Trainer.model appears to not be a PreTrainedModel")
641
642self.model.save_pretrained(output_dir)
643