pytorch-lightning

# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Various hooks to be used in the Lightning code."""

from typing import Any, Dict, Optional

import torch
from torch import Tensor
from torch.optim.optimizer import Optimizer

from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
from lightning.pytorch.utilities import move_data_to_device
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.types import EVAL_DATALOADERS, STEP_OUTPUT, TRAIN_DATALOADERS


class ModelHooks:
    """Hooks to be used in LightningModule."""

    def on_fit_start(self) -> None:
        """Called at the very beginning of fit.

        If on DDP it is called on every process

        """

    def on_fit_end(self) -> None:
        """Called at the very end of fit.

        If on DDP it is called on every process

        """

    def on_train_start(self) -> None:
        """Called at the beginning of training after sanity check."""

    def on_train_end(self) -> None:
        """Called at the end of training before logger experiment is closed."""

    def on_validation_start(self) -> None:
        """Called at the beginning of validation."""

    def on_validation_end(self) -> None:
        """Called at the end of validation."""

    def on_test_start(self) -> None:
        """Called at the beginning of testing."""

    def on_test_end(self) -> None:
        """Called at the end of testing."""

    def on_predict_start(self) -> None:
        """Called at the beginning of predicting."""

    def on_predict_end(self) -> None:
        """Called at the end of predicting."""

    def on_train_batch_start(self, batch: Any, batch_idx: int) -> Optional[int]:
        """Called in the training loop before anything happens for that batch.

        If you return -1 here, you will skip training for the rest of the current epoch.

        Args:
            batch: The batched data as it is returned by the training DataLoader.
            batch_idx: the index of the batch

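        Example (a minimal sketch; ``skip_rest_of_epoch`` is a hypothetical attribute you
        would set yourself, e.g. from a callback)::

            def on_train_batch_start(self, batch, batch_idx):
                if getattr(self, "skip_rest_of_epoch", False):
                    return -1  # skip the remaining batches of this epoch
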
        """
79

80
    def on_train_batch_end(self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int) -> None:
81
        """Called in the training loop after the batch.
82

83
        Args:
84
            outputs: The outputs of training_step(x)
85
            batch: The batched data as it is returned by the training DataLoader.
86
            batch_idx: the index of the batch
87

88
        Note:
89
            The value ``outputs["loss"]`` here will be the normalized value w.r.t ``accumulate_grad_batches`` of the
90
            loss returned from ``training_step``.
91

92
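        Example (a minimal sketch that collects the per-batch loss described in the note above;
        ``self.batch_losses`` is a hypothetical list created in ``__init__``)::

            def on_train_batch_end(self, outputs, batch, batch_idx):
                # keep the (accumulation-normalized) loss of every batch for later inspection
                self.batch_losses.append(outputs["loss"].detach())
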
        """
93

94
    def on_validation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
95
        """Called in the validation loop before anything happens for that batch.
96

97
        Args:
98
            batch: The batched data as it is returned by the validation DataLoader.
99
            batch_idx: the index of the batch
100
            dataloader_idx: the index of the dataloader
101

102
        """
103

104
    def on_validation_batch_end(
105
        self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int, dataloader_idx: int = 0
106
    ) -> None:
107
        """Called in the validation loop after the batch.
108

109
        Args:
110
            outputs: The outputs of validation_step(x)
111
            batch: The batched data as it is returned by the validation DataLoader.
112
            batch_idx: the index of the batch
113
            dataloader_idx: the index of the dataloader
114

115
        """
116

117
    def on_test_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
118
        """Called in the test loop before anything happens for that batch.
119

120
        Args:
121
            batch: The batched data as it is returned by the test DataLoader.
122
            batch_idx: the index of the batch
123
            dataloader_idx: the index of the dataloader
124

125
        """
126

127
    def on_test_batch_end(self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
128
        """Called in the test loop after the batch.
129

130
        Args:
131
            outputs: The outputs of test_step(x)
132
            batch: The batched data as it is returned by the test DataLoader.
133
            batch_idx: the index of the batch
134
            dataloader_idx: the index of the dataloader
135

136
        """
137

138
    def on_predict_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
139
        """Called in the predict loop before anything happens for that batch.
140

141
        Args:
142
            batch: The batched data as it is returned by the prediction DataLoader.
            batch_idx: the index of the batch
            dataloader_idx: the index of the dataloader

        """

    def on_predict_batch_end(self, outputs: Optional[Any], batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
        """Called in the predict loop after the batch.

        Args:
            outputs: The outputs of predict_step(x)
            batch: The batched data as it is returned by the prediction DataLoader.
            batch_idx: the index of the batch
            dataloader_idx: the index of the dataloader

        """

    def on_validation_model_zero_grad(self) -> None:
        """Called by the training loop to release gradients before entering the validation loop."""
        zero_grad_kwargs = {} if _TORCH_GREATER_EQUAL_2_0 else {"set_to_none": True}
        self.zero_grad(**zero_grad_kwargs)

    def on_validation_model_eval(self) -> None:
        """Called when the validation loop starts.

        The validation loop by default calls ``.eval()`` on the LightningModule before it starts. Override this hook
        to change the behavior. See also :meth:`~lightning.pytorch.core.hooks.ModelHooks.on_validation_model_train`.

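        Example (a minimal sketch of an override that keeps dropout layers active during
        validation, e.g. for Monte-Carlo dropout; purely illustrative)::

            def on_validation_model_eval(self):
                self.trainer.model.eval()
                for module in self.trainer.model.modules():
                    if isinstance(module, torch.nn.Dropout):
                        module.train()  # re-enable dropout at validation time
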
        """
171
        self.trainer.model.eval()
172

173
    def on_validation_model_train(self) -> None:
174
        """Called when the validation loop ends.
175

176
        The validation loop by default restores the `training` mode of the LightningModule to what it was before
177
        starting validation. Override this hook to change the behavior. See also
178
        :meth:`~lightning.pytorch.core.hooks.ModelHooks.on_validation_model_eval`.
179

180
        """
181
        # The loop won't call this hook unless it is overridden. The line below is here in case the user calls super().
182
        self.trainer.model.train()
183

184
    def on_test_model_eval(self) -> None:
185
        """Called when the test loop starts.
186

187
        The test loop by default calls ``.eval()`` on the LightningModule before it starts. Override this hook
188
        to change the behavior. See also :meth:`~lightning.pytorch.core.hooks.ModelHooks.on_test_model_train`.
189

190
        """
191
        self.trainer.model.eval()
192

193
    def on_test_model_train(self) -> None:
194
        """Called when the test loop ends.
195

196
        The test loop by default restores the `training` mode of the LightningModule to what it was before
197
        starting testing. Override this hook to change the behavior. See also
198
        :meth:`~lightning.pytorch.core.hooks.ModelHooks.on_test_model_eval`.
199

200
        """
201
        # The loop won't call this hook unless it is overridden. The line below is here in case the user calls super().
202
        self.trainer.model.train()
203

204
    def on_predict_model_eval(self) -> None:
205
        """Called when the predict loop starts.
206

207
        The predict loop by default calls ``.eval()`` on the LightningModule before it starts. Override this hook
208
        to change the behavior.
209

210
        """
211
        self.trainer.model.eval()
212

213
    def on_train_epoch_start(self) -> None:
214
        """Called in the training loop at the very beginning of the epoch."""
215

216
    def on_train_epoch_end(self) -> None:
217
        """Called in the training loop at the very end of the epoch.
218

219
        To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the
220
        :class:`~lightning.pytorch.LightningModule` and access them in this hook:
221

222
        .. code-block:: python
223

224
            class MyLightningModule(L.LightningModule):
225
                def __init__(self):
226
                    super().__init__()
227
                    self.training_step_outputs = []
228

229
                def training_step(self):
230
                    loss = ...
231
                    self.training_step_outputs.append(loss)
232
                    return loss
233

234
                def on_train_epoch_end(self):
235
                    # do something with all training_step outputs, for example:
236
                    epoch_mean = torch.stack(self.training_step_outputs).mean()
237
                    self.log("training_epoch_mean", epoch_mean)
238
                    # free up the memory
239
                    self.training_step_outputs.clear()
240

241
        """
242

243
    def on_validation_epoch_start(self) -> None:
244
        """Called in the validation loop at the very beginning of the epoch."""
245

246
    def on_validation_epoch_end(self) -> None:
247
        """Called in the validation loop at the very end of the epoch."""
248

249
    def on_test_epoch_start(self) -> None:
250
        """Called in the test loop at the very beginning of the epoch."""
251

252
    def on_test_epoch_end(self) -> None:
253
        """Called in the test loop at the very end of the epoch."""
254

255
    def on_predict_epoch_start(self) -> None:
256
        """Called at the beginning of predicting."""
257

258
    def on_predict_epoch_end(self) -> None:
259
        """Called at the end of predicting."""
260

261
    def on_before_zero_grad(self, optimizer: Optimizer) -> None:
262
        """Called after ``training_step()`` and before ``optimizer.zero_grad()``.
263

264
        Called in the training loop after taking an optimizer step and before zeroing grads.
265
        Good place to inspect weight information with weights updated.
266

267
        This is where it is called::
268

269
            for optimizer in optimizers:
270
                out = training_step(...)
271

272
                model.on_before_zero_grad(optimizer) # < ---- called here
273
                optimizer.zero_grad()
274

275
                backward()
276

277
        Args:
278
            optimizer: The optimizer for which grads should be zeroed.
279

280
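        Example (a minimal sketch of inspecting the freshly updated weights; the logged
        metric name is arbitrary)::

            def on_before_zero_grad(self, optimizer):
                # the weights have just been updated by optimizer.step()
                total_norm = torch.norm(
                    torch.stack([p.detach().norm() for p in self.parameters()])
                )
                self.log("weight_norm", total_norm)
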
        """
281

282
    def on_before_backward(self, loss: Tensor) -> None:
283
        """Called before ``loss.backward()``.
284

285
        Args:
286
            loss: Loss divided by number of batches for gradient accumulation and scaled if using AMP.
287

288
        """
289
        pass
290

291
    def on_after_backward(self) -> None:
292
        """Called after ``loss.backward()`` and before optimizers are stepped.
293

294
        Note:
295
            If using native AMP, the gradients will not be unscaled at this point.
296
            Use the ``on_before_optimizer_step`` if you need the unscaled gradients.
297

298
        """
299

300
    def on_before_optimizer_step(self, optimizer: Optimizer) -> None:
301
        """Called before ``optimizer.step()``.
302

303
        If using gradient accumulation, the hook is called once the gradients have been accumulated.
304
        See: :paramref:`~lightning.pytorch.trainer.trainer.Trainer.accumulate_grad_batches`.
305

306
        If using AMP, the loss will be unscaled before calling this hook.
307
        See these `docs <https://pytorch.org/docs/stable/notes/amp_examples.html#working-with-unscaled-gradients>`__
308
        for more information on the scaling of gradients.
309

310
        If clipping gradients, the gradients will not have been clipped yet.
311

312
        Args:
313
            optimizer: Current optimizer being used.
314

315
        Example::
316

317
            def on_before_optimizer_step(self, optimizer):
318
                # example to inspect gradient information in tensorboard
319
                if self.trainer.global_step % 25 == 0:  # don't make the tf file huge
320
                    for k, v in self.named_parameters():
321
                        self.logger.experiment.add_histogram(
322
                            tag=k, values=v.grad, global_step=self.trainer.global_step
323
                        )
324

325
        """
326

327
    def configure_sharded_model(self) -> None:
328
        """Deprecated.
329

330
        Use :meth:`~lightning.pytorch.core.hooks.ModelHooks.configure_model` instead.
331

332
        """
333

334
    def configure_model(self) -> None:
335
        """Hook to create modules in a strategy and precision aware context.
336

337
        This is particularly useful for when using sharded strategies (FSDP and DeepSpeed), where we'd like to shard
338
        the model instantly to save memory and initialization time.
339
        For non-sharded strategies, you can choose to override this hook or to initialize your model under the
340
        :meth:`~lightning.pytorch.trainer.trainer.Trainer.init_module` context manager.
341

342
        This hook is called during each of fit/val/test/predict stages in the same process, so ensure that
343
        implementation of this hook is **idempotent**, i.e., after the first time the hook is called, subsequent calls
344
        to it should be a no-op.
345

346
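        Example (a minimal sketch of an idempotent implementation; ``self.model`` is a
        hypothetical attribute initialized to ``None`` in ``__init__``)::

            def configure_model(self):
                if self.model is not None:
                    return  # already configured, keep subsequent calls a no-op
                self.model = torch.nn.Transformer(d_model=512, nhead=8)
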
        """
347

348

349
class DataHooks:
350
    """Hooks to be used for data related stuff."""
351

352
    def __init__(self) -> None:
353
        """
354
        Attributes:
355
            prepare_data_per_node:
356
                If True, each LOCAL_RANK=0 will call prepare data.
357
                Otherwise only NODE_RANK=0, LOCAL_RANK=0 will prepare data.
358
            allow_zero_length_dataloader_with_multiple_devices:
359
                If True, dataloader with zero length within local rank is allowed.
360
                Default value is False.
361
        """
362
        super().__init__()
363
        self.prepare_data_per_node: bool = True
364
        self.allow_zero_length_dataloader_with_multiple_devices: bool = False
365

366
    def prepare_data(self) -> None:
367
        """Use this to download and prepare data. Downloading and saving data with multiple processes (distributed
368
        settings) will result in corrupted data. Lightning ensures this method is called only within a single process,
369
        so you can safely add your downloading logic within.
370

371
        .. warning:: DO NOT set state to the model (use ``setup`` instead)
372
            since this is NOT called on every device
373

374
        Example::
375

376
            def prepare_data(self):
377
                # good
378
                download_data()
379
                tokenize()
380
                etc()
381

382
                # bad
383
                self.split = data_split
384
                self.some_state = some_other_state()
385

386
        In a distributed environment, ``prepare_data`` can be called in two ways
387
        (using :ref:`prepare_data_per_node<common/lightning_module:prepare_data_per_node>`)
388

389
        1. Once per node. This is the default and is only called on LOCAL_RANK=0.
390
        2. Once in total. Only called on GLOBAL_RANK=0.
391

392
        Example::
393

394
            # DEFAULT
395
            # called once per node on LOCAL_RANK=0 of that node
396
            class LitDataModule(LightningDataModule):
397
                def __init__(self):
398
                    super().__init__()
399
                    self.prepare_data_per_node = True
400

401

402
            # call on GLOBAL_RANK=0 (great for shared file systems)
403
            class LitDataModule(LightningDataModule):
404
                def __init__(self):
405
                    super().__init__()
406
                    self.prepare_data_per_node = False
407

408
        This is called before requesting the dataloaders:
409

410
        .. code-block:: python
411

412
            model.prepare_data()
413
            initialize_distributed()
414
            model.setup(stage)
415
            model.train_dataloader()
416
            model.val_dataloader()
417
            model.test_dataloader()
418
            model.predict_dataloader()
419

420
        """
421

422
    def setup(self, stage: str) -> None:
423
        """Called at the beginning of fit (train + validate), validate, test, or predict. This is a good hook when you
424
        need to build models dynamically or adjust something about them. This hook is called on every process when
425
        using DDP.
426

427
        Args:
428
            stage: either ``'fit'``, ``'validate'``, ``'test'``, or ``'predict'``
429

430
        Example::
431

432
            class LitModel(...):
433
                def __init__(self):
434
                    self.l1 = None
435

436
                def prepare_data(self):
437
                    download_data()
438
                    tokenize()
439

440
                    # don't do this
441
                    self.something = else
442

443
                def setup(self, stage):
444
                    data = load_data(...)
445
                    self.l1 = nn.Linear(28, data.num_classes)
446

447
        """
448

449
    def teardown(self, stage: str) -> None:
450
        """Called at the end of fit (train + validate), validate, test, or predict.
451

452
        Args:
453
            stage: either ``'fit'``, ``'validate'``, ``'test'``, or ``'predict'``
454

455
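        Example (a minimal sketch; ``self.data_cache`` is a hypothetical attribute created
        in ``setup``)::

            def teardown(self, stage):
                # release resources acquired in setup
                self.data_cache = None
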
        """
456

457
    def train_dataloader(self) -> TRAIN_DATALOADERS:
458
        """An iterable or collection of iterables specifying training samples.
459

460
        For more information about multiple dataloaders, see this :ref:`section <multiple-dataloaders>`.
461

462
        The dataloader you return will not be reloaded unless you set
463
        :paramref:`~lightning.pytorch.trainer.trainer.Trainer.reload_dataloaders_every_n_epochs` to
464
        a positive integer.
465

466
        For data processing use the following pattern:
467

468
            - download in :meth:`prepare_data`
469
            - process and split in :meth:`setup`
470

471
        However, the above are only necessary for distributed processing.
472

473
        .. warning:: do not assign state in prepare_data
474

475
        - :meth:`~lightning.pytorch.trainer.trainer.Trainer.fit`
476
        - :meth:`prepare_data`
477
        - :meth:`setup`
478

479
        Note:
480
            Lightning tries to add the correct sampler for distributed and arbitrary hardware.
481
            There is no need to set it yourself.
482

483
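        Example (a minimal sketch; ``self.train_dataset`` is a hypothetical dataset created
        in ``setup``)::

            def train_dataloader(self):
                return torch.utils.data.DataLoader(self.train_dataset, batch_size=32, shuffle=True)
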
        """
484
        raise MisconfigurationException("`train_dataloader` must be implemented to be used with the Lightning Trainer")
485

486
    def test_dataloader(self) -> EVAL_DATALOADERS:
487
        r"""An iterable or collection of iterables specifying test samples.
488

489
        For more information about multiple dataloaders, see this :ref:`section <multiple-dataloaders>`.
490

491
        For data processing use the following pattern:
492

493
            - download in :meth:`prepare_data`
494
            - process and split in :meth:`setup`
495

496
        However, the above are only necessary for distributed processing.
497

498
        .. warning:: do not assign state in prepare_data
499

500

501
        - :meth:`~lightning.pytorch.trainer.trainer.Trainer.test`
502
        - :meth:`prepare_data`
503
        - :meth:`setup`
504

505
        Note:
506
            Lightning tries to add the correct sampler for distributed and arbitrary hardware.
507
            There is no need to set it yourself.
508

509
        Note:
510
            If you don't need a test dataset and a :meth:`test_step`, you don't need to implement
511
            this method.
512

513
        """
514
        raise MisconfigurationException("`test_dataloader` must be implemented to be used with the Lightning Trainer")
515

516
    def val_dataloader(self) -> EVAL_DATALOADERS:
517
        r"""An iterable or collection of iterables specifying validation samples.
518

519
        For more information about multiple dataloaders, see this :ref:`section <multiple-dataloaders>`.
520

521
        The dataloader you return will not be reloaded unless you set
522
        :paramref:`~lightning.pytorch.trainer.trainer.Trainer.reload_dataloaders_every_n_epochs` to
523
        a positive integer.
524

525
        It's recommended that all data downloads and preparation happen in :meth:`prepare_data`.
526

527
        - :meth:`~lightning.pytorch.trainer.trainer.Trainer.fit`
528
        - :meth:`~lightning.pytorch.trainer.trainer.Trainer.validate`
529
        - :meth:`prepare_data`
530
        - :meth:`setup`
531

532
        Note:
533
            Lightning tries to add the correct sampler for distributed and arbitrary hardware
534
            There is no need to set it yourself.
535

536
        Note:
537
            If you don't need a validation dataset and a :meth:`validation_step`, you don't need to
538
            implement this method.
539

540
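        Example (a minimal sketch returning two validation dataloaders; ``self.val_a`` and
        ``self.val_b`` are hypothetical datasets created in ``setup``)::

            def val_dataloader(self):
                loader_a = torch.utils.data.DataLoader(self.val_a, batch_size=64)
                loader_b = torch.utils.data.DataLoader(self.val_b, batch_size=64)
                return [loader_a, loader_b]
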
        """
541
        raise MisconfigurationException("`val_dataloader` must be implemented to be used with the Lightning Trainer")
542

543
    def predict_dataloader(self) -> EVAL_DATALOADERS:
544
        r"""An iterable or collection of iterables specifying prediction samples.
545

546
        For more information about multiple dataloaders, see this :ref:`section <multiple-dataloaders>`.
547

548
        It's recommended that all data downloads and preparation happen in :meth:`prepare_data`.
549

550
        - :meth:`~lightning.pytorch.trainer.trainer.Trainer.predict`
551
        - :meth:`prepare_data`
552
        - :meth:`setup`
553

554
        Note:
555
            Lightning tries to add the correct sampler for distributed and arbitrary hardware
556
            There is no need to set it yourself.
557

558
        Return:
559
            A :class:`torch.utils.data.DataLoader` or a sequence of them specifying prediction samples.
560

561
        """
562
        raise MisconfigurationException(
563
            "`predict_dataloader` must be implemented to be used with the Lightning Trainer"
564
        )
565

566
    def transfer_batch_to_device(self, batch: Any, device: torch.device, dataloader_idx: int) -> Any:
567
        """Override this hook if your :class:`~torch.utils.data.DataLoader` returns tensors wrapped in a custom data
568
        structure.
569

570
        The data types listed below (and any arbitrary nesting of them) are supported out of the box:
571

572
        - :class:`torch.Tensor` or anything that implements `.to(...)`
573
        - :class:`list`
574
        - :class:`dict`
575
        - :class:`tuple`
576

577
        For anything else, you need to define how the data is moved to the target device (CPU, GPU, TPU, ...).
578

579
        Note:
580
            This hook should only transfer the data and not modify it, nor should it move the data to
581
            any other device than the one passed in as argument (unless you know what you are doing).
582
            To check the current state of execution of this hook you can use
583
            ``self.trainer.training/testing/validating/predicting`` so that you can
584
            add different logic as per your requirement.
585

586
        Args:
587
            batch: A batch of data that needs to be transferred to a new device.
588
            device: The target device as defined in PyTorch.
589
            dataloader_idx: The index of the dataloader to which the batch belongs.
590

591
        Returns:
592
            A reference to the data on the new device.
593

594
        Example::
595

596
            def transfer_batch_to_device(self, batch, device, dataloader_idx):
597
                if isinstance(batch, CustomBatch):
598
                    # move all tensors in your custom data structure to the device
599
                    batch.samples = batch.samples.to(device)
600
                    batch.targets = batch.targets.to(device)
601
                elif dataloader_idx == 0:
602
                    # skip device transfer for the first dataloader or anything you wish
603
                    pass
604
                else:
605
                    batch = super().transfer_batch_to_device(batch, device, dataloader_idx)
606
                return batch
607

608
        See Also:
609
            - :meth:`move_data_to_device`
610
            - :meth:`apply_to_collection`
611

612
        """
613
        return move_data_to_device(batch, device)
614

615
    def on_before_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any:
616
        """Override to alter or apply batch augmentations to your batch before it is transferred to the device.
617

618
        Note:
619
            To check the current state of execution of this hook you can use
620
            ``self.trainer.training/testing/validating/predicting`` so that you can
621
            add different logic as per your requirement.
622

623
        Args:
624
            batch: A batch of data that needs to be altered or augmented.
625
            dataloader_idx: The index of the dataloader to which the batch belongs.
626

627
        Returns:
628
            A batch of data
629

630
        Example::
631

632
            def on_before_batch_transfer(self, batch, dataloader_idx):
633
                batch['x'] = transforms(batch['x'])
634
                return batch
635

636
        See Also:
637
            - :meth:`on_after_batch_transfer`
638
            - :meth:`transfer_batch_to_device`
639

640
        """
641
        return batch
642

643
    def on_after_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any:
644
        """Override to alter or apply batch augmentations to your batch after it is transferred to the device.
645

646
        Note:
647
            To check the current state of execution of this hook you can use
648
            ``self.trainer.training/testing/validating/predicting`` so that you can
649
            add different logic as per your requirement.
650

651
        Args:
652
            batch: A batch of data that needs to be altered or augmented.
653
            dataloader_idx: The index of the dataloader to which the batch belongs.
654

655
        Returns:
656
            A batch of data
657

658
        Example::
659

660
            def on_after_batch_transfer(self, batch, dataloader_idx):
661
                batch['x'] = gpu_transforms(batch['x'])
662
                return batch
663

664
        See Also:
665
            - :meth:`on_before_batch_transfer`
666
            - :meth:`transfer_batch_to_device`
667

668
        """
669
        return batch
670

671

672
class CheckpointHooks:
    """Hooks to be used with Checkpointing."""

    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        r"""Called by Lightning to restore your model. If you saved something with :meth:`on_save_checkpoint` this is
        your chance to restore this.

        Args:
            checkpoint: Loaded checkpoint

        Example::

            def on_load_checkpoint(self, checkpoint):
                # 99% of the time you don't need to implement this method
                self.something_cool_i_want_to_save = checkpoint['something_cool_i_want_to_save']

        Note:
            Lightning auto-restores global step, epoch, and train state including amp scaling.
            There is no need for you to restore anything regarding training.

        """

    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        r"""Called by Lightning when saving a checkpoint to give you a chance to store anything else you might want to
        save.

        Args:
            checkpoint: The full checkpoint dictionary before it gets dumped to a file.
                Implementations of this hook can insert additional data into this dictionary.

        Example::

            def on_save_checkpoint(self, checkpoint):
                # 99% of use cases you don't need to implement this method
                checkpoint['something_cool_i_want_to_save'] = my_cool_pickable_object

        Note:
            Lightning saves all aspects of training (epoch, global step, etc...)
            including amp scaling.
            There is no need for you to store anything about training.

        """
