pytorch-lightning

Форк
0
152 строки · 5.3 Кб
1
# Copyright The Lightning AI team.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

15
import os
16
from contextlib import contextmanager
17
from typing import TYPE_CHECKING, Any, Dict, Generator, Optional
18

19
from deepdiff.helper import NotPresent
20
from lightning_utilities.core.apply_func import apply_to_collection
21

22
from lightning.app.utilities.app_helpers import is_overridden
23
from lightning.app.utilities.enum import ComponentContext
24
from lightning.app.utilities.packaging.cloud_compute import CloudCompute
25
from lightning.app.utilities.tree import breadth_first
26

27
if TYPE_CHECKING:
28
    from lightning.app.core import LightningFlow
29

30
# Process-wide record of which kind of component (flow, work or frontend) the current process is running.
# It is written by the ``_set_*_context`` helpers, ``_set_context`` and the ``_context`` context manager in this
# module, and read by ``_get_context`` and the ``_is_*_context`` predicates. ``None`` means "no context set".
COMPONENT_CONTEXT: Optional[ComponentContext] = None
31

32

33
def _convert_paths_after_init(root: "LightningFlow"):
    """Replace every ``Path`` attribute of each component in the tree with its dictionary form.

    Components (flows and works) are visited breadth-first starting at ``root``. Any instance attribute that holds a
    ``Path`` is removed from the component, and its ``to_dict()`` output is stored under the same name in
    ``component._paths``.

    This is necessary because, while a component is being instantiated, its full affiliation is not yet known, so
    Paths passed to other components during ``__init__`` could not otherwise reference their origin or consumer.

    """
    from lightning.app.core import LightningFlow, LightningWork
    from lightning.app.storage.path import Path

    for node in breadth_first(root, types=(LightningFlow, LightningWork)):
        # Snapshot the attribute names first: ``delattr`` below mutates ``__dict__`` while we iterate.
        attribute_names = list(node.__dict__)
        for name in attribute_names:
            candidate = getattr(node, name)
            if not isinstance(candidate, Path):
                continue
            delattr(node, name)
            node._paths[name] = candidate.to_dict()
50

51

52
def _sanitize_state(state: Dict[str, Any]) -> Dict[str, Any]:
    """Utility function to sanitize the state of a component so it can be deep-copied and hashed.

    The state collection is traversed in four successive passes, in this order:

    * every ``Path`` is replaced by a new ``Path`` built from it, on which ``_sanitize()`` has been called;
    * every payload (``_BasePayload``) is rebuilt with ``type(payload).from_dict(content=payload.to_dict())``;
    * every ``Drive`` is replaced by its ``to_dict()`` output;
    * every ``CloudCompute`` is replaced by its ``to_dict()`` output.

    """
    from lightning.app.storage import Drive, Path
    from lightning.app.storage.payload import _BasePayload

    def _clean_path(original: Path) -> Path:
        duplicate = Path(original)
        duplicate._sanitize()
        return duplicate

    def _rebuild_payload(payload: _BasePayload) -> _BasePayload:
        return type(payload).from_dict(content=payload.to_dict())

    def _as_dict(obj: Any) -> Dict:
        return obj.to_dict()

    # Each pass sees the output of the previous one, so the order of this table matters.
    passes = (
        (Path, _clean_path),
        (_BasePayload, _rebuild_payload),
        (Drive, _as_dict),
        (CloudCompute, _as_dict),
    )
    for dtype, converter in passes:
        state = apply_to_collection(state, dtype=dtype, function=converter)
    return state
80

81

82
def _state_to_json(state: Dict[str, Any]) -> Dict[str, Any]:
    """Utility function to make sure that state dict is json serializable.

    ``Path`` and payload (``_BasePayload``) objects are replaced by their ``to_dict()`` output, and deepdiff
    ``NotPresent`` markers are replaced by ``None``.

    Args:
        state: The component state collection to convert.

    Returns:
        The converted collection.

    """
    from lightning.app.storage.path import Path
    from lightning.app.storage.payload import _BasePayload

    state_paths_cleaned = apply_to_collection(state, dtype=(Path, _BasePayload), function=lambda x: x.to_dict())
    # deepdiff.helper defines ``NotPresent`` as a class whose *instances* are the "not present" markers, so
    # ``type(NotPresent)`` on its own evaluates to ``type`` and never matches a marker. Match the instances
    # explicitly; ``type(NotPresent)`` is kept so everything converted before is still converted.
    not_present_types = (NotPresent, type(NotPresent))
    return apply_to_collection(state_paths_cleaned, dtype=not_present_types, function=lambda x: None)
89

90

91
def _set_context(name: Optional[str]) -> None:
    """Set the global component context.

    Args:
        name: Value of a ``ComponentContext`` member (a member itself is also accepted). When ``None``, the value is
            read from the ``COMPONENT_CONTEXT`` environment variable. If that variable is unset or empty, the context
            is cleared (set to ``None``).

    Raises:
        ValueError: If the given (or environment-provided) value is not a valid ``ComponentContext``.

    """
    global COMPONENT_CONTEXT
    if name is None:
        # An unset or empty environment variable means "no context".
        name = os.getenv("COMPONENT_CONTEXT") or None
    # Always store a ``ComponentContext`` (or ``None``), as declared by the ``COMPONENT_CONTEXT`` annotation. Storing
    # the raw environment string made the ``== ComponentContext.X`` checks in ``_is_*_context`` unreliable.
    COMPONENT_CONTEXT = None if name is None else ComponentContext(name)
94

95

96
def _get_context() -> Optional[ComponentContext]:
    """Return the component context currently recorded for this process, or ``None`` when none is set."""
    # Reading a module global needs no ``global`` declaration.
    return COMPONENT_CONTEXT
99

100

101
def _set_flow_context() -> None:
    """Record that the current process is running a LightningFlow (``ComponentContext.FLOW``)."""
    global COMPONENT_CONTEXT
    COMPONENT_CONTEXT = ComponentContext.FLOW
104

105

106
def _set_work_context() -> None:
    """Record that the current process is running a LightningWork (``ComponentContext.WORK``)."""
    global COMPONENT_CONTEXT
    COMPONENT_CONTEXT = ComponentContext.WORK
109

110

111
def _set_frontend_context() -> None:
    """Record that the current process is running a frontend (``ComponentContext.FRONTEND``)."""
    global COMPONENT_CONTEXT
    COMPONENT_CONTEXT = ComponentContext.FRONTEND
114

115

116
def _is_flow_context() -> bool:
    """Return whether the recorded component context equals ``ComponentContext.FLOW``."""
    return COMPONENT_CONTEXT == ComponentContext.FLOW
119

120

121
def _is_work_context() -> bool:
    """Return whether the recorded component context equals ``ComponentContext.WORK``."""
    return COMPONENT_CONTEXT == ComponentContext.WORK
124

125

126
def _is_frontend_context() -> bool:
    """Return whether the recorded component context equals ``ComponentContext.FRONTEND``."""
    return COMPONENT_CONTEXT == ComponentContext.FRONTEND
129

130

131
@contextmanager
def _context(ctx: str) -> Generator[None, None, None]:
    """Set the global component context for the block below this context manager.

    The context is used to determine whether the current process is running for a LightningFlow or for a LightningWork.
    See also :func:`_get_context`, :func:`_set_context`. For internal use only.

    The previous context is restored when the block exits, including when it exits by raising an exception.

    Args:
        ctx: Context name, passed to :func:`_set_context`.

    """
    global COMPONENT_CONTEXT
    prev = _get_context()
    _set_context(ctx)
    try:
        yield
    finally:
        # Restore the saved value directly. Calling ``_set_context(prev)`` would, when ``prev`` is None, read the
        # ``COMPONENT_CONTEXT`` environment variable instead of restoring "no context".
        COMPONENT_CONTEXT = prev
143

144

145
def _validate_root_flow(flow: "LightningFlow") -> None:
    """Ensure the root flow passed to ``LightningApp`` provides its own ``run()`` implementation.

    Raises:
        TypeError: If the class of ``flow`` does not override ``LightningFlow.run``.

    """
    from lightning.app.core.flow import LightningFlow

    if is_overridden("run", instance=flow, parent=LightningFlow):
        return

    class_name = flow.__class__.__name__
    raise TypeError(
        "The root flow passed to `LightningApp` does not override the `run()` method. This is required. Please"
        f" implement `run()` in your `{class_name}` class."
    )
153

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.