pytorch-lightning

Форк
0
132 строки · 4.6 Кб
1
# Copyright The Lightning AI team.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

15
import importlib
16
import os
17
import warnings
18
from dataclasses import dataclass
19
from typing import Any, Callable, Protocol, Type, runtime_checkable
20

21
from lightning.app.components.multi_node.base import MultiNode
22
from lightning.app.components.multi_node.pytorch_spawn import _PyTorchSpawnRunExecutor
23
from lightning.app.core.work import LightningWork
24
from lightning.app.utilities.packaging.cloud_compute import CloudCompute
25
from lightning.app.utilities.tracer import Tracer
26

27

28
@runtime_checkable
class _FabricWorkProtocol(Protocol):
    """Structural type for work classes accepted by ``FabricMultiNode``.

    The protocol is ``runtime_checkable`` so it can be used with
    ``issubclass``/``isinstance``; the only required member is ``run``.
    """

    @staticmethod
    def run() -> None:
        """Run."""
33

34

35
@dataclass
class _FabricRunExecutor(_PyTorchSpawnRunExecutor):
    """Run executor that prepares the process environment for Fabric before calling the work's ``run``."""

    @staticmethod
    def run(
        local_rank: int,
        work_run: Callable,
        main_address: str,
        main_port: int,
        num_nodes: int,
        node_rank: int,
        nprocs: int,
    ):
        """Set up the distributed environment variables, trace ``Fabric.__init__`` and call ``work_run``.

        Args:
            local_rank: Rank of this process on the current node.
            work_run: Zero-argument callable to execute once the environment is prepared.
            main_address: Value exported as ``MASTER_ADDR``.
            main_port: Value exported (as a string) as ``MASTER_PORT``.
            num_nodes: Number of nodes; forced into the ``num_nodes`` argument of ``Fabric``.
            node_rank: Rank of the current node; exported as ``GROUP_RANK``.
            nprocs: Number of processes per node; forced into the ``devices`` argument of ``Fabric``.

        Returns:
            Whatever ``work_run()`` returns.

        Raises:
            ValueError: From within a traced ``Fabric.__init__`` call, when a DDP strategy instance
                using the ``spawn`` or ``fork`` start method is passed as ``strategy``.

        """
        fabrics = []
        strategies = []
        mps_accelerators = []

        # Collect the classes from every importable Fabric distribution; packages that are not
        # installed are skipped. The second name is built by concatenation in the original code
        # (presumably so automated import rewriting leaves it alone -- TODO confirm).
        for pkg_name in ("lightning.fabric", "lightning_" + "fabric"):
            try:
                pkg = importlib.import_module(pkg_name)
                fabrics.append(pkg.Fabric)
                strategies.append(pkg.strategies.DDPStrategy)
                mps_accelerators.append(pkg.accelerators.MPSAccelerator)
            except ImportError:  # also covers its subclass ModuleNotFoundError
                continue

        # Used to configure PyTorch process group
        os.environ["MASTER_ADDR"] = main_address
        os.environ["MASTER_PORT"] = str(main_port)

        # Used to hijack TorchElastic Cluster Environment.
        os.environ["GROUP_RANK"] = str(node_rank)
        os.environ["RANK"] = str(local_rank + node_rank * nprocs)
        os.environ["LOCAL_RANK"] = str(local_rank)
        os.environ["WORLD_SIZE"] = str(num_nodes * nprocs)
        os.environ["LOCAL_WORLD_SIZE"] = str(nprocs)
        os.environ["TORCHELASTIC_RUN_ID"] = "1"

        # Used to force Fabric to setup the distributed environment.
        os.environ["LT_CLI_USED"] = "1"

        # Used to pass information to Fabric directly.
        def pre_fn(fabric, *args: Any, **kwargs: Any):
            """Rewrite the keyword arguments of a traced ``Fabric.__init__`` call."""
            kwargs["devices"] = nprocs
            kwargs["num_nodes"] = num_nodes

            if any(acc.is_available() for acc in mps_accelerators):
                old_acc_value = kwargs.get("accelerator", "auto")
                kwargs["accelerator"] = "cpu"

                if old_acc_value != kwargs["accelerator"]:
                    warnings.warn("Forcing `accelerator=cpu` as MPS does not support distributed training.")
            else:
                kwargs["accelerator"] = "auto"

            strategy = kwargs.get("strategy", None)
            if strategy:
                if isinstance(strategy, str):
                    # Map the spawn-based strategy names to their non-spawn counterparts.
                    if strategy == "ddp_spawn":
                        strategy = "ddp"
                    elif strategy == "ddp_sharded_spawn":
                        strategy = "ddp_sharded"
                elif isinstance(strategy, tuple(strategies)) and strategy._start_method in ("spawn", "fork"):
                    raise ValueError("DDP Spawned strategies aren't supported yet.")

            kwargs["strategy"] = strategy

            # NOTE(review): the meaning of the leading empty dict is defined by `Tracer` -- confirm there.
            return {}, args, kwargs

        tracer = Tracer()
        for lf in fabrics:
            tracer.add_traced(lf, "__init__", pre_fn=pre_fn)
        tracer._instrument()
        try:
            ret_val = work_run()
        finally:
            # Always undo the instrumentation, even when the work raises, so the traced
            # `Fabric.__init__` is not left in place after a failure.
            tracer._restore()
        return ret_val
109

110

111
class FabricMultiNode(MultiNode):
    """``MultiNode`` component whose work class is executed through ``_FabricRunExecutor``."""

    def __init__(
        self,
        work_cls: Type["LightningWork"],
        cloud_compute: "CloudCompute",
        num_nodes: int,
        *work_args: Any,
        **work_kwargs: Any,
    ) -> None:
        """Validate ``work_cls``, install the Fabric run executor on it and initialise ``MultiNode``.

        Args:
            work_cls: Work class to run; it must satisfy ``_FabricWorkProtocol`` (i.e. define ``run``).
            cloud_compute: Forwarded to ``MultiNode`` as ``cloud_compute``.
            num_nodes: Forwarded to ``MultiNode`` as ``num_nodes``.
            *work_args: Extra positional arguments forwarded to ``MultiNode``.
            **work_kwargs: Extra keyword arguments forwarded to ``MultiNode``.

        Raises:
            AssertionError: If ``work_cls`` does not satisfy ``_FabricWorkProtocol``.

        """
        assert issubclass(work_cls, _FabricWorkProtocol), (
            f"`work_cls` must define a `run` method to be used with `FabricMultiNode`, got {work_cls!r}."
        )

        # Note: Private way to modify the work run executor
        # Probably exposed to the users in the future if needed.
        # This sets the attribute on the class object passed in by the caller.
        work_cls._run_executor_cls = _FabricRunExecutor

        super().__init__(
            work_cls,
            *work_args,
            num_nodes=num_nodes,
            cloud_compute=cloud_compute,
            **work_kwargs,
        )
133

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.