# pytorch-lightning
# (source listing: 203 lines · 7.1 KB)
1# Copyright The Lightning AI team.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import os
16from typing import Any, Dict, List, Optional, Tuple, Type, Union
17
18from lightning.app.components.python import TracerPythonScript
19from lightning.app.core.flow import LightningFlow
20from lightning.app.storage.path import Path
21from lightning.app.structures import List as _List
22from lightning.app.utilities.app_helpers import Logger
23from lightning.app.utilities.packaging.cloud_compute import CloudCompute
24
25_logger = Logger(__name__)
26
27
class PyTorchLightningScriptRunner(TracerPythonScript):
    """Execute a PyTorch Lightning training script as one node of a multi-node run.

    The runner traces ``Trainer.__init__`` so callbacks can be injected before the
    script's trainer is built, and collects checkpoint metadata from the script's
    globals once it finishes.
    """

    def __init__(
        self,
        script_path: str,
        script_args: Optional[Union[list, str]] = None,
        node_rank: int = 1,
        num_nodes: int = 1,
        sanity_serving: bool = False,
        cloud_compute: Optional[CloudCompute] = None,
        parallel: bool = True,
        raise_exception: bool = True,
        env: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ):
        """
        Arguments:
            script_path: Path to the training script to execute.
            script_args: Arguments forwarded to the script.
            node_rank: Rank of this node within the multi-node run.
            num_nodes: Total number of nodes participating in training.
            sanity_serving: Whether to validate that the model implements the
                ServableModule API (checked on rank 0 only).
            cloud_compute: Compute requirements used when running in the cloud.
            parallel: Whether this work runs in parallel with its siblings.
            raise_exception: Whether script failures propagate to the flow.
            env: Extra environment variables to set before running the script.
        """
        super().__init__(
            script_path,
            script_args,
            raise_exception=raise_exception,
            parallel=parallel,
            cloud_compute=cloud_compute,
            **kwargs,
        )
        self.node_rank = node_rank
        self.num_nodes = num_nodes
        # Populated by ``on_after_run`` once the traced script has completed.
        self.best_model_path = None
        self.best_model_score = None
        self.monitor = None
        self.sanity_serving = sanity_serving
        self.has_finished = False
        self.env = env

    def configure_tracer(self):
        from lightning.pytorch import Trainer

        tracer = super().configure_tracer()
        # Intercept Trainer construction so the middleware below can inject callbacks.
        tracer.add_traced(Trainer, "__init__", pre_fn=self._trainer_init_pre_middleware)
        return tracer

    def run(self, internal_urls: Optional[List[Tuple[str, str]]] = None, **kwargs: Any) -> None:
        """Run the script once the rendezvous addresses of all nodes are known.

        Arguments:
            internal_urls: ``(internal_ip, port)`` of every node; entry 0 is the master.
        """
        if not internal_urls:
            # Note: This is called only once.
            _logger.info(f"The node {self.node_rank} started !")
            return None

        if self.env:
            os.environ.update(self.env)

        # Wire up the distributed rendezvous: the first node acts as the master.
        distributed_env_vars = {
            "MASTER_ADDR": internal_urls[0][0],
            "MASTER_PORT": str(internal_urls[0][1]),
            "NODE_RANK": str(self.node_rank),
            "PL_TRAINER_NUM_NODES": str(self.num_nodes),
            "PL_TRAINER_DEVICES": "auto",
            "PL_TRAINER_ACCELERATOR": "auto",
        }

        os.environ.update(distributed_env_vars)
        return super().run(**kwargs)

    def on_after_run(self, script_globals):
        """Collect checkpoint metadata from the finished script's globals.

        Raises:
            RuntimeError: If neither a ``LightningCLI`` nor a ``Trainer`` instance
                is found among the script's globals.
        """
        from lightning.pytorch import Trainer
        from lightning.pytorch.cli import LightningCLI

        for v in script_globals.values():
            if isinstance(v, LightningCLI):
                trainer = v.trainer
                break
            if isinstance(v, Trainer):
                trainer = v
                break
        else:
            raise RuntimeError("No trainer instance found.")

        self.monitor = trainer.checkpoint_callback.monitor

        # Compare against None explicitly: a legitimate best score of 0.0 is falsy
        # and a truthiness test would wrongly fall back to the last checkpoint.
        if trainer.checkpoint_callback.best_model_score is not None:
            self.best_model_path = Path(trainer.checkpoint_callback.best_model_path)
            self.best_model_score = float(trainer.checkpoint_callback.best_model_score)
        else:
            self.best_model_path = Path(trainer.checkpoint_callback.last_model_path)

        self.has_finished = True

    def _trainer_init_pre_middleware(self, trainer, *args: Any, **kwargs: Any):
        """Pre-hook for ``Trainer.__init__``: add the serving validator on rank 0."""
        if self.node_rank != 0:
            return {}, args, kwargs

        from lightning.pytorch.serve import ServableModuleValidator

        callbacks = kwargs.get("callbacks", [])
        if self.sanity_serving:
            callbacks = callbacks + [ServableModuleValidator()]
        kwargs["callbacks"] = callbacks
        return {}, args, kwargs

    @property
    def is_running_in_cloud(self) -> bool:
        # The state URL is only set when the app is executed in the Lightning cloud.
        return "LIGHTNING_APP_STATE_URL" in os.environ
126
127
class LightningTrainerScript(LightningFlow):
    def __init__(
        self,
        script_path: str,
        script_args: Optional[Union[list, str]] = None,
        num_nodes: int = 1,
        cloud_compute: Optional[CloudCompute] = None,
        sanity_serving: bool = False,
        script_runner: Type[TracerPythonScript] = PyTorchLightningScriptRunner,
        **script_runner_kwargs,
    ):
        """This component enables performing distributed multi-node multi-device training.

        Example::

            from lightning.app import LightningApp
            from lightning.app.components.training import LightningTrainerScript
            from lightning.app.utilities.packaging.cloud_compute import CloudCompute

            app = LightningApp(
                LightningTrainerScript(
                    "train.py",
                    num_nodes=2,
                    cloud_compute=CloudCompute("gpu"),
                ),
            )

        Arguments:
            script_path: Path to the script to be executed.
            script_args: The arguments to be passed to the script.
            num_nodes: Number of nodes.
            cloud_compute: The cloud compute object used in the cloud. Falls back to
                ``CloudCompute("default")`` when not provided.
            sanity_serving: Whether to validate that the model correctly implements
                the ServableModule API

        """
        super().__init__()
        self.script_path = script_path
        self.script_args = script_args
        self.num_nodes = num_nodes
        self.sanity_serving = sanity_serving
        self._script_runner = script_runner
        self._script_runner_kwargs = script_runner_kwargs

        # Build the default here rather than in the signature, so a single
        # CloudCompute instance created at import time is not shared across
        # every flow that relies on the default.
        if cloud_compute is None:
            cloud_compute = CloudCompute("default")

        # One script-runner work per node; each knows its own rank.
        self.ws = _List()
        for node_rank in range(self.num_nodes):
            self.ws.append(
                self._script_runner(
                    script_path=self.script_path,
                    script_args=self.script_args,
                    cloud_compute=cloud_compute,
                    node_rank=node_rank,
                    sanity_serving=self.sanity_serving,
                    num_nodes=self.num_nodes,
                    **self._script_runner_kwargs,
                )
            )

    def run(self, **run_kwargs):
        """Drive every node's work; launch training once all internal IPs are known."""
        for work in self.ws:
            if all(w.internal_ip for w in self.ws):
                # Every node is up: share the rendezvous addresses and start training.
                internal_urls = [(w.internal_ip, w.port) for w in self.ws]
                work.run(internal_urls=internal_urls, **run_kwargs)
                if all(w.has_finished for w in self.ws):
                    for w in self.ws:
                        w.stop()
            else:
                work.run()

    @property
    def best_model_score(self) -> Optional[float]:
        """Best monitored score reported by the rank-0 node."""
        return self.ws[0].best_model_score

    @property
    def best_model_paths(self) -> List[Optional[Path]]:
        """Best checkpoint path reported by each node.

        Fixed: previously read the misspelled ``best_mode_path`` attribute, which
        raised ``AttributeError`` on access (the runner sets ``best_model_path``).
        """
        return [self.ws[node_idx].best_model_path for node_idx in range(len(self.ws))]
204