pytorch-lightning
160 строк · 6.1 Кб
1# Copyright The Lightning AI team.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import typing as t
16
17from lightning.app.utilities.app_helpers import _LightningAppRef, _set_child_name
18
if t.TYPE_CHECKING:
    # Only seen by static type checkers; this import is never executed at runtime.
    from lightning.app.utilities.types import Component

# Generic value type parameter for the ``Dict`` container defined in this module.
T = t.TypeVar("T")
23
24
def _prepare_name(component: "Component") -> str:
    """Return the last dot-separated segment of ``component.name``.

    A component named ``"root.dict.work_0"`` yields ``"work_0"``; a name without
    any dot is returned unchanged.
    """
    _head, _sep, leaf = component.name.rpartition(".")
    return str(leaf)
27
28
29# TODO: add support and tests for dict operations (insertion, update, etc.)
class Dict(t.Dict[str, T]):
    def __init__(self, **kwargs: T):
        """The Dict object is used to represent a dict collection of :class:`~lightning.app.core.work.LightningWork`
        or :class:`~lightning.app.core.flow.LightningFlow`.

        Example:

            >>> from lightning.app import LightningFlow, LightningWork
            >>> from lightning.app.structures import Dict
            >>> class CounterWork(LightningWork):
            ...     def __init__(self):
            ...         super().__init__()
            ...         self.counter = 0
            ...     def run(self):
            ...         self.counter += 1
            ...
            >>> class RootFlow(LightningFlow):
            ...     def __init__(self):
            ...         super().__init__()
            ...         self.dict = Dict(**{"work_0": CounterWork(), "work_1": CounterWork()})
            ...     def run(self):
            ...         for work_name, work in self.dict.items():
            ...             work.run()
            ...
            >>> flow = RootFlow()
            >>> flow.run()
            >>> assert flow.dict["work_0"].counter == 1

        Arguments:
            **kwargs: Mapping from child name to a LightningWork or LightningFlow. Names must not contain ``.``.

        Raises:
            ValueError: If a provided name contains ``.``.

        """
        super().__init__(**kwargs)
        from lightning.app.runners.backends import Backend

        self._name: t.Optional[str] = ""
        # Stays None until a backend is assigned; ``__setitem__`` uses it to wire up children added afterwards.
        self._backend: t.Optional[Backend] = None
        for k, v in kwargs.items():
            self._validate_key(k)
            _set_child_name(self, v, k)

    @staticmethod
    def _validate_key(k: t.Any) -> None:
        """Raise if ``k`` cannot be used as a child name (must be a ``str`` without ``.``)."""
        if not isinstance(k, str):
            raise TypeError("The provided key should be an string")
        if "." in k:
            # ``.`` is the separator used to build child names (``f"{self.name}.{k}"``), so it is reserved.
            raise ValueError(f"The provided name {k} contains . which is forbidden.")

    def __setitem__(self, k: str, v: T) -> None:
        """Store ``v`` under ``k`` after naming it as a child of this Dict.

        If a backend is already set, the new child is hooked into it: flows through
        ``LightningFlow._attach_backend`` and works through the backend's ``_wrap_run_method``.

        Raises:
            TypeError: If ``k`` is not a string.
            ValueError: If ``k`` contains ``.``.

        """
        from lightning.app.core import LightningFlow, LightningWork

        self._validate_key(k)

        _set_child_name(self, v, k)
        if self._backend:
            if isinstance(v, LightningFlow):
                LightningFlow._attach_backend(v, self._backend)
            elif isinstance(v, LightningWork):
                self._backend._wrap_run_method(_LightningAppRef().get_current(), v)
        v._name = f"{self.name}.{k}"
        super().__setitem__(k, v)

    @property
    def works(self):
        """Return the LightningWork values of this Dict, followed by the works of each LightningFlow value."""
        from lightning.app.core import LightningFlow, LightningWork

        values = list(self.values())
        works = [item for item in values if isinstance(item, LightningWork)]
        for flow in (item for item in values if isinstance(item, LightningFlow)):
            # NOTE(review): ``recurse=False`` collects only each flow's own works, not those of its nested
            # flows - confirm this is intended.
            works.extend(flow.works(recurse=False))
        return works

    @property
    def flows(self):
        """Return the flows held by this Dict, keyed by each flow's ``name``.

        This includes the LightningFlow values themselves, plus the flows exposed through the ``flows`` attribute of
        each LightningFlow value and of each nested ``Dict``/``List`` value.
        """
        from lightning.app.core.flow import LightningFlow
        from lightning.app.structures import Dict as _Dict
        from lightning.app.structures import List as _List

        flows = {}
        for item in self.values():
            if isinstance(item, LightningFlow):
                flows[item.name] = item
            if isinstance(item, (LightningFlow, _Dict, _List)):
                for child_flow in item.flows.values():
                    flows[child_flow.name] = child_flow
        return flows

    @property
    def name(self) -> str:
        """Return this Dict's assigned name, or ``"root"`` when none has been assigned."""
        return self._name or "root"

    def _collect_child_attr(self, attr: str) -> t.Dict[str, t.Dict[str, t.Any]]:
        """Return ``{"works": {key: item.<attr>}, "flows": {key: item.<attr>}}`` over the work and flow values."""
        from lightning.app.core import LightningFlow, LightningWork

        return {
            "works": {key: getattr(item, attr) for key, item in self.items() if isinstance(item, LightningWork)},
            "flows": {key: getattr(item, attr) for key, item in self.items() if isinstance(item, LightningFlow)},
        }

    @property
    def state(self) -> t.Dict[str, t.Dict[str, t.Any]]:
        """Return the ``state`` of each work and flow value, grouped under ``"works"`` and ``"flows"``."""
        return self._collect_child_attr("state")

    @property
    def state_vars(self) -> t.Dict[str, t.Dict[str, t.Any]]:
        """Return the ``state_vars`` of each work and flow value, grouped under ``"works"`` and ``"flows"``."""
        return self._collect_child_attr("state_vars")

    @property
    def state_with_changes(self) -> t.Dict[str, t.Dict[str, t.Any]]:
        """Return the ``state_with_changes`` of each work and flow value, grouped under ``"works"`` and ``"flows"``."""
        return self._collect_child_attr("state_with_changes")

    def set_state(self, state):
        """Restore every child from ``state``.

        Arguments:
            state: A mapping with ``"works"`` and ``"flows"`` sub-mappings keyed by child name, i.e. the same shape as
                :attr:`state`. Together the two sub-mappings must cover exactly this Dict's keys.

        Raises:
            ValueError: If the keys in ``state`` differ from the keys of this Dict.

        """
        state_keys = set(state["works"]) | set(state["flows"])
        current_state_keys = set(self)
        key_diff = current_state_keys ^ state_keys
        if key_diff:
            raise ValueError(
                f"The provided state doesn't match the `Dict` {self.name}. Found `{key_diff}` un-matching keys"
            )
        for work_key, work_state in state["works"].items():
            self[work_key].set_state(work_state)
        for child_key, child_state in state["flows"].items():
            self[child_key].set_state(child_state)
161