CSS-LM — trainer_tf.py (642 lines · 25.9 KB)

"""TensorFlow trainer class."""

import datetime
import logging
import math
import os
import sys
import warnings
from typing import Callable, Dict, Optional, Tuple

import numpy as np
import tensorflow as tf
from packaging.version import parse

from .modeling_tf_utils import TFPreTrainedModel
from .optimization_tf import GradientAccumulator, create_optimizer
from .trainer_utils import PREFIX_CHECKPOINT_DIR, EvalPrediction, PredictionOutput, is_wandb_available, set_seed
from .training_args_tf import TFTrainingArguments


if is_wandb_available():
    import wandb


logger = logging.getLogger(__name__)


if parse(tf.__version__).release < (2, 2, 0):
    logger.error(
        "You need to run the TensorFlow trainer with at least version 2.2.0; your version is {}".format(
            tf.__version__
        )
    )
    sys.exit(1)


class TFTrainer:
    """
    TFTrainer is a simple but feature-complete training and eval loop for TensorFlow,
    optimized for 🤗 Transformers.

    Args:
        model (:class:`~transformers.TFPreTrainedModel`):
            The model to train, evaluate, or use for predictions.
        args (:class:`~transformers.TFTrainingArguments`):
            The arguments to tweak training.
        train_dataset (:class:`~tf.data.Dataset`, `optional`):
            The dataset to use for training.
        eval_dataset (:class:`~tf.data.Dataset`, `optional`):
            The dataset to use for evaluation.
        compute_metrics (:obj:`Callable[[EvalPrediction], Dict]`, `optional`):
            The function that will be used to compute metrics at evaluation. Must take a
            :class:`~transformers.EvalPrediction` and return a dictionary mapping metric names to values.
        prediction_loss_only (:obj:`bool`, `optional`, defaults to `False`):
            When performing evaluation and predictions, only returns the loss.
        tb_writer (:obj:`tf.summary.SummaryWriter`, `optional`):
            Object to write to TensorBoard.
        optimizers (:obj:`Tuple[tf.keras.optimizers.Optimizer, tf.keras.optimizers.schedules.LearningRateSchedule]`, `optional`):
            A tuple containing the optimizer and the scheduler to use. The optimizer defaults to an instance of
            :class:`tf.keras.optimizers.Adam` if :obj:`args.weight_decay_rate` is 0, else an instance of
            :class:`~transformers.AdamWeightDecay`. The scheduler defaults to an instance of
            :class:`tf.keras.optimizers.schedules.PolynomialDecay` if :obj:`args.num_warmup_steps` is 0, else
            an instance of :class:`~transformers.WarmUp`.
    """

    def __init__(
        self,
        model: TFPreTrainedModel,
        args: TFTrainingArguments,
        train_dataset: Optional[tf.data.Dataset] = None,
        eval_dataset: Optional[tf.data.Dataset] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
        prediction_loss_only: bool = False,
        tb_writer: Optional[tf.summary.SummaryWriter] = None,
        optimizers: Tuple[tf.keras.optimizers.Optimizer, tf.keras.optimizers.schedules.LearningRateSchedule] = (
            None,
            None,
        ),
    ):
        self.model = model
        self.args = args
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.compute_metrics = compute_metrics
        self.prediction_loss_only = prediction_loss_only
        self.optimizer, self.lr_scheduler = optimizers
        self.gradient_accumulator = GradientAccumulator()
        self.global_step = 0
        self.epoch_logging = 0

        if tb_writer is not None:
            self.tb_writer = tb_writer
        else:
            self.tb_writer = tf.summary.create_file_writer(self.args.logging_dir)

        if is_wandb_available():
            self.setup_wandb()
        elif os.environ.get("WANDB_DISABLED") != "true":
            logger.info(
                "You are instantiating a Trainer but W&B is not installed. To use wandb logging, "
                "run `pip install wandb; wandb login`; see https://docs.wandb.com/huggingface."
            )

        set_seed(self.args.seed)

    def get_train_tfdataset(self) -> tf.data.Dataset:
        """
        Returns the training :class:`~tf.data.Dataset`.

        Subclass and override this method if you want to inject some custom behavior.
        """
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")

        self.total_train_batch_size = self.args.train_batch_size * self.args.gradient_accumulation_steps
        self.num_train_examples = tf.data.experimental.cardinality(self.train_dataset).numpy()

        if self.num_train_examples < 0:
            raise ValueError("The training dataset must have an asserted cardinality")

        ds = (
            self.train_dataset.repeat()
            .shuffle(self.num_train_examples, seed=self.args.seed)
            .batch(self.total_train_batch_size, drop_remainder=self.args.dataloader_drop_last)
            .prefetch(tf.data.experimental.AUTOTUNE)
        )

        return self.args.strategy.experimental_distribute_dataset(ds)
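    # Note (illustrative, not part of the original file): `tf.data.experimental.cardinality`
    # returns a negative sentinel (UNKNOWN_CARDINALITY or INFINITE_CARDINALITY) when
    # TensorFlow cannot infer the dataset size, e.g. for generator-backed datasets. A sketch
    # of how a caller can assert the size up front so the check above passes
    # (`num_examples` is hypothetical):
    #
    #   train_dataset = train_dataset.apply(
    #       tf.data.experimental.assert_cardinality(num_examples)
    #   )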

    def get_eval_tfdataset(self, eval_dataset: Optional[tf.data.Dataset] = None) -> Tuple[tf.data.Dataset, int, int]:
        """
        Returns the evaluation :class:`~tf.data.Dataset`, the number of steps, and the number of examples.

        Args:
            eval_dataset (:class:`~tf.data.Dataset`, `optional`):
                If provided, will override `self.eval_dataset`.

        Subclass and override this method if you want to inject some custom behavior.
        """
        if eval_dataset is None and self.eval_dataset is None:
            raise ValueError("Trainer: evaluation requires an eval_dataset.")

        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
        num_examples = tf.data.experimental.cardinality(eval_dataset).numpy()

        if num_examples < 0:
            raise ValueError("The evaluation dataset must have an asserted cardinality")

        approx = math.floor if self.args.dataloader_drop_last else math.ceil
        steps = approx(num_examples / self.args.eval_batch_size)
        ds = (
            eval_dataset.repeat()
            .batch(self.args.eval_batch_size, drop_remainder=self.args.dataloader_drop_last)
            .prefetch(tf.data.experimental.AUTOTUNE)
        )

        return self.args.strategy.experimental_distribute_dataset(ds), steps, num_examples

    def get_test_tfdataset(self, test_dataset: tf.data.Dataset) -> Tuple[tf.data.Dataset, int, int]:
        """
        Returns a test :class:`~tf.data.Dataset`, the number of steps, and the number of examples.

        Args:
            test_dataset (:class:`~tf.data.Dataset`): The dataset to use.

        Subclass and override this method if you want to inject some custom behavior.
        """

        num_examples = tf.data.experimental.cardinality(test_dataset).numpy()

        if num_examples < 0:
            raise ValueError("The test dataset must have an asserted cardinality")

        approx = math.floor if self.args.dataloader_drop_last else math.ceil
        steps = approx(num_examples / self.args.eval_batch_size)
        ds = (
            test_dataset.repeat()
            .batch(self.args.eval_batch_size, drop_remainder=self.args.dataloader_drop_last)
            .prefetch(tf.data.experimental.AUTOTUNE)
        )

        return self.args.strategy.experimental_distribute_dataset(ds), steps, num_examples

    def create_optimizer_and_scheduler(self, num_training_steps: int):
        """
        Set up the optimizer and the learning rate scheduler.

        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in
        the TFTrainer's init through :obj:`optimizers`, or subclass and override this method.
        """
        if not self.optimizer and not self.lr_scheduler:
            self.optimizer, self.lr_scheduler = create_optimizer(
                self.args.learning_rate,
                num_training_steps,
                self.args.warmup_steps,
                adam_beta1=self.args.adam_beta1,
                adam_beta2=self.args.adam_beta2,
                adam_epsilon=self.args.adam_epsilon,
                weight_decay_rate=self.args.weight_decay,
            )
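    # Example (illustrative sketch, not part of the original file): the default above can be
    # bypassed by handing the trainer a custom optimizer/schedule pair at construction time.
    # A minimal sketch with stock Keras classes (hyperparameter values illustrative):
    #
    #   lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
    #       initial_learning_rate=5e-5, decay_steps=10000, end_learning_rate=0.0
    #   )
    #   optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
    #   trainer = TFTrainer(model, args, train_dataset=train_ds, optimizers=(optimizer, lr_schedule))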

    def setup_wandb(self):
        """
        Set up the optional Weights & Biases (`wandb`) integration.

        One can subclass and override this method to customize the setup if needed. Find more information
        `here <https://docs.wandb.com/huggingface>`__. You can also override the following environment variables:

        Environment:
            WANDB_PROJECT:
                (Optional): str - "huggingface" by default, set this to a custom string to store results in a different project
            WANDB_DISABLED:
                (Optional): boolean - defaults to false, set to "true" to disable wandb entirely
        """
        if hasattr(self, "_setup_wandb"):
            warnings.warn(
                "The `_setup_wandb` method is deprecated and won't be called in a future version, define `setup_wandb` in your subclass.",
                FutureWarning,
            )
            return self._setup_wandb()

        logger.info('Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"')
        wandb.init(project=os.getenv("WANDB_PROJECT", "huggingface"), config=vars(self.args))

    def prediction_loop(
        self,
        dataset: tf.data.Dataset,
        steps: int,
        num_examples: int,
        description: str,
        prediction_loss_only: Optional[bool] = None,
    ) -> PredictionOutput:
        """
        Prediction/evaluation loop, shared by :func:`~transformers.TFTrainer.evaluate` and
        :func:`~transformers.TFTrainer.predict`.

        Works both with or without labels.
        """
        if hasattr(self, "_prediction_loop"):
            warnings.warn(
                "The `_prediction_loop` method is deprecated and won't be called in a future version, define `prediction_loop` in your subclass.",
                FutureWarning,
            )
            return self._prediction_loop(
                dataset, steps, num_examples, description, prediction_loss_only=prediction_loss_only
            )

        prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else self.prediction_loss_only

        logger.info("***** Running %s *****", description)
        logger.info("  Num examples = %d", num_examples)
        logger.info("  Batch size = %d", self.args.eval_batch_size)

        label_ids: np.ndarray = None
        preds: np.ndarray = None
        self.eval_loss = tf.keras.metrics.Sum()

        # Reset the past mems state at the beginning of the evaluation if necessary.
        if self.args.past_index >= 0:
            self._past = None

        for step, batch in enumerate(dataset):
            logits = self.distributed_prediction_steps(batch)
            _, labels = batch

            if not prediction_loss_only:
                if isinstance(logits, tuple):
                    logits = logits[0]

                if isinstance(labels, tuple):
                    labels = labels[0]

                if self.args.n_replicas > 1:
                    for val in logits.values:
                        if preds is None:
                            preds = val.numpy()
                        else:
                            preds = np.append(preds, val.numpy(), axis=0)

                    for val in labels.values:
                        if label_ids is None:
                            label_ids = val.numpy()
                        else:
                            label_ids = np.append(label_ids, val.numpy(), axis=0)
                else:
                    if preds is None:
                        preds = logits.numpy()
                    else:
                        preds = np.append(preds, logits.numpy(), axis=0)

                    if label_ids is None:
                        label_ids = labels.numpy()
                    else:
                        label_ids = np.append(label_ids, labels.numpy(), axis=0)

            # The dataset is repeated, so stop once the expected number of steps has run,
            # even when only the loss is being collected.
            if step == steps:
                break

        if self.compute_metrics is not None and preds is not None and label_ids is not None:
            metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids))
        else:
            metrics = {}

        metrics["eval_loss"] = self.eval_loss.result().numpy() / (steps * self.args.eval_batch_size)

        for key in list(metrics.keys()):
            if not key.startswith("eval_"):
                metrics[f"eval_{key}"] = metrics.pop(key)

        if self.args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of the evaluation loop
            delattr(self, "_past")

        return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)

    def log(self, logs: Dict[str, float]) -> None:
        """
        Log :obj:`logs` on the various objects watching training.

        Subclass and override this method to inject custom behavior.

        Args:
            logs (:obj:`Dict[str, float]`):
                The values to log.
        """
        if hasattr(self, "_log"):
            warnings.warn(
                "The `_log` method is deprecated and won't be called in a future version, define `log` in your subclass.",
                FutureWarning,
            )
            return self._log(logs)

        logs["epoch"] = self.epoch_logging

        if self.tb_writer:
            with self.tb_writer.as_default():
                for k, v in logs.items():
                    tf.summary.scalar(k, v, step=self.global_step)
            self.tb_writer.flush()

        if is_wandb_available():
            wandb.log(logs, step=self.global_step)

        output = {**logs, **{"step": self.global_step}}

        logger.info(output)

    def evaluate(self, eval_dataset: Optional[tf.data.Dataset] = None) -> Dict[str, float]:
        """
        Run evaluation and return metrics.

        The calling script will be responsible for providing a method to compute metrics, as they are
        task-dependent (pass it to the init :obj:`compute_metrics` argument).

        Args:
            eval_dataset (:class:`~tf.data.Dataset`, `optional`):
                Pass a dataset if you wish to override :obj:`self.eval_dataset`.

        Returns:
            A dictionary containing the evaluation loss and the potential metrics computed from the predictions.
        """
        eval_ds, steps, num_examples = self.get_eval_tfdataset(eval_dataset)

        output = self.prediction_loop(eval_ds, steps, num_examples, description="Evaluation")
        logs = {**output.metrics}
        logs["epoch"] = self.epoch_logging

        self.log(logs)

        return output.metrics

    def prediction_step(self, features: tf.Tensor, labels: tf.Tensor) -> tf.Tensor:
        """
        Compute the prediction on features and update the loss with labels.

        Subclass and override to inject some custom behavior.
        """
        per_example_loss, logits = self.run_model(features, labels, False)

        self.eval_loss.update_state(per_example_loss)

        return logits

    @tf.function
    def distributed_prediction_steps(self, batch):
        logits = self.args.strategy.run(self.prediction_step, batch)

        return logits

    def train(self) -> None:
        """
        Train the model.
        """
        train_ds = self.get_train_tfdataset()

        if self.args.debug:
            tf.summary.trace_on(graph=True, profiler=True)

        self.gradient_accumulator.reset()

        if self.args.max_steps > 0:
            t_total = self.args.max_steps
            self.steps_per_epoch = self.args.max_steps
        else:
            approx = math.floor if self.args.dataloader_drop_last else math.ceil
            self.steps_per_epoch = approx(self.num_train_examples / self.total_train_batch_size)
            t_total = self.steps_per_epoch * self.args.num_train_epochs

        with self.args.strategy.scope():
            self.create_optimizer_and_scheduler(num_training_steps=t_total)
            iterations = self.optimizer.iterations
            self.global_step = iterations.numpy()
            folder = os.path.join(self.args.output_dir, PREFIX_CHECKPOINT_DIR)
            ckpt = tf.train.Checkpoint(optimizer=self.optimizer, model=self.model)
            self.model.ckpt_manager = tf.train.CheckpointManager(ckpt, folder, max_to_keep=self.args.save_total_limit)

            if self.model.ckpt_manager.latest_checkpoint:
                epochs_trained = self.global_step // (self.num_train_examples // self.args.gradient_accumulation_steps)
                steps_trained_in_current_epoch = self.global_step % (
                    self.num_train_examples // self.args.gradient_accumulation_steps
                )

                logger.info("  Continuing training from checkpoint, will skip to saved global_step")
                logger.info("  Continuing training from epoch %d", epochs_trained)
                logger.info("  Continuing training from global step %d", self.global_step)
                logger.info("  Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
                logger.info(
                    "Checkpoint file %s found and restoring from checkpoint", self.model.ckpt_manager.latest_checkpoint
                )

                ckpt.restore(self.model.ckpt_manager.latest_checkpoint).expect_partial()
            else:
                epochs_trained = 1

            tf.summary.experimental.set_step(iterations)

            epochs = 1 if self.args.max_steps > 0 else self.args.num_train_epochs

            if self.args.fp16:
                policy = tf.keras.mixed_precision.experimental.Policy("mixed_float16")
                tf.keras.mixed_precision.experimental.set_policy(policy)

            with self.tb_writer.as_default():
                tf.summary.text("args", self.args.to_json_string())

            self.tb_writer.flush()

            logger.info("***** Running training *****")
            logger.info("  Num examples = %d", self.num_train_examples)
            logger.info("  Num Epochs = %d", epochs)
            logger.info("  Instantaneous batch size per device = %d", self.args.per_device_train_batch_size)
            logger.info(
                "  Total train batch size (w. parallel, distributed & accumulation) = %d", self.total_train_batch_size
            )
            logger.info("  Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps)
            logger.info("  Steps per epoch = %d", self.steps_per_epoch)
            logger.info("  Total optimization steps = %d", t_total)

            self.train_loss = tf.keras.metrics.Sum()
            start_time = datetime.datetime.now()

            for epoch_iter in range(epochs_trained, int(epochs + 1)):
                # Reset the past mems state at the beginning of each epoch if necessary.
                if self.args.past_index >= 0:
                    self._past = None

                for step, batch in enumerate(train_ds):
                    self.global_step = iterations.numpy()
                    self.epoch_logging = epoch_iter - 1 + (step + 1) / self.steps_per_epoch

                    self.distributed_training_steps(batch)

                    training_loss = self.train_loss.result() / ((step + 1) * self.total_train_batch_size)

                    if self.args.debug:
                        logs = {}
                        logs["loss"] = training_loss.numpy()
                        logs["epoch"] = self.epoch_logging

                        self.log(logs)

                    if self.global_step == 1 and self.args.debug:
                        with self.tb_writer.as_default():
                            tf.summary.trace_export(
                                name="training", step=self.global_step, profiler_outdir=self.args.logging_dir
                            )

                    if (
                        self.global_step > 0
                        and self.args.evaluate_during_training
                        and self.global_step % self.args.eval_steps == 0
                    ):
                        self.evaluate()

                    if (self.global_step > 0 and self.global_step % self.args.logging_steps == 0) or (
                        self.global_step == 1 and self.args.logging_first_step
                    ):
                        logs = {}
                        logs["loss"] = training_loss.numpy()
                        logs["learning_rate"] = self.lr_scheduler(self.global_step).numpy()
                        logs["epoch"] = self.epoch_logging

                        self.log(logs)

                    if self.global_step > 0 and self.global_step % self.args.save_steps == 0:
                        ckpt_save_path = self.model.ckpt_manager.save()

                        logger.info("Saving checkpoint for step {} at {}".format(self.global_step, ckpt_save_path))

                    if self.global_step > 0 and self.global_step % self.steps_per_epoch == 0:
                        break

                self.train_loss.reset_states()

            end_time = datetime.datetime.now()

            logger.info("Training took: {}".format(str(end_time - start_time)))

        if self.args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of training
            delattr(self, "_past")

    def training_step(self, features, labels):
        """
        Perform a training step on features and labels.

        Subclass and override to inject some custom behavior.
        """
        per_example_loss, _ = self.run_model(features, labels, True)
        scaled_loss = per_example_loss / self.total_train_batch_size
        gradients = tf.gradients(scaled_loss, self.model.trainable_variables)
        gradients = [
            g if g is not None else tf.zeros_like(v) for g, v in zip(gradients, self.model.trainable_variables)
        ]

        if self.args.gradient_accumulation_steps > 1:
            self.gradient_accumulator(gradients)

        self.train_loss.update_state(per_example_loss)

        if self.args.gradient_accumulation_steps == 1:
            return gradients

    def apply_gradients(self, features, labels):
        if self.args.gradient_accumulation_steps == 1:
            gradients = self.training_step(features, labels)

            self.optimizer.apply_gradients(list(zip(gradients, self.model.trainable_variables)))
        else:
            # Each replica processes `train_batch_size // n_replicas` examples per micro-batch;
            # integer division is required because tensor slicing does not accept floats.
            per_replica_batch_size = self.args.train_batch_size // self.args.n_replicas

            for _ in tf.range(self.args.gradient_accumulation_steps):
                reduced_features = features[:per_replica_batch_size]
                reduced_labels = labels[:per_replica_batch_size]

                self.training_step(reduced_features, reduced_labels)

                # Rotate the batch so the next accumulation step sees the next micro-batch;
                # labels are rotated together with features to keep them aligned.
                features = tf.concat([features[per_replica_batch_size:], reduced_features], axis=0)
                labels = tf.concat([labels[per_replica_batch_size:], reduced_labels], axis=0)

            gradients = self.gradient_accumulator.gradients
            gradients = [
                (tf.clip_by_value(grad, -self.args.max_grad_norm, self.args.max_grad_norm)) for grad in gradients
            ]

            self.optimizer.apply_gradients(list(zip(gradients, self.model.trainable_variables)))
            self.gradient_accumulator.reset()

    @tf.function
    def distributed_training_steps(self, batch):
        with self.args.strategy.scope():
            self.args.strategy.run(self.apply_gradients, batch)

    def run_model(self, features, labels, training):
        """
        Computes the loss of the given features and labels pair.

        Subclass and override this method if you want to inject some custom behavior.

        Args:
            features (:obj:`tf.Tensor`): A batch of input features.
            labels (:obj:`tf.Tensor`): A batch of labels.
            training (:obj:`bool`): Whether or not to run the model in training mode.

        Returns:
            A tuple of two :obj:`tf.Tensor`: The loss and logits.
        """
        if hasattr(self, "_run_model"):
            warnings.warn(
                "The `_run_model` method is deprecated and won't be called in a future version, define `run_model` in your subclass.",
                FutureWarning,
            )
            return self._run_model(features, labels, training)

        if self.args.past_index >= 0 and getattr(self, "_past", None) is not None:
            features["mems"] = self._past

        # Keep the full model output so that `past_index` can address entries beyond the first two.
        if isinstance(labels, dict):
            outputs = self.model(features, training=training, **labels)
        else:
            outputs = self.model(features, labels=labels, training=training)

        loss, logits = outputs[:2]

        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]

        return loss, logits

    def predict(self, test_dataset: tf.data.Dataset) -> PredictionOutput:
        """
        Run prediction and return predictions and potential metrics.

        Depending on the dataset and your use case, your test dataset may contain labels.
        In that case, this method will also return metrics, like in :obj:`evaluate()`.

        Args:
            test_dataset (:class:`~tf.data.Dataset`):
                Dataset to run the predictions on.

        Returns:
            `NamedTuple`:
            predictions (:obj:`np.ndarray`):
                The predictions on :obj:`test_dataset`.
            label_ids (:obj:`np.ndarray`, `optional`):
                The labels (if the dataset contained some).
            metrics (:obj:`Dict[str, float]`, `optional`):
                The potential dictionary of metrics (if the dataset contained labels).
        """
        test_ds, steps, num_examples = self.get_test_tfdataset(test_dataset)

        return self.prediction_loop(test_ds, steps, num_examples, description="Prediction")

    def save_model(self, output_dir: Optional[str] = None):
        """
        Will save the model, so you can reload it using :obj:`from_pretrained()`.
        """
        output_dir = output_dir if output_dir is not None else self.args.output_dir

        logger.info("Saving model in {}".format(output_dir))

        if not isinstance(self.model, TFPreTrainedModel):
            raise ValueError("Trainer.model appears not to be a TFPreTrainedModel")

        self.model.save_pretrained(output_dir)
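

# Example (illustrative, not part of the original file): a model saved by `save_model`
# above can be reloaded with `from_pretrained`, assuming the same model class
# (names illustrative):
#
#   model = TFAutoModelForSequenceClassification.from_pretrained(args.output_dir)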