pytorch-lightning
152 строки · 5.3 Кб
1# Copyright The Lightning AI team.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import os
16from contextlib import contextmanager
17from typing import TYPE_CHECKING, Any, Dict, Generator, Optional
18
19from deepdiff.helper import NotPresent
20from lightning_utilities.core.apply_func import apply_to_collection
21
22from lightning.app.utilities.app_helpers import is_overridden
23from lightning.app.utilities.enum import ComponentContext
24from lightning.app.utilities.packaging.cloud_compute import CloudCompute
25from lightning.app.utilities.tree import breadth_first
26
27if TYPE_CHECKING:
28from lightning.app.core import LightningFlow
29
# Module-level component context. Starts as ``None`` and is reassigned by the ``_set_*context``
# helpers in this module; read back with ``_get_context`` and the ``_is_*_context`` predicates.
COMPONENT_CONTEXT: Optional[ComponentContext] = None
31
32
def _convert_paths_after_init(root: "LightningFlow"):
    """Move every ``Path`` attribute of every component in the tree into that component's ``_paths`` dict.

    The tree rooted at ``root`` is walked breadth-first over ``LightningFlow`` and ``LightningWork`` instances.
    Each attribute whose value is a ``Path`` is deleted from the component and stored, in ``to_dict()`` form,
    under the same name in ``component._paths``.

    This is necessary because at the time of instantiating a component its full affiliation is not known, and
    Paths that get passed to other components during ``__init__`` would otherwise not be able to reference
    their origin or consumer.

    Args:
        root: the root flow of the component tree.

    """
    from lightning.app.core import LightningFlow, LightningWork
    from lightning.app.storage.path import Path

    component_types = (LightningFlow, LightningWork)
    for comp in breadth_first(root, types=component_types):
        # Snapshot the attribute names up front: ``delattr`` below mutates ``__dict__`` while we iterate.
        attr_names = list(comp.__dict__.keys())
        for name in attr_names:
            candidate = getattr(comp, name)
            if not isinstance(candidate, Path):
                continue
            delattr(comp, name)
            comp._paths[name] = candidate.to_dict()
50
51
def _sanitize_state(state: Dict[str, Any]) -> Dict[str, Any]:
    """Return a sanitized version of a component's state.

    Sanitization enables the state to be deep-copied and hashed. The following conversions are applied,
    in this order, each pass running over the output of the previous one:

    * ``Path`` values are copied with ``Path(path)`` and the copy has ``_sanitize()`` called on it;
    * ``_BasePayload`` values are rebuilt via ``type(payload).from_dict(content=payload.to_dict())``;
    * ``Drive`` values are replaced by ``drive.to_dict()``;
    * ``CloudCompute`` values are replaced by ``cloud_compute.to_dict()``.

    Args:
        state: the component state collection to sanitize.

    Returns:
        The collection produced by the final ``apply_to_collection`` pass.

    """
    from lightning.app.storage import Drive, Path
    from lightning.app.storage.payload import _BasePayload

    def _copy_and_sanitize_path(original: Path) -> Path:
        duplicate = Path(original)
        duplicate._sanitize()
        return duplicate

    def _rebuild_payload(payload: _BasePayload):
        payload_cls = type(payload)
        return payload_cls.from_dict(content=payload.to_dict())

    def _drive_as_dict(drive: Drive) -> Dict:
        return drive.to_dict()

    def _cloud_compute_as_dict(cloud_compute: CloudCompute) -> Dict:
        return cloud_compute.to_dict()

    converters = (
        (Path, _copy_and_sanitize_path),
        (_BasePayload, _rebuild_payload),
        (Drive, _drive_as_dict),
        (CloudCompute, _cloud_compute_as_dict),
    )
    for target_type, convert in converters:
        state = apply_to_collection(state, dtype=target_type, function=convert)
    return state
80
81
def _state_to_json(state: Dict[str, Any]) -> Dict[str, Any]:
    """Utility function to make sure that state dict is json serializable.

    Two passes are applied:

    * ``Path`` and ``_BasePayload`` values are replaced by their ``to_dict()`` output;
    * deepdiff ``NotPresent`` markers are replaced by ``None``.

    Args:
        state: the component state collection to convert.

    Returns:
        The collection produced by the second ``apply_to_collection`` pass.

    """
    from lightning.app.storage.path import Path
    from lightning.app.storage.payload import _BasePayload

    # deepdiff defines ``NotPresent`` as a class (its module-level instance is ``notpresent``), so
    # ``type(NotPresent)`` is just the metaclass ``type``. Matching on that alone never caught the
    # marker *instances*, which are not JSON serializable. Match the instances explicitly, and keep
    # ``type(NotPresent)`` in the tuple so class objects are still mapped to ``None`` as before.
    not_present_dtypes = (NotPresent, type(NotPresent))

    state_paths_cleaned = apply_to_collection(state, dtype=(Path, _BasePayload), function=lambda x: x.to_dict())
    return apply_to_collection(state_paths_cleaned, dtype=not_present_dtypes, function=lambda x: None)
89
90
def _set_context(name: Optional[str]) -> None:
    """Set the module-level component context.

    Args:
        name: value passed to ``ComponentContext(...)``. When ``None``, the value is read from the
            ``COMPONENT_CONTEXT`` environment variable instead; if that variable is unset or empty,
            the context is cleared to ``None``.

    Raises:
        ValueError: presumably raised by ``ComponentContext(...)`` (an enum lookup) when the given name,
            or the environment value, is not a valid context.

    """
    global COMPONENT_CONTEXT
    if name is not None:
        COMPONENT_CONTEXT = ComponentContext(name)
        return

    env_value = os.getenv("COMPONENT_CONTEXT")
    # Coerce the environment value the same way as an explicit ``name``, so the global always holds a
    # ``ComponentContext`` (or ``None``) as annotated. Previously the raw string was stored, and it may
    # not compare equal to the ``ComponentContext`` members checked by the ``_is_*_context`` helpers.
    COMPONENT_CONTEXT = ComponentContext(env_value) if env_value else None
94
95
def _get_context() -> Optional[ComponentContext]:
    """Return the current value of the module-level ``COMPONENT_CONTEXT``."""
    # Only reading the module global, so no ``global`` declaration is required.
    return COMPONENT_CONTEXT
99
100
def _set_flow_context() -> None:
    """Set the module-level component context to ``ComponentContext.FLOW``."""
    global COMPONENT_CONTEXT  # rebinding the module global requires the declaration
    COMPONENT_CONTEXT = ComponentContext.FLOW
104
105
def _set_work_context() -> None:
    """Set the module-level component context to ``ComponentContext.WORK``."""
    global COMPONENT_CONTEXT  # rebinding the module global requires the declaration
    COMPONENT_CONTEXT = ComponentContext.WORK
109
110
def _set_frontend_context() -> None:
    """Set the module-level component context to ``ComponentContext.FRONTEND``."""
    global COMPONENT_CONTEXT  # rebinding the module global requires the declaration
    COMPONENT_CONTEXT = ComponentContext.FRONTEND
114
115
def _is_flow_context() -> bool:
    """Return whether the module-level component context equals ``ComponentContext.FLOW``."""
    # Read-only access to the module global; no ``global`` declaration needed.
    return COMPONENT_CONTEXT == ComponentContext.FLOW
119
120
def _is_work_context() -> bool:
    """Return whether the module-level component context equals ``ComponentContext.WORK``."""
    # Read-only access to the module global; no ``global`` declaration needed.
    return COMPONENT_CONTEXT == ComponentContext.WORK
124
125
def _is_frontend_context() -> bool:
    """Return whether the module-level component context equals ``ComponentContext.FRONTEND``."""
    # Read-only access to the module global; no ``global`` declaration needed.
    return COMPONENT_CONTEXT == ComponentContext.FRONTEND
129
130
@contextmanager
def _context(ctx: str) -> Generator[None, None, None]:
    """Set the global component context for the block below this context manager.

    The context is used to determine whether the current process is running for a LightningFlow or for a
    LightningWork. See also :func:`_get_context`, :func:`_set_context`. For internal use only.

    The previous context is restored when the block exits, including when the block raises.

    Args:
        ctx: name of the context to activate, forwarded to :func:`_set_context`.

    """
    prev = _get_context()
    _set_context(ctx)
    try:
        yield
    finally:
        # Restore even if the ``with`` body raised; previously an exception leaked the temporary context.
        # NOTE(review): when ``prev`` is ``None`` this goes through ``_set_context(None)``, which re-reads
        # the ``COMPONENT_CONTEXT`` environment variable rather than restoring ``None`` - confirm intended.
        _set_context(prev)
143
144
def _validate_root_flow(flow: "LightningFlow") -> None:
    """Check that the root flow given to ``LightningApp`` implements its own ``run()``.

    Args:
        flow: the root flow instance.

    Raises:
        TypeError: if ``is_overridden`` reports that ``run`` is not overridden relative to ``LightningFlow``.

    """
    from lightning.app.core.flow import LightningFlow

    overrides_run = is_overridden("run", instance=flow, parent=LightningFlow)
    if overrides_run:
        return

    class_name = flow.__class__.__name__
    raise TypeError(
        "The root flow passed to `LightningApp` does not override the `run()` method. This is required. Please"
        f" implement `run()` in your `{class_name}` class."
    )
153