pytorch-lightning
1633 lines · 69.3 KB
1# Copyright The Lightning AI team.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""The LightningModule - an nn.Module with many additional features."""
15
16import logging
17import numbers
18import weakref
19from contextlib import contextmanager
20from pathlib import Path
21from typing import (
22IO,
23Any,
24Callable,
25Dict,
26Generator,
27List,
28Literal,
29Mapping,
30Optional,
31Sequence,
32Tuple,
33Union,
34cast,
35overload,
36)
37
38import torch
39from lightning_utilities.core.apply_func import apply_to_collection
40from lightning_utilities.core.imports import RequirementCache
41from torch import ScriptModule, Tensor
42from torch.nn import Module
43from torch.optim.optimizer import Optimizer
44from torchmetrics import Metric, MetricCollection
45from typing_extensions import Self, override
46
47import lightning.fabric as lf
48import lightning.pytorch as pl
49from lightning.fabric.loggers import Logger as FabricLogger
50from lightning.fabric.utilities.apply_func import convert_to_tensors
51from lightning.fabric.utilities.cloud_io import get_filesystem
52from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
53from lightning.fabric.utilities.imports import _IS_WINDOWS, _TORCH_GREATER_EQUAL_2_0, _TORCH_GREATER_EQUAL_2_1
54from lightning.fabric.utilities.types import _MAP_LOCATION_TYPE, _PATH
55from lightning.fabric.wrappers import _FabricOptimizer
56from lightning.pytorch.callbacks.callback import Callback
57from lightning.pytorch.core.hooks import CheckpointHooks, DataHooks, ModelHooks
58from lightning.pytorch.core.mixins import HyperparametersMixin
59from lightning.pytorch.core.optimizer import LightningOptimizer
60from lightning.pytorch.core.saving import _load_from_checkpoint
61from lightning.pytorch.loggers import Logger
62from lightning.pytorch.trainer import call
63from lightning.pytorch.trainer.connectors.logger_connector.fx_validator import _FxValidator
64from lightning.pytorch.trainer.connectors.logger_connector.result import _get_default_dtype
65from lightning.pytorch.utilities import GradClipAlgorithmType
66from lightning.pytorch.utilities.exceptions import MisconfigurationException
67from lightning.pytorch.utilities.imports import _TORCHMETRICS_GREATER_EQUAL_0_9_1
68from lightning.pytorch.utilities.model_helpers import _restricted_classmethod
69from lightning.pytorch.utilities.rank_zero import WarningCache, rank_zero_debug, rank_zero_warn
70from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature
71from lightning.pytorch.utilities.types import (
72_METRIC,
73STEP_OUTPUT,
74LRSchedulerPLType,
75LRSchedulerTypeUnion,
76OptimizerLRScheduler,
77)
78
# Tracks whether the ``onnx`` requirement is available in the current environment.
_ONNX_AVAILABLE = RequirementCache("onnx")

# Module-level logger, plus a cache of emitted warnings (presumably used to warn only once — see ``WarningCache``).
log = logging.getLogger(__name__)
warning_cache = WarningCache()

# Every form that ``LightningModule.optimizers()`` may hand back: a single optimizer (plain, wrapped by the
# Trainer as a ``LightningOptimizer``, or wrapped by Fabric) or a list of any one of those kinds.
MODULE_OPTIMIZERS = Union[
    Optimizer,
    LightningOptimizer,
    _FabricOptimizer,
    List[Optimizer],
    List[LightningOptimizer],
    List[_FabricOptimizer],
]
87
88
89class LightningModule(
90_DeviceDtypeModuleMixin,
91HyperparametersMixin,
92ModelHooks,
93DataHooks,
94CheckpointHooks,
95Module,
96):
97# Below is for property support of JIT
98# since none of these are important when using JIT, we are going to ignore them.
99__jit_unused_properties__: List[str] = (
100[
101"example_input_array",
102"on_gpu",
103"current_epoch",
104"global_step",
105"global_rank",
106"local_rank",
107"logger",
108"loggers",
109"automatic_optimization",
110"trainer",
111"fabric",
112"strict_loading",
113]
114+ _DeviceDtypeModuleMixin.__jit_unused_properties__
115+ HyperparametersMixin.__jit_unused_properties__
116)
117_jit_is_scripting = False
118
119CHECKPOINT_HYPER_PARAMS_KEY = "hyper_parameters"
120CHECKPOINT_HYPER_PARAMS_NAME = "hparams_name"
121CHECKPOINT_HYPER_PARAMS_TYPE = "hparams_type"
122
def __init__(self, *args: Any, **kwargs: Any) -> None:
    """Initialize the module together with the bookkeeping attributes Lightning relies on.

    All positional and keyword arguments are forwarded unchanged to ``super().__init__``.

    """
    super().__init__(*args, **kwargs)

    # pointer to the trainer object
    self._trainer: Optional["pl.Trainer"] = None

    # attributes that can be set by user
    # (backing fields of the ``example_input_array``, ``automatic_optimization`` and ``strict_loading`` properties)
    self._example_input_array: Optional[Union[Tensor, Tuple, Dict]] = None
    self._automatic_optimization: bool = True
    # ``None`` means "not set by the user"; the ``strict_loading`` property reports it as ``True``
    self._strict_loading: Optional[bool] = None

    # attributes used internally
    # name of the hook currently running; ``log`` raises while this is ``None``
    self._current_fx_name: Optional[str] = None
    self._param_requires_grad_state: Dict[str, bool] = {}
    # ``id(metric) -> attribute name`` map, built lazily on the first ``self.log`` of a ``Metric`` value
    self._metric_attributes: Optional[Dict[int, str]] = None
    self._register_sharded_tensor_state_dict_hooks_if_available()
    self._compiler_ctx: Optional[Dict[str, Any]] = None

    # attributes only used when using fabric
    self._fabric: Optional["lf.Fabric"] = None
    self._fabric_optimizers: List[_FabricOptimizer] = []
144
@overload
def optimizers(
    self, use_pl_optimizer: Literal[True] = True
) -> Union[LightningOptimizer, List[LightningOptimizer]]: ...

@overload
def optimizers(self, use_pl_optimizer: Literal[False]) -> Union[Optimizer, List[Optimizer]]: ...

@overload
def optimizers(self, use_pl_optimizer: bool) -> MODULE_OPTIMIZERS: ...

def optimizers(self, use_pl_optimizer: bool = True) -> MODULE_OPTIMIZERS:
    """Returns the optimizer(s) that are being used during training. Useful for manual optimization.

    When the module is driven by Fabric, the Fabric-wrapped optimizers are returned and ``use_pl_optimizer``
    has no effect.

    Args:
        use_pl_optimizer: If ``True``, will wrap the optimizer(s) in a
            :class:`~lightning.pytorch.core.optimizer.LightningOptimizer` for automatic handling of precision,
            profiling, and counting of step calls for proper logging and checkpointing. It specifically wraps the
            ``step`` method and custom optimizers that don't have this method are not supported.

    Returns:
        A single optimizer, or a list of optimizers in case multiple ones are present.

    """
    # Explicit ``is not None`` (as in ``trainer``/``logger``/``log``): truth-testing a weakref proxy whose
    # referent is gone would raise ``ReferenceError``.
    if self._fabric is not None:
        opts: MODULE_OPTIMIZERS = self._fabric_optimizers
    elif use_pl_optimizer:
        opts = self.trainer.strategy._lightning_optimizers
    else:
        opts = self.trainer.optimizers

    # single optimizer: unwrap it so the common one-optimizer case needs no indexing
    if (
        isinstance(opts, list)
        and len(opts) == 1
        and isinstance(opts[0], (Optimizer, LightningOptimizer, _FabricOptimizer))
    ):
        return opts[0]
    # multiple (or zero) optimizers are returned as-is
    return opts
185
def lr_schedulers(self) -> Union[None, List[LRSchedulerPLType], LRSchedulerPLType]:
    """Access the learning rate scheduler(s) in use during training, typically for manual optimization.

    Returns:
        ``None`` if :meth:`~lightning.pytorch.core.LightningModule.configure_optimizers` returned no scheduler,
        the scheduler itself when exactly one is configured, and otherwise the list of all schedulers.

    """
    if not self.trainer.lr_scheduler_configs:
        return None

    # Only the scheduler objects are exposed; config keys such as "interval" or "frequency" are dropped.
    schedulers: List[LRSchedulerPLType] = [cfg.scheduler for cfg in self.trainer.lr_scheduler_configs]
    if len(schedulers) != 1:
        return schedulers
    return schedulers[0]
206
@property
def trainer(self) -> "pl.Trainer":
    """The ``Trainer`` this module is attached to.

    When a Fabric object is attached, a ``_TrainerFabricShim`` wrapping it is returned instead.

    Raises:
        RuntimeError: If neither Fabric nor a ``Trainer`` is attached (not raised while JIT-scripting).

    """
    fabric = self._fabric
    if fabric is not None:
        return _TrainerFabricShim(fabric=fabric)  # type: ignore[return-value]
    attached = self._trainer
    if attached is None and not self._jit_is_scripting:
        raise RuntimeError(f"{self.__class__.__qualname__} is not attached to a `Trainer`.")
    return attached  # type: ignore[return-value]

@trainer.setter
def trainer(self, trainer: Optional["pl.Trainer"]) -> None:
    # Direct children that are LightningModules receive the same trainer.
    for child in self.children():
        if isinstance(child, LightningModule):
            child.trainer = trainer  # type: ignore[assignment]
    # On torch < 2.0 a weak proxy is stored instead of the trainer itself.
    # https://github.com/pytorch/pytorch/issues/95857
    wrap_in_proxy = (
        not _TORCH_GREATER_EQUAL_2_0 and trainer is not None and not isinstance(trainer, weakref.ProxyTypes)
    )
    self._trainer = weakref.proxy(trainer) if wrap_in_proxy else trainer
224
@property
def fabric(self) -> Optional["lf.Fabric"]:
    """The Fabric object driving this module, or ``None`` when it is not used through Fabric."""
    return self._fabric

@fabric.setter
def fabric(self, fabric: Optional["lf.Fabric"]) -> None:
    # Direct children that are LightningModules share the same Fabric reference.
    for child in self.children():
        if isinstance(child, LightningModule):
            child.fabric = fabric
    # Store a weak proxy rather than a strong reference; existing proxies and ``None`` are kept as-is.
    if fabric is None or isinstance(fabric, weakref.ProxyTypes):
        self._fabric = fabric
    else:
        self._fabric = weakref.proxy(fabric)
237
@property
def example_input_array(self) -> Optional[Union[Tensor, Tuple, Dict]]:
    """Sample input describing what :meth:`forward` accepts.

    How the value is interpreted depends on its type:

    - a single tensor is the only positional argument: ``model.forward(model.example_input_array)``
    - a tuple is unpacked into positional arguments: ``model.forward(*model.example_input_array)``
    - a dict is unpacked into keyword arguments: ``model.forward(**model.example_input_array)``

    """
    return self._example_input_array

@example_input_array.setter
def example_input_array(self, example: Optional[Union[Tensor, Tuple, Dict]]) -> None:
    # stored as-is; no validation of the value's type
    self._example_input_array = example
256
@property
def current_epoch(self) -> int:
    """Epoch counter of the attached ``Trainer``; ``0`` when no ``Trainer`` is attached."""
    if not self._trainer:
        return 0
    return self.trainer.current_epoch
261
@property
def global_step(self) -> int:
    """Total training batches seen across all epochs.

    If no Trainer is attached, this property is 0.

    """
    return self.trainer.global_step if self._trainer else 0
270
@property
def global_rank(self) -> int:
    """Index of this process among all processes on all nodes; ``0`` without an attached ``Trainer``."""
    if not self._trainer:
        return 0
    return self.trainer.global_rank
275
@property
def local_rank(self) -> int:
    """Index of this process within its own node; ``0`` without an attached ``Trainer``."""
    if not self._trainer:
        return 0
    return self.trainer.local_rank
280
@property
def on_gpu(self) -> bool:
    """Whether ``self.device`` is a CUDA device.

    Handy for switching between CPU- and GPU-specific behavior inside the LightningModule.

    """
    device_type = self.device.type
    return device_type == "cuda"
289
@property
def automatic_optimization(self) -> bool:
    """Whether optimization is automatic.

    When set to ``False``, you are responsible for calling ``.backward()``, ``.step()`` and ``.zero_grad()``
    yourself.

    """
    return self._automatic_optimization

@automatic_optimization.setter
def automatic_optimization(self, automatic_optimization: bool) -> None:
    # plain backing field, no validation
    self._automatic_optimization = automatic_optimization
298
@property
def strict_loading(self) -> bool:
    """The ``strict`` flag Lightning uses when loading this model: ``.load_state_dict(..., strict=model.strict_loading)``.

    The backing field defaults to ``None`` ("not set by the user"), which is reported as ``True``.

    """
    current = self._strict_loading
    return current in (None, True)

@strict_loading.setter
def strict_loading(self, strict_loading: bool) -> None:
    self._strict_loading = strict_loading
308
@property
def logger(self) -> Optional[Union[Logger, FabricLogger]]:
    """The logger of the attached Fabric object or ``Trainer`` (Fabric takes precedence), else ``None``."""
    if self._fabric is not None:
        return self._fabric.logger
    if self._trainer is None:
        return None
    return self._trainer.logger
315
@property
def loggers(self) -> Union[List[Logger], List[FabricLogger]]:
    """All loggers of the attached Fabric object or ``Trainer`` (Fabric takes precedence); empty if neither."""
    source = self._fabric if self._fabric is not None else self._trainer
    return [] if source is None else source.loggers
324
def _call_batch_hook(self, hook_name: str, *args: Any) -> Any:
    """Invoke the batch hook ``hook_name`` with ``args`` and return its result.

    Without a Trainer the hook defined on this module is called directly. Otherwise the Trainer's datahook
    selector picks the owning object, and the call goes through the module- or datamodule-hook helper in ``call``.

    """
    trainer = self._trainer
    if not trainer:
        return getattr(self, hook_name)(*args)

    selector = trainer._data_connector._datahook_selector
    assert selector is not None
    owner = selector.get_instance(hook_name)
    hook_caller = (
        call._call_lightning_module_hook
        if isinstance(owner, self.__class__)
        else call._call_lightning_datamodule_hook
    )
    return hook_caller(trainer, hook_name, *args)
339
def _on_before_batch_transfer(self, batch: Any, dataloader_idx: int = 0) -> Any:
    """Run the ``on_before_batch_transfer`` hook on ``batch`` and return whatever it returns."""
    hook_name = "on_before_batch_transfer"
    return self._call_batch_hook(hook_name, batch, dataloader_idx)
342
def _apply_batch_transfer_handler(
    self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: int = 0
) -> Any:
    """Run ``transfer_batch_to_device`` and then ``on_after_batch_transfer`` on ``batch``.

    ``device`` falls back to ``self.device`` when not given.

    """
    target_device = device or self.device
    moved = self._call_batch_hook("transfer_batch_to_device", batch, target_device, dataloader_idx)
    return self._call_batch_hook("on_after_batch_transfer", moved, dataloader_idx)
350
def print(self, *args: Any, **kwargs: Any) -> None:
    r"""Print from the global-zero process only, so distributed runs don't print once per process.

    When an enabled progress bar callback exists, printing goes through it; otherwise the built-in ``print``
    is used.

    Args:
        *args: The thing to print. The same as for Python's built-in print function.
        **kwargs: The same as for Python's built-in print function.

    Example::

        def forward(self, x):
            self.print(x, 'in forward')

    """
    if not self.trainer.is_global_zero:
        return
    progress_bar = self.trainer.progress_bar_callback
    if progress_bar is None or not progress_bar.is_enabled:
        print(*args, **kwargs)
    else:
        progress_bar.print(*args, **kwargs)
370
def log(
    self,
    name: str,
    value: _METRIC,
    prog_bar: bool = False,
    logger: Optional[bool] = None,
    on_step: Optional[bool] = None,
    on_epoch: Optional[bool] = None,
    reduce_fx: Union[str, Callable] = "mean",
    enable_graph: bool = False,
    sync_dist: bool = False,
    sync_dist_group: Optional[Any] = None,
    add_dataloader_idx: bool = True,
    batch_size: Optional[int] = None,
    metric_attribute: Optional[str] = None,
    rank_zero_only: bool = False,
) -> None:
    """Log a key, value pair.

    Example::

        self.log('train_loss', loss)

    The default behavior per hook is documented here: :ref:`extensions/logging:Automatic Logging`.

    When the module is driven by Fabric, the pair is forwarded to Fabric and only ``logger`` is taken into
    account. Without an attached Trainer, or with ``Trainer(barebones=True)``, a warning is emitted and nothing
    is logged.

    Args:
        name: key to log.
        value: value to log. Can be a ``float``, ``Tensor``, or a ``Metric``.
        prog_bar: if ``True`` logs to the progress bar.
        logger: if ``True`` logs to the logger.
        on_step: if ``True`` logs at this step. The default value is determined by the hook.
            See :ref:`extensions/logging:Automatic Logging` for details.
        on_epoch: if ``True`` logs epoch accumulated metrics. The default value is determined by the hook.
            See :ref:`extensions/logging:Automatic Logging` for details.
        reduce_fx: reduction function over step values for end of epoch. :meth:`torch.mean` by default.
        enable_graph: if ``True``, will not auto detach the graph.
        sync_dist: if ``True``, reduces the metric across devices. Use with care as this may lead to a significant
            communication overhead.
        sync_dist_group: the DDP group to sync across.
        add_dataloader_idx: if ``True``, appends the index of the current dataloader to
            the name (when using multiple dataloaders). If False, user needs to give unique names for
            each dataloader to not mix the values.
        batch_size: Current batch_size. This will be directly inferred from the loaded batch,
            but for some data structures you might need to explicitly provide it.
        metric_attribute: To restore the metric state, Lightning requires the reference of the
            :class:`torchmetrics.Metric` in your model. This is found automatically if it is a model attribute.
        rank_zero_only: Tells Lightning if you are calling ``self.log`` from every process (default) or only from
            rank 0. If ``True``, you won't be able to use this metric as a monitor in callbacks
            (e.g., early stopping). Warning: Improper use can lead to deadlocks! See
            :ref:`Advanced Logging <visualize/logging_advanced:rank_zero_only>` for more details.

    Raises:
        ValueError: If ``value`` contains nested dictionaries or unsupported types, or holds a tensor that does not
            have exactly one element.
        MisconfigurationException: If logging happens where the Trainer has no result collection (e.g. ``predict``
            hooks) or outside the Trainer control flow. Also raised if ``name`` contains ``"/dataloader_idx_"``,
            if a logged ``Metric`` cannot be matched to a module attribute, or if ``training_step`` takes
            ``dataloader_iter`` and no ``batch_size`` is given.

    """
    if self._fabric is not None:
        # Fabric path: only ``logger`` is forwarded, all other options are ignored
        self._log_dict_through_fabric(dictionary={name: value}, logger=logger)
        return

    # check for invalid values
    apply_to_collection(value, dict, self.__check_not_nested, name)
    apply_to_collection(
        value, object, self.__check_allowed, name, value, wrong_dtype=(numbers.Number, Metric, Tensor)
    )

    trainer = self._trainer
    if trainer is None:
        # not an error to support testing the `*_step` methods without a `Trainer` reference
        rank_zero_warn(
            "You are trying to `self.log()` but the `self.trainer` reference is not registered on the model yet."
            " This is most likely because the model hasn't been passed to the `Trainer`"
        )
        return
    if trainer.barebones:
        rank_zero_warn(
            "You are trying to `self.log()` but `Trainer(barebones=True)` is configured."
            " Logging can impact raw speed so it is disabled under this setting."
        )
        return
    results = trainer._results
    if results is None:
        raise MisconfigurationException(
            "You are trying to `self.log()` but the loop's result collection is not registered"
            " yet. This is most likely because you are trying to log in a `predict` hook,"
            " but it doesn't support logging"
        )
    if self._current_fx_name is None:
        raise MisconfigurationException(
            "You are trying to `self.log()` but it is not managed by the `Trainer` control flow"
        )

    # resolve the per-hook defaults for ``on_step`` / ``on_epoch`` (validation delegated to ``_FxValidator``)
    on_step, on_epoch = _FxValidator.check_logging_and_get_default_levels(
        self._current_fx_name, on_step=on_step, on_epoch=on_epoch
    )

    # make sure user doesn't introduce logic for multi-dataloaders
    if "/dataloader_idx_" in name:
        raise MisconfigurationException(
            f"You called `self.log` with the key `{name}`"
            " but it should not contain information about `dataloader_idx`"
        )

    # convert tensors/numbers into single-element tensors (``Metric`` values pass through untouched)
    value = apply_to_collection(value, (Tensor, numbers.Number), self.__to_tensor, name)

    if trainer._logger_connector.should_reset_tensors(self._current_fx_name):
        # if we started a new epoch (running its first batch) the hook name has changed
        # reset any tensors for the new hook name
        results.reset(metrics=False, fx=self._current_fx_name)

    if metric_attribute is None and isinstance(value, Metric):
        if self._metric_attributes is None:
            # compute once
            self._metric_attributes = {
                id(module): name for name, module in self.named_modules() if isinstance(module, Metric)
            }
            if not self._metric_attributes:
                raise MisconfigurationException(
                    "Could not find the `LightningModule` attribute for the `torchmetrics.Metric` logged."
                    " You can fix this by setting an attribute for the metric in your `LightningModule`."
                )
        # try to find the passed metric in the LightningModule
        metric_attribute = self._metric_attributes.get(id(value), None)
        if metric_attribute is None:
            raise MisconfigurationException(
                "Could not find the `LightningModule` attribute for the `torchmetrics.Metric` logged."
                f" You can fix this by calling `self.log({name}, ..., metric_attribute=name)` where `name` is one"
                f" of {list(self._metric_attributes.values())}"
            )

    if (
        trainer.training
        and is_param_in_hook_signature(self.training_step, "dataloader_iter", explicit=True)
        and batch_size is None
    ):
        raise MisconfigurationException(
            "With `def training_step(self, dataloader_iter)`, `self.log(..., batch_size=...)` should be provided."
        )

    if logger and trainer.logger is None:
        rank_zero_warn(
            f"You called `self.log({name!r}, ..., logger=True)` but have no logger configured. You can enable one"
            " by doing `Trainer(logger=ALogger(...))`"
        )
    if logger is None:
        # we could set false here if there's no configured logger, however, we still need to compute the "logged"
        # metrics anyway because that's what the evaluation loops use as return value
        logger = True

    results.log(
        self._current_fx_name,
        name,
        value,
        prog_bar=prog_bar,
        logger=logger,
        on_step=on_step,
        on_epoch=on_epoch,
        reduce_fx=reduce_fx,  # type: ignore[arg-type]
        enable_graph=enable_graph,
        add_dataloader_idx=add_dataloader_idx,
        batch_size=batch_size,
        # syncing only applies when the accelerator connector reports a distributed run
        sync_dist=sync_dist and trainer._accelerator_connector.is_distributed,
        sync_dist_fn=trainer.strategy.reduce,
        sync_dist_group=sync_dist_group,
        metric_attribute=metric_attribute,
        rank_zero_only=rank_zero_only,
    )

    trainer._logger_connector._current_fx = self._current_fx_name
536
def log_dict(
    self,
    dictionary: Union[Mapping[str, _METRIC], MetricCollection],
    prog_bar: bool = False,
    logger: Optional[bool] = None,
    on_step: Optional[bool] = None,
    on_epoch: Optional[bool] = None,
    reduce_fx: Union[str, Callable] = "mean",
    enable_graph: bool = False,
    sync_dist: bool = False,
    sync_dist_group: Optional[Any] = None,
    add_dataloader_idx: bool = True,
    batch_size: Optional[int] = None,
    rank_zero_only: bool = False,
) -> None:
    """Log a dictionary of values at once.

    Each entry is forwarded to :meth:`log` with the same options. When the module is driven by Fabric, the whole
    dictionary is handed to Fabric instead and only ``logger`` is taken into account.

    Example::

        values = {'loss': loss, 'acc': acc, ..., 'metric_n': metric_n}
        self.log_dict(values)

    Args:
        dictionary: key value pairs, or a ``MetricCollection``.
            The values can be a ``float``, ``Tensor``, or ``Metric``.
        prog_bar: if ``True`` logs to the progress bar.
        logger: if ``True`` logs to the logger.
        on_step: if ``True`` logs at this step.
            ``None`` auto-logs for training_step but not validation/test_step.
            The default value is determined by the hook.
            See :ref:`extensions/logging:Automatic Logging` for details.
        on_epoch: if ``True`` logs epoch accumulated metrics.
            ``None`` auto-logs for val/test step but not ``training_step``.
            The default value is determined by the hook.
            See :ref:`extensions/logging:Automatic Logging` for details.
        reduce_fx: reduction function over step values for end of epoch. :meth:`torch.mean` by default.
        enable_graph: if ``True``, will not auto-detach the graph.
        sync_dist: if ``True``, reduces the metric across GPUs/TPUs. Use with care as this may lead to a significant
            communication overhead.
        sync_dist_group: the ddp group to sync across.
        add_dataloader_idx: if ``True``, appends the index of the current dataloader to
            the name (when using multiple). If ``False``, user needs to give unique names for
            each dataloader to not mix values.
        batch_size: Current batch size. This will be directly inferred from the loaded batch,
            but some data structures might need to explicitly provide it.
        rank_zero_only: Tells Lightning if you are calling ``self.log_dict`` from every process (default) or only
            from rank 0. If ``True``, you won't be able to use this metric as a monitor in callbacks
            (e.g., early stopping). Warning: Improper use can lead to deadlocks! See
            :ref:`Advanced Logging <visualize/logging_advanced:rank_zero_only>` for more details.

    """
    if self._fabric is not None:
        return self._log_dict_through_fabric(dictionary=dictionary, logger=logger)

    kwargs: Dict[str, bool] = {}

    if isinstance(dictionary, MetricCollection):
        # iteration options for ``MetricCollection.items``; see the torchmetrics docs for ``keep_base``/``copy_state``
        kwargs["keep_base"] = False
        if _TORCHMETRICS_GREATER_EQUAL_0_9_1 and dictionary._enable_compute_groups:
            kwargs["copy_state"] = False

    for k, v in dictionary.items(**kwargs):
        self.log(
            name=k,
            value=v,
            prog_bar=prog_bar,
            logger=logger,
            on_step=on_step,
            on_epoch=on_epoch,
            reduce_fx=reduce_fx,
            enable_graph=enable_graph,
            sync_dist=sync_dist,
            sync_dist_group=sync_dist_group,
            add_dataloader_idx=add_dataloader_idx,
            batch_size=batch_size,
            rank_zero_only=rank_zero_only,
        )
    return None
615
def _log_dict_through_fabric(
    self, dictionary: Union[Mapping[str, _METRIC], MetricCollection], logger: Optional[bool] = None
) -> None:
    """Validate ``dictionary`` and forward it to ``Fabric.log_dict``; does nothing when ``logger`` is ``False``.

    Raises:
        ValueError: If a value is a nested dictionary, or contains something other than numbers or tensors.

    """
    if logger is False:
        # With Fabric there is no other destination than the logger, so ``logger=False`` simply means "skip".
        # It is still accepted so that code written for the Trainer keeps working.
        return

    for candidate in dictionary.values():
        if isinstance(candidate, dict):
            raise ValueError(f"`self.log_dict({dictionary})` was called, but nested dictionaries cannot be logged")
    for key, item in dictionary.items():
        apply_to_collection(item, object, self.__check_allowed, key, item, wrong_dtype=(numbers.Number, Tensor))

    assert self._fabric is not None
    self._fabric.log_dict(metrics=dictionary)  # type: ignore[arg-type]
631
@staticmethod
def __check_not_nested(value: dict, name: str) -> None:
    # Self-imposed restriction, kept for simplicity: a dict may be logged, but not one containing dicts.
    for entry in value.values():
        if isinstance(entry, dict):
            raise ValueError(f"`self.log({name}, {value})` was called, but nested dictionaries cannot be logged")
637
@staticmethod
def __check_allowed(v: Any, name: str, value: Any) -> None:
    # Used as an ``apply_to_collection`` callback by the logging methods; it unconditionally raises.
    offending_type = type(v).__name__
    raise ValueError(f"`self.log({name}, {value})` was called, but `{offending_type}` values cannot be logged")
641
def __to_tensor(self, value: Union[Tensor, numbers.Number], name: str) -> Tensor:
    # Tensors are cloned and detached. Plain numbers become a tensor on ``self.device`` with the default dtype.
    if isinstance(value, Tensor):
        tensor = value.clone().detach()
    else:
        tensor = torch.tensor(value, device=self.device, dtype=_get_default_dtype())
    if torch.numel(tensor) != 1:
        raise ValueError(
            f"`self.log({name}, {tensor})` was called, but the tensor must have a single element."
            f" You can try doing `self.log({name}, {tensor}.mean())`"
        )
    # single-element tensors are reduced to 0-dim
    return tensor.squeeze()
655
def all_gather(
    self, data: Union[Tensor, Dict, List, Tuple], group: Optional[Any] = None, sync_grads: bool = False
) -> Union[Tensor, Dict, List, Tuple]:
    r"""Collect tensors, or (nested) collections of tensors, from every process.

    Every process has to call this method, and the tensors must have identical shapes on all processes;
    otherwise the program hangs forever.

    Args:
        data: int, float, tensor of shape (batch, ...), or a (possibly nested) collection thereof.
        group: the process group to gather results from. Defaults to all processes (world)
        sync_grads: flag that allows users to synchronize gradients for the all_gather operation

    Return:
        A tensor of shape (world_size, batch, ...); collections come back as collections of such tensors.
        When world_size is 1, no extra dimension is added to the tensor(s).

    """
    if group is None:
        group = torch.distributed.group.WORLD
    gather_fn = self.trainer.strategy.all_gather
    # plain numbers are turned into tensors on this module's device before gathering
    tensors = convert_to_tensors(data, device=self.device)
    return apply_to_collection(tensors, Tensor, gather_fn, group=group, sync_grads=sync_grads)
679
@override
def forward(self, *args: Any, **kwargs: Any) -> Any:
    r"""Identical to :meth:`torch.nn.Module.forward`; this implementation simply defers to the parent class.

    Args:
        *args: Whatever you decide to pass into the forward method.
        **kwargs: Keyword arguments are also possible.

    Return:
        Your model's output

    """
    output = super().forward(*args, **kwargs)
    return output
693
def training_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
    r"""Compute and return the training loss for one batch, plus any extra metrics (e.g. for the progress
    bar or logger).

    Args:
        batch: The output of your data iterable, normally a :class:`~torch.utils.data.DataLoader`.
        batch_idx: The index of this batch.
        dataloader_idx: The index of the dataloader that produced this batch.
            (only if multiple dataloaders used)

    Return:
        - :class:`~torch.Tensor` - The loss tensor
        - ``dict`` - A dictionary which can include any keys, but must include the key ``'loss'`` in the case of
          automatic optimization.
        - ``None`` - In automatic optimization, this will skip to the next batch (but is not supported for
          multi-GPU, TPU, or DeepSpeed). For manual optimization, this has no special meaning, as returning
          the loss is not required.

    This is where the forward pass and the loss computation for a batch normally happen. More elaborate logic,
    such as several forward passes or anything model-specific, belongs here as well.

    Example::

        def training_step(self, batch, batch_idx):
            x, y, z = batch
            out = self.encoder(x)
            loss = self.loss(out, x)
            return loss

    With several optimizers, switch to 'manual optimization' and step each of them yourself:

    .. code-block:: python

        def __init__(self):
            super().__init__()
            self.automatic_optimization = False


        # Multiple optimizers (e.g.: GANs)
        def training_step(self, batch, batch_idx):
            opt1, opt2 = self.optimizers()

            # do training_step with encoder
            ...
            opt1.step()
            # do training_step with decoder
            ...
            opt2.step()

    Note:
        When ``accumulate_grad_batches`` > 1, the loss returned here is automatically
        normalized by ``accumulate_grad_batches`` internally.

    Note:
        This default implementation only emits a warning and returns ``None``; override it.

    """
    rank_zero_warn("`training_step` must be implemented to be used with the Lightning Trainer")
749
def validation_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
    r"""Operates on a single batch of data from the validation set. In this step you might generate examples or
    calculate anything of interest like accuracy.

    Args:
        batch: The output of your data iterable, normally a :class:`~torch.utils.data.DataLoader`.
        batch_idx: The index of this batch.
        dataloader_idx: The index of the dataloader that produced this batch.
            (only if multiple dataloaders used)

    Return:
        - :class:`~torch.Tensor` - The loss tensor
        - ``dict`` - A dictionary. Can include any keys, but must include the key ``'loss'``.
        - ``None`` - Skip to the next batch.

    .. code-block:: python

        # if you have one val dataloader:
        def validation_step(self, batch, batch_idx): ...


        # if you have multiple val dataloaders:
        def validation_step(self, batch, batch_idx, dataloader_idx=0): ...

    Examples::

        # CASE 1: A single validation dataset
        def validation_step(self, batch, batch_idx):
            x, y = batch

            # implement your own
            out = self(x)
            loss = self.loss(out, y)

            # log 6 example images
            # or generated text... or whatever
            sample_imgs = x[:6]
            grid = torchvision.utils.make_grid(sample_imgs)
            self.logger.experiment.add_image('example_images', grid, 0)

            # calculate acc
            labels_hat = torch.argmax(out, dim=1)
            val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

            # log the outputs!
            self.log_dict({'val_loss': loss, 'val_acc': val_acc})

    If you pass in multiple val dataloaders, :meth:`validation_step` will have an additional argument. We recommend
    setting the default value of 0 so that you can quickly switch between single and multiple dataloaders.

    .. code-block:: python

        # CASE 2: multiple validation dataloaders
        def validation_step(self, batch, batch_idx, dataloader_idx=0):
            # dataloader_idx tells you which dataset this is.
            ...

    Note:
        If you don't need to validate you don't need to implement this method.

    Note:
        When the :meth:`validation_step` is called, the model has been put in eval mode
        and PyTorch gradients have been disabled. At the end of validation,
        the model goes back to training mode and gradients are enabled.

    """
816
def test_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
    r"""Process one batch of the test set, e.g. to generate examples or compute metrics such as accuracy.

    Args:
        batch: The output of your data iterable, normally a :class:`~torch.utils.data.DataLoader`.
        batch_idx: The index of this batch.
        dataloader_idx: The index of the dataloader that produced this batch.
            (only if multiple dataloaders used)

    Return:
        - :class:`~torch.Tensor` - The loss tensor
        - ``dict`` - A dictionary. Can include any keys, but must include the key ``'loss'``.
        - ``None`` - Skip to the next batch.

    .. code-block:: python

        # if you have one test dataloader:
        def test_step(self, batch, batch_idx): ...


        # if you have multiple test dataloaders:
        def test_step(self, batch, batch_idx, dataloader_idx=0): ...

    Examples::

        # CASE 1: A single test dataset
        def test_step(self, batch, batch_idx):
            x, y = batch

            # implement your own
            out = self(x)
            loss = self.loss(out, y)

            # log 6 example images
            # or generated text... or whatever
            sample_imgs = x[:6]
            grid = torchvision.utils.make_grid(sample_imgs)
            self.logger.experiment.add_image('example_images', grid, 0)

            # calculate acc
            labels_hat = torch.argmax(out, dim=1)
            test_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

            # log the outputs!
            self.log_dict({'test_loss': loss, 'test_acc': test_acc})

    With multiple test dataloaders, :meth:`test_step` receives an extra ``dataloader_idx`` argument. Giving it a
    default of 0 lets you switch between one and several dataloaders without touching the signature.

    .. code-block:: python

        # CASE 2: multiple test dataloaders
        def test_step(self, batch, batch_idx, dataloader_idx=0):
            # dataloader_idx tells you which dataset this is.
            ...

    Note:
        Implementing this method is only necessary if you want to run testing.

    Note:
        While :meth:`test_step` runs, the model is in eval mode and PyTorch gradients are disabled.
        Once the test epoch finishes, training mode and gradients are restored.

    """
883
def predict_step(self, *args: Any, **kwargs: Any) -> Any:
    """Step function called during :meth:`~lightning.pytorch.trainer.trainer.Trainer.predict`. By default, it calls
    :meth:`~lightning.pytorch.core.LightningModule.forward`. Override to add any processing logic.

    The :meth:`~lightning.pytorch.core.LightningModule.predict_step` is used
    to scale inference on multi-devices.

    To prevent an OOM error, it is possible to use :class:`~lightning.pytorch.callbacks.BasePredictionWriter`
    callback to write the predictions to disk or database after each batch or on epoch end.

    The :class:`~lightning.pytorch.callbacks.BasePredictionWriter` should be used while using a spawn
    based accelerator. This happens for ``Trainer(strategy="ddp_spawn")``
    or training on 8 TPU cores with ``Trainer(accelerator="tpu", devices=8)`` as predictions won't be returned.

    Args:
        batch: The output of your data iterable, normally a :class:`~torch.utils.data.DataLoader`.
            The default implementation accepts it either as the first positional argument or as the
            ``batch`` keyword argument (the keyword takes precedence).
        batch_idx: The index of this batch.
        dataloader_idx: The index of the dataloader that produced this batch.
            (only if multiple dataloaders used)

    Return:
        Predicted output (optional). The default implementation returns ``self(batch)``.

    Raises:
        IndexError: In the default implementation, if ``batch`` is passed neither positionally nor by keyword.

    Example ::

        class MyModel(LightningModule):

            def predict_step(self, batch, batch_idx, dataloader_idx=0):
                return self(batch)

        dm = ...
        model = MyModel()
        trainer = Trainer(accelerator="gpu", devices=2)
        predictions = trainer.predict(model, dm)

    """
    # For backwards compatibility, ``batch`` may arrive positionally or as a keyword. Check the keyword first:
    # ``kwargs.get("batch", args[0])`` evaluates ``args[0]`` eagerly and raises ``IndexError`` when the batch is
    # passed only by keyword (``args`` is then empty).
    batch = kwargs["batch"] if "batch" in kwargs else args[0]
    return self(batch)
923
def configure_callbacks(self) -> Union[Sequence[Callback], Callback]:
    """Return callbacks that belong to this model. Once the model is attached (for example when ``.fit()`` or
    ``.test()`` is called), whatever is returned here is merged into the callbacks given to the Trainer's
    ``callbacks`` argument. A returned callback whose type matches one or more callbacks already registered on the
    Trainer takes priority and replaces them. Lightning also makes sure that
    :class:`~lightning.pytorch.callbacks.model_checkpoint.ModelCheckpoint` callbacks run last.

    Return:
        A callback or a list of callbacks which will extend the list of callbacks in the Trainer.

    Example::

        def configure_callbacks(self):
            early_stop = EarlyStopping(monitor="val_acc", mode="max")
            checkpoint = ModelCheckpoint(monitor="val_loss")
            return [early_stop, checkpoint]

    """
    # By default the model contributes no callbacks of its own; a fresh list is returned on every call.
    model_callbacks = list()
    return model_callbacks
943
def configure_optimizers(self) -> OptimizerLRScheduler:
    r"""Choose what optimizers and learning-rate schedulers to use in your optimization. Normally you'd need one.
    But in the case of GANs or similar you might have multiple. Optimization with multiple optimizers only works in
    the manual optimization mode.

    Return:
        Any of these 6 options.

        - **Single optimizer**.
        - **List or Tuple** of optimizers.
        - **Two lists** - The first list has multiple optimizers, and the second has multiple LR schedulers
          (or multiple ``lr_scheduler_config``).
        - **Dictionary**, with an ``"optimizer"`` key, and (optionally) a ``"lr_scheduler"``
          key whose value is a single LR scheduler or ``lr_scheduler_config``.
        - **List or Tuple of dictionaries**, each structured like the dictionary option above (one per
          optimizer), as shown in the two-optimizer example below.
        - **None** - Fit will run without any optimizer.

    The ``lr_scheduler_config`` is a dictionary which contains the scheduler and its associated configuration.
    The default configuration is shown below.

    .. code-block:: python

        lr_scheduler_config = {
            # REQUIRED: The scheduler instance
            "scheduler": lr_scheduler,
            # The unit of the scheduler's step size, could also be 'step'.
            # 'epoch' updates the scheduler on epoch end whereas 'step'
            # updates it after an optimizer update.
            "interval": "epoch",
            # How many epochs/steps should pass between calls to
            # `scheduler.step()`. 1 corresponds to updating the learning
            # rate after every epoch/step.
            "frequency": 1,
            # Metric to monitor for schedulers like `ReduceLROnPlateau`
            "monitor": "val_loss",
            # If set to `True`, will enforce that the value specified in 'monitor'
            # is available when the scheduler is updated, thus stopping
            # training if not found. If set to `False`, it will only produce a warning
            "strict": True,
            # If using the `LearningRateMonitor` callback to monitor the
            # learning rate progress, this keyword can be used to specify
            # a custom logged name
            "name": None,
        }

    When there are schedulers in which the ``.step()`` method is conditioned on a value, such as the
    :class:`torch.optim.lr_scheduler.ReduceLROnPlateau` scheduler, Lightning requires that the
    ``lr_scheduler_config`` contains the keyword ``"monitor"`` set to the metric name that the scheduler
    should be conditioned on.

    .. testcode::

        # The ReduceLROnPlateau scheduler requires a monitor
        def configure_optimizers(self):
            optimizer = Adam(...)
            return {
                "optimizer": optimizer,
                "lr_scheduler": {
                    "scheduler": ReduceLROnPlateau(optimizer, ...),
                    "monitor": "metric_to_track",
                    "frequency": "indicates how often the metric is updated",
                    # If "monitor" references validation metrics, then "frequency" should be set to a
                    # multiple of "trainer.check_val_every_n_epoch".
                },
            }


        # In the case of two optimizers, only one using the ReduceLROnPlateau scheduler
        def configure_optimizers(self):
            optimizer1 = Adam(...)
            optimizer2 = SGD(...)
            scheduler1 = ReduceLROnPlateau(optimizer1, ...)
            scheduler2 = LambdaLR(optimizer2, ...)
            return (
                {
                    "optimizer": optimizer1,
                    "lr_scheduler": {
                        "scheduler": scheduler1,
                        "monitor": "metric_to_track",
                    },
                },
                {"optimizer": optimizer2, "lr_scheduler": scheduler2},
            )

    Metrics can be made available to monitor by simply logging it using
    ``self.log('metric_to_track', metric_val)`` in your :class:`~lightning.pytorch.core.LightningModule`.

    Note:
        Some things to know:

        - Lightning calls ``.backward()`` and ``.step()`` automatically in case of automatic optimization.
        - If a learning rate scheduler is specified in ``configure_optimizers()`` with key
          ``"interval"`` (default "epoch") in the scheduler configuration, Lightning will call
          the scheduler's ``.step()`` method automatically in case of automatic optimization.
        - If you use 16-bit precision (``precision=16``), Lightning will automatically handle the optimizer.
        - If you use :class:`torch.optim.LBFGS`, Lightning handles the closure function automatically for you.
        - If you use multiple optimizers, you will have to switch to 'manual optimization' mode and step them
          yourself.
        - If you need to control how often the optimizer steps, override the :meth:`optimizer_step` hook.

    """
    # The base implementation only warns: every model trained with the Trainer is expected to override this hook.
    rank_zero_warn("`configure_optimizers` must be implemented to be used with the Lightning Trainer")
1045
def manual_backward(self, loss: Tensor, *args: Any, **kwargs: Any) -> None:
    """Run the backward pass yourself from :meth:`training_step` when optimizing manually. Routing the call through
    this method lets Lightning apply the appropriate scaling when mixed precision is enabled.

    See :ref:`manual optimization<common/optimization:Manual optimization>` for more examples.

    Example::

        def training_step(...):
            opt = self.optimizers()
            loss = ...
            opt.zero_grad()
            # automatically applies scaling, etc...
            self.manual_backward(loss)
            opt.step()

    Args:
        loss: The tensor on which to compute gradients. Must have a graph attached.
        *args: Additional positional arguments to be forwarded to :meth:`~torch.Tensor.backward`
        **kwargs: Additional keyword arguments to be forwarded to :meth:`~torch.Tensor.backward`

    """
    fabric = self._fabric
    if not fabric:
        # Trainer-driven run: only valid with automatic optimization disabled, then delegate to the strategy.
        self._verify_is_manual_optimization("manual_backward")
        self.trainer.strategy.backward(loss, None, *args, **kwargs)
        return
    # Running under Fabric: let Fabric handle precision/strategy-specific backward logic.
    fabric.backward(loss, *args, **kwargs)
1073
def backward(self, loss: Tensor, *args: Any, **kwargs: Any) -> None:
    """Hook that runs the backward pass on the loss returned from :meth:`training_step`. Override it to supply
    your own implementation.

    Args:
        loss: The loss tensor returned by :meth:`training_step`. If gradient accumulation is used, the loss here
            holds the normalized value (scaled by 1 / accumulation steps).

    Example::

        def backward(self, loss):
            loss.backward()

    """
    if not self._fabric:
        # Plain PyTorch backward, forwarding any extra arguments untouched.
        loss.backward(*args, **kwargs)
        return
    # Under Fabric, delegate so that Fabric's precision/strategy handling is applied.
    self._fabric.backward(loss, *args, **kwargs)
1092
def toggle_optimizer(self, optimizer: Union[Optimizer, LightningOptimizer]) -> None:
    """Restrict gradient computation in the training step to the parameters of ``optimizer``, so that no dangling
    gradients accumulate on other optimizers' parameters in a multiple-optimizer setup.

    Pair it with :meth:`untoggle_optimizer`, which restores ``param_requires_grad_state`` afterwards.

    Args:
        optimizer: The optimizer to toggle.

    """
    saved_flags = {}

    # Pass 1: every parameter known to any trainer optimizer has its current `requires_grad` recorded (it may have
    # been set deliberately in `configure_optimizers`) and is then frozen. A parameter shared by several optimizers
    # is recorded only the first time it is seen, so the saved value is its original one.
    every_param = (
        param for opt in self.trainer.optimizers for group in opt.param_groups for param in group["params"]
    )
    for param in every_param:
        if param not in saved_flags:
            saved_flags[param] = param.requires_grad
            param.requires_grad = False

    # Pass 2: give the parameters of the optimizer being toggled their original flag back.
    for group in optimizer.param_groups:
        for param in group["params"]:
            param.requires_grad = saved_flags[param]

    self._param_requires_grad_state = saved_flags
1121
def untoggle_optimizer(self, optimizer: Union[Optimizer, LightningOptimizer]) -> None:
    """Undo :meth:`toggle_optimizer`: give every parameter frozen by it its saved ``requires_grad`` value back.

    Args:
        optimizer: The optimizer to untoggle.

    """
    saved_flags = self._param_requires_grad_state
    for opt in self.trainer.optimizers:
        # A LightningOptimizer wraps the raw optimizer stored in the trainer, so compare against both.
        is_toggled = opt is optimizer or (isinstance(optimizer, LightningOptimizer) and opt is optimizer.optimizer)
        if is_toggled:
            # `toggle_optimizer` already restored the flags of this optimizer's parameters.
            continue
        for group in opt.param_groups:
            for param in group["params"]:
                if param in saved_flags:
                    param.requires_grad = saved_flags[param]
    # save memory
    self._param_requires_grad_state = {}
1137
def clip_gradients(
    self,
    optimizer: Optimizer,
    gradient_clip_val: Optional[Union[int, float]] = None,
    gradient_clip_algorithm: Optional[str] = None,
) -> None:
    """Handles gradient clipping internally.

    Note:
        - Do not override this method. If you want to customize gradient clipping, consider using
          :meth:`configure_gradient_clipping` method.
        - For manual optimization (``self.automatic_optimization = False``), if you want to use
          gradient clipping, consider calling
          ``self.clip_gradients(opt, gradient_clip_val=0.5, gradient_clip_algorithm="norm")``
          manually in the training step.

    Args:
        optimizer: Current optimizer being used.
        gradient_clip_val: The value at which to clip gradients.
        gradient_clip_algorithm: The gradient clipping algorithm to use. Pass ``gradient_clip_algorithm="value"``
            to clip by value, and ``gradient_clip_algorithm="norm"`` to clip by norm.

    Raises:
        MisconfigurationException: If a value/algorithm is passed here that conflicts with the one set on the
            Trainer, or if the algorithm is not supported.
        TypeError: If the resolved ``gradient_clip_val`` is not an ``int`` or ``float``.

    """

    if self.fabric is not None:
        # Under Fabric, translate the (value, algorithm) pair into Fabric's `clip_val` / `max_norm` arguments.
        self.fabric.clip_gradients(
            self,
            optimizer,
            clip_val=gradient_clip_val if gradient_clip_algorithm == GradClipAlgorithmType.VALUE else None,
            max_norm=None if gradient_clip_algorithm == GradClipAlgorithmType.VALUE else gradient_clip_val,
        )
        return

    # Fall back to the Trainer's setting; an explicit value must not contradict it.
    if gradient_clip_val is None:
        gradient_clip_val = self.trainer.gradient_clip_val or 0.0
    elif self.trainer.gradient_clip_val is not None and self.trainer.gradient_clip_val != gradient_clip_val:
        raise MisconfigurationException(
            f"You have set `Trainer(gradient_clip_val={self.trainer.gradient_clip_val!r})`"
            f" and have passed `clip_gradients(gradient_clip_val={gradient_clip_val!r})`."
            " Please use only one of them."
        )

    if gradient_clip_algorithm is None:
        gradient_clip_algorithm = self.trainer.gradient_clip_algorithm or "norm"
    else:
        gradient_clip_algorithm = gradient_clip_algorithm.lower()
        if (
            self.trainer.gradient_clip_algorithm is not None
            and self.trainer.gradient_clip_algorithm != gradient_clip_algorithm
        ):
            raise MisconfigurationException(
                f"You have set `Trainer(gradient_clip_algorithm={self.trainer.gradient_clip_algorithm.value!r})`"
                f" and have passed `clip_gradients(gradient_clip_algorithm={gradient_clip_algorithm!r})`."
                " Please use only one of them."
            )

    if not isinstance(gradient_clip_val, (int, float)):
        raise TypeError(f"`gradient_clip_val` should be an int or a float. Got {gradient_clip_val}.")

    if not GradClipAlgorithmType.supported_type(gradient_clip_algorithm.lower()):
        raise MisconfigurationException(
            f"`gradient_clip_algorithm` {gradient_clip_algorithm} is invalid."
            f" Allowed algorithms: {GradClipAlgorithmType.supported_types()}."
        )

    gradient_clip_algorithm = GradClipAlgorithmType(gradient_clip_algorithm)
    self.trainer.precision_plugin.clip_gradients(optimizer, gradient_clip_val, gradient_clip_algorithm)
1205
def configure_gradient_clipping(
    self,
    optimizer: Optimizer,
    gradient_clip_val: Optional[Union[int, float]] = None,
    gradient_clip_algorithm: Optional[str] = None,
) -> None:
    """Hook for clipping the gradients of the optimizer's parameters; it runs right before :meth:`optimizer_step`.

    Args:
        optimizer: Current optimizer being used.
        gradient_clip_val: The value at which to clip gradients. By default, value passed in Trainer
            will be available here.
        gradient_clip_algorithm: The gradient clipping algorithm to use. By default, value
            passed in Trainer will be available here.

    Example::

        def configure_gradient_clipping(self, optimizer, gradient_clip_val, gradient_clip_algorithm):
            # Implement your own custom logic to clip gradients
            # You can call `self.clip_gradients` with your settings:
            self.clip_gradients(
                optimizer,
                gradient_clip_val=gradient_clip_val,
                gradient_clip_algorithm=gradient_clip_algorithm
            )

    """
    # Default behaviour: forward the received settings unchanged to `clip_gradients`.
    clip_settings = {
        "gradient_clip_val": gradient_clip_val,
        "gradient_clip_algorithm": gradient_clip_algorithm,
    }
    self.clip_gradients(optimizer, **clip_settings)
1236
def lr_scheduler_step(self, scheduler: LRSchedulerTypeUnion, metric: Optional[Any]) -> None:
    r"""Customize how the :class:`~lightning.pytorch.trainer.trainer.Trainer` steps each learning-rate scheduler.
    Out of the box, Lightning calls ``step()`` on every scheduler according to its ``interval``, as the example
    below shows.

    Args:
        scheduler: Learning rate scheduler.
        metric: Value of the monitor used for schedulers like ``ReduceLROnPlateau``.

    Examples::

        # DEFAULT
        def lr_scheduler_step(self, scheduler, metric):
            if metric is None:
                scheduler.step()
            else:
                scheduler.step(metric)

        # Alternative way to update schedulers if it requires an epoch value
        def lr_scheduler_step(self, scheduler, metric):
            scheduler.step(epoch=self.current_epoch)

    """
    if metric is not None:
        # Monitor-driven schedulers (e.g. ReduceLROnPlateau) receive the monitored value.
        scheduler.step(metric)
        return
    scheduler.step()  # type: ignore[call-arg]
1264
def optimizer_step(
    self,
    epoch: int,
    batch_idx: int,
    optimizer: Union[Optimizer, LightningOptimizer],
    optimizer_closure: Optional[Callable[[], Any]] = None,
) -> None:
    r"""Customize how the :class:`~lightning.pytorch.trainer.trainer.Trainer` steps the optimizer.

    Lightning's default calls ``step()`` and ``zero_grad()`` as in the example below. Neither this hook nor
    ``zero_grad()`` runs during the accumulation phase when ``Trainer(accumulate_grad_batches != 1)``. With manual
    optimization, overriding this hook has no effect.

    Args:
        epoch: Current epoch
        batch_idx: Index of current batch
        optimizer: A PyTorch optimizer
        optimizer_closure: The optimizer closure. This closure must be executed as it includes the
            calls to ``training_step()``, ``optimizer.zero_grad()``, and ``backward()``.

    Examples::

        def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure):
            # Add your custom logic to run directly before `optimizer.step()`

            optimizer.step(closure=optimizer_closure)

            # Add your custom logic to run directly after `optimizer.step()`

    """
    # The closure has to reach `step()`: it is what runs `training_step`, `zero_grad` and `backward`.
    take_step = optimizer.step
    take_step(closure=optimizer_closure)
1297
def optimizer_zero_grad(self, epoch: int, batch_idx: int, optimizer: Optimizer) -> None:
    """Hook that clears the optimizer's gradients; override it to replace the default ``optimizer.zero_grad()``.

    Args:
        epoch: Current epoch
        batch_idx: Index of current batch
        optimizer: A PyTorch optimizer

    Examples::

        # DEFAULT
        def optimizer_zero_grad(self, epoch, batch_idx, optimizer):
            optimizer.zero_grad()

        # Set gradients to `None` instead of zero to improve performance (not required on `torch>=2.0.0`).
        def optimizer_zero_grad(self, epoch, batch_idx, optimizer):
            optimizer.zero_grad(set_to_none=True)

    See :meth:`torch.optim.Optimizer.zero_grad` for the explanation of the above example.

    """
    # Called with no arguments, so the optimizer's own `zero_grad` defaults apply.
    reset_gradients = optimizer.zero_grad
    reset_gradients()
1320
def freeze(self) -> None:
    r"""Prepare the model for inference: stop tracking gradients on all parameters and switch to eval mode.

    Example::

        model = MyLightningModule(...)
        model.freeze()

    """
    for parameter in self.parameters():
        # stop autograd from tracking this parameter
        parameter.requires_grad = False
    # disable training-only behaviour such as dropout
    self.eval()
1334
def unfreeze(self) -> None:
    """Prepare the model for training again: re-enable gradients on all parameters and switch to train mode.

    .. code-block:: python

        model = MyLightningModule(...)
        model.unfreeze()

    """
    for parameter in self.parameters():
        # let autograd track this parameter again
        parameter.requires_grad = True
    # re-enable training-only behaviour such as dropout
    self.train()
1348
1349def _verify_is_manual_optimization(self, fn_name: str) -> None:
1350if self.automatic_optimization:
1351raise MisconfigurationException(
1352f"to use {fn_name}, please disable automatic optimization:"
1353" set model property `automatic_optimization` as False"
1354)
1355
@torch.no_grad()
def to_onnx(self, file_path: Union[str, Path], input_sample: Optional[Any] = None, **kwargs: Any) -> None:
    """Saves the model in ONNX format.

    Args:
        file_path: The path of the file the onnx model should be saved to.
        input_sample: An input for tracing. Default: None (Use self.example_input_array)
        **kwargs: Will be passed to torch.onnx.export function.

    Raises:
        ModuleNotFoundError: If ``torch>=2.0`` is installed but ``onnx`` is not.
        ValueError: If neither ``input_sample`` nor ``example_input_array`` is set.

    Note:
        The model's train/eval mode is restored afterwards, even if the export fails.

    Example::

        class SimpleModel(LightningModule):
            def __init__(self):
                super().__init__()
                self.l1 = torch.nn.Linear(in_features=64, out_features=4)

            def forward(self, x):
                return torch.relu(self.l1(x.view(x.size(0), -1)))

        model = SimpleModel()
        input_sample = torch.randn(1, 64)
        model.to_onnx("export.onnx", input_sample, export_params=True)

    """
    if _TORCH_GREATER_EQUAL_2_0 and not _ONNX_AVAILABLE:
        raise ModuleNotFoundError(
            f"`torch>=2.0` requires `onnx` to be installed to use `{type(self).__name__}.to_onnx()`"
        )

    mode = self.training

    if input_sample is None:
        if self.example_input_array is None:
            raise ValueError(
                "Could not export to ONNX since neither `input_sample` nor"
                " `model.example_input_array` attribute is set."
            )
        input_sample = self.example_input_array

    # move the sample through the batch-transfer hooks so it ends up on the model's device
    input_sample = self._on_before_batch_transfer(input_sample)
    input_sample = self._apply_batch_transfer_handler(input_sample)

    try:
        torch.onnx.export(self, input_sample, file_path, **kwargs)
    finally:
        # restore the original train/eval mode even if the export raises
        self.train(mode)
1400
@torch.no_grad()
def to_torchscript(
    self,
    file_path: Optional[Union[str, Path]] = None,
    method: Optional[str] = "script",
    example_inputs: Optional[Any] = None,
    **kwargs: Any,
) -> Union[ScriptModule, Dict[str, ScriptModule]]:
    """By default compiles the whole model to a :class:`~torch.jit.ScriptModule`. If you want to use tracing,
    please provide the argument ``method='trace'`` and make sure that either the `example_inputs` argument is
    provided, or the model has :attr:`example_input_array` set. If you would like to customize the modules that are
    scripted you should override this method. In case you want to return multiple modules, we recommend using a
    dictionary.

    Args:
        file_path: Path where to save the torchscript. Default: None (no file saved).
        method: Whether to use TorchScript's script or trace method. Default: 'script'
        example_inputs: An input to be used to do tracing when method is set to 'trace'.
            Default: None (uses :attr:`example_input_array`)
        **kwargs: Additional arguments that will be passed to the :func:`torch.jit.script` or
            :func:`torch.jit.trace` function.

    Raises:
        ValueError: If ``method`` is neither ``'script'`` nor ``'trace'``, or if tracing is requested without
            ``example_inputs`` and without :attr:`example_input_array`.

    Note:
        - Requires the implementation of the
          :meth:`~lightning.pytorch.core.LightningModule.forward` method.
        - The exported script will be set to evaluation mode.
        - The model's own train/eval mode is restored afterwards, even if scripting or tracing fails.
        - It is recommended that you install the latest supported version of PyTorch
          to use this feature without limitations. See also the :mod:`torch.jit`
          documentation for supported features.

    Example::

        class SimpleModel(LightningModule):
            def __init__(self):
                super().__init__()
                self.l1 = torch.nn.Linear(in_features=64, out_features=4)

            def forward(self, x):
                return torch.relu(self.l1(x.view(x.size(0), -1)))

        model = SimpleModel()
        model.to_torchscript(file_path="model.pt")

        # the returned module can also be saved manually
        traced = model.to_torchscript(method='trace', example_inputs=torch.randn(1, 64))
        torch.jit.save(traced, "model_trace.pt")

    Return:
        This LightningModule as a torchscript, regardless of whether `file_path` is
        defined or not.

    """
    mode = self.training

    # Validate the arguments and prepare the tracing inputs before the model's mode is touched.
    if method == "trace":
        # if no example inputs are provided, try to see if model has example_input_array set
        if example_inputs is None:
            if self.example_input_array is None:
                raise ValueError(
                    "Choosing method=`trace` requires either `example_inputs`"
                    " or `model.example_input_array` to be defined."
                )
            example_inputs = self.example_input_array

        # automatically send example inputs to the right device
        example_inputs = self._on_before_batch_transfer(example_inputs)
        example_inputs = self._apply_batch_transfer_handler(example_inputs)
    elif method != "script":
        raise ValueError(f"The 'method' parameter only supports 'script' or 'trace', but value given was: {method}")

    try:
        with _jit_is_scripting():
            if method == "script":
                torchscript_module = torch.jit.script(self.eval(), **kwargs)
            else:
                torchscript_module = torch.jit.trace(func=self.eval(), example_inputs=example_inputs, **kwargs)
    finally:
        # `self.eval()` above switched the model to eval mode; restore it even if scripting/tracing raised
        self.train(mode)

    if file_path is not None:
        fs = get_filesystem(file_path)
        with fs.open(file_path, "wb") as f:
            torch.jit.save(torchscript_module, f)

    return torchscript_module
1484
@_restricted_classmethod
def load_from_checkpoint(
    cls,
    checkpoint_path: Union[_PATH, IO],
    map_location: _MAP_LOCATION_TYPE = None,
    hparams_file: Optional[_PATH] = None,
    strict: Optional[bool] = None,
    **kwargs: Any,
) -> Self:
    r"""Primary way of loading a model from a checkpoint. When Lightning saves a checkpoint it stores the arguments
    passed to ``__init__``  in the checkpoint under ``"hyper_parameters"``.

    Any arguments specified through \*\*kwargs will override args stored in ``"hyper_parameters"``.

    Args:
        checkpoint_path: Path to checkpoint. This can also be a URL, or file-like object.
        map_location:
            If your checkpoint saved a GPU model and you now load on CPUs
            or a different number of GPUs, use this to map to the new setup.
            The behaviour is the same as in :func:`torch.load`.
        hparams_file: Optional path to a ``.yaml`` or ``.csv`` file with hierarchical structure
            as in this example::

                drop_prob: 0.2
                dataloader:
                    batch_size: 32

            You most likely won't need this since Lightning will always save the hyperparameters
            to the checkpoint.
            However, if your checkpoint weights don't have the hyperparameters saved,
            use this method to pass in a ``.yaml`` file with the hparams you'd like to use.
            These will be converted into a :class:`~dict` and passed into your
            :class:`LightningModule` for use.

            If your model's ``hparams`` argument is :class:`~argparse.Namespace`
            and ``.yaml`` file has hierarchical structure, you need to refactor your model to treat
            ``hparams`` as :class:`~dict`.
        strict: Whether to strictly enforce that the keys in :attr:`checkpoint_path` match the keys
            returned by this module's state dict. Defaults to ``True`` unless ``LightningModule.strict_loading`` is
            set, in which case it defaults to the value of ``LightningModule.strict_loading``.
        \**kwargs: Any extra keyword args needed to init the model. Can also be used to override saved
            hyperparameter values.

    Return:
        :class:`LightningModule` instance with loaded weights and hyperparameters (if available).

    Note:
        ``load_from_checkpoint`` is a **class** method. You should use your :class:`LightningModule`
        **class** to call it instead of the :class:`LightningModule` instance, or a
        ``TypeError`` will be raised.

    Note:
        To ensure all layers can be loaded from the checkpoint, this function will call
        :meth:`~lightning.pytorch.core.hooks.ModelHooks.configure_model` directly after instantiating the
        model if this hook is overridden in your LightningModule. However, note that ``load_from_checkpoint`` does
        not support loading sharded checkpoints, and you may run out of memory if the model is too large. In this
        case, consider loading through the Trainer via ``.fit(ckpt_path=...)``.

    Example::

        # load weights without mapping ...
        model = MyLightningModule.load_from_checkpoint('path/to/checkpoint.ckpt')

        # or load weights mapping all weights from GPU 1 to GPU 0 ...
        map_location = {'cuda:1':'cuda:0'}
        model = MyLightningModule.load_from_checkpoint(
            'path/to/checkpoint.ckpt',
            map_location=map_location
        )

        # or load weights and hyperparameters from separate files.
        model = MyLightningModule.load_from_checkpoint(
            'path/to/checkpoint.ckpt',
            hparams_file='/path/to/hparams_file.yaml'
        )

        # override some of the params with new values
        model = MyLightningModule.load_from_checkpoint(
            PATH,
            num_layers=128,
            pretrained_ckpt_path=NEW_PATH,
        )

        # predict
        model.eval()
        model.freeze()
        y_hat = model(x)

    """
    loaded = _load_from_checkpoint(
        cls,  # type: ignore[arg-type]
        checkpoint_path,
        map_location,
        hparams_file,
        strict,
        **kwargs,
    )
    # `_load_from_checkpoint` is typed loosely; narrow its result to the calling class for type checkers.
    return cast(Self, loaded)
1583
@override
def __getstate__(self) -> Dict[str, Any]:
    # Shallow-copy the instance dict, replacing the trainer reference with None so it is not part of the state.
    return {**self.__dict__, "_trainer": None}
1589
def _register_sharded_tensor_state_dict_hooks_if_available(self) -> None:
    """Register the ShardedTensor state-dict hooks when ShardedTensors are supported.

    The hooks make sure ShardedTensors are included when the state dict is saved and are loaded into the
    LightningModule correctly.

    """
    # ShardedTensor is deprecated in favor of DistributedTensor, so newer torch versions need nothing here.
    if _TORCH_GREATER_EQUAL_2_1:
        return

    hooks_supported = not _IS_WINDOWS and torch.distributed.is_available()
    if not hooks_supported:
        rank_zero_debug("Could not register sharded tensor state dict hooks")
        return

    # Imported lazily: the module is only usable when torch.distributed is available.
    from torch.distributed._shard.sharded_tensor import pre_load_state_dict_hook, state_dict_hook

    self._register_state_dict_hook(state_dict_hook)
    self._register_load_state_dict_pre_hook(pre_load_state_dict_hook, True)
1607
1608
@contextmanager
def _jit_is_scripting() -> Generator:
    """Set ``LightningModule._jit_is_scripting`` to ``True`` for the duration of the ``with`` block.

    Workaround for https://github.com/pytorch/pytorch/issues/67146.
    """
    setattr(LightningModule, "_jit_is_scripting", True)
    try:
        yield
    finally:
        # always clear the flag, even when the body of the `with` block raised
        setattr(LightningModule, "_jit_is_scripting", False)
1617
1618
1619class _TrainerFabricShim:
1620"""Intercepts attribute access on LightningModule's trainer reference and redirects it to the Fabric object."""
1621
1622def __init__(self, fabric: lf.Fabric) -> None:
1623super().__init__()
1624self._fabric = fabric
1625
1626def __getattr__(self, item: Any) -> Any:
1627try:
1628return getattr(self._fabric, item)
1629except AttributeError:
1630raise AttributeError(
1631f"Your LightningModule code tried to access `self.trainer.{item}` but this attribute is not available"
1632f" when using Fabric with a LightningModule."
1633)
1634