pytorch-lightning

Форк
0
160 строк · 6.1 Кб
1
# Copyright The Lightning AI team.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

15
import typing as t
16

17
from lightning.app.utilities.app_helpers import _LightningAppRef, _set_child_name
18

19
# Value type of the ``Dict`` collection defined below (its children, keyed by name).
T = t.TypeVar("T")
20

21
if t.TYPE_CHECKING:
22
    from lightning.app.utilities.types import Component
23

24

25
def _prepare_name(component: "Component") -> str:
26
    return str(component.name.split(".")[-1])
27

28

29
# TODO: add support and tests for dict operations (insertion, update, etc.)
30
class Dict(t.Dict[str, T]):
    def __init__(self, **kwargs: T):
        """The Dict object is used to represent a dict collection of :class:`~lightning.app.core.work.LightningWork`
        or :class:`~lightning.app.core.flow.LightningFlow`.

        Example:

            >>> from lightning.app import LightningFlow, LightningWork
            >>> from lightning.app.structures import Dict
            >>> class CounterWork(LightningWork):
            ...     def __init__(self):
            ...         super().__init__()
            ...         self.counter = 0
            ...     def run(self):
            ...         self.counter += 1
            ...
            >>> class RootFlow(LightningFlow):
            ...     def __init__(self):
            ...         super().__init__()
            ...         self.dict = Dict(**{"work_0": CounterWork(), "work_1": CounterWork()})
            ...     def run(self):
            ...         for work_name, work in self.dict.items():
            ...             work.run()
            ...
            >>> flow = RootFlow()
            >>> flow.run()
            >>> assert flow.dict["work_0"].counter == 1

        Arguments:
            kwargs: The children of the collection, keyed by name. Each value is expected to be a
                LightningWork or LightningFlow. Names must not contain ``.``.

        Raises:
            Exception: If a provided name contains ``.``.

        """
        super().__init__(**kwargs)
        from lightning.app.runners.backends import Backend

        self._name: t.Optional[str] = ""
        self._backend: t.Optional[Backend] = None
        for k, v in kwargs.items():
            if "." in k:
                raise Exception(f"The provided name {k} contains . which is forbidden.")
            _set_child_name(self, v, k)

    def __setitem__(self, k, v):
        """Insert child ``v`` under name ``k``.

        The child is renamed relative to this Dict. When a backend is set, the child is also wired
        to it: flows get the backend attached and works get their run method wrapped.

        Raises:
            Exception: If ``k`` is not a string or contains ``.``.

        """
        # Validate the key before importing anything: the checks need nothing from lightning.
        if not isinstance(k, str):
            raise Exception("The provided key should be a string")
        if "." in k:
            raise Exception(f"The provided name {k} contains . which is forbidden.")

        from lightning.app.core import LightningFlow, LightningWork

        _set_child_name(self, v, k)
        # ``_backend`` starts as None in ``__init__``. When it has been set elsewhere (presumably once
        # the app is running), children inserted later must be hooked up to the same backend.
        if self._backend:
            if isinstance(v, LightningFlow):
                LightningFlow._attach_backend(v, self._backend)
            elif isinstance(v, LightningWork):
                self._backend._wrap_run_method(_LightningAppRef().get_current(), v)
        v._name = f"{self.name}.{k}"
        super().__setitem__(k, v)

    @property
    def works(self):
        """Works stored directly in this Dict, followed by the direct works of each child flow.

        Child flows are queried with ``works(recurse=False)``.
        """
        from lightning.app.core import LightningFlow, LightningWork

        # NOTE(review): works of flows nested deeper than one level, and works inside nested
        # Dict/List values, are not collected here. Confirm this is intended.
        works = [item for item in self.values() if isinstance(item, LightningWork)]
        for flow in [item for item in self.values() if isinstance(item, LightningFlow)]:
            for child_work in flow.works(recurse=False):
                works.append(child_work)
        return works

    @property
    def flows(self):
        """Flows below this Dict, keyed by their ``name``.

        Includes each child flow, every entry of that child's own ``flows``, and the ``flows`` of
        any nested Dict/List structure.
        """
        from lightning.app.core.flow import LightningFlow
        from lightning.app.structures import Dict as _Dict
        from lightning.app.structures import List as _List

        flows = {}
        for item in self.values():
            if isinstance(item, LightningFlow):
                flows[item.name] = item
                for child_flow in item.flows.values():
                    flows[child_flow.name] = child_flow
            elif isinstance(item, (_Dict, _List)):
                for child_flow in item.flows.values():
                    flows[child_flow.name] = child_flow
        return flows

    @property
    def name(self) -> str:
        """The full name of this Dict; ``"root"`` when no name has been assigned yet."""
        return self._name or "root"

    @property
    def state(self):
        """Returns the state of its flows and works."""
        from lightning.app.core import LightningFlow, LightningWork

        return {
            "works": {key: item.state for key, item in self.items() if isinstance(item, LightningWork)},
            "flows": {key: item.state for key, item in self.items() if isinstance(item, LightningFlow)},
        }

    @property
    def state_vars(self):
        """Returns the ``state_vars`` of its works and flows, split under ``"works"`` / ``"flows"``."""
        from lightning.app.core import LightningFlow, LightningWork

        return {
            "works": {key: item.state_vars for key, item in self.items() if isinstance(item, LightningWork)},
            "flows": {key: item.state_vars for key, item in self.items() if isinstance(item, LightningFlow)},
        }

    @property
    def state_with_changes(self):
        """Returns the ``state_with_changes`` of its works and flows, split under ``"works"`` / ``"flows"``."""
        from lightning.app.core import LightningFlow, LightningWork

        return {
            "works": {key: item.state_with_changes for key, item in self.items() if isinstance(item, LightningWork)},
            "flows": {key: item.state_with_changes for key, item in self.items() if isinstance(item, LightningFlow)},
        }

    def set_state(self, state):
        """Restore children from ``state``, which has the shape produced by :attr:`state`.

        Arguments:
            state: A mapping with ``"works"`` and ``"flows"`` sub-mappings, keyed by child name.

        Raises:
            Exception: If the names in ``state`` differ from the keys currently held by this Dict.

        """
        expected_keys = set(state["works"]) | set(state["flows"])
        current_keys = set(self.keys())
        if current_keys != expected_keys:
            # Keys present on only one side, in either direction.
            key_diff = current_keys ^ expected_keys
            raise Exception(
                f"The provided state doesn't match the `Dict` {self.name}. Found `{key_diff}` un-matching keys"
            )
        for work_key, work_state in state["works"].items():
            self[work_key].set_state(work_state)
        for child_key, child_state in state["flows"].items():
            self[child_key].set_state(child_state)
161

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.