# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from dataclasses import fields
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union, overload
from weakref import proxy

import torch
from torch import optim
from torch.optim import Optimizer
from typing_extensions import override

import lightning.pytorch as pl
from lightning.fabric.utilities.types import Optimizable, ReduceLROnPlateau, _Stateful
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.model_helpers import is_overridden
from lightning.pytorch.utilities.rank_zero import rank_zero_warn
from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature
from lightning.pytorch.utilities.types import LRSchedulerConfig, LRSchedulerTypeTuple


def do_nothing_closure() -> None:
    return


class LightningOptimizer:
    """This class wraps the user's optimizers and properly handles the backward and optimizer_step logic across
    accelerators, AMP, and accumulate_grad_batches.

    Note: The purpose of this wrapper is only to define new methods and redirect the `.step()` call. The internal
    state ``__dict__`` is not kept in sync with the internal state of the original optimizer, but the Trainer never
    relies on the internal state of the wrapper.
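
    Example::

        # Illustrative sketch: inside a ``LightningModule`` that uses manual optimization,
        # ``self.optimizers()`` returns instances of this wrapper instead of the raw optimizers.
        opt = self.optimizers()
        # thanks to the dynamic subclass created in ``__init__``, isinstance checks against the
        # wrapped optimizer's class (and ``torch.optim.Optimizer``) still succeed
        assert isinstance(opt, torch.optim.Optimizer)
        opt.step()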
    """

    def __init__(self, optimizer: Optimizer):
        self._optimizer = optimizer
        self._strategy: Optional[pl.strategies.Strategy] = None
        # to inject logic around the optimizer step, particularly useful with manual optimization
        self._on_before_step = do_nothing_closure
        self._on_after_step = do_nothing_closure
        # imitate the class of the wrapped object to make isinstance checks work
        self.__class__ = type("Lightning" + optimizer.__class__.__name__, (self.__class__, optimizer.__class__), {})

    @property
    def optimizer(self) -> Optimizer:
        return self._optimizer

    @contextmanager
    def toggle_model(self, sync_grad: bool = True) -> Generator[None, None, None]:
        """This function is just a helper for advanced users.

        Consider the current optimizer as A and all other optimizers as B. Toggling means that all parameters
        from B that are exclusive to A will have ``requires_grad`` set to ``False``.

        When performing gradient accumulation, there is no need to perform grad synchronization
        during the accumulation phase.
        Setting ``sync_grad`` to ``False`` will block this synchronization and improve performance.
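
        Example::

            # Minimal sketch, assuming manual optimization inside ``training_step`` with two
            # configured optimizers; ``compute_loss_a`` is a hypothetical helper.
            opt_a, opt_b = self.optimizers()
            with opt_a.toggle_model():
                loss_a = self.compute_loss_a(batch)
                self.manual_backward(loss_a)
                opt_a.step()
                opt_a.zero_grad()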

        """
        # local import here to avoid circular import
        from lightning.pytorch.loops.utilities import _block_parallel_sync_behavior

        assert self._strategy is not None
        lightning_module = self._strategy.lightning_module
        assert lightning_module is not None
        with _block_parallel_sync_behavior(self._strategy, block=(not sync_grad)):
            lightning_module.toggle_optimizer(self)
            yield
            lightning_module.untoggle_optimizer(self)

    def step(self, closure: Optional[Callable[[], Any]] = None, **kwargs: Any) -> Any:
        """Performs a single optimization step (parameter update).

        Args:
            closure: An optional optimizer closure.
            kwargs: Any additional arguments to the ``optimizer.step()`` call.

        Returns:
            The output from the step call, which is generally the output of the closure execution.

        Example::

            # Scenario for a GAN using manual optimization
            def training_step(self, batch, batch_idx):
                opt_gen, opt_dis = self.optimizers()

                ...

                # compute generator loss
                loss_gen = self.compute_generator_loss(...)
                # zero_grad needs to be called before backward
                opt_gen.zero_grad()
                self.manual_backward(loss_gen)
                opt_gen.step()

                # compute discriminator loss
                loss_dis = self.compute_discriminator_loss(...)

                # zero_grad needs to be called before backward
                opt_dis.zero_grad()
                self.manual_backward(loss_dis)
                opt_dis.step()


            # A more advanced example
            def training_step(self, batch, batch_idx):
                opt_gen, opt_dis = self.optimizers()

                ...
                accumulated_grad_batches = batch_idx % 2 == 0

                # compute generator loss
                def closure_gen():
                    loss_gen = self.compute_generator_loss(...)
                    self.manual_backward(loss_gen)
                    if accumulated_grad_batches:
                        opt_gen.zero_grad()

                with opt_gen.toggle_model(sync_grad=accumulated_grad_batches):
                    opt_gen.step(closure=closure_gen)

                def closure_dis():
                    loss_dis = self.compute_discriminator_loss(...)
                    self.manual_backward(loss_dis)
                    if accumulated_grad_batches:
                        opt_dis.zero_grad()

                with opt_dis.toggle_model(sync_grad=accumulated_grad_batches):
                    opt_dis.step(closure=closure_dis)

        """
        self._on_before_step()

        if closure is None:
            closure = do_nothing_closure
        elif not callable(closure):
            raise MisconfigurationException("When `optimizer.step(closure)` is called, the closure should be callable")

        assert self._strategy is not None
        step_output = self._strategy.optimizer_step(self._optimizer, closure, **kwargs)

        self._on_after_step()

        return step_output

    @classmethod
    def _to_lightning_optimizer(
        cls, optimizer: Union[Optimizer, "LightningOptimizer"], strategy: "pl.strategies.Strategy"
    ) -> "LightningOptimizer":
        # the user could return a `LightningOptimizer` from `configure_optimizers`, see test:
        # tests/core/test_lightning_optimizer.py::test_lightning_optimizer[False]
        lightning_optimizer = optimizer if isinstance(optimizer, LightningOptimizer) else cls(optimizer)
        lightning_optimizer._strategy = proxy(strategy)
        return lightning_optimizer

    def __getattr__(self, item: Any) -> Any:
        return getattr(self._optimizer, item)


def _init_optimizers_and_lr_schedulers(
    model: "pl.LightningModule",
) -> Tuple[List[Optimizer], List[LRSchedulerConfig]]:
    """Calls `LightningModule.configure_optimizers` and parses and validates the output."""
    from lightning.pytorch.trainer import call

    optim_conf = call._call_lightning_module_hook(model.trainer, "configure_optimizers", pl_module=model)

    if optim_conf is None:
        rank_zero_warn(
            "`LightningModule.configure_optimizers` returned `None`, this fit will run with no optimizer",
        )
        optim_conf = _MockOptimizer()

    optimizers, lr_schedulers, monitor = _configure_optimizers(optim_conf)
    lr_scheduler_configs = (
        _configure_schedulers_automatic_opt(lr_schedulers, monitor)
        if model.automatic_optimization
        else _configure_schedulers_manual_opt(lr_schedulers)
    )
    _validate_multiple_optimizers_support(optimizers, model)
    _validate_optimizers_attached(optimizers, lr_scheduler_configs)
    _validate_scheduler_api(lr_scheduler_configs, model)
    return optimizers, lr_scheduler_configs


def _configure_optimizers(
    optim_conf: Union[Dict[str, Any], List, Optimizer, Tuple],
) -> Tuple[List, List, Optional[str]]:
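    """Separates the output of ``configure_optimizers`` into optimizers, schedulers, and an optional monitor.

    Illustrative sketch of the accepted shapes, assuming ``optimizer`` and ``scheduler`` objects already exist::

        _configure_optimizers(optimizer)                            # -> ([optimizer], [], None)
        _configure_optimizers([optimizer])                          # -> ([optimizer], [], None)
        _configure_optimizers(([optimizer], [scheduler]))           # -> ([optimizer], [scheduler], None)
        _configure_optimizers({"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "val_loss"})
        # -> ([optimizer], [scheduler], "val_loss")

    """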
    optimizers, lr_schedulers = [], []
    monitor = None

    # single output, single optimizer
    if isinstance(optim_conf, Optimizable):
        optimizers = [optim_conf]
    # two lists, optimizer + lr schedulers
    elif (
        isinstance(optim_conf, (list, tuple))
        and len(optim_conf) == 2
        and isinstance(optim_conf[0], list)
        and all(isinstance(opt, Optimizable) for opt in optim_conf[0])
    ):
        opt, sch = optim_conf
        optimizers = opt
        lr_schedulers = sch if isinstance(sch, list) else [sch]
    # single dictionary
    elif isinstance(optim_conf, dict):
        _validate_optim_conf(optim_conf)
        optimizers = [optim_conf["optimizer"]]
        monitor = optim_conf.get("monitor", None)
        lr_schedulers = [optim_conf["lr_scheduler"]] if "lr_scheduler" in optim_conf else []
    # multiple dictionaries
    elif isinstance(optim_conf, (list, tuple)) and all(isinstance(d, dict) for d in optim_conf):
        for opt_dict in optim_conf:
            _validate_optim_conf(opt_dict)
        optimizers = [opt_dict["optimizer"] for opt_dict in optim_conf]
        scheduler_dict = lambda scheduler: dict(scheduler) if isinstance(scheduler, dict) else {"scheduler": scheduler}
        lr_schedulers = [
            scheduler_dict(opt_dict["lr_scheduler"]) for opt_dict in optim_conf if "lr_scheduler" in opt_dict
        ]
    # single list or tuple, multiple optimizer
    elif isinstance(optim_conf, (list, tuple)) and all(isinstance(opt, Optimizable) for opt in optim_conf):
        optimizers = list(optim_conf)
    # unknown configuration
    else:
        raise MisconfigurationException(
            "Unknown configuration for model optimizers."
            " Output from `model.configure_optimizers()` should be one of:\n"
            " * `Optimizer`\n"
            " * [`Optimizer`]\n"
            " * ([`Optimizer`], [`LRScheduler`])\n"
            ' * {"optimizer": `Optimizer`, (optional) "lr_scheduler": `LRScheduler`}\n'
        )
    return optimizers, lr_schedulers, monitor


def _configure_schedulers_automatic_opt(schedulers: list, monitor: Optional[str]) -> List[LRSchedulerConfig]:
    """Convert each scheduler into `LRSchedulerConfig` with relevant information, when using automatic optimization."""
    lr_scheduler_configs = []
    for scheduler in schedulers:
        if isinstance(scheduler, dict):
            # check provided keys
            supported_keys = {field.name for field in fields(LRSchedulerConfig)}
            extra_keys = scheduler.keys() - supported_keys
            if extra_keys:
                rank_zero_warn(
                    f"Found unsupported keys in the lr scheduler dict: {extra_keys}."
                    " HINT: remove them from the output of `configure_optimizers`.",
                    category=RuntimeWarning,
                )
                scheduler = {k: v for k, v in scheduler.items() if k in supported_keys}
            if "scheduler" not in scheduler:
                raise MisconfigurationException(
                    'The lr scheduler dict must have the key "scheduler" with its item being an lr scheduler'
                )
            if "interval" in scheduler and scheduler["interval"] not in ("step", "epoch"):
                raise MisconfigurationException(
                    'The "interval" key in lr scheduler dict must be "step" or "epoch"'
                    f' but is "{scheduler["interval"]}"'
                )
            scheduler["reduce_on_plateau"] = scheduler.get(
                "reduce_on_plateau", isinstance(scheduler["scheduler"], optim.lr_scheduler.ReduceLROnPlateau)
            )
            if scheduler["reduce_on_plateau"] and scheduler.get("monitor", None) is None:
                raise MisconfigurationException(
                    "The lr scheduler dict must include a monitor when a `ReduceLROnPlateau` scheduler is used."
                    ' For example: {"optimizer": optimizer, "lr_scheduler":'
                    ' {"scheduler": scheduler, "monitor": "your_loss"}}'
                )
            is_one_cycle = isinstance(scheduler["scheduler"], optim.lr_scheduler.OneCycleLR)
            if is_one_cycle and scheduler.get("interval", "epoch") == "epoch":
                rank_zero_warn(
                    "A `OneCycleLR` scheduler is using 'interval': 'epoch'."
                    " Are you sure you didn't mean 'interval': 'step'?",
                    category=RuntimeWarning,
                )
            config = LRSchedulerConfig(**scheduler)
        elif isinstance(scheduler, ReduceLROnPlateau):
            if monitor is None:
                raise MisconfigurationException(
                    "`configure_optimizers` must include a monitor when a `ReduceLROnPlateau`"
                    " scheduler is used. For example:"
                    ' {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "metric_to_track"}'
                )
            config = LRSchedulerConfig(scheduler, reduce_on_plateau=True, monitor=monitor)
        else:
            config = LRSchedulerConfig(scheduler)
        lr_scheduler_configs.append(config)
    return lr_scheduler_configs


def _configure_schedulers_manual_opt(schedulers: list) -> List[LRSchedulerConfig]:
    """Convert each scheduler into an `LRSchedulerConfig` structure with relevant information, when using manual
    optimization."""
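    # In manual optimization the user steps the scheduler themselves, e.g. inside `training_step`
    # (illustrative sketch):
    #
    #     sch = self.lr_schedulers()
    #     sch.step()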
    lr_scheduler_configs = []
    for scheduler in schedulers:
        if isinstance(scheduler, dict):
            # interval is not in this list even though the user needs to manually call the scheduler because
            # the `LearningRateMonitor` callback needs to check its value to know when to log the learning rate
            invalid_keys = {"reduce_on_plateau", "monitor", "strict"}
            keys_to_warn = [k for k in scheduler if k in invalid_keys]

            if keys_to_warn:
                rank_zero_warn(
                    f"The lr scheduler dict contains the key(s) {keys_to_warn}, but the keys will be ignored."
                    " You need to call `lr_scheduler.step()` manually in manual optimization.",
                    category=RuntimeWarning,
                )

            config = LRSchedulerConfig(**{key: scheduler[key] for key in scheduler if key not in invalid_keys})
        else:
            config = LRSchedulerConfig(scheduler)
        lr_scheduler_configs.append(config)
    return lr_scheduler_configs


def _validate_scheduler_api(lr_scheduler_configs: List[LRSchedulerConfig], model: "pl.LightningModule") -> None:
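    # A custom scheduler that does not follow PyTorch's LRScheduler API is only accepted in automatic
    # optimization if the LightningModule overrides `lr_scheduler_step`, e.g. (illustrative sketch):
    #
    #     def lr_scheduler_step(self, scheduler, metric):
    #         scheduler.step(epoch=self.current_epoch)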
    for config in lr_scheduler_configs:
        scheduler = config.scheduler
        if not isinstance(scheduler, _Stateful):
            raise TypeError(
                f"The provided lr scheduler `{scheduler.__class__.__name__}` is invalid."
                " It should have `state_dict` and `load_state_dict` methods defined."
            )

        if (
            not isinstance(scheduler, LRSchedulerTypeTuple)
            and not is_overridden("lr_scheduler_step", model)
            and model.automatic_optimization
        ):
            raise MisconfigurationException(
                f"The provided lr scheduler `{scheduler.__class__.__name__}` doesn't follow PyTorch's LRScheduler"
                " API. You should override the `LightningModule.lr_scheduler_step` hook with your own logic if"
                " you are using a custom LR scheduler."
            )


def _validate_multiple_optimizers_support(optimizers: List[Optimizer], model: "pl.LightningModule") -> None:
    if is_param_in_hook_signature(model.training_step, "optimizer_idx", explicit=True):
        raise RuntimeError(
            "Training with multiple optimizers is only supported with manual optimization. Remove the `optimizer_idx`"
            " argument from `training_step`, set `self.automatic_optimization = False` and access your optimizers"
            " in `training_step` with `opt1, opt2, ... = self.optimizers()`."
        )
    if model.automatic_optimization and len(optimizers) > 1:
        raise RuntimeError(
            "Training with multiple optimizers is only supported with manual optimization. Set"
            " `self.automatic_optimization = False`, then access your optimizers in `training_step` with"
            " `opt1, opt2, ... = self.optimizers()`."
        )


def _validate_optimizers_attached(optimizers: List[Optimizer], lr_scheduler_configs: List[LRSchedulerConfig]) -> None:
    for config in lr_scheduler_configs:
        if config.scheduler.optimizer not in optimizers:
            raise MisconfigurationException(
                "Some schedulers are attached with an optimizer that wasn't returned from `configure_optimizers`."
            )


def _validate_optim_conf(optim_conf: Dict[str, Any]) -> None:
    valid_keys = {"optimizer", "lr_scheduler", "monitor"}
    extra_keys = optim_conf.keys() - valid_keys
    if extra_keys:
        rank_zero_warn(
            f"Found unsupported keys in the optimizer configuration: {set(extra_keys)}", category=RuntimeWarning
        )


class _MockOptimizer(Optimizer):
    """The `_MockOptimizer` will be used in place of an optimizer in the event that `None` is returned from
    :meth:`~lightning.pytorch.core.LightningModule.configure_optimizers`."""
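    # `step` still executes the provided closure so the training loop runs unchanged, while every other
    # optimizer operation is a no-op.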

    def __init__(self) -> None:
        super().__init__([torch.zeros(1)], {})

    @override
    def add_param_group(self, param_group: Dict[Any, Any]) -> None:
        pass  # Do Nothing

    @override
    def load_state_dict(self, state_dict: Dict[Any, Any]) -> None:
        pass  # Do Nothing

    @override
    def state_dict(self) -> Dict[str, Any]:
        return {}  # Return Empty

    @overload
    def step(self, closure: None = ...) -> None: ...

    @overload
    def step(self, closure: Callable[[], float]) -> float: ...

    @override
    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
        if closure is not None:
            return closure()

    @override
    def zero_grad(self, set_to_none: Optional[bool] = True) -> None:
        pass  # Do Nothing

    @override
    def __repr__(self) -> str:
        return "No Optimizer"