# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import queue
from typing import TYPE_CHECKING, Any, Callable, Optional, Union

import torch.multiprocessing as mp
from typing_extensions import override

from lightning.fabric.accelerators.xla import _XLA_AVAILABLE
from lightning.fabric.strategies.launchers.xla import _rank_teardown
from lightning.fabric.utilities import move_data_to_device
from lightning.pytorch.strategies.launchers.multiprocessing import (
    _GlobalStateSnapshot,
    _MultiProcessingLauncher,
    _WorkerOutput,
)
from lightning.pytorch.trainer.states import TrainerFn
from lightning.pytorch.utilities.rank_zero import rank_zero_debug

if TYPE_CHECKING:
    import lightning.pytorch as pl

class _XLALauncher(_MultiProcessingLauncher):
    r"""Launches processes that run a given function in parallel on XLA-supported hardware, and joins them all at
    the end.

    The main process in which this launcher is invoked creates N so-called worker processes (using
    ``torch_xla``'s :func:`xmp.spawn`) that run the given function.
    Worker processes have a rank that ranges from 0 to N - 1.

    Note:
        - This launcher requires all objects to be pickleable.
        - It is important that the entry point to the program/script is guarded by ``if __name__ == "__main__"``.

    Args:
        strategy: A reference to the strategy that is used together with this launcher.

    """

    def __init__(self, strategy: "pl.strategies.XLAStrategy") -> None:
        if not _XLA_AVAILABLE:
            raise ModuleNotFoundError(str(_XLA_AVAILABLE))
        super().__init__(strategy=strategy, start_method="fork")

    @property
    @override
    def is_interactive_compatible(self) -> bool:
        # This launcher uses the "fork" start method (see ``__init__``), which, unlike "spawn",
        # also works in interactive environments such as Jupyter notebooks.
        return True

    @override
    def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"] = None, **kwargs: Any) -> Any:
        """Launches processes that run the given function in parallel.

        The function is allowed to have a return value. However, when all processes join, only the return value
        of worker process 0 gets returned from this `launch` method in the main process.

        Arguments:
            function: The entry point for all launched processes.
            *args: Optional positional arguments to be passed to the given function.
            trainer: Optional reference to the :class:`~lightning.pytorch.trainer.trainer.Trainer` for which
                a selected set of attributes get restored in the main process after processes join.
            **kwargs: Optional keyword arguments to be passed to the given function.

        """
        if self._already_fit and trainer is not None and trainer.state.fn == TrainerFn.FITTING:
            # resolving https://github.com/Lightning-AI/lightning/issues/18775 will lift this restriction
            raise NotImplementedError(
                "Calling `trainer.fit()` twice on the same Trainer instance using a spawn-based strategy is not"
                " supported. You can work around this by creating a new Trainer instance and passing the"
                " `fit(ckpt_path=...)` argument."
            )

        # PJRT requires that the queue is serializable
        return_queue = mp.Manager().Queue()

        import torch_xla.distributed.xla_multiprocessing as xmp

        spawn_kwargs = {}
        nprocs = self._strategy.num_processes
        if nprocs == 1:
            # avoid warning: "Unsupported nprocs". If it's 1, it will call the launched function directly;
            # otherwise it will use all devices
            spawn_kwargs["nprocs"] = nprocs

        process_context = xmp.spawn(
            self._wrapping_function,
            args=(trainer, function, args, kwargs, return_queue),
            start_method=self._start_method,
            join=False,  # we will join ourselves to get the process references
            **spawn_kwargs,
        )
        # XLA will not actually create processes if there is only 1 device
        if process_context is not None:
            self.procs = process_context.processes
            while not process_context.join():
                pass

        worker_output = return_queue.get()
        if trainer is None:
            return worker_output

        self._already_fit |= trainer.state.fn == TrainerFn.FITTING
        self._recover_results_in_main_process(worker_output, trainer)
        return worker_output.trainer_results

    @override
    def _wrapping_function(
        self,
        # XLA's multiprocessing returns the global index, not the local index as torch's multiprocessing
        # https://github.com/pytorch/xla/blob/v1.13.0/torch_xla/distributed/xla_multiprocessing.py#L321
        process_idx: int,
        trainer: Optional["pl.Trainer"],
        function: Callable,
        args: Any,
        kwargs: Any,
        return_queue: Union[mp.SimpleQueue, queue.Queue],
        global_states: Optional[_GlobalStateSnapshot] = None,
    ) -> None:
        import torch_xla.core.xla_model as xm

        if len(xm.get_xla_supported_devices()) > 1:
            # `get_xla_supported_devices` in the spawned process returns the logical devices (2 for v2/v3 and 1 for
            # v4), so when there's more than one (multithreading), objects need to be deep-copied
            import copy

            trainer, function, args, kwargs = copy.deepcopy((trainer, function, args, kwargs))

        results = function(*args, **kwargs)

        if trainer is not None:
            results = self._collect_rank_zero_results(trainer, results)

        if self._strategy.local_rank == 0:
            return_queue.put(move_data_to_device(results, "cpu"))

        _rank_teardown(self._strategy.local_rank)

    @override
    def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_WorkerOutput"]:
        rank_zero_debug("Collecting results from rank 0 process.")
        checkpoint_callback = trainer.checkpoint_callback
        best_model_path = (
            checkpoint_callback.best_model_path
            if checkpoint_callback and hasattr(checkpoint_callback, "best_model_path")
            else None
        )

        # save the last weights
        weights_path = None
        if trainer.state.fn == TrainerFn.FITTING:
            # requires to compute the state_dict on all processes in case Metrics are present
            state_dict = self._strategy.lightning_module_state_dict()
            weights_path = os.path.join(trainer.default_root_dir, ".temp.ckpt")
            self._strategy.checkpoint_io.save_checkpoint(state_dict, weights_path)

        # We use `local_rank` here as separate filesystems are used for each VM for TPU Pod Training
        if self._strategy.local_rank != 0:
            return None

        # add extra result data from trainer to send to main process
        extra = self.get_extra_results(trainer)

        return _WorkerOutput(best_model_path, weights_path, trainer.state, results, extra)
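
# For reference, the spawn-and-queue pattern this launcher implements, reduced to a
# standalone sketch (assumptions: ``torch_xla`` is installed, and ``_worker`` is a
# hypothetical function, not part of this module):
#
#     import torch.multiprocessing as mp
#     import torch_xla.distributed.xla_multiprocessing as xmp
#
#     def _worker(index: int, return_queue) -> None:
#         ...  # do work on the XLA device assigned to this process
#         if index == 0:
#             return_queue.put("rank 0 result")  # only rank 0 reports back
#
#     if __name__ == "__main__":  # guarded entry point, as the class docstring requires
#         return_queue = mp.Manager().Queue()  # pickleable, as PJRT requires
#         xmp.spawn(_worker, args=(return_queue,), start_method="fork")
#         print(return_queue.get())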