pytorch-lightning

# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from typing import Any, Type

from lightning.app.core.flow import LightningFlow
from lightning.app.core.work import LightningWork
from lightning.app.structures import List as _List
from lightning.app.utilities.cloud import is_running_in_cloud
from lightning.app.utilities.packaging.cloud_compute import CloudCompute


class MultiNode(LightningFlow):
    def __init__(
        self,
        work_cls: Type["LightningWork"],
        num_nodes: int,
        cloud_compute: "CloudCompute",
        *work_args: Any,
        **work_kwargs: Any,
    ) -> None:
        """This component enables performing distributed multi-node multi-device training.

        Example::

            import torch

            from lightning.app import CloudCompute, LightningApp, LightningWork
            from lightning.app.components import MultiNode

            class AnyDistributedComponent(LightningWork):
                def run(
                    self,
                    main_address: str,
                    main_port: int,
                    num_nodes: int,
                    node_rank: int,
                ):
                    print(f"ADD YOUR DISTRIBUTED CODE: {main_address} {main_port} {num_nodes} {node_rank}")


            compute = CloudCompute("gpu")
            app = LightningApp(
                MultiNode(
                    AnyDistributedComponent,
                    num_nodes=8,
                    cloud_compute=compute,
                )
            )

        Arguments:
            work_cls: The work to be executed.
            num_nodes: Number of nodes. Ignored when running locally; launch the app with --cloud to run on
                multiple cloud machines.
            cloud_compute: The cloud compute object used in the cloud. The value provided here is ignored when
                running locally.
            work_args: Arguments to be provided to the work on instantiation.
            work_kwargs: Keyword arguments to be provided to the work on instantiation.

        """
        super().__init__()
        # Multi-node execution only applies in the cloud; fall back to a single node when running locally.
        if num_nodes > 1 and not is_running_in_cloud():
            warnings.warn(
                f"You set `{type(self).__name__}(num_nodes={num_nodes}, ...)` but this app is running locally."
                " We assume you are debugging and will ignore the `num_nodes` argument."
                " To run on multiple nodes in the cloud, launch your app with `--cloud`."
            )
            num_nodes = 1
        # Create one work per node, each with its own clone of the cloud compute config.
        self.ws = _List(*[
            work_cls(
                *work_args,
                cloud_compute=cloud_compute.clone(),
                **work_kwargs,
                parallel=True,
            )
            for _ in range(num_nodes)
        ])

    def run(self) -> None:
        # 1. Wait for all works to be started!
        if not all(w.internal_ip for w in self.ws):
            return

        # 2. Loop over all node machines
        for node_rank in range(len(self.ws)):
            # 3. Run the user code in a distributed way!
            self.ws[node_rank].run(
                main_address=self.ws[0].internal_ip,
                main_port=self.ws[0].port,
                num_nodes=len(self.ws),
                node_rank=node_rank,
            )

            # 4. Stop the machine when finished.
            if self.ws[node_rank].has_succeeded:
                self.ws[node_rank].stop()

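To make the hand-off above concrete, here is a minimal sketch (not part of the file) of a work that wires the main_address, main_port, num_nodes, and node_rank arguments passed by MultiNode.run into torch.distributed. The class name TorchDistributedWork is hypothetical; the sketch assumes one process per node (so world_size == num_nodes) and uses the gloo backend so it stays runnable on CPU machines::

    import torch
    import torch.distributed as dist

    from lightning.app import LightningWork


    class TorchDistributedWork(LightningWork):
        def run(self, main_address: str, main_port: int, num_nodes: int, node_rank: int):
            # Assumption: one process per node, so world_size == num_nodes and rank == node_rank.
            dist.init_process_group(
                backend="gloo",  # typically "nccl" when every node trains on GPUs
                init_method=f"tcp://{main_address}:{main_port}",
                world_size=num_nodes,
                rank=node_rank,
            )

            # Sanity check: sum the node ranks across the cluster.
            t = torch.tensor([float(node_rank)])
            dist.all_reduce(t, op=dist.ReduceOp.SUM)
            print(f"node {node_rank}/{num_nodes}: all_reduce sum = {t.item()}")

            dist.destroy_process_group()

Used in place of AnyDistributedComponent from the docstring example, MultiNode would call this run once per node, with node 0's internal IP and port acting as the rendezvous address.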