pytorch-lightning

Форк
0
200 строк · 7.2 Кб
1
# Copyright The Lightning AI team.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

15
import os
16
import signal
17
import sys
18
from copy import deepcopy
19
from typing import Any, Dict, List, Optional, Union
20

21
from typing_extensions import TypedDict
22

23
from lightning.app.core.work import LightningWork
24
from lightning.app.storage.drive import Drive
25
from lightning.app.storage.payload import Payload
26
from lightning.app.utilities.app_helpers import Logger, _collect_child_process_pids
27
from lightning.app.utilities.packaging.tarfile import clean_tarfile, extract_tarfile
28
from lightning.app.utilities.tracer import Tracer
29

30
# Module-level logger for this component.
# NOTE(review): not referenced by the code visible in this file — confirm before removing.
logger = Logger(__name__)
31

32

33
class Code(TypedDict):
    """Location of a packaged code archive consumed by ``TracerPythonScript``.

    When passed as ``code`` to ``TracerPythonScript``, ``run`` fetches the archive ``name`` from ``drive``
    and extracts it as a gzip-compressed tarball into ``code_dir`` before executing the script.
    """

    drive: Drive  # drive the archive is fetched from
    name: str  # archive name within ``drive``; also used as the local file name of the downloaded archive
36

37

38
class TracerPythonScript(LightningWork):
    # Presumably the multiprocessing start method LightningWork uses to launch this work — confirm in LightningWork.
    _start_method = "spawn"

    def on_before_run(self):
        """Called before the python script is executed.

        Runs after the working directory has been switched to ``code_dir`` and before ``env`` is applied.
        """

    def on_after_run(self, res: Any):
        """Called after the python script is executed.

        Args:
            res: The mapping returned by the tracer. Each name listed in ``self.outputs`` is looked up in it,
                wrapped in a :class:`~lightning.app.storage.payload.Payload` and stored as an attribute of this work.
        """
        for name in self.outputs:
            setattr(self, name, Payload(res[name]))

    def configure_tracer(self) -> Tracer:
        """Override this hook to customize your tracer when running PythonScript."""
        return Tracer()

    def __init__(
        self,
        script_path: str,
        script_args: Optional[Union[list, str]] = None,
        outputs: Optional[List[str]] = None,
        env: Optional[Dict] = None,
        code: Optional[Code] = None,
        **kwargs: Any,
    ):
        """The TracerPythonScript class makes it easy to run a python script.

        When subclassing this class, you can configure your own :class:`~lightning.app.utilities.tracer.Tracer`
        by overriding the :meth:`~lightning.app.components.python.tracer.TracerPythonScript.configure_tracer` method.

        The tracer is quite a magical class. It enables you to inject code into a script execution without changing it.

        Arguments:
            script_path: Path of the python script to run, resolved relative to ``code_dir`` at run time.
            script_args: The arguments to be passed to the script, either as a list or as a whitespace-separated
                string.
            outputs: Collection of object names to collect after the script execution.
            env: Environment variables to be set while the script runs; the previous environment is restored after.
            code: Optional drive/name pair of a gzip-compressed tarball that :meth:`run` fetches and extracts into
                ``code_dir`` before executing the script.
            kwargs: LightningWork keyword arguments.

        Raises:
            FileNotFoundError: Raised by :meth:`run` if the provided ``script_path`` doesn't exist.

        **How does it work?**

        It works by executing the python script with python built-in `runpy
        <https://docs.python.org/3/library/runpy.html>`_ run_path method.
        This method accepts initial globals that are injected before the script executes,
        e.g., you can modify classes or functions used by the script.

        Example:

            >>> from lightning.app.components.python import TracerPythonScript
            >>> f = open("a.py", "w")
            >>> f.write("print('Hello World !')")
            22
            >>> f.close()
            >>> python_script = TracerPythonScript("a.py")
            >>> python_script.run()
            Hello World !
            >>> os.remove("a.py")

        In the example below, we subclass the :class:`~lightning.app.components.python.TracerPythonScript`
        component and override its configure_tracer method.

        Using the Tracer, we are patching the ``__init__`` method of the PyTorch Lightning Trainer.
        Once the script starts running and if a Trainer is instantiated, the provided ``pre_fn`` is
        called and we inject a Lightning callback.

        This callback has a reference to the work and on every batch end, we are capturing the
        trainer ``global_step`` and ``best_model_path``.

        Even more interesting, this component works for ANY PyTorch Lightning script and
        its state can be used in real time in a UI.

        .. literalinclude:: ../../../../examples/app/components/python/component_tracer.py
            :language: python


        Once implemented, this component can easily be integrated within a larger app
        to execute a specific python script.

        .. literalinclude:: ../../../../examples/app/components/python/app.py
            :language: python

        """
        super().__init__(**kwargs)
        self.script_path = str(script_path)
        if isinstance(script_args, str):
            # ``split()`` without a separator collapses repeated/leading/trailing whitespace
            # instead of producing empty-string arguments.
            script_args = script_args.split()
        # ``list(...)`` copies the caller's sequence and lets tuples be concatenated in ``run``.
        self.script_args = list(script_args) if script_args else []
        # Kept so ``run(params=...)`` always appends to the constructor arguments, not to a previous run's.
        self.original_args = deepcopy(self.script_args)
        self.env = env
        self.outputs = outputs or []
        # Pre-declare each output attribute so it exists before the first run.
        for name in self.outputs:
            setattr(self, name, None)
        self.params = None
        self.drive = code.get("drive") if code else None
        self.code_name = code.get("name") if code else None
        self.restart_count = 0

    def run(
        self,
        params: Optional[Dict[str, Any]] = None,
        restart_count: Optional[int] = None,
        code_dir: Optional[str] = ".",
        **kwargs: Any,
    ):
        """Execute the script under the configured tracer.

        The working directory and ``os.environ`` are restored afterwards, even if the script raises.

        Arguments:
            params: A dictionary of arguments to be added to ``script_args``, each rendered as ``key=value``.
            restart_count: Passes an incrementing counter to enable the re-execution of LightningWorks.
            code_dir: A path string determining where the source is extracted and where the script runs.
                Defaults to the current directory; ``None`` is treated the same way.
            kwargs: Extra names injected into the script's initial globals. ``Payload`` values are unwrapped
                to their ``.value`` first.

        Returns:
            Whatever :meth:`on_after_run` returns (``None`` for the default implementation).

        Raises:
            FileNotFoundError: If ``script_path`` doesn't exist relative to ``code_dir``.
        """
        if restart_count:
            self.restart_count = restart_count

        if params:
            self.params = params
            self.script_args = self.original_args + [self._to_script_args(k, v) for k, v in params.items()]

        if code_dir is None:
            code_dir = "."

        if self.drive:
            self._sync_code_from_drive(code_dir)

        prev_cwd = os.getcwd()
        os.chdir(code_dir)
        try:
            if not os.path.exists(self.script_path):
                raise FileNotFoundError(f"The provided `script_path` `{self.script_path}` wasn't found.")

            kwargs = {k: v.value if isinstance(v, Payload) else v for k, v in kwargs.items()}
            # Build a fresh namespace: updating ``globals()`` in place would leak ``kwargs`` into this
            # module (possibly clobbering names such as ``os``) and into every later run.
            init_globals = {**globals(), **kwargs}

            self.on_before_run()
            env_snapshot = os.environ.copy()
            try:
                if self.env:
                    os.environ.update(self.env)
                res = self._run_tracer(init_globals)
            finally:
                self._restore_environ(env_snapshot)
        finally:
            os.chdir(prev_cwd)
        return self.on_after_run(res)

    def _sync_code_from_drive(self, code_dir: str) -> None:
        """Fetch the ``code_name`` archive from ``self.drive`` (if it is listed there) and extract it into ``code_dir``."""
        assert self.code_name
        # A local file with the archive name already exists; hand it to ``clean_tarfile`` first
        # (presumably removes the previous archive/extraction — confirm in ``clean_tarfile``).
        if os.path.exists(self.code_name):
            clean_tarfile(self.code_name, "r:gz")

        if self.code_name in self.drive.list():
            self.drive.get(self.code_name)
            extract_tarfile(self.code_name, code_dir, "r:gz")

    @staticmethod
    def _restore_environ(snapshot: Dict[str, str]) -> None:
        """Make ``os.environ`` match ``snapshot`` again, mutating it in place.

        Rebinding ``os.environ`` to a plain ``dict`` would detach it from the real process environment;
        mutating the live mapping keeps the restored values visible to child processes too.
        """
        for key in [key for key in os.environ if key not in snapshot]:
            del os.environ[key]
        os.environ.update(snapshot)

    def _run_tracer(self, init_globals):
        """Run the script through a freshly configured tracer and return the tracer's result."""
        # Reset argv so the script does not see this process's arguments; the script arguments
        # are handed to the tracer explicitly below.
        sys.argv = [self.script_path]
        tracer = self.configure_tracer()
        return tracer.trace(self.script_path, *self.script_args, init_globals=init_globals)

    def on_exit(self):
        """Send ``SIGTERM`` to every child process of the current process."""
        for child_pid in _collect_child_process_pids(os.getpid()):
            try:
                os.kill(child_pid, signal.SIGTERM)
            except ProcessLookupError:
                # The child exited between collection and kill; keep terminating the remaining ones.
                continue

    @staticmethod
    def _to_script_args(k: str, v: Any) -> str:
        """Render one ``params`` entry as a ``key=value`` command-line argument."""
        return f"{k}={v}"
198

199

200
# Public API of this module: only the tracer component is exported.
__all__ = ["TracerPythonScript"]
201

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.