pytorch-lightning

Форк
0
304 строки · 10.5 Кб
1
# Copyright The Lightning AI team.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

15
import inspect
16
import os
17
import sys
18
import traceback
19
import types
20
from contextlib import contextmanager
21
from copy import copy
22
from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, Union
23

24
from lightning.app.utilities.exceptions import MisconfigurationException
25

26
if TYPE_CHECKING:
27
    from lightning.app.core import LightningApp, LightningFlow, LightningWork
28
    from lightning.app.plugin.plugin import LightningPlugin
29

30
from lightning.app.utilities.app_helpers import Logger, _mock_missing_imports
31

32
# Module-level logger named after this module; used below to report errors raised while loading an app.
logger = Logger(__name__)
33

34

35
def _prettifiy_exception(filepath: str):
    """Log the exception currently being handled in a readable form, then exit the process.

    The exception is read from ``sys.exc_info()``, so this must be called from inside an ``except`` block.
    After logging, the process is terminated through ``sys.exit(1)``.

    Args:
        filepath: Path of the application file that failed to load; only used in the logged message.

    """
    formatted = traceback.format_exception(*sys.exc_info())
    # Drop the entry for the first frame so the report is formatted as if no frame was on top of the user code.
    formatted.pop(1)
    header = f"Found an exception when loading your application from {filepath}. Please, resolve it to run your app.\n\n"
    logger.error(header + "".join(formatted))
    sys.exit(1)
47

48

49
def _load_objects_from_file(
    filepath: str,
    target_type: Type,
    raise_exception: bool = False,
    mock_imports: bool = False,
    env_vars: Union[Dict[str, str], None] = None,
) -> Tuple[List[Any], types.ModuleType]:
    """Load all of the top-level objects of the given type from a file.

    The file is compiled and executed inside a temporary ``__main__`` module, with the file's directory
    temporarily on ``sys.path``, ``sys.argv`` patched and ``env_vars`` temporarily added to ``os.environ``.

    Args:
        filepath: The file to load from.
        target_type: The type of object to load.
        raise_exception: If ``True`` exceptions will be raised, otherwise exceptions will trigger system exit.
        mock_imports: If ``True`` imports of missing packages will be replaced with a mock. This can allow the object to
            be loaded without installing dependencies.
        env_vars: Environment variables to set while the file is being executed. ``None`` (the default) means
            no additional variables.

    Returns:
        A tuple ``(objects, module)`` where ``objects`` lists every value in the executed module's namespace that
        is an instance of ``target_type`` and ``module`` is the module the file was executed in.

    """
    # Use a ``None`` sentinel rather than a mutable ``{}`` default, which would be shared between calls.
    env_vars = {} if env_vars is None else env_vars

    # Taken from StreamLit: https://github.com/streamlit/streamlit/blob/develop/lib/streamlit/script_runner.py#L313

    # In order for imports to work in a non-package, Python normally adds the current working directory to the
    # system path, not however when running from an entry point like the `lightning` CLI command. So we do it manually:
    with _patch_sys_path(os.path.dirname(os.path.abspath(filepath))):
        code = _create_code(filepath)
        with _create_fake_main_module(filepath) as module:
            try:
                with _add_to_env(env_vars), _patch_sys_argv():
                    if mock_imports:
                        with _mock_missing_imports():
                            exec(code, module.__dict__)  # noqa: S102
                    else:
                        exec(code, module.__dict__)  # noqa: S102
            except Exception:
                if raise_exception:
                    # Bare ``raise`` re-raises the active exception with its traceback untouched.
                    raise
                # Logs the failure and exits the process; must run inside this ``except`` block.
                _prettifiy_exception(filepath)

    return [v for v in module.__dict__.values() if isinstance(v, target_type)], module
87

88

89
def _load_plugin_from_file(filepath: str) -> "LightningPlugin":
    """Execute ``filepath`` and return the single ``LightningPlugin`` instance it defines at module level.

    Exceptions raised while executing the file are propagated, and missing imports are not mocked.

    Args:
        filepath: The file to load the plugin from.

    Raises:
        RuntimeError: If the file defines no plugin, or more than one.

    """
    from lightning.app.plugin.plugin import LightningPlugin

    # TODO: Plugin should be run in the context of the created main module here
    plugins, _ = _load_objects_from_file(filepath, LightningPlugin, raise_exception=True, mock_imports=False)

    found = len(plugins)
    if found == 0:
        raise RuntimeError(f"The provided file {filepath} does not contain a Plugin.")
    if found > 1:
        raise RuntimeError(f"There should not be multiple plugins instantiated within the file. Found {plugins}")
    return plugins[0]
101

102

103
def load_app_from_file(
    filepath: str,
    raise_exception: bool = False,
    mock_imports: bool = False,
    env_vars: Union[Dict[str, str], None] = None,
) -> "LightningApp":
    """Load a LightningApp from a file.

    As a side effect, the directory of ``filepath`` is appended to ``sys.path`` and the module the file was
    executed in is installed as ``sys.modules["__main__"]``.

    Arguments:
        filepath:  The path to the file containing the LightningApp.
        raise_exception: If True, raise an exception if the app cannot be loaded.
        mock_imports: If True, imports of missing packages are replaced with a mock while the file is executed.
        env_vars: Environment variables to set while the file is being executed. ``None`` (the default) means
            no additional variables.

    Returns:
        The single ``LightningApp`` instance found at module level in the file.

    Raises:
        MisconfigurationException: If the file defines no ``LightningApp``, or more than one.

    """
    from lightning.app.core.app import LightningApp

    # Use a ``None`` sentinel rather than a mutable ``{}`` default, and always forward a real dict.
    env_vars = {} if env_vars is None else env_vars

    apps, main_module = _load_objects_from_file(
        filepath, LightningApp, raise_exception=raise_exception, mock_imports=mock_imports, env_vars=env_vars
    )

    # TODO: Remove this, downstream code shouldn't depend on side-effects here but it does
    sys.path.append(os.path.dirname(os.path.abspath(filepath)))
    sys.modules["__main__"] = main_module

    if len(apps) > 1:
        raise MisconfigurationException(f"There should not be multiple apps instantiated within a file. Found {apps}")
    if len(apps) == 1:
        return apps[0]

    raise MisconfigurationException(
        f"The provided file {filepath} does not contain a LightningApp. Instantiate your app at the module level"
        " like so: `app = LightningApp(flow, ...)`"
    )
135

136

137
def _new_module(name):
    """Return a fresh, empty module object whose ``__name__`` is ``name``."""
    module = types.ModuleType(name)
    return module
140

141

142
def open_python_file(filename):
    """Open a read-only Python file taking proper care of its encoding.

    In Python 3, we would like all files to be opened with utf-8 encoding. However, some authors like to specify
    PEP263 headers in their source files with their own encodings. In that case, we should respect the author's
    encoding.

    Args:
        filename: Path of the Python source file to open.

    Returns:
        A text file object opened for reading. The caller is responsible for closing it, e.g. with ``with``.

    """
    import tokenize

    # ``tokenize.open`` honours a PEP263 encoding header and falls back to utf-8 when there is none. It has
    # existed since Python 3.2, so the former ``hasattr`` guard and plain ``open`` fallback were dead code.
    return tokenize.open(filename)
156

157

158
def _create_code(script_path: str):
    """Read the script at ``script_path`` and compile its whole content into a code object.

    Args:
        script_path: Path of the Python file to compile. It is also recorded as the code's filename so that it
            shows up in exceptions.

    """
    with open_python_file(script_path) as source_file:
        source = source_file.read()

    compile_options = {
        # Entire blocks of Python are compiled, hence "exec" rather than "eval" or "single".
        "mode": "exec",
        # Do not inherit any flags or "future" statements from this module.
        "flags": 0,
        "dont_inherit": 1,
        # Default optimization level.
        "optimize": -1,
    }
    return compile(source, script_path, **compile_options)
175

176

177
@contextmanager
def _create_fake_main_module(script_path):
    """Context manager that temporarily installs a fresh module as ``sys.modules["__main__"]``.

    The new module is yielded so that code can be executed in its namespace; the previous ``__main__`` module is
    put back on exit, whether or not an exception occurred.

    Args:
        script_path: Path of the script; its absolute form is stored as the module's ``__file__``.

    """
    # The fake module provides a named global namespace in which to execute the code.
    fake_main = _new_module("__main__")

    # Registering it as __main__ lets the pickle module work inside the user's code, because it can then
    # resolve the module that pickled objects stem from.
    # IMPORTANT: as a consequence "if __name__ == '__main__'" cannot be used in our own code, since it would
    # point to the wrong module!
    real_main = sys.modules["__main__"]
    sys.modules["__main__"] = fake_main

    # Per the original note, the CodeHasher requires this: it is scoped to files located in the directory of
    # __main__.__file__, which is assumed to be the main script directory.
    fake_main.__file__ = os.path.abspath(script_path)

    try:
        yield fake_main
    finally:
        sys.modules["__main__"] = real_main
202

203

204
@contextmanager
def _patch_sys_path(append):
    """Context manager that makes sure ``append`` is on ``sys.path`` while the block runs.

    If the value is already present nothing is changed; otherwise it is appended on entry and removed on exit.

    Args:
        append: The value to append to the path.

    """
    needs_patch = append not in sys.path
    if needs_patch:
        sys.path.append(append)

    try:
        yield
    finally:
        # Only undo what was added here; an entry that was already present is left alone.
        if needs_patch:
            sys.path.remove(append)
222

223

224
@contextmanager
def _add_to_env(envs: Dict[str, str]):
    """Context manager that adds ``envs`` to ``os.environ`` and restores the previous environment on exit.

    On exit the environment is reset to the snapshot taken on entry, so any variable added or changed inside the
    block is reverted as well.

    Args:
        envs: The environment variables to set for the duration of the block.

    """
    snapshot = os.environ.copy()
    os.environ.update(envs)

    try:
        yield
    finally:
        # Wipe everything first so variables created inside the block do not survive.
        os.environ.clear()
        os.environ.update(snapshot)
235

236

237
@contextmanager
def _patch_sys_argv():
    """Context manager that temporarily rewrites ``sys.argv`` so the user script only sees its own arguments.

    Only the arguments that follow ``--app_args`` are kept, up to the next option of the ``run app`` command if
    there is one; everything else is removed. The original ``sys.argv`` is restored on exit.

    The command: ``lightning_app run app app.py --without-server --app_args --use_gpu --env ...`` will be converted into
    ``app.py --use_gpu``

    """
    from lightning.app.cli.lightning_cli import run_app

    saved_argv = copy(sys.argv)

    # Strip the leading CLI command when present.
    if sys.argv[:3] == ["lightning", "run", "app"]:
        sys.argv = sys.argv[3:]
    argv = sys.argv

    if "--app_args" in argv:
        # First option string of every ``run_app`` parameter after the first one, ``--app_args`` excluded.
        cli_flags = [param.opts[0] for param in run_app.params[1:] if param.opts[0] != "--app_args"]

        # The app arguments start right after ``--app_args`` ...
        start = argv.index("--app_args") + 1

        # ... and stop at the earliest CLI option located at or after that position, if any.
        stops = []
        for flag in cli_flags:
            if flag in argv:
                position = argv.index(flag)
                if position >= start:
                    stops.append(position)
        end = min(stops) if stops else len(argv)

        patched_argv = [argv[0]] + argv[start:end]
    else:
        # Without ``--app_args`` there are no app arguments: keep only the first entry.
        patched_argv = argv[:1]

    sys.argv = patched_argv

    try:
        yield
    finally:
        sys.argv = saved_argv
278

279

280
def component_to_metadata(obj: Union["LightningWork", "LightningFlow"]) -> Dict:
    """Build a metadata dictionary describing a component.

    The result always holds ``affiliation`` (the component name split on dots), ``cls_name``, ``module`` and
    ``docstring`` (the documentation of the component's ``__init__``). For a ``LightningWork`` it additionally
    holds ``local_build_config``, ``cloud_build_config`` and ``cloud_compute``, each converted with ``to_dict()``.

    Args:
        obj: The work or flow to describe.

    """
    from lightning.app.core import LightningWork

    work_details = {}
    if isinstance(obj, LightningWork):
        work_details["local_build_config"] = obj.local_build_config.to_dict()
        work_details["cloud_build_config"] = obj.cloud_build_config.to_dict()
        work_details["cloud_compute"] = obj.cloud_compute.to_dict()

    return {
        "affiliation": obj.name.split("."),
        "cls_name": obj.__class__.__name__,
        "module": obj.__module__,
        "docstring": inspect.getdoc(obj.__init__),
        **work_details,
    }
299

300

301
def extract_metadata_from_app(app: "LightningApp") -> List:
    """Return the metadata of every flow and work of ``app``, ordered by component name.

    Args:
        app: The app whose ``flows`` and ``works`` are described.

    """
    by_name = {}
    for flow in app.flows:
        by_name[flow.name] = component_to_metadata(flow)
    # Works are added after flows, so a work replaces a flow registered under the same name.
    for work in app.works:
        by_name[work.name] = component_to_metadata(work)
    return [by_name[name] for name in sorted(by_name)]
305

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.