pytorch-lightning

Форк
0
203 строки · 7.1 Кб
1
# Copyright The Lightning AI team.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

15
import os
16
from typing import Any, Dict, List, Optional, Tuple, Type, Union
17

18
from lightning.app.components.python import TracerPythonScript
19
from lightning.app.core.flow import LightningFlow
20
from lightning.app.storage.path import Path
21
from lightning.app.structures import List as _List
22
from lightning.app.utilities.app_helpers import Logger
23
from lightning.app.utilities.packaging.cloud_compute import CloudCompute
24

25
# Module-level logger shared by the components defined in this file.
_logger = Logger(__name__)
26

27

28
class PyTorchLightningScriptRunner(TracerPythonScript):
    """Run a PyTorch Lightning training script as one node of a multi-node job.

    The work traces ``Trainer.__init__`` so callbacks can be injected, exports
    distributed environment variables before the script is executed, and
    records checkpoint information from the trainer once the script finishes.
    """

    def __init__(
        self,
        script_path: str,
        script_args: Optional[Union[list, str]] = None,
        node_rank: int = 1,
        num_nodes: int = 1,
        sanity_serving: bool = False,
        cloud_compute: Optional[CloudCompute] = None,
        parallel: bool = True,
        raise_exception: bool = True,
        env: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        """Create the script runner.

        Arguments:
            script_path: Path to the script to be executed.
            script_args: The arguments to be passed to the script.
            node_rank: Rank of this node; exported as ``NODE_RANK`` in :meth:`run`.
            num_nodes: Number of nodes; exported as ``PL_TRAINER_NUM_NODES`` in :meth:`run`.
            sanity_serving: Whether to add a ``ServableModuleValidator`` callback
                to the traced ``Trainer`` (only done on node rank 0).
            cloud_compute: Forwarded to the parent class.
            parallel: Forwarded to the parent class.
            raise_exception: Forwarded to the parent class.
            env: Optional environment variables applied to ``os.environ``
                before the distributed variables are set.
            **kwargs: Extra keyword arguments forwarded to the parent class.

        """
        super().__init__(
            script_path,
            script_args,
            raise_exception=raise_exception,
            parallel=parallel,
            cloud_compute=cloud_compute,
            **kwargs,
        )
        self.node_rank = node_rank
        self.num_nodes = num_nodes
        # Filled in by ``on_after_run`` once the script has completed.
        self.best_model_path = None
        self.best_model_score = None
        self.monitor = None
        self.sanity_serving = sanity_serving
        self.has_finished = False
        self.env = env

    def configure_tracer(self):
        """Extend the parent tracer so ``Trainer.__init__`` calls are intercepted."""
        from lightning.pytorch import Trainer

        tracer = super().configure_tracer()
        tracer.add_traced(Trainer, "__init__", pre_fn=self._trainer_init_pre_middleware)
        return tracer

    def run(self, internal_urls: Optional[List[Tuple[str, str]]] = None, **kwargs: Any) -> None:
        """Export the distributed environment and execute the script.

        Arguments:
            internal_urls: ``(address, port)`` pairs, one per node. The first
                entry is exported as ``MASTER_ADDR`` / ``MASTER_PORT``. When it
                is empty or ``None`` the call only logs and returns.
            **kwargs: Forwarded to the parent ``run``.

        """
        if not internal_urls:
            # Note: This is called only once.
            _logger.info(f"The node {self.node_rank} started !")
            return None

        if self.env:
            os.environ.update(self.env)

        # Applied after ``self.env`` so these values take precedence over it.
        distributed_env_vars = {
            "MASTER_ADDR": internal_urls[0][0],
            "MASTER_PORT": str(internal_urls[0][1]),
            "NODE_RANK": str(self.node_rank),
            "PL_TRAINER_NUM_NODES": str(self.num_nodes),
            "PL_TRAINER_DEVICES": "auto",
            "PL_TRAINER_ACCELERATOR": "auto",
        }

        os.environ.update(distributed_env_vars)
        return super().run(**kwargs)

    def on_after_run(self, script_globals) -> None:
        """Collect checkpoint information from the trainer created by the script.

        Arguments:
            script_globals: Mapping of the executed script's global names to
                their values; searched for a ``LightningCLI`` or ``Trainer``.

        Raises:
            RuntimeError: If neither a ``LightningCLI`` nor a ``Trainer``
                instance is found among the script globals.

        """
        from lightning.pytorch import Trainer
        from lightning.pytorch.cli import LightningCLI

        for value in script_globals.values():
            if isinstance(value, LightningCLI):
                trainer = value.trainer
                break
            if isinstance(value, Trainer):
                trainer = value
                break
        else:
            raise RuntimeError("No trainer instance found.")

        checkpoint_callback = trainer.checkpoint_callback
        if checkpoint_callback is None:
            # No checkpoint callback on the trainer: there is nothing to
            # record, but the work must still be flagged as finished.
            _logger.info(f"The node {self.node_rank} has no checkpoint callback; no model path recorded.")
            self.has_finished = True
            return

        self.monitor = checkpoint_callback.monitor

        # Compare against ``None`` rather than relying on truthiness, so a
        # legitimate score of ``0.0`` is not mistaken for "no score".
        if checkpoint_callback.best_model_score is not None:
            self.best_model_path = Path(checkpoint_callback.best_model_path)
            self.best_model_score = float(checkpoint_callback.best_model_score)
        else:
            self.best_model_path = Path(checkpoint_callback.last_model_path)

        self.has_finished = True

    def _trainer_init_pre_middleware(self, trainer, *args: Any, **kwargs: Any):
        """Pre-hook for ``Trainer.__init__`` that may inject a serving validator.

        On node ranks other than 0 the arguments are returned untouched. On
        rank 0 the ``callbacks`` keyword is always set, and a
        ``ServableModuleValidator`` is appended when ``sanity_serving`` is on.

        Returns:
            A ``({}, args, kwargs)`` tuple with the possibly updated ``kwargs``.

        """
        if self.node_rank != 0:
            return {}, args, kwargs

        from lightning.pytorch.serve import ServableModuleValidator

        callbacks = kwargs.get("callbacks", [])
        if self.sanity_serving:
            # The script may pass ``None``, a single callback or any sequence;
            # normalise to a fresh list so appending cannot raise ``TypeError``
            # and the caller's own list is not mutated.
            if callbacks is None:
                existing = []
            elif isinstance(callbacks, (list, tuple)):
                existing = list(callbacks)
            else:
                existing = [callbacks]
            callbacks = existing + [ServableModuleValidator()]
        kwargs["callbacks"] = callbacks
        return {}, args, kwargs

    @property
    def is_running_in_cloud(self) -> bool:
        """Whether ``LIGHTNING_APP_STATE_URL`` is present in the environment."""
        return "LIGHTNING_APP_STATE_URL" in os.environ
126

127

128
class LightningTrainerScript(LightningFlow):
    """Flow that creates one script-runner work per node and drives them together."""

    def __init__(
        self,
        script_path: str,
        script_args: Optional[Union[list, str]] = None,
        num_nodes: int = 1,
        cloud_compute: CloudCompute = CloudCompute("default"),
        sanity_serving: bool = False,
        script_runner: Type[TracerPythonScript] = PyTorchLightningScriptRunner,
        **script_runner_kwargs,
    ):
        """This component enables performing distributed multi-node multi-device training.

        Example::

            from lightning.app import LightningApp
            from lightning.app.components.training import LightningTrainerScript
            from lightning.app.utilities.packaging.cloud_compute import CloudCompute

            app = LightningApp(
                LightningTrainerScript(
                    "train.py",
                    num_nodes=2,
                    cloud_compute=CloudCompute("gpu"),
                ),
            )

        Arguments:
            script_path: Path to the script to be executed.
            script_args: The arguments to be passed to the script.
            num_nodes: Number of nodes.
            cloud_compute: The cloud compute object used in the cloud.
            sanity_serving: Whether to validate that the model correctly implements
                the ServableModule API
            script_runner: The work class instantiated once per node.
            script_runner_kwargs: Extra keyword arguments forwarded to every
                ``script_runner`` instance.

        """
        super().__init__()
        self.script_path = script_path
        self.script_args = script_args
        self.num_nodes = num_nodes
        self.sanity_serving = sanity_serving
        self._script_runner = script_runner
        self._script_runner_kwargs = script_runner_kwargs

        # One runner per node; the list index doubles as the node rank.
        self.ws = _List()
        for node_rank in range(self.num_nodes):
            self.ws.append(
                self._script_runner(
                    script_path=self.script_path,
                    script_args=self.script_args,
                    cloud_compute=cloud_compute,
                    node_rank=node_rank,
                    sanity_serving=self.sanity_serving,
                    num_nodes=self.num_nodes,
                    **self._script_runner_kwargs,
                )
            )

    def run(self, **run_kwargs):
        """Run every work, passing the node addresses once all of them are known.

        While at least one work has no ``internal_ip``, each work is run
        without arguments. Once every work has one, each work is run with the
        list of ``(internal_ip, port)`` pairs, and all works are stopped as
        soon as every one of them reports ``has_finished``.
        """
        for work in self.ws:
            if not all(w.internal_ip for w in self.ws):
                work.run()
                continue

            internal_urls = [(w.internal_ip, w.port) for w in self.ws]
            work.run(internal_urls=internal_urls, **run_kwargs)
            if all(w.has_finished for w in self.ws):
                for finished_work in self.ws:
                    finished_work.stop()

    @property
    def best_model_score(self) -> Optional[float]:
        """The ``best_model_score`` recorded by the first work."""
        return self.ws[0].best_model_score

    @property
    def best_model_paths(self) -> List[Optional[Path]]:
        """The ``best_model_path`` recorded by each work, in node order."""
        # The works store this under ``best_model_path``; the previous
        # ``best_mode_path`` spelling referred to an attribute they never set.
        return [work.best_model_path for work in self.ws]
204

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.