pytorch-lightning

# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
from typing import Any

import lightning.pytorch as pl
from lightning.fabric.strategies import _StrategyRegistry
from lightning.pytorch.accelerators.xla import XLAAccelerator
from lightning.pytorch.plugins.precision import XLAPrecision
from lightning.pytorch.strategies.single_xla import SingleDeviceXLAStrategy
from lightning.pytorch.utilities.rank_zero import rank_zero_deprecation


def _patch_sys_modules() -> None:
    # Alias the removed legacy module paths to this module, so that old-style
    # imports (and unpickling of objects that reference those paths) keep working.
    self = sys.modules[__name__]
    sys.modules["lightning.pytorch.strategies.single_tpu"] = self
    sys.modules["lightning.pytorch.accelerators.tpu"] = self
    sys.modules["lightning.pytorch.plugins.precision.tpu"] = self
    sys.modules["lightning.pytorch.plugins.precision.tpu_bf16"] = self
    sys.modules["lightning.pytorch.plugins.precision.xlabf16"] = self


class SingleTPUStrategy(SingleDeviceXLAStrategy):
    """Legacy class.

    Use :class:`~lightning.pytorch.strategies.single_xla.SingleDeviceXLAStrategy` instead.

    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        rank_zero_deprecation("The 'single_tpu' strategy is deprecated. Use 'single_xla' instead.")
        super().__init__(*args, **kwargs)

    @classmethod
    def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None:
        if "single_tpu" not in strategy_registry:
            strategy_registry.register("single_tpu", cls, description="Legacy class. Use `single_xla` instead.")


class TPUAccelerator(XLAAccelerator):
    """Legacy class.

    Use :class:`~lightning.pytorch.accelerators.xla.XLAAccelerator` instead.

    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        rank_zero_deprecation(
            "The `TPUAccelerator` class is deprecated. Use `lightning.pytorch.accelerators.XLAAccelerator` instead."
        )
        super().__init__(*args, **kwargs)


class TPUPrecisionPlugin(XLAPrecision):
    """Legacy class.

    Use :class:`~lightning.pytorch.plugins.precision.xla.XLAPrecision` instead.

    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        rank_zero_deprecation(
            "The `TPUPrecisionPlugin` class is deprecated. Use `lightning.pytorch.plugins.precision.XLAPrecision`"
            " instead."
        )
        super().__init__(precision="32-true")


class TPUBf16PrecisionPlugin(XLAPrecision):
    """Legacy class.

    Use :class:`~lightning.pytorch.plugins.precision.xla.XLAPrecision` instead.

    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        rank_zero_deprecation(
            "The `TPUBf16PrecisionPlugin` class is deprecated. Use"
            " `lightning.pytorch.plugins.precision.XLAPrecision` instead."
        )
        super().__init__(precision="bf16-true")


class XLABf16PrecisionPlugin(XLAPrecision):
    """Legacy class.

    Use :class:`~lightning.pytorch.plugins.precision.xla.XLAPrecision` instead.

    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        rank_zero_deprecation(
            "The `XLABf16PrecisionPlugin` class is deprecated. Use"
            " `lightning.pytorch.plugins.precision.XLAPrecision` instead."
        )
        super().__init__(precision="bf16-true")


def _patch_classes() -> None:
    # Re-expose the legacy class names on the public namespaces they used to
    # live in, so attribute access such as `pl.accelerators.TPUAccelerator`
    # still resolves.
    setattr(pl.strategies, "SingleTPUStrategy", SingleTPUStrategy)
    setattr(pl.accelerators, "TPUAccelerator", TPUAccelerator)
    setattr(pl.plugins, "TPUPrecisionPlugin", TPUPrecisionPlugin)
    setattr(pl.plugins.precision, "TPUPrecisionPlugin", TPUPrecisionPlugin)
    setattr(pl.plugins, "TPUBf16PrecisionPlugin", TPUBf16PrecisionPlugin)
    setattr(pl.plugins.precision, "TPUBf16PrecisionPlugin", TPUBf16PrecisionPlugin)
    setattr(pl.plugins, "XLABf16PrecisionPlugin", XLABf16PrecisionPlugin)
    setattr(pl.plugins.precision, "XLABf16PrecisionPlugin", XLABf16PrecisionPlugin)


_patch_sys_modules()
_patch_classes()

SingleTPUStrategy.register_strategies(pl.strategies.StrategyRegistry)  # type: ignore[has-type]
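
The net effect for users: code written against the old TPU-named API keeps importing and running, but warns on rank zero at construction time. A minimal sketch of what a legacy caller sees (actually constructing the accelerator additionally assumes a working torch_xla / TPU environment):

    # Resolves through the sys.modules alias installed by _patch_sys_modules().
    from lightning.pytorch.accelerators.tpu import TPUAccelerator

    # Warns: "The `TPUAccelerator` class is deprecated. Use
    # `lightning.pytorch.accelerators.XLAAccelerator` instead."
    # and then behaves exactly like XLAAccelerator.
    accelerator = TPUAccelerator()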

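The `_patch_sys_modules` trick generalizes to any module rename: registering the new module object in `sys.modules` under the old dotted path makes the import machinery hand back the same object for both names, because imports are served from that cache. A self-contained sketch of the pattern (the `mypkg` names are hypothetical, for illustration only):

    import importlib
    import sys
    import types

    # Stand-ins for a real package and its relocated module.
    pkg = types.ModuleType("mypkg")
    new_mod = types.ModuleType("mypkg.new_home")
    new_mod.greet = lambda: "hello"

    sys.modules["mypkg"] = pkg
    sys.modules["mypkg.new_home"] = new_mod

    # The aliasing step: the old dotted path now points at the new module object.
    sys.modules["mypkg.old_home"] = new_mod

    # Both paths now import to the very same module.
    old = importlib.import_module("mypkg.old_home")
    assert old.greet() == "hello"
    assert old is importlib.import_module("mypkg.new_home")

One subtlety of the approach taken here: because all five legacy paths alias this single module rather than per-name stubs, any name defined in this file is importable from any of the aliased paths, e.g. `from lightning.pytorch.plugins.precision.tpu import TPUPrecisionPlugin`.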