pytorch-lightning
304 строки · 10.5 Кб
1# Copyright The Lightning AI team.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import inspect
16import os
17import sys
18import traceback
19import types
20from contextlib import contextmanager
21from copy import copy
22from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, Union
23
24from lightning.app.utilities.exceptions import MisconfigurationException
25
26if TYPE_CHECKING:
27from lightning.app.core import LightningApp, LightningFlow, LightningWork
28from lightning.app.plugin.plugin import LightningPlugin
29
30from lightning.app.utilities.app_helpers import Logger, _mock_missing_imports
31
# Module-level logger; load failures are reported through ``logger.error``.
logger = Logger(__name__)
33
34
def _prettifiy_exception(filepath: str):
    """Pretty print the exception that occurred when loading the app.

    Intended to be called while an exception is being handled. The exception is taken from ``sys.exc_info()``,
    logged with a hint naming ``filepath``, and the process then exits with status 1.

    Args:
        filepath: The path of the app file that failed to load, used in the log message.

    """
    exp, val, tb = sys.exc_info()
    # We want to format the exception as if no frame was on top, i.e. without the loader's own frame (the one
    # that called ``exec``). Skip that frame on the traceback object itself instead of deleting ``listing[1]``.
    # With a chained exception (``__context__`` / ``__cause__``), the first entry of the formatted listing
    # belongs to the *chained* exception and would be dropped by mistake. This also cannot raise IndexError
    # when there is no traceback.
    if tb is not None:
        tb = tb.tb_next
    listing = traceback.format_exception(exp, val, tb)
    listing.insert(
        0,
        f"Found an exception when loading your application from {filepath}. Please, resolve it to run your app.\n\n",
    )
    logger.error("".join(listing))
    sys.exit(1)
47
48
def _load_objects_from_file(
    filepath: str,
    target_type: Type,
    raise_exception: bool = False,
    mock_imports: bool = False,
    env_vars: Union[Dict[str, str], None] = None,
) -> Tuple[List[Any], types.ModuleType]:
    """Load all of the top-level objects of the given type from a file.

    Args:
        filepath: The file to load from.
        target_type: The type of object to load.
        raise_exception: If ``True`` exceptions will be raised, otherwise exceptions will trigger system exit.
        mock_imports: If ``True`` imports of missing packages will be replaced with a mock. This can allow the object to
            be loaded without installing dependencies.
        env_vars: Extra environment variables that are set while the file is executed. ``None`` (the default)
            adds none.

    Returns:
        A tuple of the list of top-level objects that are instances of ``target_type``, and the module the file
        was executed in.

    """
    from contextlib import nullcontext

    # A fresh dict per call rather than a shared mutable default argument.
    env_vars = {} if env_vars is None else env_vars

    # Taken from StreamLit: https://github.com/streamlit/streamlit/blob/develop/lib/streamlit/script_runner.py#L313

    # In order for imports to work in a non-package, Python normally adds the current working directory to the
    # system path, not however when running from an entry point like the `lightning` CLI command. So we do it manually:
    with _patch_sys_path(os.path.dirname(os.path.abspath(filepath))):
        # Compilation happens outside the ``try`` below, so a ``SyntaxError`` in the file propagates
        # regardless of ``raise_exception``.
        code = _create_code(filepath)
        with _create_fake_main_module(filepath) as module:
            try:
                # Items of a multi-item ``with`` are entered left to right, so the import mock (if any) is still
                # created and entered after the env / argv patches, exactly as with nested blocks.
                with _add_to_env(env_vars), _patch_sys_argv(), (
                    _mock_missing_imports() if mock_imports else nullcontext()
                ):
                    exec(code, module.__dict__)  # noqa: S102
            except Exception:
                if raise_exception:
                    raise
                _prettifiy_exception(filepath)

    return [v for v in module.__dict__.values() if isinstance(v, target_type)], module
87
88
def _load_plugin_from_file(filepath: str) -> "LightningPlugin":
    """Execute ``filepath`` and return the single ``LightningPlugin`` instantiated at its top level.

    Any exception raised while executing the file is propagated (``raise_exception=True``). A ``RuntimeError``
    is raised when the file instantiates no plugin, or more than one.
    """
    from lightning.app.plugin.plugin import LightningPlugin

    # TODO: Plugin should be run in the context of the created main module here
    found, _ = _load_objects_from_file(filepath, LightningPlugin, raise_exception=True, mock_imports=False)

    if not found:
        raise RuntimeError(f"The provided file {filepath} does not contain a Plugin.")
    if len(found) != 1:
        raise RuntimeError(f"There should not be multiple plugins instantiated within the file. Found {found}")
    return found[0]
101
102
def load_app_from_file(
    filepath: str,
    raise_exception: bool = False,
    mock_imports: bool = False,
    env_vars: Union[Dict[str, str], None] = None,
) -> "LightningApp":
    """Load a LightningApp from a file.

    Arguments:
        filepath: The path to the file containing the LightningApp.
        raise_exception: If True, raise the exception if executing the file fails; otherwise the error is logged
            and the process exits.
        mock_imports: If True, imports of missing packages are replaced with mocks while the file is executed.
        env_vars: Extra environment variables set while the file is executed. ``None`` (the default) adds none.

    Returns:
        The single ``LightningApp`` instantiated at the top level of the file.

    Raises:
        MisconfigurationException: If the file instantiates no app or more than one app. This is raised
            regardless of ``raise_exception``.

    """
    from lightning.app.core.app import LightningApp

    apps, main_module = _load_objects_from_file(
        filepath,
        LightningApp,
        raise_exception=raise_exception,
        mock_imports=mock_imports,
        # Normalize here (instead of relying on a mutable ``{}`` default) so a plain dict is always forwarded.
        env_vars={} if env_vars is None else env_vars,
    )

    # TODO: Remove this, downstream code shouldn't depend on side-effects here but it does
    sys.path.append(os.path.dirname(os.path.abspath(filepath)))
    sys.modules["__main__"] = main_module

    if len(apps) > 1:
        raise MisconfigurationException(f"There should not be multiple apps instantiated within a file. Found {apps}")
    if len(apps) == 1:
        return apps[0]

    raise MisconfigurationException(
        f"The provided file {filepath} does not contain a LightningApp. Instantiate your app at the module level"
        " like so: `app = LightningApp(flow, ...)`"
    )
135
136
137def _new_module(name):
138"""Create a new module with the given name."""
139return types.ModuleType(name)
140
141
def open_python_file(filename):
    """Open a read-only Python file taking proper care of its encoding.

    In Python 3, we would like all files to be opened with utf-8 encoding. However, some authors like to specify
    PEP 263 headers in their source files with their own encodings. In that case, we should respect the author's
    encoding.

    Returns:
        A text file object opened for reading; the caller is responsible for closing it.

    """
    import tokenize

    # ``tokenize.open`` (available since Python 3.2, so always present on Python 3) detects the encoding from a
    # PEP 263 coding cookie or a UTF-8 BOM and falls back to utf-8 when neither is present. The previous
    # ``hasattr`` check and its utf-8-only fallback were unreachable.
    return tokenize.open(filename)
156
157
def _create_code(script_path: str):
    """Read ``script_path`` and compile its whole source into a code object suitable for ``exec``.

    The file is read through ``open_python_file`` so its declared encoding is honoured, and ``script_path`` is
    used as the code object's filename so it shows up in exceptions.
    """
    with open_python_file(script_path) as source_file:
        source = source_file.read()

    compile_options = dict(
        mode="exec",  # an entire block of statements, not "eval" or "single"
        flags=0,  # no extra compiler flags ...
        dont_inherit=1,  # ... and no inherited flags or ``__future__`` statements
        optimize=-1,  # the interpreter's default optimization level
    )
    return compile(source, script_path, **compile_options)
175
176
@contextmanager
def _create_fake_main_module(script_path):
    """Temporarily install a fresh ``__main__`` module in which ``script_path`` can be executed.

    Yields the new module. Its ``__file__`` is set to the absolute path of ``script_path``. The previous
    ``sys.modules["__main__"]`` entry is restored on exit; if there was none, the entry is removed.

    Args:
        script_path: Path of the script that will be executed in the yielded module.

    """
    # Create fake module. This gives us a clean global namespace to execute the code in.
    module = _new_module("__main__")

    # Add special variables to the module's globals dict.
    # Note: The following is a requirement for the CodeHasher to
    # work correctly. The CodeHasher is scoped to
    # files contained in the directory of __main__.__file__, which we
    # assume is the main script directory.
    # This is done *before* the module is installed, so a failure here cannot leave ``sys.modules`` patched.
    module.__dict__["__file__"] = os.path.abspath(script_path)

    # Install the fake module as the __main__ module. This allows
    # the pickle module to work inside the user's code, since it now
    # can know the module where the pickled objects stem from.
    # IMPORTANT: This means we can't use "if __name__ == '__main__'" in
    # our code, as it will point to the wrong module!!!
    old_main_module = sys.modules.get("__main__")
    sys.modules["__main__"] = module

    try:
        yield module
    finally:
        if old_main_module is None:
            sys.modules.pop("__main__", None)
        else:
            sys.modules["__main__"] = old_main_module
202
203
204@contextmanager
205def _patch_sys_path(append):
206"""A context manager that appends the given value to the path once entered.
207
208Args:
209append: The value to append to the path.
210
211"""
212if append in sys.path:
213yield
214return
215
216sys.path.append(append)
217
218try:
219yield
220finally:
221sys.path.remove(append)
222
223
224@contextmanager
225def _add_to_env(envs: Dict[str, str]):
226"""This function adds the given environment variables to the current environment."""
227original_envs = dict(os.environ)
228os.environ.update(envs)
229
230try:
231yield
232finally:
233os.environ.clear()
234os.environ.update(original_envs)
235
236
@contextmanager
def _patch_sys_argv():
    """This function modifies ``sys.argv`` by extracting the arguments after ``--app_args`` and removes everything
    else before executing the user app script.

    The command: ``lightning_app run app app.py --without-server --app_args --use_gpu --env ...`` will be converted into
    ``app.py --use_gpu``

    The original ``sys.argv`` is restored when the context exits.

    """
    from lightning.app.cli.lightning_cli import run_app

    original_argv = copy(sys.argv)
    # Work on a local list; ``sys.argv`` itself is only replaced inside the ``try`` so it is always restored.
    argv = sys.argv
    # 1: Remove the CLI command
    if argv[:3] == ["lightning", "run", "app"]:
        argv = argv[3:]

    if "--app_args" not in argv:
        # 2: If app_args wasn't used, there are no arguments, so we assign the shortened arguments.
        new_argv = argv[:1]
    else:
        # 3: Collect all the options of the CLI command (except ``--app_args`` itself)
        options = {p.opts[0] for p in run_app.params[1:] if p.opts[0] != "--app_args"}
        # 4: Find the index of the first argument after `--app_args`
        first_index = argv.index("--app_args") + 1
        # 5: Find the next CLI option after `--app_args`, if any. Positions are scanned directly because
        # ``list.index`` only returns the first occurrence. With it, an option that was also used *before*
        # ``--app_args`` would be missed and its later occurrence leaked into the app's arguments.
        last_index = next(
            (i for i in range(first_index, len(argv)) if argv[i] in options),
            len(argv),
        )
        # 6: last_index is either the end of the command or the position of the next CLI option.
        new_argv = [argv[0]] + argv[first_index:last_index]

    try:
        # 7: Patch the command
        sys.argv = new_argv
        yield
    finally:
        # 8: Restore the command
        sys.argv = original_argv
278
279
def component_to_metadata(obj: Union["LightningWork", "LightningFlow"]) -> Dict:
    """Describe a flow or work component as a plain dictionary.

    Every component yields:
    - ``affiliation``: its dotted ``name`` split on ``"."``;
    - its class name and ``__module__``;
    - the docstring of its ``__init__``.

    Works additionally carry ``local_build_config``, ``cloud_build_config`` and ``cloud_compute``, each
    converted with ``to_dict()``.
    """
    from lightning.app.core import LightningWork

    if isinstance(obj, LightningWork):
        work_fields = {
            "local_build_config": obj.local_build_config.to_dict(),
            "cloud_build_config": obj.cloud_build_config.to_dict(),
            "cloud_compute": obj.cloud_compute.to_dict(),
        }
    else:
        work_fields = {}

    return {
        "affiliation": obj.name.split("."),
        "cls_name": obj.__class__.__name__,
        "module": obj.__module__,
        "docstring": inspect.getdoc(obj.__init__),
        **work_fields,
    }
299
300
def extract_metadata_from_app(app: "LightningApp") -> List:
    """Return the metadata of every flow and work of ``app``, ordered by component name.

    Flows are collected first and works second, so when a work shares a name with a flow the work's entry wins.
    """
    collected = {}
    for flow in app.flows:
        collected[flow.name] = component_to_metadata(flow)
    for work in app.works:
        collected[work.name] = component_to_metadata(work)
    return [collected[name] for name in sorted(collected)]
305