pytorch-lightning

Форк
0
323 строки · 12.3 Кб
1
# Copyright The Lightning AI team.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

15
import enum
16
import json
17
import os
18
from copy import deepcopy
19
from time import sleep
20
from typing import Any, Dict, List, Optional, Tuple, Union
21

22
from deepdiff import DeepDiff
23
from requests import Session
24
from requests.exceptions import ConnectionError
25

26
from lightning.app.core.constants import APP_SERVER_HOST, APP_SERVER_PORT
27
from lightning.app.storage.drive import _maybe_create_drive
28
from lightning.app.utilities.app_helpers import AppStatePlugin, BaseStatePlugin, Logger
29
from lightning.app.utilities.network import LightningClient, _configure_session
30

31
logger = Logger(__name__)

# GLOBAL APP STATE
# Snapshots of the *entire* app state shared by every ``AppState`` in this process:
# ``AppState._store_state`` writes both, and ``AppState.send_delta`` diffs them, so a
# delta is always computed on the whole state even for affiliated (scoped) instances.
_LAST_STATE: Optional[Dict[str, Any]] = None  # deep copy of the state as received from the server
_STATE: Optional[Dict[str, Any]] = None  # live state that ``AppState`` instances mutate in place
36

37

38
class AppStateType(enum.Enum):
    """Enumeration of app state types (``STREAMLIT`` or ``DEFAULT``).

    Values are written out explicitly and equal what ``enum.auto()`` assigns (1 and 2).

    """

    STREAMLIT = 1
    DEFAULT = 2
41

42

43
def headers_for(context: Dict[str, str]) -> Dict[str, str]:
    """Build the session-identifying HTTP headers sent to the app REST API.

    Each header takes its value from the matching ``context`` entry (``token``, ``session_id``,
    ``type``), or the empty string when that entry is absent. Other context entries are ignored.

    """
    header_sources = (
        ("X-Lightning-Session-UUID", "token"),
        ("X-Lightning-Session-ID", "session_id"),
        ("X-Lightning-Type", "type"),
    )
    return {header: context.get(context_key, "") for header, context_key in header_sources}
49

50

51
class AppState:
    """Client-side view of a Lightning App's state, read from and written to the app's REST API server."""

    # Names that live on the ``AppState`` instance itself; ``__getattr__`` / ``__setattr__``
    # never forward these to the remote app state.
    _APP_PRIVATE_KEYS: Tuple[str, ...] = (
        "_use_localhost",
        "_host",
        "_session_id",
        "_state",
        "_last_state",
        "_url",
        "_port",
        "_request_state",
        "_store_state",
        "_send_state",
        "_my_affiliation",
        "_find_state_under_affiliation",
        "_plugin",
        "_attach_plugin",
        "_authorized",
        "is_authorized",
        "_debug",
        "_session",
    )
    # Affiliation used when ``my_affiliation`` is not passed to ``__init__``.
    _MY_AFFILIATION: Tuple[str, ...] = ()

    def __init__(
        self,
        host: Optional[str] = None,
        port: Optional[int] = None,
        last_state: Optional[Dict] = None,
        state: Optional[Dict] = None,
        my_affiliation: Optional[Tuple[str, ...]] = None,
        plugin: Optional["BaseStatePlugin"] = None,
    ) -> None:
        """The AppState class enables Frontend users to interact with their application state.

        When the state isn't defined, it is pulled from the app REST API Server.
        If the state gets modified by the user, the new state is sent to the API Server.

        Arguments:
            host: Rest API Server current host. Defaults to ``http://127.0.0.1`` when the
                ``LIGHTNING_APP_STATE_URL`` environment variable is not set; otherwise it is resolved lazily.
            port: Rest API Server current port. Defaults to ``APP_SERVER_PORT`` when the
                ``LIGHTNING_APP_STATE_URL`` environment variable is not set.
            last_state: The state pulled on first access.
            state: The state modified by the user.
            my_affiliation: A tuple describing the affiliation this app state represents. When storing a state dict
                on this AppState, this affiliation will be used to reduce the scope of the given state.
                Defaults to ``AppState._MY_AFFILIATION``.
            plugin: A plugin to handle authorization. Defaults to an ``AppStatePlugin``.

        """
        # Without ``LIGHTNING_APP_STATE_URL`` in the environment, talk to a server on localhost.
        self._use_localhost = "LIGHTNING_APP_STATE_URL" not in os.environ
        self._host = host or ("http://127.0.0.1" if self._use_localhost else None)
        self._port = port or (APP_SERVER_PORT if self._use_localhost else None)
        self._last_state = last_state
        self._state = state
        # NOTE(review): hard-coded and not read anywhere in this module.
        self._session_id = "1234"
        self._my_affiliation = my_affiliation if my_affiliation is not None else AppState._MY_AFFILIATION
        # HTTP status code of the last state request; ``None`` until a request was made.
        self._authorized = None
        self._attach_plugin(plugin)
        self._session = self._configure_session()

    @property
    def _url(self) -> str:
        """Base URL of the app REST API server.

        When no host is known yet, it is resolved once and cached in ``self._host``. If both
        ``LIGHTNING_CLOUD_PROJECT_ID`` and ``LIGHTNING_CLOUD_APP_ID`` are set and the cloud app instance
        reports an IP address, that IP (port 8080) is used; otherwise ``APP_SERVER_HOST``.
        The port is only appended in localhost mode.

        """
        if self._host is None:
            app_ip = ""

            if "LIGHTNING_CLOUD_PROJECT_ID" in os.environ and "LIGHTNING_CLOUD_APP_ID" in os.environ:
                client = LightningClient()
                app_instance = client.lightningapp_instance_service_get_lightningapp_instance(
                    os.environ.get("LIGHTNING_CLOUD_PROJECT_ID"),
                    os.environ.get("LIGHTNING_CLOUD_APP_ID"),
                )
                app_ip = app_instance.status.ip_address

            # TODO: Don't hard code port 8080 here
            self._host = f"http://{app_ip}:8080" if app_ip else APP_SERVER_HOST
        return f"{self._host}:{self._port}" if self._use_localhost else self._host

    def _attach_plugin(self, plugin: Optional["BaseStatePlugin"]) -> None:
        """Use ``plugin`` for authorization, falling back to a default ``AppStatePlugin``."""
        self._plugin = plugin if plugin is not None else AppStatePlugin()

    @staticmethod
    def _find_state_under_affiliation(state: Dict[str, Any], my_affiliation: Tuple[str, ...]) -> Dict[str, Any]:
        """Extract the subset of the app state associated with the given affiliation.

        Each name of ``my_affiliation`` is looked up among the child flows of the current level first,
        then among its child works. For example, the affiliation ``("subflow", "work")`` returns
        ``state["flows"]["subflow"]["works"]["work"]``. An empty affiliation returns ``state`` itself.
        The returned dict is a reference into ``state``, not a copy.

        Raises:
            ValueError: If a name of the affiliation is not found at its level.

        """
        children_state = state
        for name in my_affiliation:
            # A level may have no "flows"/"works" entries at all; treat that as "no such children".
            flows = children_state.get("flows", {})
            works = children_state.get("works", {})
            if name in flows:
                children_state = flows[name]
            elif name in works:
                children_state = works[name]
            else:
                raise ValueError(f"Failed to extract the state under the affiliation '{my_affiliation}'.")
        return children_state

    def _store_state(self, state: Dict[str, Any]) -> None:
        """Record ``state`` as the full app state and scope this instance to its affiliation."""
        # Relying on the global variable to ensure the
        # deep_diff is done on the entire state.
        global _LAST_STATE
        global _STATE
        _LAST_STATE = deepcopy(state)
        _STATE = state
        # If the affiliation is passed, the AppState was created in a LightningFlow context.
        # The state should be only the one of this LightningFlow and its children.
        # ``self._state`` is a view into ``_STATE`` (not a copy), so edits made through this
        # instance show up in the diff computed by ``send_delta``.
        self._last_state = self._find_state_under_affiliation(_LAST_STATE, self._my_affiliation)
        self._state = self._find_state_under_affiliation(_STATE, self._my_affiliation)

    def send_delta(self) -> None:
        """Send the difference between the pulled and the modified global app state to the server.

        Nothing is sent when the plugin's ``should_update_app`` declines the diff.

        Raises:
            AttributeError: If the server cannot be reached.
            Exception: If the server answers with a status code other than 200.

        """
        app_url = f"{self._url}/api/v1/delta"
        deep_diff = DeepDiff(_LAST_STATE, _STATE, verbose_level=2)
        assert self._plugin is not None
        # TODO: Find how to prevent the infinite loop on refresh without storing the DeepDiff
        if not self._plugin.should_update_app(deep_diff):
            return
        data = {"delta": json.loads(deep_diff.to_json())}
        headers = headers_for(self._plugin.get_context())
        try:
            # TODO: Send the delta directly to the REST API.
            response = self._session.post(app_url, json=data, headers=headers)
        except ConnectionError as ex:
            raise AttributeError("Failed to connect and send the app state. Is the app running?") from ex

        # Check the status code directly: ``requests.Response`` is falsy for 4xx/5xx statuses, so a
        # truthiness check on the response would skip exactly the rejections this should report.
        if response.status_code != 200:
            raise Exception(f"The response from the server was {response.status_code}. Your inputs were rejected.")

    def _request_state(self) -> None:
        """Fetch and store the app state from the server unless this instance already holds one.

        The request is retried every 0.5 seconds while the server answers with an empty JSON object.
        The status code of each answer is kept in ``self._authorized``; on a non-200 answer the state
        is left unset.

        Raises:
            AttributeError: If the server cannot be reached.

        """
        if self._state is not None:
            return
        app_url = f"{self._url}/api/v1/state"
        headers = headers_for(self._plugin.get_context()) if self._plugin else {}

        response_json = {}

        # Sometimes the state URL can return an empty JSON when things are being set-up,
        # so we wait for it to be ready here.
        while response_json == {}:
            sleep(0.5)
            try:
                response = self._session.get(app_url, headers=headers, timeout=1)
            except ConnectionError as ex:
                raise AttributeError("Failed to connect and fetch the app state. Is the app running?") from ex

            self._authorized = response.status_code
            if self._authorized != 200:
                return

            response_json = response.json()

        logger.debug(f"GET STATE {response} {response_json}")
        self._store_state(response_json)

    def __getattr__(self, name: str) -> Union[Any, "AppState"]:
        """Resolve ``name`` against the app state: variables first, then works, flows and structures.

        Only called when regular attribute lookup fails. Fetches the state on first access. Child
        components are returned wrapped in a new ``AppState``.

        Raises:
            AttributeError: If ``name`` is an unset private attribute, a dunder name, or not in the state.

        """
        if name in self._APP_PRIVATE_KEYS:
            # Regular lookup already failed, so the private attribute is simply unset. ``object`` has
            # no ``__getattr__`` to delegate to, hence raise the standard error directly.
            raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

        if name.startswith("__") and name.endswith("__"):
            # Protocol probes (e.g. ``copy.deepcopy`` looking up ``__deepcopy__``) must not
            # trigger a request to the app server.
            raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

        # The state needs to be fetched on access if it doesn't exist.
        self._request_state()

        variables = self._state.get("vars", {})
        if name in variables:
            value = variables[name]
            if isinstance(value, dict):
                # Dict-valued variables go through ``_maybe_create_drive`` (presumably to rebuild
                # serialized Drive objects; other dicts are expected to come back unchanged).
                return _maybe_create_drive("root." + ".".join(self._my_affiliation), value)
            return value

        # Child components are checked in this order: works, flows, structures.
        for component in ("works", "flows", "structures"):
            children = self._state.get(component, {})
            if name in children:
                return AppState(
                    self._host,
                    self._port,
                    last_state=self._last_state[component][name],
                    state=children[name],
                )

        raise AttributeError(
            f"Failed to access '{name}' through `AppState`. The state provides:"
            f" Variables: {list(self._state.get('vars', {}).keys())},"
            f" Components: {list(self._state.get('flows', {}).keys()) + list(self._state.get('works', {}).keys())}",
        )

    def __getitem__(self, key: str) -> Union[Any, "AppState"]:
        """Dictionary-style access to the app state: ``state["x"]`` resolves like ``state.x`` via ``__getattr__``."""
        return self.__getattr__(key)

    def __setattr__(self, name: str, value: Any) -> None:
        """Set a state variable and push the change to the server; private names are set on the instance.

        Raises:
            AttributeError: If ``name`` is a child flow/work, or is not part of the state.

        """
        if name in self._APP_PRIVATE_KEYS:
            object.__setattr__(self, name, value)
            return

        # The state needs to be fetched on access if it doesn't exist.
        self._request_state()

        # Use ``.get`` like ``__getattr__`` does: a state may lack "flows"/"works" entries, and an
        # unknown name must end in the AttributeError below rather than a KeyError.
        variables = self._state.get("vars", {})
        flows = self._state.get("flows", {})
        works = self._state.get("works", {})

        # TODO: Find a way to aggregate deltas to avoid making
        # request for each attribute change.
        if name in variables:
            variables[name] = value
            self.send_delta()

        elif name in flows:
            raise AttributeError("You shouldn't set the flows directly onto the state. Use its attributes instead.")

        elif name in works:
            raise AttributeError("You shouldn't set the works directly onto the state. Use its attributes instead.")

        else:
            raise AttributeError(
                f"Failed to access '{name}' through `AppState`. The state provides:"
                f" Variables: {list(variables.keys())},"
                f" Components: {list(flows.keys()) + list(works.keys())}",
            )

    def __repr__(self) -> str:
        return str(self._state)

    def __bool__(self) -> bool:
        # Does not fetch the state: an instance whose state is not loaded yet (or is empty) is falsy.
        return bool(self._state)

    def __len__(self) -> int:
        """Number of child flows, works and structures; fetches the state on first access."""
        self._request_state()
        return sum(len(self._state.get(component, {})) for component in ("flows", "works", "structures"))

    def items(self) -> List[Tuple[str, "AppState"]]:
        """Return ``(name, AppState)`` pairs for the child flows and works, then for those under ``structures``.

        Fetches the state on first access.

        """
        # The state needs to be fetched on access if it doesn't exist.
        self._request_state()

        items = []
        for component in ["flows", "works"]:
            state = self._state.get(component, {})
            last_state = self._last_state.get(component, {})
            for name, state_value in state.items():
                v = AppState(
                    self._host,
                    self._port,
                    last_state=last_state[name],
                    state=state_value,
                )
                items.append((name, v))

        # NOTE(review): this reads "flows"/"works" directly from the "structures" mapping, whereas
        # ``__getattr__`` indexes "structures" by child name -- confirm the expected layout.
        structures = self._state.get("structures", {})
        last_structures = self._last_state.get("structures", {})
        if structures:
            for component in ["flows", "works"]:
                state = structures.get(component, {})
                last_state = last_structures.get(component, {})
                for name, state_value in state.items():
                    v = AppState(
                        self._host,
                        self._port,
                        last_state=last_state[name],
                        state=state_value,
                    )
                    items.append((name, v))
        return items

    @staticmethod
    def _configure_session() -> "Session":
        """Create the HTTP session used for all requests to the app server."""
        return _configure_session()
324

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.