# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import queue
import tempfile
from contextlib import suppress
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Literal, NamedTuple, Optional, Union

import numpy as np
import torch
import torch.backends.cudnn
import torch.multiprocessing as mp
from lightning_utilities.core.apply_func import apply_to_collection
from torch import Tensor
from typing_extensions import override

import lightning.pytorch as pl
from lightning.fabric.strategies.launchers.multiprocessing import (
    _check_bad_cuda_fork,
    _check_missing_main_guard,
    _disable_module_memory_sharing,
)
from lightning.fabric.utilities import move_data_to_device
from lightning.fabric.utilities.distributed import _set_num_threads_if_needed
from lightning.fabric.utilities.seed import _collect_rng_states, _set_rng_states
from lightning.fabric.utilities.types import _PATH
from lightning.pytorch.accelerators import CPUAccelerator
from lightning.pytorch.strategies.launchers.launcher import _Launcher
from lightning.pytorch.trainer.connectors.signal_connector import _SIGNUM
from lightning.pytorch.trainer.states import TrainerFn, TrainerState
from lightning.pytorch.utilities.rank_zero import rank_zero_debug

log = logging.getLogger(__name__)


class _MultiProcessingLauncher(_Launcher):
    r"""Launches processes that run a given function in parallel, and joins them all at the end.

    The main process in which this launcher is invoked creates N so-called worker processes (using
    :func:`torch.multiprocessing.start_processes`) that run the given function.
    Worker processes have a rank that ranges from 0 to N - 1.

    Note:
        - This launcher requires all objects to be pickleable.
        - It is important that the entry point to the program/script is guarded by ``if __name__ == "__main__"``.
        - With start method 'fork' the user must ensure that no CUDA context gets created in the main process before
          the launcher is invoked. E.g., one should avoid creating CUDA tensors or calling ``torch.cuda.*`` functions
          before calling ``Trainer.fit``.

    Args:
        strategy: A reference to the strategy that is used together with this launcher.
        start_method: The method used to start the processes.
            - 'spawn': The default start method. Requires all objects to be pickleable.
            - 'fork': Preferable for IPython/Jupyter environments where 'spawn' is not available. Not available on
              the Windows platform, for example.
            - 'forkserver': Alternative implementation to 'fork'.
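
    Example:

        A minimal sketch only; in regular use the launcher is created and invoked internally by a
        spawn-based strategy (for example ``Trainer(strategy="ddp_spawn")``). The explicit
        construction below and the ``some_parallel_strategy`` placeholder are assumptions made
        purely for illustration.

        .. code-block:: python

            # ``some_parallel_strategy`` stands in for an already configured spawn-compatible
            # strategy, such as the one owned by a Trainer.
            launcher = _MultiProcessingLauncher(some_parallel_strategy, start_method="spawn")

            # 'spawn' is not usable in notebooks; only 'fork' reports as interactive-compatible.
            assert launcher.is_interactive_compatible is False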

    """

    def __init__(
        self, strategy: "pl.strategies.ParallelStrategy", start_method: Literal["spawn", "fork", "forkserver"] = "spawn"
    ) -> None:
        self._strategy = strategy
        self._start_method = start_method
        if start_method not in mp.get_all_start_methods():
            raise ValueError(
                f"The start method '{self._start_method}' is not available on this platform. Available methods are:"
                f" {', '.join(mp.get_all_start_methods())}"
            )
        self.procs: List[mp.Process] = []
        self._already_fit = False

    @property
    @override
    def is_interactive_compatible(self) -> bool:
        # The start method 'spawn' is not supported in interactive environments
        # The start method 'fork' is the only one supported in Jupyter environments, with constraints around CUDA
        # initialization. For more context, see https://github.com/Lightning-AI/lightning/issues/7550
        return self._start_method == "fork"

    @override
    def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"] = None, **kwargs: Any) -> Any:
        """Launches processes that run the given function in parallel.

        The function is allowed to have a return value. However, when all processes join, only the return value
        of worker process 0 gets returned from this `launch` method in the main process.

        Arguments:
            function: The entry point for all launched processes.
            *args: Optional positional arguments to be passed to the given function.
            trainer: Optional reference to the :class:`~lightning.pytorch.trainer.trainer.Trainer` for which
                a selected set of attributes get restored in the main process after processes join.
            **kwargs: Optional keyword arguments to be passed to the given function.
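
        Example:

            A minimal sketch of the return-value behavior, assuming ``launcher`` is an already
            constructed ``_MultiProcessingLauncher`` with a two-process strategy; ``compute_square``
            is a hypothetical, pickleable function used only for illustration.

            .. code-block:: python

                def compute_square(x):
                    return x * x


                # Every worker runs ``compute_square(4)``; with ``trainer=None`` the value computed
                # by worker process 0 is what gets returned here in the main process.
                result = launcher.launch(compute_square, 4)  # result == 16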

        """
        if self._start_method in ("fork", "forkserver"):
            _check_bad_cuda_fork()
        if self._start_method == "spawn":
            _check_missing_main_guard()
        if self._already_fit and trainer is not None and trainer.state.fn == TrainerFn.FITTING:
            # resolving https://github.com/Lightning-AI/lightning/issues/18775 will lift this restriction
            raise NotImplementedError(
                "Calling `trainer.fit()` twice on the same Trainer instance using a spawn-based strategy is not"
                " supported. You can work around this limitation by creating a new Trainer instance and passing the"
                " `fit(ckpt_path=...)` argument."
            )

        # The default cluster environment in Lightning chooses a random free port number
        # This needs to be done in the main process here before starting processes to ensure each rank will connect
        # through the same port
        assert self._strategy.cluster_environment is not None
        os.environ["MASTER_PORT"] = str(self._strategy.cluster_environment.main_port)

        context = mp.get_context(self._start_method)
        return_queue = context.SimpleQueue()

        if self._start_method == "spawn":
            global_states = _GlobalStateSnapshot.capture()
            process_args = [trainer, function, args, kwargs, return_queue, global_states]
        else:
            process_args = [trainer, function, args, kwargs, return_queue]

        process_context = mp.start_processes(
            self._wrapping_function,
            args=process_args,
            nprocs=self._strategy.num_processes,
            start_method=self._start_method,
            join=False,  # we will join ourselves to get the process references
        )
        self.procs = process_context.processes
        while not process_context.join():
            pass

        worker_output = return_queue.get()
        if trainer is None:
            return worker_output

        self._already_fit |= trainer.state.fn == TrainerFn.FITTING
        self._recover_results_in_main_process(worker_output, trainer)
        return worker_output.trainer_results

    def _wrapping_function(
        self,
        process_idx: int,
        trainer: Optional["pl.Trainer"],
        function: Callable,
        args: Any,
        kwargs: Any,
        return_queue: Union[mp.SimpleQueue, queue.Queue],
        global_states: Optional["_GlobalStateSnapshot"] = None,
    ) -> None:
        if global_states:
            global_states.restore()
        if self._start_method == "spawn" and isinstance(self._strategy.accelerator, CPUAccelerator):
            args, kwargs = _disable_module_memory_sharing((args, kwargs))

        _set_num_threads_if_needed(num_processes=self._strategy.num_processes)

        os.environ["LOCAL_RANK"] = str(process_idx)
        results = function(*args, **kwargs)

        if trainer is not None:
            results = self._collect_rank_zero_results(trainer, results)

        if process_idx == 0:
            return_queue.put(move_data_to_device(results, "cpu"))

    def _recover_results_in_main_process(self, worker_output: "_WorkerOutput", trainer: "pl.Trainer") -> None:
        # transfer back the best path to the trainer
        if trainer.checkpoint_callback and hasattr(trainer.checkpoint_callback, "best_model_path"):
            trainer.checkpoint_callback.best_model_path = str(worker_output.best_model_path)

        # TODO: pass also best score
        # load last weights
        if worker_output.weights_path is not None:
            ckpt = self._strategy.checkpoint_io.load_checkpoint(worker_output.weights_path)
            # choose non-strict loading of parameters on the main process, because the model's composition
            # could have changed in the worker process (layers added or removed)
            trainer.lightning_module.load_state_dict(ckpt, strict=False)
            self._strategy.checkpoint_io.remove_checkpoint(worker_output.weights_path)

        trainer.state = worker_output.trainer_state

        # get the `callback_metrics` and set it to the trainer
        self.update_main_process_results(trainer, worker_output.extra)

    def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_WorkerOutput"]:
        rank_zero_debug("Collecting results from rank 0 process.")
        checkpoint_callback = trainer.checkpoint_callback
        best_model_path = (
            checkpoint_callback.best_model_path
            if checkpoint_callback and hasattr(checkpoint_callback, "best_model_path")
            else None
        )

        # the state_dict must be computed on all processes in case Metrics are present
        state_dict = trainer.lightning_module.state_dict()

        if self._strategy.local_rank != 0:
            return None

        # save the last weights
        weights_path = None
        if trainer.state.fn == TrainerFn.FITTING:
            # use tempdir here to avoid race conditions because the filesystem may be shared between nodes
            weights_path = os.path.join(tempfile.mkdtemp(), ".temp.ckpt")
            self._strategy.checkpoint_io.save_checkpoint(state_dict, weights_path)

        # add extra result data from trainer to send to main process
        extra = self.get_extra_results(trainer)

        return _WorkerOutput(best_model_path, weights_path, trainer.state, results, extra)

    def get_extra_results(self, trainer: "pl.Trainer") -> Dict[str, Any]:
        """Gather extra state from the Trainer and return it as a dictionary for sending back to the main process. To
        avoid issues with memory sharing, we cast the data to numpy.

        Args:
            trainer: reference to the Trainer.

        Returns:
            A dictionary with items to send back to the main process where :meth:`update_main_process_results` will
            process this output.
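
        Example:

            A sketch of the conversion this method performs, shown standalone with a hand-built
            metrics dictionary rather than a real Trainer.

            .. code-block:: python

                callback_metrics = {"val_loss": torch.tensor(0.25)}
                as_numpy = apply_to_collection(callback_metrics, Tensor, lambda x: x.cpu().numpy())
                # ``as_numpy["val_loss"]`` is now a ``numpy.ndarray`` and is safe to send through
                # the multiprocessing queue without tensor memory-sharing issues.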

        """
        callback_metrics: dict = apply_to_collection(
            trainer.callback_metrics, Tensor, lambda x: x.cpu().numpy()
        )  # send as numpy to avoid issues with memory sharing
        return {"callback_metrics": callback_metrics}

    def update_main_process_results(self, trainer: "pl.Trainer", extra: Dict[str, Any]) -> None:
        """Restore the :attr:`trainer.callback_metrics` dictionary from the ``extra`` payload sent by the worker
        process. To preserve consistency, we cast the data back to ``torch.Tensor``.

        Args:
            trainer: reference to the Trainer.
            extra: A dictionary with trainer state that was sent from the worker process and needs to be restored
                on the current trainer.

        """
        # NOTE: `get_extra_results` needs to be called beforehand
        callback_metrics = extra["callback_metrics"]
        trainer.callback_metrics.update(apply_to_collection(callback_metrics, np.ndarray, lambda x: torch.tensor(x)))

    @override
    def kill(self, signum: _SIGNUM) -> None:
        for proc in self.procs:
            if proc.is_alive() and proc.pid is not None:
                log.info(f"pid {os.getpid()} killing {proc.pid} with {signum}")
                with suppress(ProcessLookupError):
                    os.kill(proc.pid, signum)

    def __getstate__(self) -> Dict:
        state = self.__dict__.copy()
        state["procs"] = []  # SpawnProcess can't be pickled
        return state


class _WorkerOutput(NamedTuple):
    best_model_path: Optional[_PATH]
    weights_path: Optional[_PATH]
    trainer_state: TrainerState
    trainer_results: Any
    extra: Dict[str, Any]


@dataclass
class _GlobalStateSnapshot:
    """Captures a hand-selected set of (global) variables in modules and provides a way to restore them.

    It facilitates and encapsulates the transfer of globals like PyTorch's deterministic flags or random generator state
    across process boundaries when launching processes with :func:`torch.multiprocessing.spawn`.

    Example:

        .. code-block:: python

            # in main process
            snapshot = _GlobalStateSnapshot.capture()

            # in worker process
            snapshot.restore()

    """

    use_deterministic_algorithms: bool
    use_deterministic_algorithms_warn_only: bool
    cudnn_benchmark: bool
    rng_states: Dict[str, Any]

    @classmethod
    def capture(cls) -> "_GlobalStateSnapshot":
        """Capture a few global states from torch, numpy, etc., that we want to restore in a spawned worker process."""
        return cls(
            use_deterministic_algorithms=torch.are_deterministic_algorithms_enabled(),
            use_deterministic_algorithms_warn_only=torch.is_deterministic_algorithms_warn_only_enabled(),
            cudnn_benchmark=torch.backends.cudnn.benchmark,
            rng_states=_collect_rng_states(),
        )

    def restore(self) -> None:
        """Restores all globals to the values captured in the :meth:`capture` method."""
        torch.use_deterministic_algorithms(
            self.use_deterministic_algorithms, warn_only=self.use_deterministic_algorithms_warn_only
        )
        torch.backends.cudnn.benchmark = self.cudnn_benchmark
        _set_rng_states(self.rng_states)
