pytorch-lightning
177 строк · 6.6 Кб
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import typing as t

from lightning.app.utilities.app_helpers import _LightningAppRef, _set_child_name

# Element type of the generic ``List`` structure defined in this module.
T = t.TypeVar("T")

if t.TYPE_CHECKING:
    # Imported for annotations only; never executed at runtime.
    from lightning.app.utilities.types import Component

25def _prepare_name(component: "Component") -> str:26return str(component.name.split(".")[-1])27
28
# TODO: add support and tests for list operations (concatenation, deletion, insertion, etc.)
class List(t.List[T]):
    """List structure holding ``LightningWork`` / ``LightningFlow`` children."""

    def __init__(self, *items: T):
        """The List Object represents a list collection of :class:`~lightning.app.core.work.LightningWork` or
        :class:`~lightning.app.core.flow.LightningFlow`.

        Example:

            >>> from lightning.app import LightningFlow, LightningWork
            >>> from lightning.app.structures import List
            >>> class CounterWork(LightningWork):
            ...     def __init__(self):
            ...         super().__init__()
            ...         self.counter = 0
            ...     def run(self):
            ...         self.counter += 1
            ...
            >>> class RootFlow(LightningFlow):
            ...     def __init__(self):
            ...         super().__init__()
            ...         self.list = List(*[CounterWork(), CounterWork()])
            ...     def run(self):
            ...         for work in self.list:
            ...             work.run()
            ...
            >>> flow = RootFlow()
            >>> flow.run()
            >>> assert flow.list[0].counter == 1

        Arguments:
            items: A sequence of LightningWork or LightningFlow.

        """
        super().__init__()
        from lightning.app.runners.backends import Backend

        self._name: t.Optional[str] = ""
        # Index handed to the next appended item; it only ever grows.
        self._last_index = 0
        self._backend: t.Optional[Backend] = None
        # Go through ``append`` so every initial item gets named like later additions.
        for item in items:
            self.append(item)

    def append(self, v: T) -> None:
        """Register ``v`` under the next index, wire it to the backend when one is set, and store it.

        Arguments:
            v: The item to add. Its ``_name`` is set to ``"<list name>.<index>"``.

        """
        from lightning.app.core import LightningFlow, LightningWork

        _set_child_name(self, v, str(self._last_index))
        if self._backend:
            if isinstance(v, LightningFlow):
                LightningFlow._attach_backend(v, self._backend)
            elif isinstance(v, LightningWork):
                self._backend._wrap_run_method(_LightningAppRef().get_current(), v)
        # NOTE(review): naming is assumed to happen whether or not a backend is attached — confirm.
        v._name = f"{self.name}.{self._last_index}"
        self._last_index += 1
        super().append(v)

    @property
    def name(self) -> str:
        """Returns the name of this List object."""
        return self._name or "root"

    @property
    def works(self):
        """Returns the works stored in this List, followed by ``flow.works(recurse=False)`` of each stored flow."""
        from lightning.app.core import LightningFlow, LightningWork

        works = [item for item in self if isinstance(item, LightningWork)]
        for item in self:
            if isinstance(item, LightningFlow):
                works.extend(item.works(recurse=False))
        return works

    @property
    def flows(self):
        """Returns a name-to-flow mapping of the stored flows and the flows exposed by stored flows/structures."""
        from lightning.app.core import LightningFlow
        from lightning.app.structures import Dict as _Dict
        from lightning.app.structures import List as _List

        flows = {}
        for item in self:
            if isinstance(item, LightningFlow):
                flows[item.name] = item
                for child_flow in item.flows.values():
                    flows[child_flow.name] = child_flow
            if isinstance(item, (_Dict, _List)):
                for child_flow in item.flows.values():
                    flows[child_flow.name] = child_flow
        return flows

    def _direct_children(self):
        """Split the stored items into ``(works, flows)`` lists; items of any other type are left out."""
        from lightning.app.core import LightningFlow, LightningWork

        works = [item for item in self if isinstance(item, LightningWork)]
        flows = [item for item in self if isinstance(item, LightningFlow)]
        return works, flows

    @property
    def state(self) -> t.Dict[str, t.Any]:
        """Returns the state of its flows and works."""
        works, flows = self._direct_children()
        return {
            "works": {_prepare_name(w): w.state for w in works},
            "flows": {_prepare_name(flow): flow.state for flow in flows},
        }

    @property
    def state_vars(self) -> t.Dict[str, t.Any]:
        """Returns the ``state_vars`` of its flows and works, keyed like :attr:`state`."""
        works, flows = self._direct_children()
        return {
            "works": {_prepare_name(w): w.state_vars for w in works},
            "flows": {_prepare_name(flow): flow.state_vars for flow in flows},
        }

    @property
    def state_with_changes(self) -> t.Dict[str, t.Any]:
        """Returns the ``state_with_changes`` of its flows and works, keyed like :attr:`state`."""
        works, flows = self._direct_children()
        return {
            "works": {_prepare_name(w): w.state_with_changes for w in works},
            "flows": {_prepare_name(flow): flow.state_with_changes for flow in flows},
        }

    def set_state(self, state: t.Dict[str, t.Any]) -> None:
        """Method to set the state of the list and its children.

        Arguments:
            state: Mapping with a ``"works"`` and a ``"flows"`` entry, each mapping a child key
                (the last segment of the child's name) to the state to apply.

        Raises:
            Exception: If the keys in ``state`` differ from the keys of the items currently stored.

        """
        works, flows = self._direct_children()

        # Every stored item takes part in the key check, not only works and flows.
        current_state_keys = {_prepare_name(w) for w in self}
        state_keys = set(list(state["works"].keys()) + list(state["flows"].keys()))

        if current_state_keys != state_keys:
            key_diff = (current_state_keys - state_keys) | (state_keys - current_state_keys)
            raise Exception(
                f"The provided state doesn't match the `List` {self.name}. Found `{key_diff}` un-matching keys"
            )

        # One dict lookup per child instead of scanning all children for each state key.
        work_states = state["works"]
        for work in works:
            key = _prepare_name(work)
            if key in work_states:
                work.set_state(work_states[key])

        flow_states = state["flows"]
        for flow in flows:
            key = _prepare_name(flow)
            if key in flow_states:
                flow.set_state(flow_states[key])

    def __len__(self) -> int:
        """Returns the number of elements within this List."""
        return super().__len__()