pytorch-lightning
362 lines · 12.9 KB
1# Copyright The Lightning AI team.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14import io
15import os
16from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
17
18import torch
19from torch import Tensor
20from torch.nn import Module
21from typing_extensions import override
22
23import lightning.pytorch as pl
24from lightning.fabric.accelerators.xla import _XLA_AVAILABLE, _XLA_GREATER_EQUAL_2_1
25from lightning.fabric.plugins import XLACheckpointIO
26from lightning.fabric.plugins.environments import XLAEnvironment
27from lightning.fabric.strategies import _StrategyRegistry
28from lightning.fabric.utilities.optimizer import _optimizers_to_device
29from lightning.fabric.utilities.types import _PATH, ReduceOp
30from lightning.pytorch.plugins import XLAPrecision
31from lightning.pytorch.plugins.io.wrapper import _WrappingCheckpointIO
32from lightning.pytorch.strategies.ddp import DDPStrategy
33from lightning.pytorch.strategies.launchers.xla import _XLALauncher
34from lightning.pytorch.strategies.strategy import TBroadcast
35from lightning.pytorch.trainer.states import TrainerFn
36from lightning.pytorch.utilities import find_shared_parameters, set_shared_parameters
37from lightning.pytorch.utilities.rank_zero import rank_zero_only
38
39if TYPE_CHECKING:
40from torch_xla.distributed.parallel_loader import MpDeviceLoader
41
42
class XLAStrategy(DDPStrategy):
    """Strategy for training multiple TPU devices using the :func:`torch_xla.distributed.xla_multiprocessing.spawn`
    method."""

    strategy_name = "xla"

    def __init__(
        self,
        accelerator: Optional["pl.accelerators.Accelerator"] = None,
        parallel_devices: Optional[List[torch.device]] = None,
        checkpoint_io: Optional[Union[XLACheckpointIO, _WrappingCheckpointIO]] = None,
        precision_plugin: Optional[XLAPrecision] = None,
        debug: bool = False,
        sync_module_states: bool = True,
        **_: Any,
    ) -> None:
        if not _XLA_AVAILABLE:
            # `_XLA_AVAILABLE` stringifies to a message explaining why torch_xla could not be imported
            raise ModuleNotFoundError(str(_XLA_AVAILABLE))
        super().__init__(
            accelerator=accelerator,
            parallel_devices=parallel_devices,
            cluster_environment=XLAEnvironment(),
            checkpoint_io=checkpoint_io,
            precision_plugin=precision_plugin,
            start_method="fork",
        )
        self.debug = debug
        # flipped to True in `setup_environment` once we are inside the spawned processes;
        # the rank/device properties below return safe defaults until then
        self._launched = False
        self._sync_module_states = sync_module_states

    @property
    @override
    def checkpoint_io(self) -> Union[XLACheckpointIO, _WrappingCheckpointIO]:
        # lazily fall back to a default `XLACheckpointIO` when no plugin was configured
        plugin = self._checkpoint_io
        if plugin is not None:
            assert isinstance(plugin, (XLACheckpointIO, _WrappingCheckpointIO))
            return plugin
        return XLACheckpointIO()

    @checkpoint_io.setter
    @override
    def checkpoint_io(self, io: Optional[Union[XLACheckpointIO, _WrappingCheckpointIO]]) -> None:
        if io is not None and not isinstance(io, (XLACheckpointIO, _WrappingCheckpointIO)):
            raise TypeError(f"The XLA strategy can only work with the `XLACheckpointIO` plugin, found {io}")
        self._checkpoint_io = io

    @property
    @override
    def precision_plugin(self) -> XLAPrecision:
        # lazily fall back to a default `XLAPrecision` when no plugin was configured
        plugin = self._precision_plugin
        if plugin is not None:
            assert isinstance(plugin, XLAPrecision)
            return plugin
        return XLAPrecision()

    @precision_plugin.setter
    @override
    def precision_plugin(self, precision_plugin: Optional[XLAPrecision]) -> None:
        if precision_plugin is not None and not isinstance(precision_plugin, XLAPrecision):
            raise TypeError(f"The XLA strategy can only work with the `XLAPrecision` plugin, found {precision_plugin}")
        self._precision_plugin = precision_plugin

    @property
    @override
    def root_device(self) -> torch.device:
        """Return the XLA device of this process.

        Raises:
            RuntimeError: If accessed before the processes were spawned, since querying the device would
                initialize the XLA runtime in the wrong process.

        """
        if not self._launched:
            raise RuntimeError("Accessing the XLA device before processes have spawned is not allowed.")
        import torch_xla.core.xla_model as xm

        return xm.xla_device()

    @property
    @override
    def global_rank(self) -> int:
        # before launch, rank queries would initialize the XLA computation client, so return defaults
        return super().global_rank if self._launched else 0

    @property
    @override
    def local_rank(self) -> int:
        return super().local_rank if self._launched else 0

    @property
    @override
    def node_rank(self) -> int:
        return super().node_rank if self._launched else 0

    @property
    @override
    def world_size(self) -> int:
        return super().world_size if self._launched else 1

    @override
    def _configure_launcher(self) -> None:
        # use the XLA-specific launcher instead of DDP's subprocess/popen launchers
        self._launcher = _XLALauncher(self)

    @override
    def setup(self, trainer: "pl.Trainer") -> None:
        """Set up the accelerator, move the model to the XLA device, and configure optimizers/precision.

        Runs inside the spawned processes. Optionally broadcasts the rank-zero module states to all
        replicas when ``sync_module_states=True``.

        """
        assert self.accelerator is not None
        self.accelerator.setup(trainer)

        if self.debug:
            os.environ["PT_XLA_DEBUG"] = "1"

        assert self.model is not None
        self.precision_plugin.convert_module(self.model)

        # shared (tied) parameters must be re-tied after the module is moved to the XLA device
        shared_params = find_shared_parameters(self.model)
        self.model_to_device()
        set_shared_parameters(self.model, shared_params)

        self.model = self._setup_model(self.model)

        if self._sync_module_states:
            # the import location of `broadcast_master_param` moved in torch_xla 2.1
            if _XLA_GREATER_EQUAL_2_1:
                from torch_xla.core.xla_model import broadcast_master_param
            else:
                from torch_xla.experimental.pjrt import broadcast_master_param

            broadcast_master_param(self.model)

        if trainer.state.fn == TrainerFn.FITTING:
            self.setup_optimizers(trainer)
        self.setup_precision_plugin()
        if trainer.state.fn == TrainerFn.FITTING:
            _optimizers_to_device(self.optimizers, self.root_device)

    @override
    def _setup_model(self, model: Module) -> Module:  # type: ignore
        # XLA needs no DDP wrapper; the model is returned unchanged
        return model

    @property
    @override
    def distributed_sampler_kwargs(self) -> Dict[str, int]:
        return {"num_replicas": self.world_size, "rank": self.global_rank}

    @override
    def process_dataloader(self, dataloader: object) -> "MpDeviceLoader":
        """Wrap the dataloader in an ``MpDeviceLoader`` that feeds batches to the XLA device."""
        from torch_xla.distributed.parallel_loader import MpDeviceLoader

        if isinstance(dataloader, MpDeviceLoader):
            # dataloader is already wrapped by MpDeviceLoader
            return dataloader

        dataloader = MpDeviceLoader(dataloader, self.root_device)
        # Mimic interface to torch.utils.data.DataLoader
        dataloader.dataset = dataloader._loader.dataset
        dataloader.batch_sampler = getattr(dataloader._loader, "batch_sampler", None)
        return dataloader

    @override
    def configure_ddp(self) -> None:
        # DDP configuration does not apply to XLA
        pass

    @override
    def model_to_device(self) -> None:
        assert self.model is not None
        self.model = self.model.to(self.root_device)

    @override
    def barrier(self, name: Optional[str] = None, *args: Any, **kwargs: Any) -> None:
        """Synchronize all processes via an XLA rendezvous; a no-op before the processes are spawned."""
        if not self._launched:
            return

        import torch_xla.core.xla_model as xm

        if name is None:
            # `None` is not supported: "TypeError: _xla_rendezvous(): incompatible function arguments"
            name = ""
        xm.rendezvous(name)

    @override
    def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
        """Broadcast ``obj`` from the process of rank ``src`` to all other processes.

        Non-tensor objects are serialized into a byte tensor for the collective and deserialized
        afterwards; tensors are moved to the XLA device for the collective and back to their
        original device on return.

        """
        if not self._launched:
            return obj

        import torch_xla.core.xla_model as xm

        is_tensor = isinstance(obj, Tensor)
        if is_tensor:
            if obj.dim() == 0:
                obj = obj.unsqueeze(0)
            original_device = obj.device
            # XLA distributed requires that the data is on the XLA device
            obj = obj.to(self.root_device)
        else:
            # support for arbitrary pickle-ables
            buffer = io.BytesIO()
            torch.save(obj, buffer)
            obj = torch.tensor(  # type: ignore[assignment]
                bytearray(buffer.getbuffer()), device=self.root_device, dtype=torch.float
            )

        obj = [obj]
        xm.collective_broadcast(obj, root_ordinal=src)
        obj = obj[0]

        if not is_tensor:
            # this will preserve the dtype and device of any tensors
            buffer = io.BytesIO(obj.cpu().byte().numpy())
            obj = torch.load(buffer)
        else:
            obj = obj.to(original_device)

        return obj

    @override
    def reduce(
        self, output: Union[Tensor, Any], group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None
    ) -> Tensor:
        """Reduce ``output`` across all processes.

        Only ``sum`` and ``mean``/``avg`` reductions are supported; the mean is computed as the
        sum divided by the world size.

        Raises:
            ValueError: If ``reduce_op`` is any other operation.

        """
        if not isinstance(output, Tensor):
            output = torch.tensor(output, device=self.root_device)

        invalid_reduce_op = isinstance(reduce_op, ReduceOp) and reduce_op != ReduceOp.SUM
        invalid_reduce_op_str = isinstance(reduce_op, str) and reduce_op.lower() not in ("sum", "mean", "avg")
        if invalid_reduce_op or invalid_reduce_op_str:
            raise ValueError(
                "Currently, the XLAStrategy only supports `sum`, `mean`, `avg` for the reduce operation, got:"
                f" {reduce_op}"
            )

        import torch_xla.core.xla_model as xm

        output = xm.mesh_reduce("reduce", output, sum)

        if isinstance(reduce_op, str) and reduce_op.lower() in ("avg", "mean"):
            output = output / self.world_size

        return output

    @override
    def setup_environment(self) -> None:
        # mark that we are now inside the spawned region before the base class queries ranks/devices
        self._launched = True
        super().setup_environment()

    @override
    def setup_distributed(self) -> None:
        assert self.parallel_devices is not None
        if len(self.parallel_devices) == 1:
            # spawning only 1 device with PjRT is not supported:
            # https://github.com/Lightning-AI/lightning/pull/17408#discussion_r1170671732
            raise NotImplementedError(
                "The `XLAStrategy` does not support running on a single device with the PjRT runtime."
                " Try using all devices or the `SingleDeviceXLAStrategy` strategy"
            )
        rank_zero_only.rank = self.global_rank

    @override
    def set_world_ranks(self) -> None:
        # accessing global_rank will initialize the XLA computation client. since this is called outside of the spawned
        # processes (by the accelerator connector), we cannot run the code that would normally be here.
        # instead it's done in `setup_distributed`
        pass

    @override
    def save_checkpoint(
        self, checkpoint: Dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None
    ) -> None:
        """Synchronize pending XLA work on all ranks, then save the checkpoint on global rank zero."""
        import torch_xla.core.xla_model as xm

        # sync any pending lazy tensors on all ranks before saving to prevent potential collective hangs
        xm.mark_step()
        # save on global rank zero only
        super().save_checkpoint(checkpoint, filepath, storage_options=storage_options)

    @override
    def remove_checkpoint(self, filepath: _PATH) -> None:
        """Remove checkpoint filepath from the filesystem.

        Args:
            filepath: Path to checkpoint

        """
        if self.local_rank == 0:
            self.checkpoint_io.remove_checkpoint(filepath)

    @override
    def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor:
        """Function to gather a tensor from several distributed processes.

        Args:
            tensor: tensor to all-gather.
            group: unused.
            sync_grads: flag that allows users to synchronize gradients for the all-gather operation.

        Return:
            A tensor of shape (world_size, ...)

        """
        if not self._launched:
            return tensor
        if not isinstance(tensor, Tensor):
            raise NotImplementedError(
                f"`{type(self).__name__}.all_gather` is only implemented for tensors. Given {tensor}"
            )
        if tensor.dim() == 0:
            tensor = tensor.unsqueeze(0)
        original_device = tensor.device
        tensor = tensor.to(self.root_device)

        import torch_xla.core.functions as xf
        import torch_xla.core.xla_model as xm

        # `xf.all_gather` participates in autograd (gradients flow back); `xm.all_gather` does not
        tensor = xf.all_gather(tensor) if sync_grads else xm.all_gather(tensor)
        tensor = tensor.to(original_device)
        return tensor

    @override
    def teardown(self) -> None:
        super().teardown()
        self._launched = False  # after the Trainer finishes, we aren't inside the spawned region
        os.environ.pop("PT_XLA_DEBUG", None)

    @classmethod
    @override
    def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None:
        strategy_registry.register("xla_debug", cls, description="XLA strategy with `debug` as True", debug=True)
        strategy_registry.register(
            cls.strategy_name,
            cls,
            description=cls.__name__,
        )
363