pytorch-lightning

Форк
0
1633 строки · 69.3 Кб
1
# Copyright The Lightning AI team.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14
"""The LightningModule - an nn.Module with many additional features."""
15

16
import logging
17
import numbers
18
import weakref
19
from contextlib import contextmanager
20
from pathlib import Path
21
from typing import (
22
    IO,
23
    Any,
24
    Callable,
25
    Dict,
26
    Generator,
27
    List,
28
    Literal,
29
    Mapping,
30
    Optional,
31
    Sequence,
32
    Tuple,
33
    Union,
34
    cast,
35
    overload,
36
)
37

38
import torch
39
from lightning_utilities.core.apply_func import apply_to_collection
40
from lightning_utilities.core.imports import RequirementCache
41
from torch import ScriptModule, Tensor
42
from torch.nn import Module
43
from torch.optim.optimizer import Optimizer
44
from torchmetrics import Metric, MetricCollection
45
from typing_extensions import Self, override
46

47
import lightning.fabric as lf
48
import lightning.pytorch as pl
49
from lightning.fabric.loggers import Logger as FabricLogger
50
from lightning.fabric.utilities.apply_func import convert_to_tensors
51
from lightning.fabric.utilities.cloud_io import get_filesystem
52
from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
53
from lightning.fabric.utilities.imports import _IS_WINDOWS, _TORCH_GREATER_EQUAL_2_0, _TORCH_GREATER_EQUAL_2_1
54
from lightning.fabric.utilities.types import _MAP_LOCATION_TYPE, _PATH
55
from lightning.fabric.wrappers import _FabricOptimizer
56
from lightning.pytorch.callbacks.callback import Callback
57
from lightning.pytorch.core.hooks import CheckpointHooks, DataHooks, ModelHooks
58
from lightning.pytorch.core.mixins import HyperparametersMixin
59
from lightning.pytorch.core.optimizer import LightningOptimizer
60
from lightning.pytorch.core.saving import _load_from_checkpoint
61
from lightning.pytorch.loggers import Logger
62
from lightning.pytorch.trainer import call
63
from lightning.pytorch.trainer.connectors.logger_connector.fx_validator import _FxValidator
64
from lightning.pytorch.trainer.connectors.logger_connector.result import _get_default_dtype
65
from lightning.pytorch.utilities import GradClipAlgorithmType
66
from lightning.pytorch.utilities.exceptions import MisconfigurationException
67
from lightning.pytorch.utilities.imports import _TORCHMETRICS_GREATER_EQUAL_0_9_1
68
from lightning.pytorch.utilities.model_helpers import _restricted_classmethod
69
from lightning.pytorch.utilities.rank_zero import WarningCache, rank_zero_debug, rank_zero_warn
70
from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature
71
from lightning.pytorch.utilities.types import (
72
    _METRIC,
73
    STEP_OUTPUT,
74
    LRSchedulerPLType,
75
    LRSchedulerTypeUnion,
76
    OptimizerLRScheduler,
77
)
78

79
_ONNX_AVAILABLE = RequirementCache("onnx")
80

81
warning_cache = WarningCache()
82
log = logging.getLogger(__name__)
83

84
MODULE_OPTIMIZERS = Union[
85
    Optimizer, LightningOptimizer, _FabricOptimizer, List[Optimizer], List[LightningOptimizer], List[_FabricOptimizer]
86
]
87

88

89
class LightningModule(
90
    _DeviceDtypeModuleMixin,
91
    HyperparametersMixin,
92
    ModelHooks,
93
    DataHooks,
94
    CheckpointHooks,
95
    Module,
96
):
97
    # Below is for property support of JIT
98
    # since none of these are important when using JIT, we are going to ignore them.
99
    __jit_unused_properties__: List[str] = (
100
        [
101
            "example_input_array",
102
            "on_gpu",
103
            "current_epoch",
104
            "global_step",
105
            "global_rank",
106
            "local_rank",
107
            "logger",
108
            "loggers",
109
            "automatic_optimization",
110
            "trainer",
111
            "fabric",
112
            "strict_loading",
113
        ]
114
        + _DeviceDtypeModuleMixin.__jit_unused_properties__
115
        + HyperparametersMixin.__jit_unused_properties__
116
    )
117
    _jit_is_scripting = False
118

119
    CHECKPOINT_HYPER_PARAMS_KEY = "hyper_parameters"
120
    CHECKPOINT_HYPER_PARAMS_NAME = "hparams_name"
121
    CHECKPOINT_HYPER_PARAMS_TYPE = "hparams_type"
122

123
    def __init__(self, *args: Any, **kwargs: Any) -> None:
124
        super().__init__(*args, **kwargs)
125

126
        # pointer to the trainer object
127
        self._trainer: Optional["pl.Trainer"] = None
128

129
        # attributes that can be set by user
130
        self._example_input_array: Optional[Union[Tensor, Tuple, Dict]] = None
131
        self._automatic_optimization: bool = True
132
        self._strict_loading: Optional[bool] = None
133

134
        # attributes used internally
135
        self._current_fx_name: Optional[str] = None
136
        self._param_requires_grad_state: Dict[str, bool] = {}
137
        self._metric_attributes: Optional[Dict[int, str]] = None
138
        self._register_sharded_tensor_state_dict_hooks_if_available()
139
        self._compiler_ctx: Optional[Dict[str, Any]] = None
140

141
        # attributes only used when using fabric
142
        self._fabric: Optional["lf.Fabric"] = None
143
        self._fabric_optimizers: List[_FabricOptimizer] = []
144

145
    @overload
146
    def optimizers(
147
        self, use_pl_optimizer: Literal[True] = True
148
    ) -> Union[LightningOptimizer, List[LightningOptimizer]]: ...
149

150
    @overload
151
    def optimizers(self, use_pl_optimizer: Literal[False]) -> Union[Optimizer, List[Optimizer]]: ...
152

153
    @overload
154
    def optimizers(self, use_pl_optimizer: bool) -> MODULE_OPTIMIZERS: ...
155

156
    def optimizers(self, use_pl_optimizer: bool = True) -> MODULE_OPTIMIZERS:
157
        """Returns the optimizer(s) that are being used during training. Useful for manual optimization.
158

159
        Args:
160
            use_pl_optimizer: If ``True``, will wrap the optimizer(s) in a
161
                :class:`~lightning.pytorch.core.optimizer.LightningOptimizer` for automatic handling of precision,
162
                profiling, and counting of step calls for proper logging and checkpointing. It specifically wraps the
163
                ``step`` method and custom optimizers that don't have this method are not supported.
164

165
        Returns:
166
            A single optimizer, or a list of optimizers in case multiple ones are present.
167

168
        """
169
        if self._fabric:
170
            opts: MODULE_OPTIMIZERS = self._fabric_optimizers
171
        elif use_pl_optimizer:
172
            opts = self.trainer.strategy._lightning_optimizers
173
        else:
174
            opts = self.trainer.optimizers
175

176
        # single optimizer
177
        if (
178
            isinstance(opts, list)
179
            and len(opts) == 1
180
            and isinstance(opts[0], (Optimizer, LightningOptimizer, _FabricOptimizer))
181
        ):
182
            return opts[0]
183
        # multiple opts
184
        return opts
185

186
    def lr_schedulers(self) -> Union[None, List[LRSchedulerPLType], LRSchedulerPLType]:
187
        """Returns the learning rate scheduler(s) that are being used during training. Useful for manual optimization.
188

189
        Returns:
190
            A single scheduler, or a list of schedulers in case multiple ones are present, or ``None`` if no
191
            schedulers were returned in :meth:`~lightning.pytorch.core.LightningModule.configure_optimizers`.
192

193
        """
194
        if not self.trainer.lr_scheduler_configs:
195
            return None
196

197
        # ignore other keys "interval", "frequency", etc.
198
        lr_schedulers: List[LRSchedulerPLType] = [config.scheduler for config in self.trainer.lr_scheduler_configs]
199

200
        # single scheduler
201
        if len(lr_schedulers) == 1:
202
            return lr_schedulers[0]
203

204
        # multiple schedulers
205
        return lr_schedulers
206

207
    @property
208
    def trainer(self) -> "pl.Trainer":
209
        if self._fabric is not None:
210
            return _TrainerFabricShim(fabric=self._fabric)  # type: ignore[return-value]
211
        if not self._jit_is_scripting and self._trainer is None:
212
            raise RuntimeError(f"{self.__class__.__qualname__} is not attached to a `Trainer`.")
213
        return self._trainer  # type: ignore[return-value]
214

215
    @trainer.setter
216
    def trainer(self, trainer: Optional["pl.Trainer"]) -> None:
217
        for v in self.children():
218
            if isinstance(v, LightningModule):
219
                v.trainer = trainer  # type: ignore[assignment]
220
        # https://github.com/pytorch/pytorch/issues/95857
221
        if not _TORCH_GREATER_EQUAL_2_0 and trainer is not None and not isinstance(trainer, weakref.ProxyTypes):
222
            trainer = weakref.proxy(trainer)
223
        self._trainer = trainer
224

225
    @property
226
    def fabric(self) -> Optional["lf.Fabric"]:
227
        return self._fabric
228

229
    @fabric.setter
230
    def fabric(self, fabric: Optional["lf.Fabric"]) -> None:
231
        for v in self.children():
232
            if isinstance(v, LightningModule):
233
                v.fabric = fabric
234
        if fabric is not None and not isinstance(fabric, weakref.ProxyTypes):
235
            fabric = weakref.proxy(fabric)
236
        self._fabric = fabric
237

238
    @property
239
    def example_input_array(self) -> Optional[Union[Tensor, Tuple, Dict]]:
240
        """The example input array is a specification of what the module can consume in the :meth:`forward` method. The
241
        return type is interpreted as follows:
242

243
        -   Single tensor: It is assumed the model takes a single argument, i.e.,
244
            ``model.forward(model.example_input_array)``
245
        -   Tuple: The input array should be interpreted as a sequence of positional arguments, i.e.,
246
            ``model.forward(*model.example_input_array)``
247
        -   Dict: The input array represents named keyword arguments, i.e.,
248
            ``model.forward(**model.example_input_array)``
249

250
        """
251
        return self._example_input_array
252

253
    @example_input_array.setter
254
    def example_input_array(self, example: Optional[Union[Tensor, Tuple, Dict]]) -> None:
255
        self._example_input_array = example
256

257
    @property
258
    def current_epoch(self) -> int:
259
        """The current epoch in the ``Trainer``, or 0 if not attached."""
260
        return self.trainer.current_epoch if self._trainer else 0
261

262
    @property
263
    def global_step(self) -> int:
264
        """Total training batches seen across all epochs.
265

266
        If no Trainer is attached, this propery is 0.
267

268
        """
269
        return self.trainer.global_step if self._trainer else 0
270

271
    @property
272
    def global_rank(self) -> int:
273
        """The index of the current process across all nodes and devices."""
274
        return self.trainer.global_rank if self._trainer else 0
275

276
    @property
277
    def local_rank(self) -> int:
278
        """The index of the current process within a single node."""
279
        return self.trainer.local_rank if self._trainer else 0
280

281
    @property
282
    def on_gpu(self) -> bool:
283
        """Returns ``True`` if this model is currently located on a GPU.
284

285
        Useful to set flags around the LightningModule for different CPU vs GPU behavior.
286

287
        """
288
        return self.device.type == "cuda"
289

290
    @property
291
    def automatic_optimization(self) -> bool:
292
        """If set to ``False`` you are responsible for calling ``.backward()``, ``.step()``, ``.zero_grad()``."""
293
        return self._automatic_optimization
294

295
    @automatic_optimization.setter
296
    def automatic_optimization(self, automatic_optimization: bool) -> None:
297
        self._automatic_optimization = automatic_optimization
298

299
    @property
300
    def strict_loading(self) -> bool:
301
        """Determines how Lightning loads this model using `.load_state_dict(..., strict=model.strict_loading)`."""
302
        # We use None as the default internally to determine whether the user has set a value
303
        return self._strict_loading in (None, True)
304

305
    @strict_loading.setter
306
    def strict_loading(self, strict_loading: bool) -> None:
307
        self._strict_loading = strict_loading
308

309
    @property
310
    def logger(self) -> Optional[Union[Logger, FabricLogger]]:
311
        """Reference to the logger object in the Trainer."""
312
        if self._fabric is not None:
313
            return self._fabric.logger
314
        return self._trainer.logger if self._trainer is not None else None
315

316
    @property
317
    def loggers(self) -> Union[List[Logger], List[FabricLogger]]:
318
        """Reference to the list of loggers in the Trainer."""
319
        if self._fabric is not None:
320
            return self._fabric.loggers
321
        if self._trainer is not None:
322
            return self._trainer.loggers
323
        return []
324

325
    def _call_batch_hook(self, hook_name: str, *args: Any) -> Any:
326
        trainer = self._trainer
327
        if trainer:
328
            datahook_selector = trainer._data_connector._datahook_selector
329
            assert datahook_selector is not None
330
            obj = datahook_selector.get_instance(hook_name)
331
            if isinstance(obj, self.__class__):
332
                trainer_method = call._call_lightning_module_hook
333
            else:
334
                trainer_method = call._call_lightning_datamodule_hook
335

336
            return trainer_method(trainer, hook_name, *args)
337
        hook = getattr(self, hook_name)
338
        return hook(*args)
339

340
    def _on_before_batch_transfer(self, batch: Any, dataloader_idx: int = 0) -> Any:
341
        return self._call_batch_hook("on_before_batch_transfer", batch, dataloader_idx)
342

343
    def _apply_batch_transfer_handler(
344
        self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: int = 0
345
    ) -> Any:
346
        device = device or self.device
347
        batch = self._call_batch_hook("transfer_batch_to_device", batch, device, dataloader_idx)
348
        batch = self._call_batch_hook("on_after_batch_transfer", batch, dataloader_idx)
349
        return batch
350

351
    def print(self, *args: Any, **kwargs: Any) -> None:
352
        r"""Prints only from process 0. Use this in any distributed mode to log only once.
353

354
        Args:
355
            *args: The thing to print. The same as for Python's built-in print function.
356
            **kwargs: The same as for Python's built-in print function.
357

358
        Example::
359

360
            def forward(self, x):
361
                self.print(x, 'in forward')
362

363
        """
364
        if self.trainer.is_global_zero:
365
            progress_bar = self.trainer.progress_bar_callback
366
            if progress_bar is not None and progress_bar.is_enabled:
367
                progress_bar.print(*args, **kwargs)
368
            else:
369
                print(*args, **kwargs)
370

371
    def log(
372
        self,
373
        name: str,
374
        value: _METRIC,
375
        prog_bar: bool = False,
376
        logger: Optional[bool] = None,
377
        on_step: Optional[bool] = None,
378
        on_epoch: Optional[bool] = None,
379
        reduce_fx: Union[str, Callable] = "mean",
380
        enable_graph: bool = False,
381
        sync_dist: bool = False,
382
        sync_dist_group: Optional[Any] = None,
383
        add_dataloader_idx: bool = True,
384
        batch_size: Optional[int] = None,
385
        metric_attribute: Optional[str] = None,
386
        rank_zero_only: bool = False,
387
    ) -> None:
388
        """Log a key, value pair.
389

390
        Example::
391

392
            self.log('train_loss', loss)
393

394
        The default behavior per hook is documented here: :ref:`extensions/logging:Automatic Logging`.
395

396
        Args:
397
            name: key to log.
398
            value: value to log. Can be a ``float``, ``Tensor``, or a ``Metric``.
399
            prog_bar: if ``True`` logs to the progress bar.
400
            logger: if ``True`` logs to the logger.
401
            on_step: if ``True`` logs at this step. The default value is determined by the hook.
402
                See :ref:`extensions/logging:Automatic Logging` for details.
403
            on_epoch: if ``True`` logs epoch accumulated metrics. The default value is determined by the hook.
404
                See :ref:`extensions/logging:Automatic Logging` for details.
405
            reduce_fx: reduction function over step values for end of epoch. :meth:`torch.mean` by default.
406
            enable_graph: if ``True``, will not auto detach the graph.
407
            sync_dist: if ``True``, reduces the metric across devices. Use with care as this may lead to a significant
408
                communication overhead.
409
            sync_dist_group: the DDP group to sync across.
410
            add_dataloader_idx: if ``True``, appends the index of the current dataloader to
411
                the name (when using multiple dataloaders). If False, user needs to give unique names for
412
                each dataloader to not mix the values.
413
            batch_size: Current batch_size. This will be directly inferred from the loaded batch,
414
                but for some data structures you might need to explicitly provide it.
415
            metric_attribute: To restore the metric state, Lightning requires the reference of the
416
                :class:`torchmetrics.Metric` in your model. This is found automatically if it is a model attribute.
417
            rank_zero_only: Tells Lightning if you are calling ``self.log`` from every process (default) or only from
418
                rank 0. If ``True``, you won't be able to use this metric as a monitor in callbacks
419
                (e.g., early stopping). Warning: Improper use can lead to deadlocks! See
420
                :ref:`Advanced Logging <visualize/logging_advanced:rank_zero_only>` for more details.
421

422
        """
423
        if self._fabric is not None:
424
            self._log_dict_through_fabric(dictionary={name: value}, logger=logger)
425
            return
426

427
        # check for invalid values
428
        apply_to_collection(value, dict, self.__check_not_nested, name)
429
        apply_to_collection(
430
            value, object, self.__check_allowed, name, value, wrong_dtype=(numbers.Number, Metric, Tensor)
431
        )
432

433
        trainer = self._trainer
434
        if trainer is None:
435
            # not an error to support testing the `*_step` methods without a `Trainer` reference
436
            rank_zero_warn(
437
                "You are trying to `self.log()` but the `self.trainer` reference is not registered on the model yet."
438
                " This is most likely because the model hasn't been passed to the `Trainer`"
439
            )
440
            return
441
        if trainer.barebones:
442
            rank_zero_warn(
443
                "You are trying to `self.log()` but `Trainer(barebones=True)` is configured."
444
                " Logging can impact raw speed so it is disabled under this setting."
445
            )
446
            return
447
        results = trainer._results
448
        if results is None:
449
            raise MisconfigurationException(
450
                "You are trying to `self.log()` but the loop's result collection is not registered"
451
                " yet. This is most likely because you are trying to log in a `predict` hook,"
452
                " but it doesn't support logging"
453
            )
454
        if self._current_fx_name is None:
455
            raise MisconfigurationException(
456
                "You are trying to `self.log()` but it is not managed by the `Trainer` control flow"
457
            )
458

459
        on_step, on_epoch = _FxValidator.check_logging_and_get_default_levels(
460
            self._current_fx_name, on_step=on_step, on_epoch=on_epoch
461
        )
462

463
        # make sure user doesn't introduce logic for multi-dataloaders
464
        if "/dataloader_idx_" in name:
465
            raise MisconfigurationException(
466
                f"You called `self.log` with the key `{name}`"
467
                " but it should not contain information about `dataloader_idx`"
468
            )
469

470
        value = apply_to_collection(value, (Tensor, numbers.Number), self.__to_tensor, name)
471

472
        if trainer._logger_connector.should_reset_tensors(self._current_fx_name):
473
            # if we started a new epoch (running its first batch) the hook name has changed
474
            # reset any tensors for the new hook name
475
            results.reset(metrics=False, fx=self._current_fx_name)
476

477
        if metric_attribute is None and isinstance(value, Metric):
478
            if self._metric_attributes is None:
479
                # compute once
480
                self._metric_attributes = {
481
                    id(module): name for name, module in self.named_modules() if isinstance(module, Metric)
482
                }
483
                if not self._metric_attributes:
484
                    raise MisconfigurationException(
485
                        "Could not find the `LightningModule` attribute for the `torchmetrics.Metric` logged."
486
                        " You can fix this by setting an attribute for the metric in your `LightningModule`."
487
                    )
488
            # try to find the passed metric in the LightningModule
489
            metric_attribute = self._metric_attributes.get(id(value), None)
490
            if metric_attribute is None:
491
                raise MisconfigurationException(
492
                    "Could not find the `LightningModule` attribute for the `torchmetrics.Metric` logged."
493
                    f" You can fix this by calling `self.log({name}, ..., metric_attribute=name)` where `name` is one"
494
                    f" of {list(self._metric_attributes.values())}"
495
                )
496

497
        if (
498
            trainer.training
499
            and is_param_in_hook_signature(self.training_step, "dataloader_iter", explicit=True)
500
            and batch_size is None
501
        ):
502
            raise MisconfigurationException(
503
                "With `def training_step(self, dataloader_iter)`, `self.log(..., batch_size=...)` should be provided."
504
            )
505

506
        if logger and trainer.logger is None:
507
            rank_zero_warn(
508
                f"You called `self.log({name!r}, ..., logger=True)` but have no logger configured. You can enable one"
509
                " by doing `Trainer(logger=ALogger(...))`"
510
            )
511
        if logger is None:
512
            # we could set false here if there's no configured logger, however, we still need to compute the "logged"
513
            # metrics anyway because that's what the evaluation loops use as return value
514
            logger = True
515

516
        results.log(
517
            self._current_fx_name,
518
            name,
519
            value,
520
            prog_bar=prog_bar,
521
            logger=logger,
522
            on_step=on_step,
523
            on_epoch=on_epoch,
524
            reduce_fx=reduce_fx,  # type: ignore[arg-type]
525
            enable_graph=enable_graph,
526
            add_dataloader_idx=add_dataloader_idx,
527
            batch_size=batch_size,
528
            sync_dist=sync_dist and trainer._accelerator_connector.is_distributed,
529
            sync_dist_fn=trainer.strategy.reduce,
530
            sync_dist_group=sync_dist_group,
531
            metric_attribute=metric_attribute,
532
            rank_zero_only=rank_zero_only,
533
        )
534

535
        trainer._logger_connector._current_fx = self._current_fx_name
536

537
    def log_dict(
538
        self,
539
        dictionary: Union[Mapping[str, _METRIC], MetricCollection],
540
        prog_bar: bool = False,
541
        logger: Optional[bool] = None,
542
        on_step: Optional[bool] = None,
543
        on_epoch: Optional[bool] = None,
544
        reduce_fx: Union[str, Callable] = "mean",
545
        enable_graph: bool = False,
546
        sync_dist: bool = False,
547
        sync_dist_group: Optional[Any] = None,
548
        add_dataloader_idx: bool = True,
549
        batch_size: Optional[int] = None,
550
        rank_zero_only: bool = False,
551
    ) -> None:
552
        """Log a dictionary of values at once.
553

554
        Example::
555

556
            values = {'loss': loss, 'acc': acc, ..., 'metric_n': metric_n}
557
            self.log_dict(values)
558

559
        Args:
560
            dictionary: key value pairs.
561
                The values can be a ``float``, ``Tensor``, ``Metric``, or ``MetricCollection``.
562
            prog_bar: if ``True`` logs to the progress base.
563
            logger: if ``True`` logs to the logger.
564
            on_step: if ``True`` logs at this step.
565
                ``None`` auto-logs for training_step but not validation/test_step.
566
                The default value is determined by the hook.
567
                See :ref:`extensions/logging:Automatic Logging` for details.
568
            on_epoch: if ``True`` logs epoch accumulated metrics.
569
                ``None`` auto-logs for val/test step but not ``training_step``.
570
                The default value is determined by the hook.
571
                See :ref:`extensions/logging:Automatic Logging` for details.
572
            reduce_fx: reduction function over step values for end of epoch. :meth:`torch.mean` by default.
573
            enable_graph: if ``True``, will not auto-detach the graph
574
            sync_dist: if ``True``, reduces the metric across GPUs/TPUs. Use with care as this may lead to a significant
575
                communication overhead.
576
            sync_dist_group: the ddp group to sync across.
577
            add_dataloader_idx: if ``True``, appends the index of the current dataloader to
578
                the name (when using multiple). If ``False``, user needs to give unique names for
579
                each dataloader to not mix values.
580
            batch_size: Current batch size. This will be directly inferred from the loaded batch,
581
                but some data structures might need to explicitly provide it.
582
            rank_zero_only: Tells Lightning if you are calling ``self.log`` from every process (default) or only from
583
                rank 0. If ``True``, you won't be able to use this metric as a monitor in callbacks
584
                (e.g., early stopping). Warning: Improper use can lead to deadlocks! See
585
                :ref:`Advanced Logging <visualize/logging_advanced:rank_zero_only>` for more details.
586

587
        """
588
        if self._fabric is not None:
589
            return self._log_dict_through_fabric(dictionary=dictionary, logger=logger)
590

591
        kwargs: Dict[str, bool] = {}
592

593
        if isinstance(dictionary, MetricCollection):
594
            kwargs["keep_base"] = False
595
            if _TORCHMETRICS_GREATER_EQUAL_0_9_1 and dictionary._enable_compute_groups:
596
                kwargs["copy_state"] = False
597

598
        for k, v in dictionary.items(**kwargs):
599
            self.log(
600
                name=k,
601
                value=v,
602
                prog_bar=prog_bar,
603
                logger=logger,
604
                on_step=on_step,
605
                on_epoch=on_epoch,
606
                reduce_fx=reduce_fx,
607
                enable_graph=enable_graph,
608
                sync_dist=sync_dist,
609
                sync_dist_group=sync_dist_group,
610
                add_dataloader_idx=add_dataloader_idx,
611
                batch_size=batch_size,
612
                rank_zero_only=rank_zero_only,
613
            )
614
        return None
615

616
    def _log_dict_through_fabric(
617
        self, dictionary: Union[Mapping[str, _METRIC], MetricCollection], logger: Optional[bool] = None
618
    ) -> None:
619
        if logger is False:
620
            # Passing `logger=False` with Fabric does not make much sense because there is no other destination to
621
            # log to, but we support it in case the original code was written for Trainer use
622
            return
623

624
        if any(isinstance(v, dict) for v in dictionary.values()):
625
            raise ValueError(f"`self.log_dict({dictionary})` was called, but nested dictionaries cannot be logged")
626
        for name, value in dictionary.items():
627
            apply_to_collection(value, object, self.__check_allowed, name, value, wrong_dtype=(numbers.Number, Tensor))
628

629
        assert self._fabric is not None
630
        self._fabric.log_dict(metrics=dictionary)  # type: ignore[arg-type]
631

632
    @staticmethod
633
    def __check_not_nested(value: dict, name: str) -> None:
634
        # self-imposed restriction. for simplicity
635
        if any(isinstance(v, dict) for v in value.values()):
636
            raise ValueError(f"`self.log({name}, {value})` was called, but nested dictionaries cannot be logged")
637

638
    @staticmethod
639
    def __check_allowed(v: Any, name: str, value: Any) -> None:
640
        raise ValueError(f"`self.log({name}, {value})` was called, but `{type(v).__name__}` values cannot be logged")
641

642
    def __to_tensor(self, value: Union[Tensor, numbers.Number], name: str) -> Tensor:
643
        value = (
644
            value.clone().detach()
645
            if isinstance(value, Tensor)
646
            else torch.tensor(value, device=self.device, dtype=_get_default_dtype())
647
        )
648
        if not torch.numel(value) == 1:
649
            raise ValueError(
650
                f"`self.log({name}, {value})` was called, but the tensor must have a single element."
651
                f" You can try doing `self.log({name}, {value}.mean())`"
652
            )
653
        value = value.squeeze()
654
        return value
655

656
    def all_gather(
657
        self, data: Union[Tensor, Dict, List, Tuple], group: Optional[Any] = None, sync_grads: bool = False
658
    ) -> Union[Tensor, Dict, List, Tuple]:
659
        r"""Gather tensors or collections of tensors from multiple processes.
660

661
        This method needs to be called on all processes and the tensors need to have the same shape across all
662
        processes, otherwise your program will stall forever.
663

664
        Args:
665
            data: int, float, tensor of shape (batch, ...), or a (possibly nested) collection thereof.
666
            group: the process group to gather results from. Defaults to all processes (world)
667
            sync_grads: flag that allows users to synchronize gradients for the all_gather operation
668

669
        Return:
670
            A tensor of shape (world_size, batch, ...), or if the input was a collection
671
            the output will also be a collection with tensors of this shape. For the special case where
672
            world_size is 1, no additional dimension is added to the tensor(s).
673

674
        """
675
        group = group if group is not None else torch.distributed.group.WORLD
676
        all_gather = self.trainer.strategy.all_gather
677
        data = convert_to_tensors(data, device=self.device)
678
        return apply_to_collection(data, Tensor, all_gather, group=group, sync_grads=sync_grads)
679

680
    @override
681
    def forward(self, *args: Any, **kwargs: Any) -> Any:
682
        r"""Same as :meth:`torch.nn.Module.forward`.
683

684
        Args:
685
            *args: Whatever you decide to pass into the forward method.
686
            **kwargs: Keyword arguments are also possible.
687

688
        Return:
689
            Your model's output
690

691
        """
692
        return super().forward(*args, **kwargs)
693

694
    def training_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
695
        r"""Here you compute and return the training loss and some additional metrics for e.g. the progress bar or
696
        logger.
697

698
        Args:
699
            batch: The output of your data iterable, normally a :class:`~torch.utils.data.DataLoader`.
700
            batch_idx: The index of this batch.
701
            dataloader_idx: The index of the dataloader that produced this batch.
702
                (only if multiple dataloaders used)
703

704
        Return:
705
            - :class:`~torch.Tensor` - The loss tensor
706
            - ``dict`` - A dictionary which can include any keys, but must include the key ``'loss'`` in the case of
707
              automatic optimization.
708
            - ``None`` - In automatic optimization, this will skip to the next batch (but is not supported for
709
              multi-GPU, TPU, or DeepSpeed). For manual optimization, this has no special meaning, as returning
710
              the loss is not required.
711

712
        In this step you'd normally do the forward pass and calculate the loss for a batch.
713
        You can also do fancier things like multiple forward passes or something model specific.
714

715
        Example::
716

717
            def training_step(self, batch, batch_idx):
718
                x, y, z = batch
719
                out = self.encoder(x)
720
                loss = self.loss(out, x)
721
                return loss
722

723
        To use multiple optimizers, you can switch to 'manual optimization' and control their stepping:
724

725
        .. code-block:: python
726

727
            def __init__(self):
728
                super().__init__()
729
                self.automatic_optimization = False
730

731

732
            # Multiple optimizers (e.g.: GANs)
733
            def training_step(self, batch, batch_idx):
734
                opt1, opt2 = self.optimizers()
735

736
                # do training_step with encoder
737
                ...
738
                opt1.step()
739
                # do training_step with decoder
740
                ...
741
                opt2.step()
742

743
        Note:
744
            When ``accumulate_grad_batches`` > 1, the loss returned here will be automatically
745
            normalized by ``accumulate_grad_batches`` internally.
746

747
        """
748
        rank_zero_warn("`training_step` must be implemented to be used with the Lightning Trainer")
749

750
    def validation_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
751
        r"""Operates on a single batch of data from the validation set. In this step you'd might generate examples or
752
        calculate anything of interest like accuracy.
753

754
        Args:
755
            batch: The output of your data iterable, normally a :class:`~torch.utils.data.DataLoader`.
756
            batch_idx: The index of this batch.
757
            dataloader_idx: The index of the dataloader that produced this batch.
758
                (only if multiple dataloaders used)
759

760
        Return:
761
            - :class:`~torch.Tensor` - The loss tensor
762
            - ``dict`` - A dictionary. Can include any keys, but must include the key ``'loss'``.
763
            - ``None`` - Skip to the next batch.
764

765
        .. code-block:: python
766

767
            # if you have one val dataloader:
768
            def validation_step(self, batch, batch_idx): ...
769

770

771
            # if you have multiple val dataloaders:
772
            def validation_step(self, batch, batch_idx, dataloader_idx=0): ...
773

774
        Examples::
775

776
            # CASE 1: A single validation dataset
777
            def validation_step(self, batch, batch_idx):
778
                x, y = batch
779

780
                # implement your own
781
                out = self(x)
782
                loss = self.loss(out, y)
783

784
                # log 6 example images
785
                # or generated text... or whatever
786
                sample_imgs = x[:6]
787
                grid = torchvision.utils.make_grid(sample_imgs)
788
                self.logger.experiment.add_image('example_images', grid, 0)
789

790
                # calculate acc
791
                labels_hat = torch.argmax(out, dim=1)
792
                val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
793

794
                # log the outputs!
795
                self.log_dict({'val_loss': loss, 'val_acc': val_acc})
796

797
        If you pass in multiple val dataloaders, :meth:`validation_step` will have an additional argument. We recommend
798
        setting the default value of 0 so that you can quickly switch between single and multiple dataloaders.
799

800
        .. code-block:: python
801

802
            # CASE 2: multiple validation dataloaders
803
            def validation_step(self, batch, batch_idx, dataloader_idx=0):
804
                # dataloader_idx tells you which dataset this is.
805
                ...
806

807
        Note:
808
            If you don't need to validate you don't need to implement this method.
809

810
        Note:
811
            When the :meth:`validation_step` is called, the model has been put in eval mode
812
            and PyTorch gradients have been disabled. At the end of validation,
813
            the model goes back to training mode and gradients are enabled.
814

815
        """
816

817
    def test_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
818
        r"""Operates on a single batch of data from the test set. In this step you'd normally generate examples or
819
        calculate anything of interest such as accuracy.
820

821
        Args:
822
            batch: The output of your data iterable, normally a :class:`~torch.utils.data.DataLoader`.
823
            batch_idx: The index of this batch.
824
            dataloader_idx: The index of the dataloader that produced this batch.
825
                (only if multiple dataloaders used)
826

827
        Return:
828
            - :class:`~torch.Tensor` - The loss tensor
829
            - ``dict`` - A dictionary. Can include any keys, but must include the key ``'loss'``.
830
            - ``None`` - Skip to the next batch.
831

832
        .. code-block:: python
833

834
            # if you have one test dataloader:
835
            def test_step(self, batch, batch_idx): ...
836

837

838
            # if you have multiple test dataloaders:
839
            def test_step(self, batch, batch_idx, dataloader_idx=0): ...
840

841
        Examples::
842

843
            # CASE 1: A single test dataset
844
            def test_step(self, batch, batch_idx):
845
                x, y = batch
846

847
                # implement your own
848
                out = self(x)
849
                loss = self.loss(out, y)
850

851
                # log 6 example images
852
                # or generated text... or whatever
853
                sample_imgs = x[:6]
854
                grid = torchvision.utils.make_grid(sample_imgs)
855
                self.logger.experiment.add_image('example_images', grid, 0)
856

857
                # calculate acc
858
                labels_hat = torch.argmax(out, dim=1)
859
                test_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
860

861
                # log the outputs!
862
                self.log_dict({'test_loss': loss, 'test_acc': test_acc})
863

864
        If you pass in multiple test dataloaders, :meth:`test_step` will have an additional argument. We recommend
865
        setting the default value of 0 so that you can quickly switch between single and multiple dataloaders.
866

867
        .. code-block:: python
868

869
            # CASE 2: multiple test dataloaders
870
            def test_step(self, batch, batch_idx, dataloader_idx=0):
871
                # dataloader_idx tells you which dataset this is.
872
                ...
873

874
        Note:
875
            If you don't need to test you don't need to implement this method.
876

877
        Note:
878
            When the :meth:`test_step` is called, the model has been put in eval mode and
879
            PyTorch gradients have been disabled. At the end of the test epoch, the model goes back
880
            to training mode and gradients are enabled.
881

882
        """
883

884
    def predict_step(self, *args: Any, **kwargs: Any) -> Any:
885
        """Step function called during :meth:`~lightning.pytorch.trainer.trainer.Trainer.predict`. By default, it calls
886
        :meth:`~lightning.pytorch.core.LightningModule.forward`. Override to add any processing logic.
887

888
        The :meth:`~lightning.pytorch.core.LightningModule.predict_step` is used
889
        to scale inference on multi-devices.
890

891
        To prevent an OOM error, it is possible to use :class:`~lightning.pytorch.callbacks.BasePredictionWriter`
892
        callback to write the predictions to disk or database after each batch or on epoch end.
893

894
        The :class:`~lightning.pytorch.callbacks.BasePredictionWriter` should be used while using a spawn
895
        based accelerator. This happens for ``Trainer(strategy="ddp_spawn")``
896
        or training on 8 TPU cores with ``Trainer(accelerator="tpu", devices=8)`` as predictions won't be returned.
897

898
        Args:
899
            batch: The output of your data iterable, normally a :class:`~torch.utils.data.DataLoader`.
900
            batch_idx: The index of this batch.
901
            dataloader_idx: The index of the dataloader that produced this batch.
902
                (only if multiple dataloaders used)
903

904
        Return:
905
            Predicted output (optional).
906

907
        Example ::
908

909
            class MyModel(LightningModule):
910

911
                def predict_step(self, batch, batch_idx, dataloader_idx=0):
912
                    return self(batch)
913

914
            dm = ...
915
            model = MyModel()
916
            trainer = Trainer(accelerator="gpu", devices=2)
917
            predictions = trainer.predict(model, dm)
918

919
        """
920
        # For backwards compatibility
921
        batch = kwargs.get("batch", args[0])
922
        return self(batch)
923

924
    def configure_callbacks(self) -> Union[Sequence[Callback], Callback]:
925
        """Configure model-specific callbacks. When the model gets attached, e.g., when ``.fit()`` or ``.test()`` gets
926
        called, the list or a callback returned here will be merged with the list of callbacks passed to the Trainer's
927
        ``callbacks`` argument. If a callback returned here has the same type as one or several callbacks already
928
        present in the Trainer's callbacks list, it will take priority and replace them. In addition, Lightning will
929
        make sure :class:`~lightning.pytorch.callbacks.model_checkpoint.ModelCheckpoint` callbacks run last.
930

931
        Return:
932
            A callback or a list of callbacks which will extend the list of callbacks in the Trainer.
933

934
        Example::
935

936
            def configure_callbacks(self):
937
                early_stop = EarlyStopping(monitor="val_acc", mode="max")
938
                checkpoint = ModelCheckpoint(monitor="val_loss")
939
                return [early_stop, checkpoint]
940

941
        """
942
        return []
943

944
    def configure_optimizers(self) -> OptimizerLRScheduler:
945
        r"""Choose what optimizers and learning-rate schedulers to use in your optimization. Normally you'd need one.
946
        But in the case of GANs or similar you might have multiple. Optimization with multiple optimizers only works in
947
        the manual optimization mode.
948

949
        Return:
950
            Any of these 6 options.
951

952
            - **Single optimizer**.
953
            - **List or Tuple** of optimizers.
954
            - **Two lists** - The first list has multiple optimizers, and the second has multiple LR schedulers
955
              (or multiple ``lr_scheduler_config``).
956
            - **Dictionary**, with an ``"optimizer"`` key, and (optionally) a ``"lr_scheduler"``
957
              key whose value is a single LR scheduler or ``lr_scheduler_config``.
958
            - **None** - Fit will run without any optimizer.
959

960
        The ``lr_scheduler_config`` is a dictionary which contains the scheduler and its associated configuration.
961
        The default configuration is shown below.
962

963
        .. code-block:: python
964

965
            lr_scheduler_config = {
966
                # REQUIRED: The scheduler instance
967
                "scheduler": lr_scheduler,
968
                # The unit of the scheduler's step size, could also be 'step'.
969
                # 'epoch' updates the scheduler on epoch end whereas 'step'
970
                # updates it after a optimizer update.
971
                "interval": "epoch",
972
                # How many epochs/steps should pass between calls to
973
                # `scheduler.step()`. 1 corresponds to updating the learning
974
                # rate after every epoch/step.
975
                "frequency": 1,
976
                # Metric to to monitor for schedulers like `ReduceLROnPlateau`
977
                "monitor": "val_loss",
978
                # If set to `True`, will enforce that the value specified 'monitor'
979
                # is available when the scheduler is updated, thus stopping
980
                # training if not found. If set to `False`, it will only produce a warning
981
                "strict": True,
982
                # If using the `LearningRateMonitor` callback to monitor the
983
                # learning rate progress, this keyword can be used to specify
984
                # a custom logged name
985
                "name": None,
986
            }
987

988
        When there are schedulers in which the ``.step()`` method is conditioned on a value, such as the
989
        :class:`torch.optim.lr_scheduler.ReduceLROnPlateau` scheduler, Lightning requires that the
990
        ``lr_scheduler_config`` contains the keyword ``"monitor"`` set to the metric name that the scheduler
991
        should be conditioned on.
992

993
        .. testcode::
994

995
            # The ReduceLROnPlateau scheduler requires a monitor
996
            def configure_optimizers(self):
997
                optimizer = Adam(...)
998
                return {
999
                    "optimizer": optimizer,
1000
                    "lr_scheduler": {
1001
                        "scheduler": ReduceLROnPlateau(optimizer, ...),
1002
                        "monitor": "metric_to_track",
1003
                        "frequency": "indicates how often the metric is updated",
1004
                        # If "monitor" references validation metrics, then "frequency" should be set to a
1005
                        # multiple of "trainer.check_val_every_n_epoch".
1006
                    },
1007
                }
1008

1009

1010
            # In the case of two optimizers, only one using the ReduceLROnPlateau scheduler
1011
            def configure_optimizers(self):
1012
                optimizer1 = Adam(...)
1013
                optimizer2 = SGD(...)
1014
                scheduler1 = ReduceLROnPlateau(optimizer1, ...)
1015
                scheduler2 = LambdaLR(optimizer2, ...)
1016
                return (
1017
                    {
1018
                        "optimizer": optimizer1,
1019
                        "lr_scheduler": {
1020
                            "scheduler": scheduler1,
1021
                            "monitor": "metric_to_track",
1022
                        },
1023
                    },
1024
                    {"optimizer": optimizer2, "lr_scheduler": scheduler2},
1025
                )
1026

1027
        Metrics can be made available to monitor by simply logging it using
1028
        ``self.log('metric_to_track', metric_val)`` in your :class:`~lightning.pytorch.core.LightningModule`.
1029

1030
        Note:
1031
            Some things to know:
1032

1033
            - Lightning calls ``.backward()`` and ``.step()`` automatically in case of automatic optimization.
1034
            - If a learning rate scheduler is specified in ``configure_optimizers()`` with key
1035
              ``"interval"`` (default "epoch") in the scheduler configuration, Lightning will call
1036
              the scheduler's ``.step()`` method automatically in case of automatic optimization.
1037
            - If you use 16-bit precision (``precision=16``), Lightning will automatically handle the optimizer.
1038
            - If you use :class:`torch.optim.LBFGS`, Lightning handles the closure function automatically for you.
1039
            - If you use multiple optimizers, you will have to switch to 'manual optimization' mode and step them
1040
              yourself.
1041
            - If you need to control how often the optimizer steps, override the :meth:`optimizer_step` hook.
1042

1043
        """
1044
        rank_zero_warn("`configure_optimizers` must be implemented to be used with the Lightning Trainer")
1045

1046
    def manual_backward(self, loss: Tensor, *args: Any, **kwargs: Any) -> None:
1047
        """Call this directly from your :meth:`training_step` when doing optimizations manually. By using this,
1048
        Lightning can ensure that all the proper scaling gets applied when using mixed precision.
1049

1050
        See :ref:`manual optimization<common/optimization:Manual optimization>` for more examples.
1051

1052
        Example::
1053

1054
            def training_step(...):
1055
                opt = self.optimizers()
1056
                loss = ...
1057
                opt.zero_grad()
1058
                # automatically applies scaling, etc...
1059
                self.manual_backward(loss)
1060
                opt.step()
1061

1062
        Args:
1063
            loss: The tensor on which to compute gradients. Must have a graph attached.
1064
            *args: Additional positional arguments to be forwarded to :meth:`~torch.Tensor.backward`
1065
            **kwargs: Additional keyword arguments to be forwarded to :meth:`~torch.Tensor.backward`
1066

1067
        """
1068
        if self._fabric:
1069
            self._fabric.backward(loss, *args, **kwargs)
1070
        else:
1071
            self._verify_is_manual_optimization("manual_backward")
1072
            self.trainer.strategy.backward(loss, None, *args, **kwargs)
1073

1074
    def backward(self, loss: Tensor, *args: Any, **kwargs: Any) -> None:
1075
        """Called to perform backward on the loss returned in :meth:`training_step`. Override this hook with your own
1076
        implementation if you need to.
1077

1078
        Args:
1079
            loss: The loss tensor returned by :meth:`training_step`. If gradient accumulation is used, the loss here
1080
                holds the normalized value (scaled by 1 / accumulation steps).
1081

1082
        Example::
1083

1084
            def backward(self, loss):
1085
                loss.backward()
1086

1087
        """
1088
        if self._fabric:
1089
            self._fabric.backward(loss, *args, **kwargs)
1090
        else:
1091
            loss.backward(*args, **kwargs)
1092

1093
    def toggle_optimizer(self, optimizer: Union[Optimizer, LightningOptimizer]) -> None:
1094
        """Makes sure only the gradients of the current optimizer's parameters are calculated in the training step to
1095
        prevent dangling gradients in multiple-optimizer setup.
1096

1097
        It works with :meth:`untoggle_optimizer` to make sure ``param_requires_grad_state`` is properly reset.
1098

1099
        Args:
1100
            optimizer: The optimizer to toggle.
1101

1102
        """
1103
        # Iterate over all optimizer parameters to preserve their `requires_grad` information
1104
        # in case these are pre-defined during `configure_optimizers`
1105
        param_requires_grad_state = {}
1106
        for opt in self.trainer.optimizers:
1107
            for group in opt.param_groups:
1108
                for param in group["params"]:
1109
                    # If a param already appear in param_requires_grad_state, continue
1110
                    if param in param_requires_grad_state:
1111
                        continue
1112
                    param_requires_grad_state[param] = param.requires_grad
1113
                    param.requires_grad = False
1114

1115
        # Then iterate over the current optimizer's parameters and set its `requires_grad`
1116
        # properties accordingly
1117
        for group in optimizer.param_groups:
1118
            for param in group["params"]:
1119
                param.requires_grad = param_requires_grad_state[param]
1120
        self._param_requires_grad_state = param_requires_grad_state
1121

1122
    def untoggle_optimizer(self, optimizer: Union[Optimizer, LightningOptimizer]) -> None:
1123
        """Resets the state of required gradients that were toggled with :meth:`toggle_optimizer`.
1124

1125
        Args:
1126
            optimizer: The optimizer to untoggle.
1127

1128
        """
1129
        for opt in self.trainer.optimizers:
1130
            if not (opt is optimizer or (isinstance(optimizer, LightningOptimizer) and opt is optimizer.optimizer)):
1131
                for group in opt.param_groups:
1132
                    for param in group["params"]:
1133
                        if param in self._param_requires_grad_state:
1134
                            param.requires_grad = self._param_requires_grad_state[param]
1135
        # save memory
1136
        self._param_requires_grad_state = {}
1137

1138
    def clip_gradients(
1139
        self,
1140
        optimizer: Optimizer,
1141
        gradient_clip_val: Optional[Union[int, float]] = None,
1142
        gradient_clip_algorithm: Optional[str] = None,
1143
    ) -> None:
1144
        """Handles gradient clipping internally.
1145

1146
        Note:
1147
            - Do not override this method. If you want to customize gradient clipping, consider using
1148
              :meth:`configure_gradient_clipping` method.
1149
            - For manual optimization (``self.automatic_optimization = False``), if you want to use
1150
              gradient clipping, consider calling
1151
              ``self.clip_gradients(opt, gradient_clip_val=0.5, gradient_clip_algorithm="norm")``
1152
              manually in the training step.
1153

1154
        Args:
1155
            optimizer: Current optimizer being used.
1156
            gradient_clip_val: The value at which to clip gradients.
1157
            gradient_clip_algorithm: The gradient clipping algorithm to use. Pass ``gradient_clip_algorithm="value"``
1158
                to clip by value, and ``gradient_clip_algorithm="norm"`` to clip by norm.
1159

1160
        """
1161

1162
        if self.fabric is not None:
1163
            self.fabric.clip_gradients(
1164
                self,
1165
                optimizer,
1166
                clip_val=gradient_clip_val if gradient_clip_algorithm == GradClipAlgorithmType.VALUE else None,
1167
                max_norm=None if gradient_clip_algorithm == GradClipAlgorithmType.VALUE else gradient_clip_val,
1168
            )
1169
            return
1170

1171
        if gradient_clip_val is None:
1172
            gradient_clip_val = self.trainer.gradient_clip_val or 0.0
1173
        elif self.trainer.gradient_clip_val is not None and self.trainer.gradient_clip_val != gradient_clip_val:
1174
            raise MisconfigurationException(
1175
                f"You have set `Trainer(gradient_clip_val={self.trainer.gradient_clip_val!r})`"
1176
                f" and have passed `clip_gradients(gradient_clip_val={gradient_clip_val!r})`."
1177
                " Please use only one of them."
1178
            )
1179

1180
        if gradient_clip_algorithm is None:
1181
            gradient_clip_algorithm = self.trainer.gradient_clip_algorithm or "norm"
1182
        else:
1183
            gradient_clip_algorithm = gradient_clip_algorithm.lower()
1184
            if (
1185
                self.trainer.gradient_clip_algorithm is not None
1186
                and self.trainer.gradient_clip_algorithm != gradient_clip_algorithm
1187
            ):
1188
                raise MisconfigurationException(
1189
                    f"You have set `Trainer(gradient_clip_algorithm={self.trainer.gradient_clip_algorithm.value!r})`"
1190
                    f" and have passed `clip_gradients(gradient_clip_algorithm={gradient_clip_algorithm!r})"
1191
                    " Please use only one of them."
1192
                )
1193

1194
        if not isinstance(gradient_clip_val, (int, float)):
1195
            raise TypeError(f"`gradient_clip_val` should be an int or a float. Got {gradient_clip_val}.")
1196

1197
        if not GradClipAlgorithmType.supported_type(gradient_clip_algorithm.lower()):
1198
            raise MisconfigurationException(
1199
                f"`gradient_clip_algorithm` {gradient_clip_algorithm} is invalid."
1200
                f" Allowed algorithms: {GradClipAlgorithmType.supported_types()}."
1201
            )
1202

1203
        gradient_clip_algorithm = GradClipAlgorithmType(gradient_clip_algorithm)
1204
        self.trainer.precision_plugin.clip_gradients(optimizer, gradient_clip_val, gradient_clip_algorithm)
1205

1206
    def configure_gradient_clipping(
1207
        self,
1208
        optimizer: Optimizer,
1209
        gradient_clip_val: Optional[Union[int, float]] = None,
1210
        gradient_clip_algorithm: Optional[str] = None,
1211
    ) -> None:
1212
        """Perform gradient clipping for the optimizer parameters. Called before :meth:`optimizer_step`.
1213

1214
        Args:
1215
            optimizer: Current optimizer being used.
1216
            gradient_clip_val: The value at which to clip gradients. By default, value passed in Trainer
1217
                will be available here.
1218
            gradient_clip_algorithm: The gradient clipping algorithm to use. By default, value
1219
                passed in Trainer will be available here.
1220

1221
        Example::
1222

1223
            def configure_gradient_clipping(self, optimizer, gradient_clip_val, gradient_clip_algorithm):
1224
                # Implement your own custom logic to clip gradients
1225
                # You can call `self.clip_gradients` with your settings:
1226
                self.clip_gradients(
1227
                    optimizer,
1228
                    gradient_clip_val=gradient_clip_val,
1229
                    gradient_clip_algorithm=gradient_clip_algorithm
1230
                )
1231

1232
        """
1233
        self.clip_gradients(
1234
            optimizer, gradient_clip_val=gradient_clip_val, gradient_clip_algorithm=gradient_clip_algorithm
1235
        )
1236

1237
    def lr_scheduler_step(self, scheduler: LRSchedulerTypeUnion, metric: Optional[Any]) -> None:
1238
        r"""Override this method to adjust the default way the :class:`~lightning.pytorch.trainer.trainer.Trainer` calls
1239
        each scheduler. By default, Lightning calls ``step()`` and as shown in the example for each scheduler based on
1240
        its ``interval``.
1241

1242
        Args:
1243
            scheduler: Learning rate scheduler.
1244
            metric: Value of the monitor used for schedulers like ``ReduceLROnPlateau``.
1245

1246
        Examples::
1247

1248
            # DEFAULT
1249
            def lr_scheduler_step(self, scheduler, metric):
1250
                if metric is None:
1251
                    scheduler.step()
1252
                else:
1253
                    scheduler.step(metric)
1254

1255
            # Alternative way to update schedulers if it requires an epoch value
1256
            def lr_scheduler_step(self, scheduler, metric):
1257
                scheduler.step(epoch=self.current_epoch)
1258

1259
        """
1260
        if metric is None:
1261
            scheduler.step()  # type: ignore[call-arg]
1262
        else:
1263
            scheduler.step(metric)
1264

1265
    def optimizer_step(
1266
        self,
1267
        epoch: int,
1268
        batch_idx: int,
1269
        optimizer: Union[Optimizer, LightningOptimizer],
1270
        optimizer_closure: Optional[Callable[[], Any]] = None,
1271
    ) -> None:
1272
        r"""Override this method to adjust the default way the :class:`~lightning.pytorch.trainer.trainer.Trainer` calls
1273
        the optimizer.
1274

1275
        By default, Lightning calls ``step()`` and ``zero_grad()`` as shown in the example.
1276
        This method (and ``zero_grad()``) won't be called during the accumulation phase when
1277
        ``Trainer(accumulate_grad_batches != 1)``. Overriding this hook has no benefit with manual optimization.
1278

1279
        Args:
1280
            epoch: Current epoch
1281
            batch_idx: Index of current batch
1282
            optimizer: A PyTorch optimizer
1283
            optimizer_closure: The optimizer closure. This closure must be executed as it includes the
1284
                calls to ``training_step()``, ``optimizer.zero_grad()``, and ``backward()``.
1285

1286
        Examples::
1287

1288
            def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure):
1289
                # Add your custom logic to run directly before `optimizer.step()`
1290

1291
                optimizer.step(closure=optimizer_closure)
1292

1293
                # Add your custom logic to run directly after `optimizer.step()`
1294

1295
        """
1296
        optimizer.step(closure=optimizer_closure)
1297

1298
    def optimizer_zero_grad(self, epoch: int, batch_idx: int, optimizer: Optimizer) -> None:
1299
        """Override this method to change the default behaviour of ``optimizer.zero_grad()``.
1300

1301
        Args:
1302
            epoch: Current epoch
1303
            batch_idx: Index of current batch
1304
            optimizer: A PyTorch optimizer
1305

1306
        Examples::
1307

1308
            # DEFAULT
1309
            def optimizer_zero_grad(self, epoch, batch_idx, optimizer):
1310
                optimizer.zero_grad()
1311

1312
            # Set gradients to `None` instead of zero to improve performance (not required on `torch>=2.0.0`).
1313
            def optimizer_zero_grad(self, epoch, batch_idx, optimizer):
1314
                optimizer.zero_grad(set_to_none=True)
1315

1316
        See :meth:`torch.optim.Optimizer.zero_grad` for the explanation of the above example.
1317

1318
        """
1319
        optimizer.zero_grad()
1320

1321
    def freeze(self) -> None:
1322
        r"""Freeze all params for inference.
1323

1324
        Example::
1325

1326
            model = MyLightningModule(...)
1327
            model.freeze()
1328

1329
        """
1330
        for param in self.parameters():
1331
            param.requires_grad = False
1332

1333
        self.eval()
1334

1335
    def unfreeze(self) -> None:
1336
        """Unfreeze all parameters for training.
1337

1338
        .. code-block:: python
1339

1340
            model = MyLightningModule(...)
1341
            model.unfreeze()
1342

1343
        """
1344
        for param in self.parameters():
1345
            param.requires_grad = True
1346

1347
        self.train()
1348

1349
    def _verify_is_manual_optimization(self, fn_name: str) -> None:
1350
        if self.automatic_optimization:
1351
            raise MisconfigurationException(
1352
                f"to use {fn_name}, please disable automatic optimization:"
1353
                " set model property `automatic_optimization` as False"
1354
            )
1355

1356
    @torch.no_grad()
1357
    def to_onnx(self, file_path: Union[str, Path], input_sample: Optional[Any] = None, **kwargs: Any) -> None:
1358
        """Saves the model in ONNX format.
1359

1360
        Args:
1361
            file_path: The path of the file the onnx model should be saved to.
1362
            input_sample: An input for tracing. Default: None (Use self.example_input_array)
1363
            **kwargs: Will be passed to torch.onnx.export function.
1364

1365
        Example::
1366

1367
            class SimpleModel(LightningModule):
1368
                def __init__(self):
1369
                    super().__init__()
1370
                    self.l1 = torch.nn.Linear(in_features=64, out_features=4)
1371

1372
                def forward(self, x):
1373
                    return torch.relu(self.l1(x.view(x.size(0), -1)
1374

1375
            model = SimpleModel()
1376
            input_sample = torch.randn(1, 64)
1377
            model.to_onnx("export.onnx", input_sample, export_params=True)
1378

1379
        """
1380
        if _TORCH_GREATER_EQUAL_2_0 and not _ONNX_AVAILABLE:
1381
            raise ModuleNotFoundError(
1382
                f"`torch>=2.0` requires `onnx` to be installed to use `{type(self).__name__}.to_onnx()`"
1383
            )
1384

1385
        mode = self.training
1386

1387
        if input_sample is None:
1388
            if self.example_input_array is None:
1389
                raise ValueError(
1390
                    "Could not export to ONNX since neither `input_sample` nor"
1391
                    " `model.example_input_array` attribute is set."
1392
                )
1393
            input_sample = self.example_input_array
1394

1395
        input_sample = self._on_before_batch_transfer(input_sample)
1396
        input_sample = self._apply_batch_transfer_handler(input_sample)
1397

1398
        torch.onnx.export(self, input_sample, file_path, **kwargs)
1399
        self.train(mode)
1400

1401
    @torch.no_grad()
1402
    def to_torchscript(
1403
        self,
1404
        file_path: Optional[Union[str, Path]] = None,
1405
        method: Optional[str] = "script",
1406
        example_inputs: Optional[Any] = None,
1407
        **kwargs: Any,
1408
    ) -> Union[ScriptModule, Dict[str, ScriptModule]]:
1409
        """By default compiles the whole model to a :class:`~torch.jit.ScriptModule`. If you want to use tracing,
1410
        please provided the argument ``method='trace'`` and make sure that either the `example_inputs` argument is
1411
        provided, or the model has :attr:`example_input_array` set. If you would like to customize the modules that are
1412
        scripted you should override this method. In case you want to return multiple modules, we recommend using a
1413
        dictionary.
1414

1415
        Args:
1416
            file_path: Path where to save the torchscript. Default: None (no file saved).
1417
            method: Whether to use TorchScript's script or trace method. Default: 'script'
1418
            example_inputs: An input to be used to do tracing when method is set to 'trace'.
1419
              Default: None (uses :attr:`example_input_array`)
1420
            **kwargs: Additional arguments that will be passed to the :func:`torch.jit.script` or
1421
              :func:`torch.jit.trace` function.
1422

1423
        Note:
1424
            - Requires the implementation of the
1425
              :meth:`~lightning.pytorch.core.LightningModule.forward` method.
1426
            - The exported script will be set to evaluation mode.
1427
            - It is recommended that you install the latest supported version of PyTorch
1428
              to use this feature without limitations. See also the :mod:`torch.jit`
1429
              documentation for supported features.
1430

1431
        Example::
1432

1433
            class SimpleModel(LightningModule):
1434
                def __init__(self):
1435
                    super().__init__()
1436
                    self.l1 = torch.nn.Linear(in_features=64, out_features=4)
1437

1438
                def forward(self, x):
1439
                    return torch.relu(self.l1(x.view(x.size(0), -1)))
1440

1441
            model = SimpleModel()
1442
            model.to_torchscript(file_path="model.pt")
1443

1444
            torch.jit.save(model.to_torchscript(
1445
                file_path="model_trace.pt", method='trace', example_inputs=torch.randn(1, 64))
1446
            )
1447

1448
        Return:
1449
            This LightningModule as a torchscript, regardless of whether `file_path` is
1450
            defined or not.
1451

1452
        """
1453
        mode = self.training
1454

1455
        if method == "script":
1456
            with _jit_is_scripting():
1457
                torchscript_module = torch.jit.script(self.eval(), **kwargs)
1458
        elif method == "trace":
1459
            # if no example inputs are provided, try to see if model has example_input_array set
1460
            if example_inputs is None:
1461
                if self.example_input_array is None:
1462
                    raise ValueError(
1463
                        "Choosing method=`trace` requires either `example_inputs`"
1464
                        " or `model.example_input_array` to be defined."
1465
                    )
1466
                example_inputs = self.example_input_array
1467

1468
            # automatically send example inputs to the right device and use trace
1469
            example_inputs = self._on_before_batch_transfer(example_inputs)
1470
            example_inputs = self._apply_batch_transfer_handler(example_inputs)
1471
            with _jit_is_scripting():
1472
                torchscript_module = torch.jit.trace(func=self.eval(), example_inputs=example_inputs, **kwargs)
1473
        else:
1474
            raise ValueError(f"The 'method' parameter only supports 'script' or 'trace', but value given was: {method}")
1475

1476
        self.train(mode)
1477

1478
        if file_path is not None:
1479
            fs = get_filesystem(file_path)
1480
            with fs.open(file_path, "wb") as f:
1481
                torch.jit.save(torchscript_module, f)
1482

1483
        return torchscript_module
1484

1485
    @_restricted_classmethod
1486
    def load_from_checkpoint(
1487
        cls,
1488
        checkpoint_path: Union[_PATH, IO],
1489
        map_location: _MAP_LOCATION_TYPE = None,
1490
        hparams_file: Optional[_PATH] = None,
1491
        strict: Optional[bool] = None,
1492
        **kwargs: Any,
1493
    ) -> Self:
1494
        r"""Primary way of loading a model from a checkpoint. When Lightning saves a checkpoint it stores the arguments
1495
        passed to ``__init__``  in the checkpoint under ``"hyper_parameters"``.
1496

1497
        Any arguments specified through \*\*kwargs will override args stored in ``"hyper_parameters"``.
1498

1499
        Args:
1500
            checkpoint_path: Path to checkpoint. This can also be a URL, or file-like object
1501
            map_location:
1502
                If your checkpoint saved a GPU model and you now load on CPUs
1503
                or a different number of GPUs, use this to map to the new setup.
1504
                The behaviour is the same as in :func:`torch.load`.
1505
            hparams_file: Optional path to a ``.yaml`` or ``.csv`` file with hierarchical structure
1506
                as in this example::
1507

1508
                    drop_prob: 0.2
1509
                    dataloader:
1510
                        batch_size: 32
1511

1512
                You most likely won't need this since Lightning will always save the hyperparameters
1513
                to the checkpoint.
1514
                However, if your checkpoint weights don't have the hyperparameters saved,
1515
                use this method to pass in a ``.yaml`` file with the hparams you'd like to use.
1516
                These will be converted into a :class:`~dict` and passed into your
1517
                :class:`LightningModule` for use.
1518

1519
                If your model's ``hparams`` argument is :class:`~argparse.Namespace`
1520
                and ``.yaml`` file has hierarchical structure, you need to refactor your model to treat
1521
                ``hparams`` as :class:`~dict`.
1522
            strict: Whether to strictly enforce that the keys in :attr:`checkpoint_path` match the keys
1523
                returned by this module's state dict. Defaults to ``True`` unless ``LightningModule.strict_loading`` is
1524
                set, in which case it defaults to the value of ``LightningModule.strict_loading``.
1525
            \**kwargs: Any extra keyword args needed to init the model. Can also be used to override saved
1526
                hyperparameter values.
1527

1528
        Return:
1529
            :class:`LightningModule` instance with loaded weights and hyperparameters (if available).
1530

1531
        Note:
1532
            ``load_from_checkpoint`` is a **class** method. You should use your :class:`LightningModule`
1533
            **class** to call it instead of the :class:`LightningModule` instance, or a
1534
            ``TypeError`` will be raised.
1535

1536
        Note:
1537
            To ensure all layers can be loaded from the checkpoint, this function will call
1538
            :meth:`~lightning.pytorch.core.hooks.ModelHooks.configure_model` directly after instantiating the
1539
            model if this hook is overridden in your LightningModule. However, note that ``load_from_checkpoint`` does
1540
            not support loading sharded checkpoints, and you may run out of memory if the model is too large. In this
1541
            case, consider loading through the Trainer via ``.fit(ckpt_path=...)``.
1542

1543
        Example::
1544

1545
            # load weights without mapping ...
1546
            model = MyLightningModule.load_from_checkpoint('path/to/checkpoint.ckpt')
1547

1548
            # or load weights mapping all weights from GPU 1 to GPU 0 ...
1549
            map_location = {'cuda:1':'cuda:0'}
1550
            model = MyLightningModule.load_from_checkpoint(
1551
                'path/to/checkpoint.ckpt',
1552
                map_location=map_location
1553
            )
1554

1555
            # or load weights and hyperparameters from separate files.
1556
            model = MyLightningModule.load_from_checkpoint(
1557
                'path/to/checkpoint.ckpt',
1558
                hparams_file='/path/to/hparams_file.yaml'
1559
            )
1560

1561
            # override some of the params with new values
1562
            model = MyLightningModule.load_from_checkpoint(
1563
                PATH,
1564
                num_layers=128,
1565
                pretrained_ckpt_path=NEW_PATH,
1566
            )
1567

1568
            # predict
1569
            pretrained_model.eval()
1570
            pretrained_model.freeze()
1571
            y_hat = pretrained_model(x)
1572

1573
        """
1574
        loaded = _load_from_checkpoint(
1575
            cls,  # type: ignore[arg-type]
1576
            checkpoint_path,
1577
            map_location,
1578
            hparams_file,
1579
            strict,
1580
            **kwargs,
1581
        )
1582
        return cast(Self, loaded)
1583

1584
    @override
1585
    def __getstate__(self) -> Dict[str, Any]:
1586
        state = dict(self.__dict__)
1587
        state["_trainer"] = None
1588
        return state
1589

1590
    def _register_sharded_tensor_state_dict_hooks_if_available(self) -> None:
1591
        """Adds ShardedTensor state dict hooks if ShardedTensors are supported.
1592

1593
        These hooks ensure that ShardedTensors are included when saving, and are loaded the LightningModule correctly.
1594

1595
        """
1596
        if _TORCH_GREATER_EQUAL_2_1:
1597
            # ShardedTensor is deprecated in favor of DistributedTensor
1598
            return
1599
        if _IS_WINDOWS or not torch.distributed.is_available():
1600
            rank_zero_debug("Could not register sharded tensor state dict hooks")
1601
            return
1602

1603
        from torch.distributed._shard.sharded_tensor import pre_load_state_dict_hook, state_dict_hook
1604

1605
        self._register_state_dict_hook(state_dict_hook)
1606
        self._register_load_state_dict_pre_hook(pre_load_state_dict_hook, True)
1607

1608

1609
@contextmanager
1610
def _jit_is_scripting() -> Generator:
1611
    """Workaround for https://github.com/pytorch/pytorch/issues/67146."""
1612
    LightningModule._jit_is_scripting = True
1613
    try:
1614
        yield
1615
    finally:
1616
        LightningModule._jit_is_scripting = False
1617

1618

1619
class _TrainerFabricShim:
1620
    """Intercepts attribute access on LightningModule's trainer reference and redirects it to the Fabric object."""
1621

1622
    def __init__(self, fabric: lf.Fabric) -> None:
1623
        super().__init__()
1624
        self._fabric = fabric
1625

1626
    def __getattr__(self, item: Any) -> Any:
1627
        try:
1628
            return getattr(self._fabric, item)
1629
        except AttributeError:
1630
            raise AttributeError(
1631
                f"Your LightningModule code tried to access `self.trainer.{item}` but this attribute is not available"
1632
                f" when using Fabric with a LightningModule."
1633
            )
1634

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.