pytorch-lightning
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Various hooks to be used in the Lightning code."""

from typing import Any, Dict, Optional

import torch
from torch import Tensor
from torch.optim.optimizer import Optimizer

from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
from lightning.pytorch.utilities import move_data_to_device
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.types import EVAL_DATALOADERS, STEP_OUTPUT, TRAIN_DATALOADERS


class ModelHooks:
    """Hooks to be used in LightningModule."""

    def on_fit_start(self) -> None:
        """Called at the very beginning of fit.

        If on DDP, it is called on every process.

        """

    def on_fit_end(self) -> None:
        """Called at the very end of fit.

        If on DDP, it is called on every process.

        """

    def on_train_start(self) -> None:
        """Called at the beginning of training after sanity check."""

    def on_train_end(self) -> None:
        """Called at the end of training before logger experiment is closed."""

    def on_validation_start(self) -> None:
        """Called at the beginning of validation."""

    def on_validation_end(self) -> None:
        """Called at the end of validation."""

    def on_test_start(self) -> None:
        """Called at the beginning of testing."""

    def on_test_end(self) -> None:
        """Called at the end of testing."""

    def on_predict_start(self) -> None:
        """Called at the beginning of predicting."""

    def on_predict_end(self) -> None:
        """Called at the end of predicting."""

    def on_train_batch_start(self, batch: Any, batch_idx: int) -> Optional[int]:
        """Called in the training loop before anything happens for that batch.

        If you return -1 here, you will skip training for the rest of the current epoch.

        Args:
            batch: The batched data as it is returned by the training DataLoader.
            batch_idx: the index of the batch

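        Example (an illustrative sketch; ``self._skip_rest_of_epoch`` is a hypothetical attribute set elsewhere)::

            def on_train_batch_start(self, batch, batch_idx):
                # returning -1 skips the remaining batches of the current epoch
                if self._skip_rest_of_epoch:
                    return -1
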
        """

    def on_train_batch_end(self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int) -> None:
        """Called in the training loop after the batch.

        Args:
            outputs: The outputs of training_step(x)
            batch: The batched data as it is returned by the training DataLoader.
            batch_idx: the index of the batch

        Note:
            The value ``outputs["loss"]`` here will be the normalized value w.r.t ``accumulate_grad_batches`` of the
            loss returned from ``training_step``.

        """

    def on_validation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
        """Called in the validation loop before anything happens for that batch.

        Args:
            batch: The batched data as it is returned by the validation DataLoader.
            batch_idx: the index of the batch
            dataloader_idx: the index of the dataloader

        """

    def on_validation_batch_end(
        self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int, dataloader_idx: int = 0
    ) -> None:
        """Called in the validation loop after the batch.

        Args:
            outputs: The outputs of validation_step(x)
            batch: The batched data as it is returned by the validation DataLoader.
            batch_idx: the index of the batch
            dataloader_idx: the index of the dataloader

        """

    def on_test_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
        """Called in the test loop before anything happens for that batch.

        Args:
            batch: The batched data as it is returned by the test DataLoader.
            batch_idx: the index of the batch
            dataloader_idx: the index of the dataloader

        """

    def on_test_batch_end(self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
        """Called in the test loop after the batch.

        Args:
            outputs: The outputs of test_step(x)
            batch: The batched data as it is returned by the test DataLoader.
            batch_idx: the index of the batch
            dataloader_idx: the index of the dataloader

        """

    def on_predict_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
        """Called in the predict loop before anything happens for that batch.

        Args:
            batch: The batched data as it is returned by the prediction DataLoader.
            batch_idx: the index of the batch
            dataloader_idx: the index of the dataloader

        """

    def on_predict_batch_end(
        self, outputs: Optional[Any], batch: Any, batch_idx: int, dataloader_idx: int = 0
    ) -> None:
        """Called in the predict loop after the batch.

        Args:
            outputs: The outputs of predict_step(x)
            batch: The batched data as it is returned by the prediction DataLoader.
            batch_idx: the index of the batch
            dataloader_idx: the index of the dataloader

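        Example (an illustrative sketch; ``self.predictions`` is a hypothetical list created in ``__init__``)::

            def on_predict_batch_end(self, outputs, batch, batch_idx, dataloader_idx=0):
                # keep per-batch predictions around for custom post-processing
                self.predictions.append(outputs)
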
        """

    def on_validation_model_zero_grad(self) -> None:
        """Called by the training loop to release gradients before entering the validation loop."""
        zero_grad_kwargs = {} if _TORCH_GREATER_EQUAL_2_0 else {"set_to_none": True}
        self.zero_grad(**zero_grad_kwargs)

    def on_validation_model_eval(self) -> None:
        """Called when the validation loop starts.

        The validation loop by default calls ``.eval()`` on the LightningModule before it starts. Override this hook
        to change the behavior. See also :meth:`~lightning.pytorch.core.hooks.ModelHooks.on_validation_model_train`.

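        Example (an illustrative sketch, e.g. keeping dropout layers active for Monte-Carlo-style validation)::

            def on_validation_model_eval(self):
                self.trainer.model.eval()
                for module in self.trainer.model.modules():
                    # re-enable dropout only; everything else stays in eval mode
                    if isinstance(module, torch.nn.Dropout):
                        module.train()
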
        """
        self.trainer.model.eval()

    def on_validation_model_train(self) -> None:
        """Called when the validation loop ends.

        The validation loop by default restores the `training` mode of the LightningModule to what it was before
        starting validation. Override this hook to change the behavior. See also
        :meth:`~lightning.pytorch.core.hooks.ModelHooks.on_validation_model_eval`.

        """
        # The loop won't call this hook unless it is overridden. The line below is here in case the user calls super().
        self.trainer.model.train()

    def on_test_model_eval(self) -> None:
        """Called when the test loop starts.

        The test loop by default calls ``.eval()`` on the LightningModule before it starts. Override this hook
        to change the behavior. See also :meth:`~lightning.pytorch.core.hooks.ModelHooks.on_test_model_train`.

        """
        self.trainer.model.eval()

    def on_test_model_train(self) -> None:
        """Called when the test loop ends.

        The test loop by default restores the `training` mode of the LightningModule to what it was before
        starting testing. Override this hook to change the behavior. See also
        :meth:`~lightning.pytorch.core.hooks.ModelHooks.on_test_model_eval`.

        """
        # The loop won't call this hook unless it is overridden. The line below is here in case the user calls super().
        self.trainer.model.train()

    def on_predict_model_eval(self) -> None:
        """Called when the predict loop starts.

        The predict loop by default calls ``.eval()`` on the LightningModule before it starts. Override this hook
        to change the behavior.

        """
        self.trainer.model.eval()

    def on_train_epoch_start(self) -> None:
        """Called in the training loop at the very beginning of the epoch."""

    def on_train_epoch_end(self) -> None:
        """Called in the training loop at the very end of the epoch.

        To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the
        :class:`~lightning.pytorch.LightningModule` and access them in this hook:

        .. code-block:: python

            class MyLightningModule(L.LightningModule):
                def __init__(self):
                    super().__init__()
                    self.training_step_outputs = []

                def training_step(self):
                    loss = ...
                    self.training_step_outputs.append(loss)
                    return loss

                def on_train_epoch_end(self):
                    # do something with all training_step outputs, for example:
                    epoch_mean = torch.stack(self.training_step_outputs).mean()
                    self.log("training_epoch_mean", epoch_mean)
                    # free up the memory
                    self.training_step_outputs.clear()

        """

    def on_validation_epoch_start(self) -> None:
        """Called in the validation loop at the very beginning of the epoch."""

    def on_validation_epoch_end(self) -> None:
        """Called in the validation loop at the very end of the epoch."""

    def on_test_epoch_start(self) -> None:
        """Called in the test loop at the very beginning of the epoch."""

    def on_test_epoch_end(self) -> None:
        """Called in the test loop at the very end of the epoch."""

    def on_predict_epoch_start(self) -> None:
        """Called at the beginning of predicting."""

    def on_predict_epoch_end(self) -> None:
        """Called at the end of predicting."""

    def on_before_zero_grad(self, optimizer: Optimizer) -> None:
        """Called after ``training_step()`` and before ``optimizer.zero_grad()``.

        Called in the training loop after taking an optimizer step and before zeroing grads.
        Good place to inspect weight information with weights updated.

        This is where it is called::

            for optimizer in optimizers:
                out = training_step(...)

                model.on_before_zero_grad(optimizer)  # < ---- called here
                optimizer.zero_grad()

                backward()

        Args:
            optimizer: The optimizer for which grads should be zeroed.

        """

    def on_before_backward(self, loss: Tensor) -> None:
        """Called before ``loss.backward()``.

        Args:
            loss: Loss divided by number of batches for gradient accumulation and scaled if using AMP.

        """
        pass

    def on_after_backward(self) -> None:
        """Called after ``loss.backward()`` and before optimizers are stepped.

        Note:
            If using native AMP, the gradients will not be unscaled at this point.
            Use the ``on_before_optimizer_step`` if you need the unscaled gradients.

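        Example (an illustrative sketch; logs the overall gradient norm, which may still be AMP-scaled here)::

            def on_after_backward(self):
                norms = [p.grad.norm() for p in self.parameters() if p.grad is not None]
                if norms:
                    self.log("grad_norm", torch.stack(norms).norm())
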
        """

    def on_before_optimizer_step(self, optimizer: Optimizer) -> None:
        """Called before ``optimizer.step()``.

        If using gradient accumulation, the hook is called once the gradients have been accumulated.
        See: :paramref:`~lightning.pytorch.trainer.trainer.Trainer.accumulate_grad_batches`.

        If using AMP, the loss will be unscaled before calling this hook.
        See these `docs <https://pytorch.org/docs/stable/notes/amp_examples.html#working-with-unscaled-gradients>`__
        for more information on the scaling of gradients.

        If clipping gradients, the gradients will not have been clipped yet.

        Args:
            optimizer: Current optimizer being used.

        Example::

            def on_before_optimizer_step(self, optimizer):
                # example to inspect gradient information in tensorboard
                if self.trainer.global_step % 25 == 0:  # don't make the tf file huge
                    for k, v in self.named_parameters():
                        self.logger.experiment.add_histogram(
                            tag=k, values=v.grad, global_step=self.trainer.global_step
                        )

        """

    def configure_sharded_model(self) -> None:
        """Deprecated.

        Use :meth:`~lightning.pytorch.core.hooks.ModelHooks.configure_model` instead.

        """

    def configure_model(self) -> None:
        """Hook to create modules in a strategy and precision aware context.

        This is particularly useful for when using sharded strategies (FSDP and DeepSpeed), where we'd like to shard
        the model instantly to save memory and initialization time.
        For non-sharded strategies, you can choose to override this hook or to initialize your model under the
        :meth:`~lightning.pytorch.trainer.trainer.Trainer.init_module` context manager.

        This hook is called during each of fit/val/test/predict stages in the same process, so ensure that
        implementation of this hook is **idempotent**, i.e., after the first time the hook is called, subsequent calls
        to it should be a no-op.

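        Example (a minimal idempotent sketch; assumes ``self.model = None`` was set in ``__init__`` and
        ``MyTransformerBlock`` is an illustrative module)::

            def configure_model(self):
                if self.model is not None:
                    # already built on a previous call, keep the hook a no-op
                    return
                self.model = torch.nn.Sequential(*[MyTransformerBlock() for _ in range(12)])
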
        """


class DataHooks:
    """Hooks to be used for data related stuff."""

    def __init__(self) -> None:
        """
        Attributes:
            prepare_data_per_node:
                If True, each LOCAL_RANK=0 will call prepare data.
                Otherwise only NODE_RANK=0, LOCAL_RANK=0 will prepare data.
            allow_zero_length_dataloader_with_multiple_devices:
                If True, dataloader with zero length within local rank is allowed.
                Default value is False.
        """
        super().__init__()
        self.prepare_data_per_node: bool = True
        self.allow_zero_length_dataloader_with_multiple_devices: bool = False

    def prepare_data(self) -> None:
        """Use this to download and prepare data. Downloading and saving data with multiple processes (distributed
        settings) will result in corrupted data. Lightning ensures this method is called only within a single process,
        so you can safely add your downloading logic within.

        .. warning:: DO NOT set state to the model (use ``setup`` instead)
            since this is NOT called on every device

        Example::

            def prepare_data(self):
                # good
                download_data()
                tokenize()
                etc()

                # bad
                self.split = data_split
                self.some_state = some_other_state()

        In a distributed environment, ``prepare_data`` can be called in two ways
        (using :ref:`prepare_data_per_node<common/lightning_module:prepare_data_per_node>`)

        1. Once per node. This is the default and is only called on LOCAL_RANK=0.
        2. Once in total. Only called on GLOBAL_RANK=0.

        Example::

            # DEFAULT
            # called once per node on LOCAL_RANK=0 of that node
            class LitDataModule(LightningDataModule):
                def __init__(self):
                    super().__init__()
                    self.prepare_data_per_node = True


            # call on GLOBAL_RANK=0 (great for shared file systems)
            class LitDataModule(LightningDataModule):
                def __init__(self):
                    super().__init__()
                    self.prepare_data_per_node = False

        This is called before requesting the dataloaders:

        .. code-block:: python

            model.prepare_data()
            initialize_distributed()
            model.setup(stage)
            model.train_dataloader()
            model.val_dataloader()
            model.test_dataloader()
            model.predict_dataloader()

        """

    def setup(self, stage: str) -> None:
        """Called at the beginning of fit (train + validate), validate, test, or predict. This is a good hook when you
        need to build models dynamically or adjust something about them. This hook is called on every process when
        using DDP.

        Args:
            stage: either ``'fit'``, ``'validate'``, ``'test'``, or ``'predict'``

        Example::

            class LitModel(...):
                def __init__(self):
                    self.l1 = None

                def prepare_data(self):
                    download_data()
                    tokenize()

                    # don't do this
                    self.something = some_other_state()

                def setup(self, stage):
                    data = load_data(...)
                    self.l1 = nn.Linear(28, data.num_classes)

        """

    def teardown(self, stage: str) -> None:
        """Called at the end of fit (train + validate), validate, test, or predict.

        Args:
            stage: either ``'fit'``, ``'validate'``, ``'test'``, or ``'predict'``

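        Example (an illustrative sketch; ``self.big_dataset`` is a hypothetical attribute created in ``setup``)::

            def teardown(self, stage):
                # release memory held by objects created in ``setup``
                self.big_dataset = None
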
        """

    def train_dataloader(self) -> TRAIN_DATALOADERS:
        """An iterable or collection of iterables specifying training samples.

        For more information about multiple dataloaders, see this :ref:`section <multiple-dataloaders>`.

        The dataloader you return will not be reloaded unless you set
        :paramref:`~lightning.pytorch.trainer.trainer.Trainer.reload_dataloaders_every_n_epochs` to
        a positive integer.

        For data processing use the following pattern:

        - download in :meth:`prepare_data`
        - process and split in :meth:`setup`

        However, the above are only necessary for distributed processing.

        .. warning:: do not assign state in prepare_data

        - :meth:`~lightning.pytorch.trainer.trainer.Trainer.fit`
        - :meth:`prepare_data`
        - :meth:`setup`

        Note:
            Lightning tries to add the correct sampler for distributed and arbitrary hardware.
            There is no need to set it yourself.

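        Example (a minimal sketch; ``MyDataset`` is an illustrative :class:`~torch.utils.data.Dataset`)::

            def train_dataloader(self):
                return torch.utils.data.DataLoader(MyDataset(train=True), batch_size=32, shuffle=True)
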
        """
        raise MisconfigurationException("`train_dataloader` must be implemented to be used with the Lightning Trainer")

    def test_dataloader(self) -> EVAL_DATALOADERS:
        r"""An iterable or collection of iterables specifying test samples.

        For more information about multiple dataloaders, see this :ref:`section <multiple-dataloaders>`.

        For data processing use the following pattern:

        - download in :meth:`prepare_data`
        - process and split in :meth:`setup`

        However, the above are only necessary for distributed processing.

        .. warning:: do not assign state in prepare_data

        - :meth:`~lightning.pytorch.trainer.trainer.Trainer.test`
        - :meth:`prepare_data`
        - :meth:`setup`

        Note:
            Lightning tries to add the correct sampler for distributed and arbitrary hardware.
            There is no need to set it yourself.

        Note:
            If you don't need a test dataset and a :meth:`test_step`, you don't need to implement
            this method.

        """
        raise MisconfigurationException("`test_dataloader` must be implemented to be used with the Lightning Trainer")

    def val_dataloader(self) -> EVAL_DATALOADERS:
        r"""An iterable or collection of iterables specifying validation samples.

        For more information about multiple dataloaders, see this :ref:`section <multiple-dataloaders>`.

        The dataloader you return will not be reloaded unless you set
        :paramref:`~lightning.pytorch.trainer.trainer.Trainer.reload_dataloaders_every_n_epochs` to
        a positive integer.

        It's recommended that all data downloads and preparation happen in :meth:`prepare_data`.

        - :meth:`~lightning.pytorch.trainer.trainer.Trainer.fit`
        - :meth:`~lightning.pytorch.trainer.trainer.Trainer.validate`
        - :meth:`prepare_data`
        - :meth:`setup`

        Note:
            Lightning tries to add the correct sampler for distributed and arbitrary hardware.
            There is no need to set it yourself.

        Note:
            If you don't need a validation dataset and a :meth:`validation_step`, you don't need to
            implement this method.

        """
        raise MisconfigurationException("`val_dataloader` must be implemented to be used with the Lightning Trainer")

    def predict_dataloader(self) -> EVAL_DATALOADERS:
        r"""An iterable or collection of iterables specifying prediction samples.

        For more information about multiple dataloaders, see this :ref:`section <multiple-dataloaders>`.

        It's recommended that all data downloads and preparation happen in :meth:`prepare_data`.

        - :meth:`~lightning.pytorch.trainer.trainer.Trainer.predict`
        - :meth:`prepare_data`
        - :meth:`setup`

        Note:
            Lightning tries to add the correct sampler for distributed and arbitrary hardware.
            There is no need to set it yourself.

        Return:
            A :class:`torch.utils.data.DataLoader` or a sequence of them specifying prediction samples.

        """
        raise MisconfigurationException(
            "`predict_dataloader` must be implemented to be used with the Lightning Trainer"
        )

    def transfer_batch_to_device(self, batch: Any, device: torch.device, dataloader_idx: int) -> Any:
        """Override this hook if your :class:`~torch.utils.data.DataLoader` returns tensors wrapped in a custom data
        structure.

        The data types listed below (and any arbitrary nesting of them) are supported out of the box:

        - :class:`torch.Tensor` or anything that implements `.to(...)`
        - :class:`list`
        - :class:`dict`
        - :class:`tuple`

        For anything else, you need to define how the data is moved to the target device (CPU, GPU, TPU, ...).

        Note:
            This hook should only transfer the data and not modify it, nor should it move the data to
            any other device than the one passed in as argument (unless you know what you are doing).
            To check the current state of execution of this hook you can use
            ``self.trainer.training/testing/validating/predicting`` so that you can
            add different logic as per your requirement.

        Args:
            batch: A batch of data that needs to be transferred to a new device.
            device: The target device as defined in PyTorch.
            dataloader_idx: The index of the dataloader to which the batch belongs.

        Returns:
            A reference to the data on the new device.

        Example::

            def transfer_batch_to_device(self, batch, device, dataloader_idx):
                if isinstance(batch, CustomBatch):
                    # move all tensors in your custom data structure to the device
                    batch.samples = batch.samples.to(device)
                    batch.targets = batch.targets.to(device)
                elif dataloader_idx == 0:
                    # skip device transfer for the first dataloader or anything you wish
                    pass
                else:
                    batch = super().transfer_batch_to_device(batch, device, dataloader_idx)
                return batch

        See Also:
            - :meth:`move_data_to_device`
            - :meth:`apply_to_collection`

        """
        return move_data_to_device(batch, device)

    def on_before_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any:
        """Override to alter or apply batch augmentations to your batch before it is transferred to the device.

        Note:
            To check the current state of execution of this hook you can use
            ``self.trainer.training/testing/validating/predicting`` so that you can
            add different logic as per your requirement.

        Args:
            batch: A batch of data that needs to be altered or augmented.
            dataloader_idx: The index of the dataloader to which the batch belongs.

        Returns:
            A batch of data

        Example::

            def on_before_batch_transfer(self, batch, dataloader_idx):
                batch['x'] = transforms(batch['x'])
                return batch

        See Also:
            - :meth:`on_after_batch_transfer`
            - :meth:`transfer_batch_to_device`

        """
        return batch

    def on_after_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any:
        """Override to alter or apply batch augmentations to your batch after it is transferred to the device.

        Note:
            To check the current state of execution of this hook you can use
            ``self.trainer.training/testing/validating/predicting`` so that you can
            add different logic as per your requirement.

        Args:
            batch: A batch of data that needs to be altered or augmented.
            dataloader_idx: The index of the dataloader to which the batch belongs.

        Returns:
            A batch of data

        Example::

            def on_after_batch_transfer(self, batch, dataloader_idx):
                batch['x'] = gpu_transforms(batch['x'])
                return batch

        See Also:
            - :meth:`on_before_batch_transfer`
            - :meth:`transfer_batch_to_device`

        """
        return batch


class CheckpointHooks:
    """Hooks to be used with Checkpointing."""

    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        r"""Called by Lightning to restore your model. If you saved something with :meth:`on_save_checkpoint` this is
        your chance to restore this.

        Args:
            checkpoint: Loaded checkpoint

        Example::

            def on_load_checkpoint(self, checkpoint):
                # 99% of the time you don't need to implement this method
                self.something_cool_i_want_to_save = checkpoint['something_cool_i_want_to_save']

        Note:
            Lightning auto-restores global step, epoch, and train state including amp scaling.
            There is no need for you to restore anything regarding training.

        """

    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        r"""Called by Lightning when saving a checkpoint to give you a chance to store anything else you might want to
        save.

        Args:
            checkpoint: The full checkpoint dictionary before it gets dumped to a file.
                Implementations of this hook can insert additional data into this dictionary.

        Example::

            def on_save_checkpoint(self, checkpoint):
                # in 99% of use cases you don't need to implement this method
                checkpoint['something_cool_i_want_to_save'] = my_cool_picklable_object

        Note:
            Lightning saves all aspects of training (epoch, global step, etc...)
            including amp scaling.
            There is no need for you to store anything about training.

        """