pytorch-lightning

# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from abc import ABC, abstractmethod
from contextlib import contextmanager, nullcontext
from typing import Any, Callable, Dict, Generator, List, Mapping, Optional, Tuple, TypeVar, Union

import torch
from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer

import lightning.pytorch as pl
from lightning.fabric.plugins import CheckpointIO
from lightning.fabric.strategies import _StrategyRegistry
from lightning.fabric.utilities import move_data_to_device
from lightning.fabric.utilities.distributed import ReduceOp
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
from lightning.fabric.utilities.init import _EmptyInit
from lightning.fabric.utilities.optimizer import _optimizer_to_device, _optimizers_to_device
from lightning.fabric.utilities.types import _PATH
from lightning.pytorch.core.optimizer import LightningOptimizer, _init_optimizers_and_lr_schedulers
from lightning.pytorch.plugins import TorchCheckpointIO
from lightning.pytorch.plugins.io.wrapper import _WrappingCheckpointIO
from lightning.pytorch.plugins.precision import Precision
from lightning.pytorch.strategies.launchers.launcher import _Launcher
from lightning.pytorch.trainer.states import TrainerFn
from lightning.pytorch.utilities.types import STEP_OUTPUT, LRSchedulerConfig

TBroadcast = TypeVar("TBroadcast")
TReduce = TypeVar("TReduce")

log = logging.getLogger(__name__)


class Strategy(ABC):
    """Base class for all strategies that change the behaviour of the training, validation and test loop."""

    def __init__(
        self,
        accelerator: Optional["pl.accelerators.Accelerator"] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[Precision] = None,
    ) -> None:
        self._accelerator: Optional["pl.accelerators.Accelerator"] = accelerator
        self._checkpoint_io: Optional[CheckpointIO] = checkpoint_io
        self._precision_plugin: Optional[Precision] = None
        # Call the precision setter for input validation
        self.precision_plugin = precision_plugin  # type: ignore[assignment]
        self._lightning_module: Optional[pl.LightningModule] = None
        self._model: Optional[Module] = None
        self._launcher: Optional[_Launcher] = None
        self._forward_redirection: _ForwardRedirection = _ForwardRedirection()
        self._optimizers: List[Optimizer] = []
        self._lightning_optimizers: List[LightningOptimizer] = []
        self.lr_scheduler_configs: List[LRSchedulerConfig] = []

    @property
    def launcher(self) -> Optional[_Launcher]:
        return self._launcher

    @property
    def accelerator(self) -> Optional["pl.accelerators.Accelerator"]:
        return self._accelerator

    @accelerator.setter
    def accelerator(self, accelerator: "pl.accelerators.Accelerator") -> None:
        self._accelerator = accelerator

    @property
    def checkpoint_io(self) -> CheckpointIO:
        if self._checkpoint_io is None:
            self._checkpoint_io = TorchCheckpointIO()
        elif isinstance(self._checkpoint_io, _WrappingCheckpointIO):
            self._checkpoint_io.checkpoint_io = TorchCheckpointIO()

        return self._checkpoint_io

    @checkpoint_io.setter
    def checkpoint_io(self, io: CheckpointIO) -> None:
        self._checkpoint_io = io

    @property
    def precision_plugin(self) -> Precision:
        return self._precision_plugin if self._precision_plugin is not None else Precision()

    @precision_plugin.setter
    def precision_plugin(self, precision_plugin: Optional[Precision]) -> None:
        self._precision_plugin = precision_plugin

    @property
    def optimizers(self) -> List[Optimizer]:
        return self._optimizers

    @optimizers.setter
    def optimizers(self, optimizers: List[Optimizer]) -> None:
        self._optimizers = optimizers
        self._lightning_optimizers = [LightningOptimizer._to_lightning_optimizer(opt, self) for opt in optimizers]

    def connect(self, model: "pl.LightningModule") -> None:
        """Called by the Trainer to connect the strategy with the model."""
        # model conversions cannot be applied at this point because `LightningModule.{setup,configure_model}` haven't
        # run yet
        self._lightning_module = model
        self.model = model

    def _configure_launcher(self) -> None:
        """Attach the launcher based on the strategy."""

    def setup_environment(self) -> None:
        """Sets up any processes or distributed connections.

        This is called before the LightningModule/DataModule setup hook, which allows the user to access the
        accelerator environment before setup is complete.

        """
        assert self.accelerator is not None
        self.accelerator.setup_device(self.root_device)

    def setup_optimizers(self, trainer: "pl.Trainer") -> None:
        """Creates optimizers and schedulers.

        Args:
            trainer: the Trainer to which these optimizers should be connected

        """
        assert self.lightning_module is not None
        self.optimizers, self.lr_scheduler_configs = _init_optimizers_and_lr_schedulers(self.lightning_module)

    def setup(self, trainer: "pl.Trainer") -> None:
        """Sets up the accelerator, plugins and initializes the optimizers (if needed).

        Args:
            trainer: the trainer instance

        """
        assert self.accelerator is not None
        self.accelerator.setup(trainer)

        assert self.model is not None
        # let the precision plugin convert the module here so that this strategy hook can decide the order
        # of operations
        self.model = self.precision_plugin.convert_module(self.model)
        self.model_to_device()
        self.model = self._setup_model(self.model)

        if trainer.state.fn == TrainerFn.FITTING:
            self.setup_optimizers(trainer)
        self.setup_precision_plugin()
        if trainer.state.fn == TrainerFn.FITTING:
            _optimizers_to_device(self.optimizers, self.root_device)

    def setup_precision_plugin(self) -> None:
        """Attaches the precision plugin to the strategy."""
        assert self.model is not None
        model, optimizers, lr_scheduler_configs = self.precision_plugin.connect(
            self.model, self.optimizers, self.lr_scheduler_configs
        )
        self.model = model
        self.optimizers = optimizers
        self.lr_scheduler_configs = lr_scheduler_configs

    def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]:
        """Returns state of an optimizer.

        Allows for syncing/collating optimizer state from processes in custom strategies.

        """
        if isinstance(optimizer, LightningOptimizer):
            optimizer = optimizer._optimizer

        if hasattr(optimizer, "consolidate_state_dict"):
            # there are optimizers like PyTorch's ZeroRedundancyOptimizer that shard their
            # states, and to avoid OOM we consolidate the full state on rank 0 only
            optimizer.consolidate_state_dict()
            return optimizer.state_dict() if self.is_global_zero else {}

        # for optimizers that are not sharded, we return the state dict on all ranks
        return optimizer.state_dict()

    def backward(
        self,
        closure_loss: Tensor,
        optimizer: Optional[Optimizer],
        *args: Any,
        **kwargs: Any,
    ) -> Tensor:
        r"""Forwards backward-calls to the precision plugin.

        Args:
            closure_loss: a tensor holding the loss value to backpropagate
            optimizer: An optional optimizer that gets passed down to the precision plugin's backward
            \*args: Positional arguments that get passed down to the precision plugin's backward, intended as arguments
                for the actual function that performs the backward, like :meth:`~torch.Tensor.backward`.
            \**kwargs: Keyword arguments for the same purpose as ``*args``.

        """
        self.pre_backward(closure_loss)
        assert self.lightning_module is not None
        closure_loss = self.precision_plugin.pre_backward(closure_loss, self.lightning_module)

        self.precision_plugin.backward(closure_loss, self.lightning_module, optimizer, *args, **kwargs)

        closure_loss = self.precision_plugin.post_backward(closure_loss, self.lightning_module)
        self.post_backward(closure_loss)

        return closure_loss

    def optimizer_step(
        self,
        optimizer: Optimizer,
        closure: Callable[[], Any],
        model: Optional[Union["pl.LightningModule", Module]] = None,
        **kwargs: Any,
    ) -> Any:
        r"""Performs the actual optimizer step.

        Args:
            optimizer: the optimizer performing the step
            closure: closure calculating the loss value
            model: reference to the model, optionally defining optimizer step related hooks
            \**kwargs: Keyword arguments to ``optimizer.step``

        """
        model = model or self.lightning_module
        # TODO(fabric): remove assertion once strategy's optimizer_step typing is fixed
        assert isinstance(model, pl.LightningModule)
        return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs)

    def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer]) -> Tuple[Module, List[Optimizer]]:
        """Sets up a model and multiple optimizers together.

        The returned objects are expected to be in the same order they were passed in. The default implementation will
        call :meth:`_setup_model` and :meth:`_setup_optimizer` on the inputs.

        """
        # TODO: standardize this across all plugins in Lightning and Fabric. Related refactor: #7324
        model = self._setup_model(model)
        optimizers = [self._setup_optimizer(optimizer) for optimizer in optimizers]
        return model, optimizers

    def _setup_model(self, model: Module) -> Module:
        """Performs setup for the model, e.g., by wrapping it in another class."""
        # TODO: standardize this across all plugins in Lightning and Fabric. Related refactor: #7324
        return model

    def _setup_optimizer(self, optimizer: Optimizer) -> Optimizer:
        """Performs setup for the optimizer, e.g., by wrapping it in another class."""
        # TODO: standardize this across all plugins in Lightning and Fabric. Related refactor: #7324
        return optimizer

    def batch_to_device(self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: int = 0) -> Any:
        """Moves the batch to the correct device.

        The returned batch is of the same type as the input batch, just
        having all tensors on the correct device.

        Args:
            batch: The batch of samples to move to the correct device
            device: The target device
            dataloader_idx: The index of the dataloader to which the batch belongs.

        """
        model = self.lightning_module
        device = device or self.root_device
        if model is not None:
            return model._apply_batch_transfer_handler(batch, device=device, dataloader_idx=dataloader_idx)
        return move_data_to_device(batch, device)

    @property
    @abstractmethod
    def root_device(self) -> torch.device:
        """Returns the root device."""

    @abstractmethod
    def model_to_device(self) -> None:
        """Moves the model to the correct device."""

    @property
    @abstractmethod
    def is_global_zero(self) -> bool:
        """Whether the current process is the rank zero process not only on the local node, but for all nodes."""

    @abstractmethod
    def reduce(
        self,
        tensor: Union[Tensor, Any],
        group: Optional[Any] = None,
        reduce_op: Optional[Union[ReduceOp, str]] = "mean",
    ) -> Union[Tensor, Any]:
        """Reduces the given tensor (e.g. across GPUs/processes).

        Args:
            tensor: the tensor to sync and reduce
            group: the process group to reduce
            reduce_op: the reduction operation. Defaults to 'mean'.
                Can also be the string 'sum' or a ReduceOp.

        """

    @abstractmethod
    def barrier(self, name: Optional[str] = None) -> None:
        """Synchronizes all processes, blocking until the whole group enters this function.

        Args:
            name: an optional name to pass into barrier.

        """

    @abstractmethod
    def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
        """Broadcasts an object to all processes.

        Args:
            obj: the object to broadcast
            src: source rank

        """

    @abstractmethod
    def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor:
        """Perform an all_gather on all processes.

        Args:
            tensor: the tensor to all_gather
            group: the process group to gather results from
            sync_grads: flag that allows users to synchronize gradients for all_gather op

        """

    def reduce_boolean_decision(self, decision: bool, all: bool = True) -> bool:
        """Reduce a boolean decision across all processes."""
        return decision

    def pre_backward(self, closure_loss: Tensor) -> None:
        """Run before precision plugin executes backward."""

    def post_backward(self, closure_loss: Tensor) -> None:
        """Run after precision plugin executes backward."""

    @property
    def model(self) -> Optional[Module]:
        """Returns the potentially wrapped LightningModule."""
        return self._model if self._model is not None else self._lightning_module

    @model.setter
    def model(self, new_model: Optional[Module]) -> None:
        self._model = new_model

    @property
    def lightning_module(self) -> Optional["pl.LightningModule"]:
        """Returns the pure LightningModule without potential wrappers."""
        return self._lightning_module

    def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
        torch.cuda.empty_cache()
        return self.checkpoint_io.load_checkpoint(checkpoint_path)

    def load_model_state_dict(self, checkpoint: Mapping[str, Any], strict: bool = True) -> None:
        assert self.lightning_module is not None
        self.lightning_module.load_state_dict(checkpoint["state_dict"], strict=strict)

    def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
        optimizer_states = checkpoint["optimizer_states"]
        for optimizer, opt_state in zip(self.optimizers, optimizer_states):
            optimizer.load_state_dict(opt_state)
            _optimizer_to_device(optimizer, self.root_device)

    def training_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
        """The actual training step.

        See :meth:`~lightning.pytorch.core.LightningModule.training_step` for more details

        """
        assert self.lightning_module is not None
        assert self.model is not None
        with self.precision_plugin.train_step_context():
            if self.model != self.lightning_module:
                return self._forward_redirection(self.model, self.lightning_module, "training_step", *args, **kwargs)
            return self.lightning_module.training_step(*args, **kwargs)

    def post_training_step(self) -> None:
        """This hook is deprecated.

        Override :meth:`training_step` instead.

        """
        pass

    def validation_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
        """The actual validation step.

        See :meth:`~lightning.pytorch.core.LightningModule.validation_step` for more details

        """
        assert self.lightning_module is not None
        assert self.model is not None
        with self.precision_plugin.val_step_context():
            if self.model != self.lightning_module:
                return self._forward_redirection(self.model, self.lightning_module, "validation_step", *args, **kwargs)
            return self.lightning_module.validation_step(*args, **kwargs)

    def test_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
        """The actual test step.

        See :meth:`~lightning.pytorch.core.LightningModule.test_step` for more details

        """
        assert self.lightning_module is not None
        assert self.model is not None
        with self.precision_plugin.test_step_context():
            if self.model != self.lightning_module:
                return self._forward_redirection(self.model, self.lightning_module, "test_step", *args, **kwargs)
            return self.lightning_module.test_step(*args, **kwargs)

    def predict_step(self, *args: Any, **kwargs: Any) -> Any:
        """The actual predict step.

        See :meth:`~lightning.pytorch.core.LightningModule.predict_step` for more details

        """
        assert self.lightning_module is not None
        assert self.model is not None
        with self.precision_plugin.predict_step_context():
            if self.model != self.lightning_module:
                return self._forward_redirection(self.model, self.lightning_module, "predict_step", *args, **kwargs)
            return self.lightning_module.predict_step(*args, **kwargs)

    def process_dataloader(self, dataloader: object) -> object:
        """Wraps the dataloader if necessary.

        Args:
            dataloader: iterable. Ideally of type: :class:`torch.utils.data.DataLoader`

        """
        return dataloader

    @property
    def restore_checkpoint_after_setup(self) -> bool:
        """Override to delay restoring from checkpoint until after the setup phase has completed. This is useful when
        the strategy requires all the setup hooks to run before loading the checkpoint.

        Returns:
            If ``True``, restore checkpoint after strategy setup.

        """
        return False

    @property
    def lightning_restore_optimizer(self) -> bool:
        """Override to disable Lightning from restoring optimizers/schedulers.

        This is useful for strategies that manage restoring optimizers/schedulers themselves.

        """
        return True

    @property
    def handles_gradient_accumulation(self) -> bool:
        """Whether the strategy handles gradient accumulation internally."""
        return False

    def lightning_module_state_dict(self) -> Dict[str, Any]:
        """Returns model state."""
        assert self.lightning_module is not None
        return self.lightning_module.state_dict()

    def save_checkpoint(
        self, checkpoint: Dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None
    ) -> None:
        """Save model/training states as a checkpoint file through state-dump and file-write.

        Args:
            checkpoint: dict containing model and trainer state
            filepath: write-target file's path
            storage_options: parameter for how to save to storage, passed to ``CheckpointIO`` plugin

        """
        if self.is_global_zero:
            self.checkpoint_io.save_checkpoint(checkpoint, filepath, storage_options=storage_options)

    def remove_checkpoint(self, filepath: _PATH) -> None:
        """Remove checkpoint filepath from the filesystem.

        Args:
            filepath: Path to checkpoint

        """
        if self.is_global_zero:
            self.checkpoint_io.remove_checkpoint(filepath)

    @contextmanager
    def tensor_init_context(self, empty_init: Optional[bool] = None) -> Generator[None, None, None]:
        """Controls how tensors get created (device, dtype).

        Args:
            empty_init: Whether to initialize the model with empty weights (uninitialized memory).
                If ``None``, the strategy will decide. Some strategies may not support all options.

        """
        device_context = self.root_device if _TORCH_GREATER_EQUAL_2_0 else nullcontext()
        empty_init_context = _EmptyInit(enabled=bool(empty_init))
        with empty_init_context, device_context, self.precision_plugin.tensor_init_context():
            yield

    @contextmanager
    def model_sharded_context(self) -> Generator[None, None, None]:
        """Provides a hook to create modules in a distributed-aware context. This is useful when we want to shard the
        model instantly, which for extremely large models can save memory and initialization time.

        Returns: Model parallel context.

        """
        yield

    def teardown(self) -> None:
        """This method is called to tear down the training process.

        It is the right place to release memory and free other resources.

        """
        _optimizers_to_device(self.optimizers, torch.device("cpu"))

        if self.lightning_module is not None:
            log.debug(f"{self.__class__.__name__}: moving model to CPU")
            self.lightning_module.cpu()
        self.precision_plugin.teardown()
        assert self.accelerator is not None
        self.accelerator.teardown()
        self.checkpoint_io.teardown()

    @classmethod
    def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None:
        pass

    def on_train_start(self) -> None:
        """Called when train begins."""
        pass

    def on_validation_start(self) -> None:
        """Called when validation begins."""
        pass

    def on_test_start(self) -> None:
        """Called when test begins."""
        pass

    def on_predict_start(self) -> None:
        """Called when predict begins."""
        pass

    def on_train_end(self) -> None:
        """Called when train ends."""
        pass

    def on_validation_end(self) -> None:
        """Called when validation ends."""
        pass

    def on_test_end(self) -> None:
        """Called when test ends."""
        pass

    def on_predict_end(self) -> None:
        """Called when predict ends."""
        pass

    def on_train_batch_start(self, batch: Any, batch_idx: int) -> None:
        """Called in the training loop before anything happens for that batch."""
        pass

    def on_exception(self, exception: BaseException) -> None:
        """Called when the trainer execution is interrupted by an exception."""
        pass

    def _reset_optimizers_and_schedulers(self) -> None:
        self._optimizers = []
        self._lightning_optimizers = []
        self.lr_scheduler_configs = []

    def __getstate__(self) -> Dict:
        # `LightningOptimizer` overrides `self.__class__` so they cannot be pickled
        state = dict(vars(self))  # copy
        state["_lightning_optimizers"] = []
        return state

    def __setstate__(self, state: Dict) -> None:
        self.__dict__ = state
        self.optimizers = self.optimizers  # re-create the `_lightning_optimizers`

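# Example (illustrative sketch): a hypothetical, minimal single-process strategy
# showing which abstract members of `Strategy` a concrete subclass must provide.
# The class name and its trivial collective ops are assumptions for demonstration
# only; real strategies such as `SingleDeviceStrategy` or `DDPStrategy` implement
# these with device- and process-group-specific logic.
class _ExampleSingleProcessStrategy(Strategy):
    """Toy strategy that runs everything in a single process on a single device."""

    def __init__(self, device: Union[str, torch.device] = "cpu", **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self._root_device = torch.device(device)

    @property
    def root_device(self) -> torch.device:
        return self._root_device

    def model_to_device(self) -> None:
        assert self.model is not None
        self.model.to(self.root_device)

    @property
    def is_global_zero(self) -> bool:
        # only one process exists, so it is always the global rank-zero process
        return True

    def reduce(
        self,
        tensor: Union[Tensor, Any],
        group: Optional[Any] = None,
        reduce_op: Optional[Union[ReduceOp, str]] = "mean",
    ) -> Union[Tensor, Any]:
        # nothing to reduce across a single process
        return tensor

    def barrier(self, name: Optional[str] = None) -> None:
        # no other processes to synchronize with
        pass

    def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
        return obj

    def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor:
        return tensor
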

class _ForwardRedirection:
    """Implements the `forward-redirection`.

    A method call to a wrapped module gets rerouted through the wrapper's `forward` method instead.

    """

    def __call__(
        self, wrapper_module: Module, original_module: "pl.LightningModule", method_name: str, *args: Any, **kwargs: Any
    ) -> STEP_OUTPUT:
        """Reroutes a method call through the `wrapper_module`'s `forward` method.

        Args:
            wrapper_module: The module that has `original_module` wrapped.
            original_module: The module that was wrapped inside `wrapper_module`.
            method_name: The name of the method that should be called on the `original_module` after inputs get
                redirected through the `wrapper_module`'s `forward` method.
            *args: The positional arguments to the method `method_name`. They will get passed to a patched
                `forward` method instead.
            **kwargs: The keyword arguments to the method `method_name`. They will get passed to a patched
                `forward` method instead.

        """
        assert method_name != "forward"
        original_forward = original_module.forward

        def wrapped_forward(*_args: Any, **_kwargs: Any) -> Any:
            # Unpatch ourselves immediately before calling the method `method_name`
            # because it may itself want to call the real `forward`
            original_module.forward = original_forward  # type: ignore[method-assign]
            # Call the actual method e.g. `.training_step(...)`
            method = getattr(original_module, method_name)
            out = method(*_args, **_kwargs)
            self.on_after_inner_forward(wrapper_module, original_module)
            return out

        # Patch the original_module's forward so we can redirect the arguments back to the real method
        original_module.forward = wrapped_forward  # type: ignore[method-assign]

        wrapper_output = wrapper_module(*args, **kwargs)
        self.on_after_outer_forward(wrapper_module, original_module)
        return wrapper_output

    def on_after_inner_forward(self, wrapper_module: Module, original_module: "pl.LightningModule") -> None:
        pass

    def on_after_outer_forward(self, wrapper_module: Module, original_module: "pl.LightningModule") -> None:
        pass
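
# Example usage (illustrative sketch, hypothetical variable names): how a strategy
# uses `_ForwardRedirection` to run `LightningModule.training_step` while keeping a
# wrapper such as DistributedDataParallel in the call path, so the wrapper's
# `forward`-side logic (e.g. gradient synchronization) still executes:
#
#     redirection = _ForwardRedirection()
#     # `wrapped_model` is assumed to be a Module whose `forward` delegates to the
#     # LightningModule, e.g. DistributedDataParallel(lightning_module).
#     output = redirection(wrapped_model, lightning_module, "training_step", batch, batch_idx)
#
# Internally, the arguments are routed through `wrapped_model.forward`, and the
# temporarily patched `LightningModule.forward` then dispatches to the real
# `training_step`.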
