pytorch-lightning
125 строк · 4.4 Кб
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14
15import sys
16from typing import Any
17
18import lightning.pytorch as pl
19from lightning.fabric.strategies import _StrategyRegistry
20from lightning.pytorch.accelerators.xla import XLAAccelerator
21from lightning.pytorch.plugins.precision import XLAPrecision
22from lightning.pytorch.strategies.single_xla import SingleDeviceXLAStrategy
23from lightning.pytorch.utilities.rank_zero import rank_zero_deprecation
24
25
def _patch_sys_modules() -> None:
    """Point every legacy TPU / XLA-bf16 module path at this module.

    Each dotted path below is inserted into :data:`sys.modules` and bound to this very module object. A later
    import of one of those paths is therefore served from the module cache, and the legacy class names defined
    here are reachable through it.
    """
    this_module = sys.modules[__name__]
    legacy_paths = (
        "lightning.pytorch.strategies.single_tpu",
        "lightning.pytorch.accelerators.tpu",
        "lightning.pytorch.plugins.precision.tpu",
        "lightning.pytorch.plugins.precision.tpu_bf16",
        "lightning.pytorch.plugins.precision.xlabf16",
    )
    for legacy_path in legacy_paths:
        sys.modules[legacy_path] = this_module
33
34
class SingleTPUStrategy(SingleDeviceXLAStrategy):
    """Legacy, deprecated ``single_tpu`` strategy.

    Prefer :class:`~lightning.pytorch.strategies.single_xla.SingleDeviceXLAStrategy`.

    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # Warn first, then hand every argument unchanged to the XLA single-device strategy.
        rank_zero_deprecation("The 'single_tpu' strategy is deprecated. Use 'single_xla' instead.")
        super().__init__(*args, **kwargs)

    @classmethod
    def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None:
        """Add this class to ``strategy_registry`` under ``"single_tpu"``.

        Nothing happens if that name is already present, so an existing entry is never overwritten.
        """
        if "single_tpu" in strategy_registry:
            return
        strategy_registry.register("single_tpu", cls, description="Legacy class. Use `single_xla` instead.")
50
51
class TPUAccelerator(XLAAccelerator):
    """Legacy, deprecated alias of the XLA accelerator.

    Prefer :class:`~lightning.pytorch.accelerators.xla.XLAAccelerator`.

    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        deprecation_message = (
            "The `TPUAccelerator` class is deprecated. Use `lightning.pytorch.accelerators.XLAAccelerator` instead."
        )
        rank_zero_deprecation(deprecation_message)
        # All constructor arguments are forwarded unchanged to XLAAccelerator.
        super().__init__(*args, **kwargs)
64
65
class TPUPrecisionPlugin(XLAPrecision):
    """Legacy, deprecated plugin that always configures ``precision="32-true"``.

    Prefer :class:`~lightning.pytorch.plugins.precision.xla.XLAPrecision`.

    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        deprecation_message = (
            "The `TPUPrecisionPlugin` class is deprecated. Use `lightning.pytorch.plugins.precision.XLAPrecision`"
            " instead."
        )
        rank_zero_deprecation(deprecation_message)
        # NOTE: `args`/`kwargs` are accepted but NOT forwarded; the precision is hard-wired to "32-true".
        # Presumably this keeps legacy call sites working. TODO(review): confirm that is the intent.
        super().__init__(precision="32-true")
79
80
class TPUBf16PrecisionPlugin(XLAPrecision):
    """Legacy class.

    Use :class:`~lightning.pytorch.plugins.precision.xla.XLAPrecision` with ``precision="bf16-true"`` instead.

    This shim always configures ``precision="bf16-true"``. Any positional or keyword arguments are accepted but
    ignored; they are not forwarded to :class:`~lightning.pytorch.plugins.precision.xla.XLAPrecision`.

    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        rank_zero_deprecation(
            "The `TPUBf16PrecisionPlugin` class is deprecated. Use"
            " `lightning.pytorch.plugins.precision.XLAPrecision` instead."
        )
        # `args`/`kwargs` are deliberately dropped; presumably they exist only so legacy call sites keep
        # working. TODO(review): confirm.
        super().__init__(precision="bf16-true")
94
95
class XLABf16PrecisionPlugin(XLAPrecision):
    """Legacy class.

    Use :class:`~lightning.pytorch.plugins.precision.xla.XLAPrecision` with ``precision="bf16-true"`` instead.

    This shim always configures ``precision="bf16-true"``. Any positional or keyword arguments are accepted but
    ignored; they are not forwarded to :class:`~lightning.pytorch.plugins.precision.xla.XLAPrecision`.

    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        rank_zero_deprecation(
            "The `XLABf16PrecisionPlugin` class is deprecated. Use"
            " `lightning.pytorch.plugins.precision.XLAPrecision` instead."
        )
        # `args`/`kwargs` are deliberately dropped; presumably they exist only so legacy call sites keep
        # working. TODO(review): confirm.
        super().__init__(precision="bf16-true")
109
110
def _patch_classes() -> None:
    """Re-attach the legacy class names to the ``lightning.pytorch`` namespaces that used to export them.

    The precision plugins are exposed on both ``pl.plugins`` and ``pl.plugins.precision``.
    """
    legacy_exports = (
        (pl.strategies, "SingleTPUStrategy", SingleTPUStrategy),
        (pl.accelerators, "TPUAccelerator", TPUAccelerator),
        (pl.plugins, "TPUPrecisionPlugin", TPUPrecisionPlugin),
        (pl.plugins.precision, "TPUPrecisionPlugin", TPUPrecisionPlugin),
        (pl.plugins, "TPUBf16PrecisionPlugin", TPUBf16PrecisionPlugin),
        (pl.plugins.precision, "TPUBf16PrecisionPlugin", TPUBf16PrecisionPlugin),
        (pl.plugins, "XLABf16PrecisionPlugin", XLABf16PrecisionPlugin),
        (pl.plugins.precision, "XLABf16PrecisionPlugin", XLABf16PrecisionPlugin),
    )
    for namespace, attr_name, legacy_cls in legacy_exports:
        setattr(namespace, attr_name, legacy_cls)
120
121
# Import-time side effects:
#   1. alias the legacy module paths (e.g. `...accelerators.tpu`) to this module in `sys.modules`;
#   2. re-expose the legacy class names on the `lightning.pytorch` namespaces;
#   3. register the legacy "single_tpu" name in the strategy registry (skipped if already registered).
_patch_sys_modules()
_patch_classes()

SingleTPUStrategy.register_strategies(pl.strategies.StrategyRegistry)  # type: ignore[has-type]
126