pytorch-lightning

Форк
0
107 строк · 4.5 Кб
1
# Copyright The Lightning AI team.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

15
import contextlib
16
import pickle
17
import sys
18
import types
19
import typing
20
from copy import deepcopy
21
from pathlib import Path
22

23
from lightning.app.core.work import LightningWork
24
from lightning.app.utilities.app_helpers import _LightningAppRef
25

26
# Names of the work attributes that get_picklable_work() temporarily sets to None
# (through _trimmed_work) while it deep-copies a work for pickling.
NON_PICKLABLE_WORK_ATTRIBUTES = [
    "_request_queue",
    "_response_queue",
    "_backend",
    "_setattr_replacement",
]
27

28

29
@contextlib.contextmanager
def _trimmed_work(work: LightningWork, to_trim: typing.List[str]) -> typing.Iterator[None]:
    """Context manager that temporarily sets the given attributes of ``work`` to ``None``.

    The original values are put back when the ``with`` block exits, including when the
    block raises, so a failure inside it never leaves ``work`` stripped of its attributes.

    Args:
        work: Object whose attributes are blanked out for the duration of the block.
        to_trim: Names of the attributes to replace with ``None``.

    Raises:
        AttributeError: If ``work`` lacks one of the attributes named in ``to_trim``.
            Attributes that were already blanked out are restored before it propagates.

    """
    holder: typing.Dict[str, typing.Any] = {}
    try:
        for arg in to_trim:
            if arg in holder:
                # A repeated name must not overwrite the saved value with the
                # ``None`` that was just assigned.
                continue
            holder[arg] = getattr(work, arg)
            setattr(work, arg, None)
        yield
    finally:
        # Restore exactly what was saved, even if the body (or the trimming loop
        # itself) raised part-way through.
        for arg, value in holder.items():
            setattr(work, arg, value)
39

40

41
def get_picklable_work(work: LightningWork) -> LightningWork:
    """Return a deep copy of ``work`` that can be pickled from within the work process.

    Pickling a LightningWork instance fails if done from the work process
    itself. This function is safe to call from the work process within both MultiprocessRuntime
    and Cloud.

    Note: This function modifies the module information of the work object. Specifically, it injects
    the relative module path into the __module__ attribute of the work object. If the object is not
    importable from the CWD, then the pickle load will fail.

    Example:
        for a directory structure like below and the work class is defined in the app.py where
        the app.py is the entrypoint for the app, it will inject `foo.bar.app` into the
        __module__ attribute

        └── foo
            ├── __init__.py
            └── bar
                └── app.py

    Args:
        work: The work to copy. Only its ``name`` is used, to look up the matching work
            in the current app reference.

    Returns:
        A deep copy of the matching work taken from the app reference, copied while the
        attributes listed in ``NON_PICKLABLE_WORK_ATTRIBUTES`` are set to ``None``.

    Raises:
        RuntimeError: If there is no current app reference.
        ValueError: If no work with the same name exists in the app reference, if the file
            of the module defining the work class cannot be identified, or if that file is
            not located under the current working directory.

    """
    # If the work object not taken from the app ref, there is a thread lock reference
    # somewhere thats preventing it from being pickled. Investigate it later. We
    # shouldn't be fetching the work object from the app ref. TODO @sherin
    app_ref = _LightningAppRef.get_current()
    if app_ref is None:
        raise RuntimeError("Cannot pickle LightningWork outside of a LightningApp")
    for w in app_ref.works:
        if work.name == w.name:
            # deep-copying the work object to avoid modifying the original work object
            with _trimmed_work(w, to_trim=NON_PICKLABLE_WORK_ATTRIBUTES):
                copied_work = deepcopy(w)
            break
    else:
        raise ValueError(f"Work with name {work.name} not found in the app references")

    # if work is defined in the __main__ or __mp__main__ (the entrypoint file for `lightning run app` command),
    # pickling/unpickling will fail, hence we need patch the module information
    work_class = copied_work.__class__
    if "_main__" in work_class.__module__:
        work_class_module = sys.modules[work_class.__module__]
        # ``__main__`` has no ``__file__`` at all in interactive sessions, so read it
        # defensively to reach the explicit error below instead of an AttributeError.
        work_class_file = getattr(work_class_module, "__file__", None)
        if not work_class_file:
            raise ValueError(
                f"Cannot pickle work class {work_class.__name__} because we "
                f"couldn't identify the module file"
            )
        relative_path = Path(work_class_file).relative_to(Path.cwd())
        # Drop only the trailing ``.py`` suffix; a ``.py`` substring elsewhere in the
        # path (e.g. a directory name) must stay intact.
        if relative_path.suffix == ".py":
            relative_path = relative_path.with_suffix("")
        expected_module_name = ".".join(relative_path.parts)
        # TODO @sherin: also check if the module is importable from the CWD
        fake_module = types.ModuleType(expected_module_name)
        fake_module.__dict__.update(work_class_module.__dict__)
        fake_module.__dict__["__name__"] = expected_module_name
        sys.modules[expected_module_name] = fake_module
        for k, v in fake_module.__dict__.items():
            if k.startswith("__"):
                continue
            # ``__module__`` may be missing or not a string (e.g. ``None``); only
            # re-home objects that really claim to live in the main module.
            owner = getattr(v, "__module__", None)
            if isinstance(owner, str) and "_main__" in owner:
                v.__module__ = expected_module_name
    return copied_work
95

96

97
def dump(work: LightningWork, f: typing.BinaryIO) -> None:
    """Pickle a picklable copy of ``work`` into the binary file object ``f``.

    The copy is produced by ``get_picklable_work``; ``work`` itself is not what
    gets written.

    Args:
        work: The work to serialize.
        f: Writable binary file object that receives the pickle data.

    """
    pickle.dump(get_picklable_work(work), f)
100

101

102
def load(f: typing.BinaryIO) -> typing.Any:
    """Unpickle and return the object stored in the binary file object ``f``.

    The current working directory is added to ``sys.path`` for the duration of the
    load and removed again afterwards, whether or not unpickling succeeds.

    Warning:
        ``pickle.load`` can execute arbitrary code while unpickling. Only call this
        on data that comes from a trusted source.

    Args:
        f: Readable binary file object positioned at the start of the pickle data.

    Returns:
        The unpickled object.

    """
    cwd = str(Path.cwd())
    # inject current working directory to sys.path
    sys.path.insert(1, cwd)
    try:
        return pickle.load(f)
    finally:
        # Remove by value rather than by index: imports triggered while unpickling
        # may have shifted the entry, and it must also be cleaned up on failure.
        with contextlib.suppress(ValueError):
            sys.path.remove(cwd)
108

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.