pytorch-lightning
130 lines · 4.6 KB
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14
import importlib
import os
import warnings
from dataclasses import dataclass
from typing import Any, Callable, Protocol, Type, runtime_checkable

from lightning.app.components.multi_node.base import MultiNode
from lightning.app.components.multi_node.pytorch_spawn import _PyTorchSpawnRunExecutor
from lightning.app.core.work import LightningWork
from lightning.app.utilities.packaging.cloud_compute import CloudCompute
from lightning.app.utilities.tracer import Tracer
26
27
28@runtime_checkable
29class _LightningTrainerWorkProtocol(Protocol):
30@staticmethod
31def run() -> None:
32"""Run."""
33
34
@dataclass
class _LightningTrainerRunExecutor(_PyTorchSpawnRunExecutor):
    """Run executor that wires a spawned process into a distributed ``Trainer``.

    It exports the torch distributed environment variables for the local
    process, then patches ``Trainer.__init__`` (for both ``lightning.pytorch``
    and ``pytorch_lightning``, whichever is importable) so user code
    transparently picks up the multi-node devices/strategy configuration.
    """

    @staticmethod
    def run(
        local_rank: int,
        work_run: Callable,
        main_address: str,
        main_port: int,
        num_nodes: int,
        node_rank: int,
        nprocs: int,
    ):
        """Execute ``work_run`` with the distributed environment prepared.

        Args:
            local_rank: Rank of this process on the current node.
            work_run: The user's ``run`` method to invoke.
            main_address: Address of the main (rank 0) node.
            main_port: Port the main node listens on for rendezvous.
            num_nodes: Total number of participating nodes.
            node_rank: Index of the current node.
            nprocs: Number of processes spawned per node.

        Returns:
            Whatever ``work_run`` returns.
        """
        trainers = []
        strategies = []
        mps_accelerators = []

        # Collect the Trainer/strategy/accelerator classes from whichever of
        # the two package namespaces can be imported.
        for pkg_name in ("lightning.pytorch", "pytorch_" + "lightning"):
            try:
                pkg = importlib.import_module(pkg_name)
                trainers.append(pkg.Trainer)
                strategies.append(pkg.strategies.DDPStrategy)
                mps_accelerators.append(pkg.accelerators.MPSAccelerator)
            # ModuleNotFoundError subclasses ImportError, so one clause suffices.
            except ImportError:
                continue

        # Used to configure the PyTorch process group.
        os.environ["MASTER_ADDR"] = main_address
        os.environ["MASTER_PORT"] = str(main_port)

        # Used to hijack the TorchElastic cluster environment.
        os.environ["GROUP_RANK"] = str(node_rank)
        os.environ["RANK"] = str(local_rank + node_rank * nprocs)
        os.environ["LOCAL_RANK"] = str(local_rank)
        os.environ["WORLD_SIZE"] = str(num_nodes * nprocs)
        os.environ["LOCAL_WORLD_SIZE"] = str(nprocs)
        os.environ["TORCHELASTIC_RUN_ID"] = "1"

        # Used to pass information to the Trainer directly.
        def pre_fn(trainer, *args: Any, **kwargs: Any):
            kwargs["devices"] = nprocs
            kwargs["num_nodes"] = num_nodes
            if any(acc.is_available() for acc in mps_accelerators):
                old_acc_value = kwargs.get("accelerator", "auto")
                kwargs["accelerator"] = "cpu"

                if old_acc_value != kwargs["accelerator"]:
                    warnings.warn("Forcing `accelerator=cpu` as MPS does not support distributed training.")
            else:
                kwargs["accelerator"] = "auto"

            strategy = kwargs.get("strategy")
            if strategy:
                if isinstance(strategy, str):
                    # Processes are already spawned by this executor, so the
                    # spawn-based strategy names are remapped to their
                    # non-spawn counterparts.
                    if strategy == "ddp_spawn":
                        strategy = "ddp"
                    elif strategy == "ddp_sharded_spawn":
                        strategy = "ddp_sharded"
                elif isinstance(strategy, tuple(strategies)):
                    raise ValueError("DDP Spawned strategies aren't supported yet.")
                kwargs["strategy"] = strategy
            return {}, args, kwargs

        tracer = Tracer()
        for trainer in trainers:
            tracer.add_traced(trainer, "__init__", pre_fn=pre_fn)
        tracer._instrument()
        ret_val = work_run()
        tracer._restore()
        return ret_val
104
105
class LightningTrainerMultiNode(MultiNode):
    """Multi-node component that runs a PyTorch Lightning ``Trainer`` across nodes.

    Swaps the work's run executor for :class:`_LightningTrainerRunExecutor`
    so each spawned process is configured for distributed training.
    """

    def __init__(
        self,
        work_cls: Type["LightningWork"],
        cloud_compute: "CloudCompute",
        num_nodes: int,
        *work_args: Any,
        **work_kwargs: Any,
    ) -> None:
        """Create the multi-node component.

        Args:
            work_cls: Work class implementing a static ``run`` method
                (see :class:`_LightningTrainerWorkProtocol`).
            cloud_compute: Compute configuration for each node.
            num_nodes: Number of nodes to launch.
            work_args: Positional arguments forwarded to ``work_cls``.
            work_kwargs: Keyword arguments forwarded to ``work_cls``.

        Raises:
            TypeError: If ``work_cls`` does not provide a ``run`` method.
        """
        # Validate with a real exception: a bare `assert` would be stripped
        # when Python runs with `-O`.
        if not issubclass(work_cls, _LightningTrainerWorkProtocol):
            raise TypeError(f"`work_cls` must implement a `run` method, got {work_cls!r}.")

        # Note: Private way to modify the work run executor.
        # Probably exposed to the users in the future if needed.
        work_cls._run_executor_cls = _LightningTrainerRunExecutor

        super().__init__(
            work_cls,
            *work_args,
            num_nodes=num_nodes,
            cloud_compute=cloud_compute,
            **work_kwargs,
        )

        # The Trainer enables TensorBoard by default, so this is often an
        # undesired directory to upload to the cloud.
        self.lightningignore += ("lightning_logs",)
131