pytorch-lightning

Форк
0
119 строк · 4.1 Кб
1
# Copyright The Lightning AI team.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

15
from typing import Any, Callable, Protocol, Type, runtime_checkable
16

17
from lightning.app.components.multi_node.base import MultiNode
18
from lightning.app.core.queues import MultiProcessQueue
19
from lightning.app.core.work import LightningWork
20
from lightning.app.utilities.packaging.cloud_compute import CloudCompute
21
from lightning.app.utilities.proxies import WorkRunExecutor, WorkStateObserver, _proxy_setattr, unwrap
22

23

24
@runtime_checkable
25
class _PyTorchSpawnWorkProtocol(Protocol):
26
    def run(
27
        self,
28
        world_size: int,
29
        node_rank: int,
30
        global_rank: int,
31
        local_rank: int,
32
    ) -> None:
33
        pass
34

35

36
class _PyTorchSpawnRunExecutor(WorkRunExecutor):
37
    enable_start_observer: bool = False
38

39
    def __call__(
40
        self,
41
        main_address: str,
42
        main_port: int,
43
        num_nodes: int,
44
        node_rank: int,
45
    ):
46
        import torch
47

48
        with self.enable_spawn():
49
            nprocs = torch.cuda.device_count() if torch.cuda.is_available() else 1
50
            queue = self.delta_queue if isinstance(self.delta_queue, MultiProcessQueue) else self.delta_queue.to_dict()
51
            torch.multiprocessing.spawn(
52
                self.dispatch_run,
53
                args=(self.__class__, self.work, queue, main_address, main_port, num_nodes, node_rank, nprocs),
54
                nprocs=nprocs,
55
            )
56

57
    @staticmethod
58
    def dispatch_run(local_rank, cls, work, delta_queue, *args: Any, **kwargs: Any):
59
        if local_rank == 0:
60
            if isinstance(delta_queue, dict):
61
                delta_queue = cls.process_queue(delta_queue)
62
                work._request_queue = cls.process_queue(work._request_queue)
63
                work._response_queue = cls.process_queue(work._response_queue)
64

65
            state_observer = WorkStateObserver(work, delta_queue=delta_queue)
66
            state_observer.start()
67
            _proxy_setattr(work, delta_queue, state_observer)
68

69
        cls.run(local_rank, unwrap(work.run), *args, **kwargs)
70

71
        if local_rank == 0:
72
            state_observer.join(0)
73

74
    @staticmethod
75
    def run(
76
        local_rank: int,
77
        work_run: Callable,
78
        main_address: str,
79
        main_port: int,
80
        num_nodes: int,
81
        node_rank: int,
82
        nprocs: int,
83
    ):
84
        import torch
85

86
        # 1. Setting distributed environment
87
        global_rank = local_rank + node_rank * nprocs
88
        world_size = num_nodes * nprocs
89

90
        if torch.distributed.is_available():
91
            if not torch.distributed.is_initialized():
92
                torch.distributed.init_process_group(
93
                    "nccl" if torch.cuda.is_available() else "gloo",
94
                    rank=global_rank,
95
                    world_size=world_size,
96
                    init_method=f"tcp://{main_address}:{main_port}",
97
                )
98
        elif world_size > 1:
99
            raise Exception("Torch distributed should be available.")
100

101
        return work_run(world_size, node_rank, global_rank, local_rank)
102

103

104
class PyTorchSpawnMultiNode(MultiNode):
105
    def __init__(
106
        self,
107
        work_cls: Type["LightningWork"],
108
        cloud_compute: "CloudCompute",
109
        num_nodes: int,
110
        *work_args: Any,
111
        **work_kwargs: Any,
112
    ) -> None:
113
        assert issubclass(work_cls, _PyTorchSpawnWorkProtocol)
114

115
        # Note: Private way to modify the work run executor
116
        # Probably exposed to the users in the future if needed.
117
        work_cls._run_executor_cls = _PyTorchSpawnRunExecutor
118

119
        super().__init__(work_cls, num_nodes, cloud_compute, *work_args, **work_kwargs)
120

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.