pytorch-lightning

Форк
0
177 строк · 6.6 Кб
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15
import typing as t
16

17
from lightning.app.utilities.app_helpers import _LightningAppRef, _set_child_name
18

19
if t.TYPE_CHECKING:
    # Only needed for the string annotations in this module; never imported at runtime.
    from lightning.app.utilities.types import Component

# Element type parameter of the generic ``List`` container defined below.
T = t.TypeVar("T")
23

24

25
def _prepare_name(component: "Component") -> str:
26
    return str(component.name.split(".")[-1])
27

28

29
# TODO: add support and tests for list operations (concatenation, deletion, insertion, etc.)
class List(t.List[T]):
    def __init__(self, *items: T):
        """The List Object is used to represent a list collection of :class:`~lightning.app.core.work.LightningWork`
        or :class:`~lightning.app.core.flow.LightningFlow`.

        Example:

            >>> from lightning.app import LightningFlow, LightningWork
            >>> from lightning.app.structures import List
            >>> class CounterWork(LightningWork):
            ...     def __init__(self):
            ...         super().__init__()
            ...         self.counter = 0
            ...     def run(self):
            ...         self.counter += 1
            ...
            >>> class RootFlow(LightningFlow):
            ...     def __init__(self):
            ...         super().__init__()
            ...         self.list = List(*[CounterWork(), CounterWork()])
            ...     def run(self):
            ...         for work in self.list:
            ...             work.run()
            ...
            >>> flow = RootFlow()
            >>> flow.run()
            >>> assert flow.list[0].counter == 1

        Arguments:
            items: A sequence of LightningWork or LightningFlow. Each one is added through :meth:`append`,
                so it is named after its position in the list.

        """
        super().__init__()
        from lightning.app.runners.backends import Backend

        # Empty until a parent assigns a name; the ``name`` property reports "root" in that case.
        self._name: t.Optional[str] = ""
        # Index given to the next appended item; it is only ever incremented.
        self._last_index = 0
        self._backend: t.Optional[Backend] = None
        for item in items:
            self.append(item)

    def append(self, v):
        """Append ``v``, naming it ``"<list name>.<index>"`` and wiring it to the backend if one is attached."""
        from lightning.app.core import LightningFlow, LightningWork

        _set_child_name(self, v, str(self._last_index))
        if self._backend:
            # With a backend already attached, the new child must be hooked up immediately:
            # flows get the backend attached, works get their ``run`` method wrapped.
            if isinstance(v, LightningFlow):
                LightningFlow._attach_backend(v, self._backend)
            elif isinstance(v, LightningWork):
                self._backend._wrap_run_method(_LightningAppRef().get_current(), v)
        v._name = f"{self.name}.{self._last_index}"
        self._last_index += 1
        super().append(v)

    @property
    def name(self):
        """Returns the name of this List object."""
        return self._name or "root"

    def _split_children(self):
        """Partition the items of this list into ``(works, flows)``, preserving their order.

        Items that are neither a LightningWork nor a LightningFlow appear in neither list.
        """
        from lightning.app.core import LightningFlow, LightningWork

        works = [item for item in self if isinstance(item, LightningWork)]
        flows = [item for item in self if isinstance(item, LightningFlow)]
        return works, flows

    @property
    def works(self):
        """Returns the works stored directly in this list, followed by ``flow.works(recurse=False)`` of each
        flow stored in it."""
        works, flows = self._split_children()
        for flow in flows:
            works.extend(flow.works(recurse=False))
        return works

    @property
    def flows(self):
        """Returns a ``{name: flow}`` mapping of the flows stored in this list, the flows each of them reports
        through its own ``flows`` property, and the flows of nested ``Dict``/``List`` structures."""
        from lightning.app.core import LightningFlow
        from lightning.app.structures import Dict as _Dict
        from lightning.app.structures import List as _List

        flows = {}
        for item in self:
            if isinstance(item, LightningFlow):
                flows[item.name] = item
                for child_flow in item.flows.values():
                    flows[child_flow.name] = child_flow
            if isinstance(item, (_Dict, _List)):
                for child_flow in item.flows.values():
                    flows[child_flow.name] = child_flow
        return flows

    @property
    def state(self):
        """Returns the state of its flows and works, keyed by the last segment of each child's name."""
        works, children = self._split_children()
        return {
            "works": {_prepare_name(w): w.state for w in works},
            "flows": {_prepare_name(flow): flow.state for flow in children},
        }

    @property
    def state_vars(self):
        """Returns the ``state_vars`` of its flows and works, keyed by the last segment of each child's name."""
        works, children = self._split_children()
        return {
            "works": {_prepare_name(w): w.state_vars for w in works},
            "flows": {_prepare_name(flow): flow.state_vars for flow in children},
        }

    @property
    def state_with_changes(self):
        """Returns the ``state_with_changes`` of its flows and works, keyed by the last segment of each child's
        name."""
        works, children = self._split_children()
        return {
            # ``_prepare_name`` already returns ``str``; no extra conversion is needed.
            "works": {_prepare_name(w): w.state_with_changes for w in works},
            "flows": {_prepare_name(flow): flow.state_with_changes for flow in children},
        }

    def set_state(self, state):
        """Method to set the state of the list and its children.

        Arguments:
            state: A mapping with ``"works"`` and ``"flows"`` sub-mappings keyed by child name (the last
                segment of each child's ``name``), in the same layout :attr:`state` returns.

        Raises:
            ValueError: If the child names found in ``state`` differ from those of the items in this list.

        """
        works, children = self._split_children()

        current_state_keys = {_prepare_name(item) for item in self}
        state_keys = set(state["works"]) | set(state["flows"])

        if current_state_keys != state_keys:
            key_diff = current_state_keys ^ state_keys
            raise ValueError(
                f"The provided state doesn't match the `List` {self.name}. Found `{key_diff}` un-matching keys"
            )

        # Index children by name once instead of rescanning them for every state entry.
        works_by_name = {_prepare_name(work): work for work in works}
        children_by_name = {_prepare_name(child): child for child in children}

        for work_key, work_state in state["works"].items():
            work = works_by_name.get(work_key)
            if work is not None:
                work.set_state(work_state)
        for child_key, child_state in state["flows"].items():
            child = children_by_name.get(child_key)
            if child is not None:
                child.set_state(child_state)

    def __len__(self):
        """Returns the number of elements within this List."""
        # The built-in list length is O(1) and always equals the number of iterated items.
        return super().__len__()
178

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.