# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from abc import ABC, abstractmethod
from contextlib import contextmanager, nullcontext
from typing import Any, Callable, Dict, Generator, List, Mapping, Optional, Tuple, TypeVar, Union

import torch
from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer

import lightning.pytorch as pl
from lightning.fabric.plugins import CheckpointIO
from lightning.fabric.strategies import _StrategyRegistry
from lightning.fabric.utilities import move_data_to_device
from lightning.fabric.utilities.distributed import ReduceOp
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
from lightning.fabric.utilities.init import _EmptyInit
from lightning.fabric.utilities.optimizer import _optimizer_to_device, _optimizers_to_device
from lightning.fabric.utilities.types import _PATH
from lightning.pytorch.core.optimizer import LightningOptimizer, _init_optimizers_and_lr_schedulers
from lightning.pytorch.plugins import TorchCheckpointIO
from lightning.pytorch.plugins.io.wrapper import _WrappingCheckpointIO
from lightning.pytorch.plugins.precision import Precision
from lightning.pytorch.strategies.launchers.launcher import _Launcher
from lightning.pytorch.trainer.states import TrainerFn
from lightning.pytorch.utilities.types import STEP_OUTPUT, LRSchedulerConfig

TBroadcast = TypeVar("TBroadcast")
TReduce = TypeVar("TReduce")

log = logging.getLogger(__name__)


class Strategy(ABC):
    """Base class for all strategies that change the behaviour of the training, validation and test loop."""

    def __init__(
        self,
        accelerator: Optional["pl.accelerators.Accelerator"] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[Precision] = None,
    ) -> None:
        self._accelerator: Optional["pl.accelerators.Accelerator"] = accelerator
        self._checkpoint_io: Optional[CheckpointIO] = checkpoint_io
        self._precision_plugin: Optional[Precision] = None
        # Call the precision setter for input validation
        self.precision_plugin = precision_plugin  # type: ignore[assignment]
        self._lightning_module: Optional[pl.LightningModule] = None
        self._model: Optional[Module] = None
        self._launcher: Optional[_Launcher] = None
        self._forward_redirection: _ForwardRedirection = _ForwardRedirection()
        self._optimizers: List[Optimizer] = []
        self._lightning_optimizers: List[LightningOptimizer] = []
        self.lr_scheduler_configs: List[LRSchedulerConfig] = []

    @property
    def launcher(self) -> Optional[_Launcher]:
        return self._launcher

    @property
    def accelerator(self) -> Optional["pl.accelerators.Accelerator"]:
        return self._accelerator

    @accelerator.setter
    def accelerator(self, accelerator: "pl.accelerators.Accelerator") -> None:
        self._accelerator = accelerator

    @property
    def checkpoint_io(self) -> CheckpointIO:
        if self._checkpoint_io is None:
            self._checkpoint_io = TorchCheckpointIO()
        elif isinstance(self._checkpoint_io, _WrappingCheckpointIO):
            self._checkpoint_io.checkpoint_io = TorchCheckpointIO()

        return self._checkpoint_io

    @checkpoint_io.setter
    def checkpoint_io(self, io: CheckpointIO) -> None:
        self._checkpoint_io = io

    @property
    def precision_plugin(self) -> Precision:
        return self._precision_plugin if self._precision_plugin is not None else Precision()

    @precision_plugin.setter
    def precision_plugin(self, precision_plugin: Optional[Precision]) -> None:
        self._precision_plugin = precision_plugin

    @property
    def optimizers(self) -> List[Optimizer]:
        return self._optimizers

    @optimizers.setter
    def optimizers(self, optimizers: List[Optimizer]) -> None:
        self._optimizers = optimizers
        self._lightning_optimizers = [LightningOptimizer._to_lightning_optimizer(opt, self) for opt in optimizers]

    def connect(self, model: "pl.LightningModule") -> None:
        """Called by the Trainer to connect the strategy with the model."""
        # model conversions cannot be applied at this point because `LightningModule.{setup,configure_model}` haven't
        # run yet
        self._lightning_module = model
        self.model = model

    def _configure_launcher(self) -> None:
        """Attach the launcher based on the Strategy."""

    def setup_environment(self) -> None:
        """Set up any processes or distributed connections.

        This is called before the LightningModule/DataModule setup hook, which allows the user to access the
        accelerator environment before setup is complete.

        """
        assert self.accelerator is not None
        self.accelerator.setup_device(self.root_device)

    def setup_optimizers(self, trainer: "pl.Trainer") -> None:
        """Creates optimizers and schedulers.

        Args:
            trainer: the Trainer these optimizers should be connected to

        """
        assert self.lightning_module is not None
        self.optimizers, self.lr_scheduler_configs = _init_optimizers_and_lr_schedulers(self.lightning_module)

    def setup(self, trainer: "pl.Trainer") -> None:
        """Sets up the accelerator, plugins and initializes the optimizers (if needed).

        Args:
            trainer: the trainer instance

        """
        assert self.accelerator is not None
        self.accelerator.setup(trainer)

        assert self.model is not None
        # let the precision plugin convert the module here so that this strategy hook can decide the order
        # of operations
        self.model = self.precision_plugin.convert_module(self.model)
        self.model_to_device()
        self.model = self._setup_model(self.model)

        if trainer.state.fn == TrainerFn.FITTING:
            self.setup_optimizers(trainer)
        self.setup_precision_plugin()
        if trainer.state.fn == TrainerFn.FITTING:
            _optimizers_to_device(self.optimizers, self.root_device)

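    # A hedged sketch of the order in which the Trainer typically exercises these hooks during
    # ``fit`` (as suggested by the docstrings above; the exact sequence can vary between versions):
    #
    #     strategy.connect(lightning_module)
    #     strategy.setup_environment()
    #     strategy.setup(trainer)   # converts the module, moves it to the device, wraps it,
    #                               # and creates optimizers when fitting
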
    def setup_precision_plugin(self) -> None:
        """Attaches the precision plugin to the strategy."""
        assert self.model is not None
        model, optimizers, lr_scheduler_configs = self.precision_plugin.connect(
            self.model, self.optimizers, self.lr_scheduler_configs
        )
        self.model = model
        self.optimizers = optimizers
        self.lr_scheduler_configs = lr_scheduler_configs

    def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]:
        """Returns the state of an optimizer.

        Allows for syncing/collating optimizer state from processes in custom strategies.

        """
        if isinstance(optimizer, LightningOptimizer):
            optimizer = optimizer._optimizer

        if hasattr(optimizer, "consolidate_state_dict"):
            # there are optimizers like PyTorch's ZeroRedundancyOptimizer that shard their
            # states, and to avoid OOM we consolidate the full state on rank 0 only
            optimizer.consolidate_state_dict()
            return optimizer.state_dict() if self.is_global_zero else {}

        # for optimizers that are not sharded, we return the state dict on all ranks
        return optimizer.state_dict()

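    # A minimal sketch of the sharded case described above; ``ZeroRedundancyOptimizer`` lives in
    # ``torch.distributed.optim``, and the surrounding distributed/model setup is assumed:
    #
    #     zero_opt = ZeroRedundancyOptimizer(model.parameters(), optimizer_class=torch.optim.Adam, lr=1e-3)
    #     state = strategy.optimizer_state(zero_opt)  # full state dict on global rank 0, {} on other ranks
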
    def backward(
        self,
        closure_loss: Tensor,
        optimizer: Optional[Optimizer],
        *args: Any,
        **kwargs: Any,
    ) -> Tensor:
        r"""Forwards backward-calls to the precision plugin.

        Args:
            closure_loss: a tensor holding the loss value to backpropagate
            optimizer: An optional optimizer that gets passed down to the precision plugin's backward
            \*args: Positional arguments that get passed down to the precision plugin's backward, intended as arguments
                for the actual function that performs the backward, like :meth:`~torch.Tensor.backward`.
            \**kwargs: Keyword arguments for the same purpose as ``*args``.

        """
        self.pre_backward(closure_loss)
        assert self.lightning_module is not None
        closure_loss = self.precision_plugin.pre_backward(closure_loss, self.lightning_module)

        self.precision_plugin.backward(closure_loss, self.lightning_module, optimizer, *args, **kwargs)

        closure_loss = self.precision_plugin.post_backward(closure_loss, self.lightning_module)
        self.post_backward(closure_loss)

        return closure_loss

    def optimizer_step(
        self,
        optimizer: Optimizer,
        closure: Callable[[], Any],
        model: Optional[Union["pl.LightningModule", Module]] = None,
        **kwargs: Any,
    ) -> Any:
        r"""Performs the actual optimizer step.

        Args:
            optimizer: the optimizer performing the step
            closure: closure calculating the loss value
            model: reference to the model, optionally defining optimizer step related hooks
            \**kwargs: Keyword arguments to ``optimizer.step``

        """
        model = model or self.lightning_module
        # TODO(fabric): remove assertion once strategy's optimizer_step typing is fixed
        assert isinstance(model, pl.LightningModule)
        return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs)

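    # A hedged sketch of how a training loop drives this hook; the closure below is illustrative
    # (in practice the Lightning loops build it from ``training_step`` plus ``backward``):
    #
    #     def closure() -> Tensor:
    #         loss = compute_loss()  # hypothetical: forward pass, loss, and backward
    #         return loss
    #
    #     strategy.optimizer_step(optimizer, closure=closure)
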
    def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer]) -> Tuple[Module, List[Optimizer]]:
        """Set up a model and multiple optimizers together.

        The returned objects are expected to be in the same order they were passed in. The default implementation will
        call :meth:`_setup_model` and :meth:`_setup_optimizer` on the inputs.

        """
        # TODO: standardize this across all plugins in Lightning and Fabric. Related refactor: #7324
        model = self._setup_model(model)
        optimizers = [self._setup_optimizer(optimizer) for optimizer in optimizers]
        return model, optimizers

    def _setup_model(self, model: Module) -> Module:
        """Performs setup for the model, e.g., by wrapping it with another class."""
        # TODO: standardize this across all plugins in Lightning and Fabric. Related refactor: #7324
        return model

    def _setup_optimizer(self, optimizer: Optimizer) -> Optimizer:
        """Performs setup for the optimizer, e.g., by wrapping it with another class."""
        # TODO: standardize this across all plugins in Lightning and Fabric. Related refactor: #7324
        return optimizer

    def batch_to_device(self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: int = 0) -> Any:
        """Moves the batch to the correct device.

        The returned batch is of the same type as the input batch, just having all tensors on the correct device.

        Args:
            batch: The batch of samples to move to the correct device
            device: The target device
            dataloader_idx: The index of the dataloader to which the batch belongs.

        """
        model = self.lightning_module
        device = device or self.root_device
        if model is not None:
            return model._apply_batch_transfer_handler(batch, device=device, dataloader_idx=dataloader_idx)
        return move_data_to_device(batch, device)

    @property
    @abstractmethod
    def root_device(self) -> torch.device:
        """Returns the root device."""

    @abstractmethod
    def model_to_device(self) -> None:
        """Moves the model to the correct device."""

    @property
    @abstractmethod
    def is_global_zero(self) -> bool:
        """Whether the current process is the rank zero process, not only on the local node, but across all nodes."""

    @abstractmethod
    def reduce(
        self,
        tensor: Union[Tensor, Any],
        group: Optional[Any] = None,
        reduce_op: Optional[Union[ReduceOp, str]] = "mean",
    ) -> Union[Tensor, Any]:
        """Reduces the given tensor (e.g. across GPUs/processes).

        Args:
            tensor: the tensor to sync and reduce
            group: the process group to reduce
            reduce_op: the reduction operation. Defaults to 'mean'. Can also be a string 'sum' or a ``ReduceOp``.

        """

    @abstractmethod
    def barrier(self, name: Optional[str] = None) -> None:
        """Synchronizes all processes, blocking until the whole group enters this function.

        Args:
            name: an optional name to pass into barrier.

        """

    @abstractmethod
    def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
        """Broadcasts an object to all processes.

        Args:
            obj: the object to broadcast
            src: source rank

        """

    @abstractmethod
    def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor:
        """Perform an all_gather on all processes.

        Args:
            tensor: the tensor to all_gather
            group: the process group to gather results from
            sync_grads: flag that allows users to synchronize gradients for the all_gather op

        """

    def reduce_boolean_decision(self, decision: bool, all: bool = True) -> bool:
        """Reduce a boolean decision across all processes."""
        return decision

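    # A minimal sketch (class and device names assumed) of the smallest set of members a concrete
    # strategy must provide; a single-process strategy can implement the collectives as no-ops:
    #
    #     class MySingleProcessStrategy(Strategy):
    #         @property
    #         def root_device(self) -> torch.device:
    #             return torch.device("cpu")
    #
    #         def model_to_device(self) -> None:
    #             assert self.model is not None
    #             self.model.to(self.root_device)
    #
    #         @property
    #         def is_global_zero(self) -> bool:
    #             return True
    #
    #         def reduce(self, tensor, group=None, reduce_op="mean"):
    #             return tensor
    #
    #         def barrier(self, name=None) -> None:
    #             pass
    #
    #         def broadcast(self, obj, src=0):
    #             return obj
    #
    #         def all_gather(self, tensor, group=None, sync_grads=False):
    #             return tensor
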
    def pre_backward(self, closure_loss: Tensor) -> None:
        """Run before the precision plugin executes backward."""

    def post_backward(self, closure_loss: Tensor) -> None:
        """Run after the precision plugin executes backward."""

    @property
    def model(self) -> Optional[Module]:
        """Returns the potentially wrapped LightningModule."""
        return self._model if self._model is not None else self._lightning_module

    @model.setter
    def model(self, new_model: Optional[Module]) -> None:
        self._model = new_model

    @property
    def lightning_module(self) -> Optional["pl.LightningModule"]:
        """Returns the pure LightningModule without potential wrappers."""
        return self._lightning_module

    def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
        torch.cuda.empty_cache()
        return self.checkpoint_io.load_checkpoint(checkpoint_path)

    def load_model_state_dict(self, checkpoint: Mapping[str, Any], strict: bool = True) -> None:
        assert self.lightning_module is not None
        self.lightning_module.load_state_dict(checkpoint["state_dict"], strict=strict)

    def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
        optimizer_states = checkpoint["optimizer_states"]
        for optimizer, opt_state in zip(self.optimizers, optimizer_states):
            optimizer.load_state_dict(opt_state)
            _optimizer_to_device(optimizer, self.root_device)

    def training_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
        """The actual training step.

        See :meth:`~lightning.pytorch.core.LightningModule.training_step` for more details.

        """
        assert self.lightning_module is not None
        assert self.model is not None
        with self.precision_plugin.train_step_context():
            if self.model != self.lightning_module:
                return self._forward_redirection(self.model, self.lightning_module, "training_step", *args, **kwargs)
            return self.lightning_module.training_step(*args, **kwargs)

    def post_training_step(self) -> None:
        """This hook is deprecated.

        Override :meth:`training_step` instead.

        """
        pass

    def validation_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
        """The actual validation step.

        See :meth:`~lightning.pytorch.core.LightningModule.validation_step` for more details.

        """
        assert self.lightning_module is not None
        assert self.model is not None
        with self.precision_plugin.val_step_context():
            if self.model != self.lightning_module:
                return self._forward_redirection(self.model, self.lightning_module, "validation_step", *args, **kwargs)
            return self.lightning_module.validation_step(*args, **kwargs)

    def test_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
        """The actual test step.

        See :meth:`~lightning.pytorch.core.LightningModule.test_step` for more details.

        """
        assert self.lightning_module is not None
        assert self.model is not None
        with self.precision_plugin.test_step_context():
            if self.model != self.lightning_module:
                return self._forward_redirection(self.model, self.lightning_module, "test_step", *args, **kwargs)
            return self.lightning_module.test_step(*args, **kwargs)

    def predict_step(self, *args: Any, **kwargs: Any) -> Any:
        """The actual predict step.

        See :meth:`~lightning.pytorch.core.LightningModule.predict_step` for more details.

        """
        assert self.lightning_module is not None
        assert self.model is not None
        with self.precision_plugin.predict_step_context():
            if self.model != self.lightning_module:
                return self._forward_redirection(self.model, self.lightning_module, "predict_step", *args, **kwargs)
            return self.lightning_module.predict_step(*args, **kwargs)

    def process_dataloader(self, dataloader: object) -> object:
        """Wraps the dataloader if necessary.

        Args:
            dataloader: iterable. Ideally of type: :class:`torch.utils.data.DataLoader`

        """
        return dataloader

    @property
    def restore_checkpoint_after_setup(self) -> bool:
        """Override to delay restoring from checkpoint until after the setup phase has completed. This is useful when
        the strategy requires all the setup hooks to run before loading the checkpoint.

        Returns:
            If ``True``, restore the checkpoint after strategy setup.

        """
        return False

    @property
    def lightning_restore_optimizer(self) -> bool:
        """Override to disable Lightning restoring optimizers/schedulers.

        This is useful for strategies which manage restoring optimizers/schedulers themselves.

        """
        return True

    @property
    def handles_gradient_accumulation(self) -> bool:
        """Whether the strategy handles gradient accumulation internally."""
        return False

    def lightning_module_state_dict(self) -> Dict[str, Any]:
        """Returns the model state."""
        assert self.lightning_module is not None
        return self.lightning_module.state_dict()

    def save_checkpoint(
        self, checkpoint: Dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None
    ) -> None:
        """Save model/training states as a checkpoint file through state-dump and file-write.

        Args:
            checkpoint: dict containing model and trainer state
            filepath: write-target file's path
            storage_options: parameter for how to save to storage, passed to the ``CheckpointIO`` plugin

        """
        if self.is_global_zero:
            self.checkpoint_io.save_checkpoint(checkpoint, filepath, storage_options=storage_options)

    def remove_checkpoint(self, filepath: _PATH) -> None:
        """Remove a checkpoint file from the filesystem.

        Args:
            filepath: Path to checkpoint

        """
        if self.is_global_zero:
            self.checkpoint_io.remove_checkpoint(filepath)

    @contextmanager
    def tensor_init_context(self, empty_init: Optional[bool] = None) -> Generator[None, None, None]:
        """Controls how tensors get created (device, dtype).

        Args:
            empty_init: Whether to initialize the model with empty weights (uninitialized memory).
                If ``None``, the strategy will decide. Some strategies may not support all options.

        """
        device_context = self.root_device if _TORCH_GREATER_EQUAL_2_0 else nullcontext()
        empty_init_context = _EmptyInit(enabled=bool(empty_init))
        with empty_init_context, device_context, self.precision_plugin.tensor_init_context():
            yield

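    # A minimal usage sketch (the module class is assumed): materialize weights directly under the
    # strategy's device/precision context, optionally leaving them uninitialized:
    #
    #     with strategy.tensor_init_context(empty_init=True):
    #         model = MyLightningModule()  # hypothetical module; its tensors are created on `root_device`
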
    @contextmanager
    def model_sharded_context(self) -> Generator[None, None, None]:
        """Provide a hook to create modules in a distributed-aware context. This is useful when we'd like to shard the
        model instantly, which can save memory and initialization time for extremely large models.

        Returns: Model parallel context.

        """
        yield

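    # A hedged sketch of how a sharding strategy might override this hook; the backend context
    # manager below is illustrative, not a real Lightning or PyTorch API:
    #
    #     @contextmanager
    #     def model_sharded_context(self) -> Generator[None, None, None]:
    #         with my_sharding_backend.init_sharded(device=self.root_device):  # hypothetical
    #             yield
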
    def teardown(self) -> None:
        """This method is called to tear down the training process.

        It is the right place to release memory and free other resources.

        """
        _optimizers_to_device(self.optimizers, torch.device("cpu"))

        if self.lightning_module is not None:
            log.debug(f"{self.__class__.__name__}: moving model to CPU")
            self.lightning_module.cpu()
        self.precision_plugin.teardown()
        assert self.accelerator is not None
        self.accelerator.teardown()
        self.checkpoint_io.teardown()

    @classmethod
    def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None:
        pass

    def on_train_start(self) -> None:
        """Called when train begins."""
        pass

    def on_validation_start(self) -> None:
        """Called when validation begins."""
        pass

    def on_test_start(self) -> None:
        """Called when test begins."""
        pass

    def on_predict_start(self) -> None:
        """Called when predict begins."""
        pass

    def on_train_end(self) -> None:
        """Called when train ends."""
        pass

    def on_validation_end(self) -> None:
        """Called when validation ends."""
        pass

    def on_test_end(self) -> None:
        """Called when test ends."""
        pass

    def on_predict_end(self) -> None:
        """Called when predict ends."""
        pass

    def on_train_batch_start(self, batch: Any, batch_idx: int) -> None:
        """Called in the training loop before anything happens for that batch."""
        pass

    def on_exception(self, exception: BaseException) -> None:
        """Called when the trainer execution is interrupted by an exception."""
        pass

    def _reset_optimizers_and_schedulers(self) -> None:
        self._optimizers = []
        self._lightning_optimizers = []
        self.lr_scheduler_configs = []

    def __getstate__(self) -> Dict:
        # `LightningOptimizer` overrides `self.__class__` so they cannot be pickled
        state = dict(vars(self))  # copy
        state["_lightning_optimizers"] = []
        return state

    def __setstate__(self, state: Dict) -> None:
        self.__dict__ = state
        self.optimizers = self.optimizers  # re-create the `_lightning_optimizers`


class _ForwardRedirection:
    """Implements the `forward-redirection`.

    A method call to a wrapped module gets rerouted through the wrapper's `forward` method instead.

    """

    def __call__(
        self, wrapper_module: Module, original_module: "pl.LightningModule", method_name: str, *args: Any, **kwargs: Any
    ) -> STEP_OUTPUT:
        """Reroutes a method call through the `wrapper_module`'s `forward` method.

        Args:
            wrapper_module: The module that has `original_module` wrapped.
            original_module: The module that was wrapped inside `wrapper_module`.
            method_name: The name of the method that should be called on the `original_module` after inputs get
                redirected through the `wrapper_module`'s `forward` method.
            *args: The positional arguments to the method `method_name`. They will get passed to a patched
                `forward` method instead.
            **kwargs: The keyword arguments to the method `method_name`. They will get passed to a patched
                `forward` method instead.

        """
        assert method_name != "forward"
        original_forward = original_module.forward

        def wrapped_forward(*_args: Any, **_kwargs: Any) -> Any:
            # Unpatch ourselves immediately before calling the method `method_name`
            # because the method itself may want to call the real `forward`
            original_module.forward = original_forward  # type: ignore[method-assign]
            # Call the actual method, e.g. `.training_step(...)`
            method = getattr(original_module, method_name)
            out = method(*_args, **_kwargs)
            self.on_after_inner_forward(wrapper_module, original_module)
            return out

        # Patch the original_module's forward so we can redirect the arguments back to the real method
        original_module.forward = wrapped_forward  # type: ignore[method-assign]

        wrapper_output = wrapper_module(*args, **kwargs)
        self.on_after_outer_forward(wrapper_module, original_module)
        return wrapper_output

    def on_after_inner_forward(self, wrapper_module: Module, original_module: "pl.LightningModule") -> None:
        pass

    def on_after_outer_forward(self, wrapper_module: Module, original_module: "pl.LightningModule") -> None:
        pass

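# A minimal sketch of what the redirection above achieves (the DDP wrapping shown is assumed,
# not taken from this file): calling ``training_step`` through the wrapper's ``forward`` so that
# wrapper-side logic, e.g. DistributedDataParallel's gradient synchronization hooks, still runs:
#
#     redirection = _ForwardRedirection()
#     wrapped = torch.nn.parallel.DistributedDataParallel(lightning_module)
#     out = redirection(wrapped, lightning_module, "training_step", batch, batch_idx)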