pytorch-lightning

# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib
import os
import warnings
from dataclasses import dataclass
from typing import Any, Callable, Protocol, Type, runtime_checkable

from lightning.app.components.multi_node.base import MultiNode
from lightning.app.components.multi_node.pytorch_spawn import _PyTorchSpawnRunExecutor
from lightning.app.core.work import LightningWork
from lightning.app.utilities.packaging.cloud_compute import CloudCompute
from lightning.app.utilities.tracer import Tracer


@runtime_checkable
class _LightningTrainerWorkProtocol(Protocol):
    @staticmethod
    def run() -> None:
        """Run."""


@dataclass
class _LightningTrainerRunExecutor(_PyTorchSpawnRunExecutor):
    @staticmethod
    def run(
        local_rank: int,
        work_run: Callable,
        main_address: str,
        main_port: int,
        num_nodes: int,
        node_rank: int,
        nprocs: int,
    ):
        trainers = []
        strategies = []
        mps_accelerators = []

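        # Both the unified `lightning.pytorch` package and the standalone
        # `pytorch_lightning` mirror may be installed; collect the Trainer,
        # DDP strategy, and MPS accelerator symbols from whichever imports.
        # (The "pytorch_" + "lightning" concatenation presumably keeps the
        # literal name out of reach of the automated package-rewriting tooling.)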
        for pkg_name in ("lightning.pytorch", "pytorch_" + "lightning"):
            try:
                pkg = importlib.import_module(pkg_name)
                trainers.append(pkg.Trainer)
                strategies.append(pkg.strategies.DDPStrategy)
                mps_accelerators.append(pkg.accelerators.MPSAccelerator)
            except (ImportError, ModuleNotFoundError):
                continue

        # Used to configure the PyTorch process group.
        os.environ["MASTER_ADDR"] = main_address
        os.environ["MASTER_PORT"] = str(main_port)

        # Used to hijack the TorchElastic cluster environment.
        os.environ["GROUP_RANK"] = str(node_rank)
        os.environ["RANK"] = str(local_rank + node_rank * nprocs)
        os.environ["LOCAL_RANK"] = str(local_rank)
        os.environ["WORLD_SIZE"] = str(num_nodes * nprocs)
        os.environ["LOCAL_WORLD_SIZE"] = str(nprocs)
        os.environ["TORCHELASTIC_RUN_ID"] = "1"
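        # Worked example: with num_nodes=2 and nprocs=4, the worker at
        # node_rank=1, local_rank=2 gets RANK = 2 + 1 * 4 = 6 in a
        # WORLD_SIZE of 2 * 4 = 8.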

        # Patch the kwargs of `Trainer.__init__` so the cluster topology is
        # passed to the Trainer directly.
        def pre_fn(trainer, *args: Any, **kwargs: Any):
            kwargs["devices"] = nprocs
            kwargs["num_nodes"] = num_nodes
            if any(acc.is_available() for acc in mps_accelerators):
                old_acc_value = kwargs.get("accelerator", "auto")
                kwargs["accelerator"] = "cpu"

                if old_acc_value != kwargs["accelerator"]:
                    warnings.warn("Forcing `accelerator=cpu` as MPS does not support distributed training.")
            else:
                kwargs["accelerator"] = "auto"

            strategy = kwargs.get("strategy", None)
            if strategy:
                if isinstance(strategy, str):
                    if strategy == "ddp_spawn":
                        strategy = "ddp"
                    elif strategy == "ddp_sharded_spawn":
                        strategy = "ddp_sharded"
                elif isinstance(strategy, tuple(strategies)):
                    raise ValueError("DDP Spawned strategies aren't supported yet.")
                kwargs["strategy"] = strategy
            return {}, args, kwargs

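        # Trace `Trainer.__init__` in every importable package flavor so that a
        # Trainer constructed inside `work_run()` picks up the injected kwargs,
        # then undo the instrumentation once the work has finished.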
        tracer = Tracer()
        for trainer in trainers:
            tracer.add_traced(trainer, "__init__", pre_fn=pre_fn)
        tracer._instrument()
        ret_val = work_run()
        tracer._restore()
        return ret_val


class LightningTrainerMultiNode(MultiNode):
    def __init__(
        self,
        work_cls: Type["LightningWork"],
        cloud_compute: "CloudCompute",
        num_nodes: int,
        *work_args: Any,
        **work_kwargs: Any,
    ) -> None:
        assert issubclass(work_cls, _LightningTrainerWorkProtocol)

        # Note: private way to modify the work's run executor.
        # Might be exposed to users in the future if needed.
        work_cls._run_executor_cls = _LightningTrainerRunExecutor

        super().__init__(
            work_cls,
            *work_args,
            num_nodes=num_nodes,
            cloud_compute=cloud_compute,
            **work_kwargs,
        )

        # The Trainer enables TensorBoard logging by default; the resulting
        # `lightning_logs` directory is usually unwanted in cloud uploads.
        self.lightningignore += ("lightning_logs",)
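
Usage sketch: a minimal, hypothetical example of driving a multi-node Trainer through this component. The `BoringModel` demo and the "gpu" CloudCompute size are illustrative assumptions, not part of this file:

from lightning.app import CloudCompute, LightningApp, LightningWork
from lightning.app.components import LightningTrainerMultiNode


class TrainerWork(LightningWork):
    def run(self):
        # Any Trainer built inside `run()` is patched by the tracer above,
        # so devices/num_nodes/accelerator need not be passed explicitly.
        from lightning.pytorch import Trainer
        from lightning.pytorch.demos.boring_classes import BoringModel

        Trainer(max_epochs=1).fit(BoringModel())


# Hypothetical wiring; the "gpu" CloudCompute size is an assumption.
app = LightningApp(
    LightningTrainerMultiNode(
        TrainerWork,
        num_nodes=2,
        cloud_compute=CloudCompute("gpu"),
    )
)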
