pytorch-lightning
107 lines · 4.5 KB
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14
15import contextlib16import pickle17import sys18import types19import typing20from copy import deepcopy21from pathlib import Path22
23from lightning.app.core.work import LightningWork24from lightning.app.utilities.app_helpers import _LightningAppRef25
# Names of LightningWork attributes that are temporarily set to ``None`` while a work
# is deep-copied for pickling (see ``get_picklable_work``).
NON_PICKLABLE_WORK_ATTRIBUTES = [
    "_request_queue",
    "_response_queue",
    "_backend",
    "_setattr_replacement",
]
28
@contextlib.contextmanager
def _trimmed_work(work: LightningWork, to_trim: typing.List[str]) -> typing.Iterator[None]:
    """Context manager that temporarily sets the named attributes of ``work`` to ``None``.

    Used to strip attributes that cannot be pickled while the work is being copied.
    The original values are restored on exit, including when the body of the ``with``
    statement raises.

    Args:
        work: The work whose attributes are trimmed in place.
        to_trim: Names of the attributes to set to ``None`` for the duration of the context.
    """
    holder = {}
    try:
        for arg in to_trim:
            # setdefault keeps the first saved value if a name appears twice in ``to_trim``;
            # otherwise the second lookup would save the ``None`` we just assigned.
            holder.setdefault(arg, getattr(work, arg))
            setattr(work, arg, None)
        yield
    finally:
        # Restore only what was actually saved, so a failing getattr part-way through the
        # loop still leaves the previously trimmed attributes restored.
        for arg, value in holder.items():
            setattr(work, arg, value)
40
def get_picklable_work(work: LightningWork) -> LightningWork:
    """Return a deep copy of ``work`` that can be pickled.

    Pickling a LightningWork instance directly fails if done from the work process itself.
    This function is safe to call from the work process within both MultiprocessRuntime
    and Cloud.

    The copy is taken from the work registered on the current ``_LightningAppRef`` whose
    ``name`` matches ``work.name``. The attributes listed in ``NON_PICKLABLE_WORK_ATTRIBUTES``
    are set to ``None`` while copying, so the returned copy carries ``None`` for them.

    Note:
        This function modifies module information. If the work class is defined in the
        entrypoint module (``__main__`` or ``__mp_main__``), a module named after the
        entrypoint's path relative to the CWD is registered in ``sys.modules``. The
        ``__module__`` attribute of the entrypoint's objects (including the work class) is
        rewritten to that name. If that module is not importable from the CWD at load time,
        unpickling will fail.

    Example:
        Consider the directory structure below, where the work class is defined in ``app.py``
        and ``app.py`` is the app entrypoint. Running from the directory that contains ``foo``
        injects ``foo.bar.app`` into the ``__module__`` attribute::

            └── foo
                ├── __init__.py
                └── bar
                    └── app.py

    Args:
        work: The work to make picklable. Only its ``name`` is used to find the registered work.

    Returns:
        A deep copy of the registered work with the non-picklable attributes set to ``None``.

    Raises:
        RuntimeError: If there is no current LightningApp.
        ValueError: If no registered work has ``work.name``, or if the entrypoint's module
            file cannot be identified or is not located under the current working directory.
    """
    # If the work object is not taken from the app ref, there is a thread lock reference
    # somewhere that's preventing it from being pickled. Investigate it later. We
    # shouldn't be fetching the work object from the app ref. TODO @sherin
    app_ref = _LightningAppRef.get_current()
    if app_ref is None:
        raise RuntimeError("Cannot pickle LightningWork outside of a LightningApp")
    for w in app_ref.works:
        if work.name == w.name:
            # Deep-copy the registered work so the original object is not modified; the
            # non-picklable attributes are set to None only for the duration of the copy.
            with _trimmed_work(w, to_trim=NON_PICKLABLE_WORK_ATTRIBUTES):
                copied_work = deepcopy(w)
            break
    else:
        raise ValueError(f"Work with name {work.name} not found in the app references")

    # If the work is defined in __main__ or __mp_main__ (the entrypoint file for the
    # `lightning run app` command), pickling/unpickling will fail, hence we need to patch
    # the module information. The substring "_main__" matches both module names.
    if "_main__" in copied_work.__class__.__module__:
        work_class_module = sys.modules[copied_work.__class__.__module__]
        # `__file__` may be missing entirely, not only empty.
        work_class_file = getattr(work_class_module, "__file__", None)
        if not work_class_file:
            raise ValueError(
                f"Cannot pickle work class {copied_work.__class__.__name__} because we "
                f"couldn't identify the module file"
            )
        cwd = Path.cwd()
        try:
            # absolute() makes a relative `__file__` comparable with the absolute CWD.
            relative_path = Path(work_class_file).absolute().relative_to(cwd)
        except ValueError as err:
            raise ValueError(
                f"Cannot pickle work class {copied_work.__class__.__name__} because its module "
                f"file {work_class_file} is not located under the current working directory {cwd}"
            ) from err
        # Drop only the file suffix. A plain `.replace(".py", "")` would also mangle directory
        # names containing ".py". Then join the path parts into a dotted module name.
        expected_module_name = ".".join(relative_path.with_suffix("").parts)
        # TODO @sherin: also check if the module is importable from the CWD
        fake_module = types.ModuleType(expected_module_name)
        fake_module.__dict__.update(work_class_module.__dict__)
        fake_module.__dict__["__name__"] = expected_module_name
        sys.modules[expected_module_name] = fake_module
        # Re-home the entrypoint's objects so pickle records `expected_module_name` instead of
        # __main__/__mp_main__. `__module__` may be None on some objects, so only string values
        # are inspected (a `in` test on None would raise TypeError).
        for k, v in fake_module.__dict__.items():
            module_name = getattr(v, "__module__", None)
            if not k.startswith("__") and isinstance(module_name, str) and "_main__" in module_name:
                v.__module__ = expected_module_name
    return copied_work
96
def dump(work: LightningWork, f: typing.BinaryIO) -> None:
    """Pickle ``work`` into the binary file object ``f``.

    The work is first converted with ``get_picklable_work``, so the object written is a
    picklable deep copy rather than ``work`` itself.

    Args:
        work: The work to serialize.
        f: A binary file-like object opened for writing.
    """
    pickle.dump(get_picklable_work(work), f)
101
def load(f: typing.BinaryIO) -> typing.Any:
    """Unpickle and return an object from the binary file object ``f``.

    Presumably used for data written by ``dump``. The current working directory is
    temporarily inserted into ``sys.path`` so that modules registered under a CWD-relative
    name (as ``get_picklable_work`` does for entrypoint-defined works) can be imported while
    unpickling. The inserted entry is removed again afterwards, even if unpickling raises.

    Warning:
        This uses ``pickle``, which can execute arbitrary code while loading. Only call it
        on data from a trusted source.

    Args:
        f: A binary file-like object opened for reading.

    Returns:
        The unpickled object.
    """
    cwd = str(Path.cwd())
    # Insert after sys.path[0] (conventionally the script directory) rather than in front of it.
    sys.path.insert(1, cwd)
    try:
        return pickle.load(f)
    finally:
        # Remove exactly the entry inserted above, matched by identity. This leaves alone an
        # equal path that was already on sys.path, and works even if unpickling shifted entries.
        for index, entry in enumerate(sys.path):
            if entry is cwd:
                del sys.path[index]
                break