pytorch-lightning
323 строки · 12.3 Кб
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15import enum16import json17import os18from copy import deepcopy19from time import sleep20from typing import Any, Dict, List, Optional, Tuple, Union21
22from deepdiff import DeepDiff23from requests import Session24from requests.exceptions import ConnectionError25
26from lightning.app.core.constants import APP_SERVER_HOST, APP_SERVER_PORT27from lightning.app.storage.drive import _maybe_create_drive28from lightning.app.utilities.app_helpers import AppStatePlugin, BaseStatePlugin, Logger29from lightning.app.utilities.network import LightningClient, _configure_session30
logger = Logger(__name__)

# GLOBAL APP STATE
# `_STATE` is the full app state most recently fetched from the server; AppState instances hold
# references into it and mutate it in place. `_LAST_STATE` is a deep copy taken at fetch time, so
# `AppState.send_delta` can diff the whole tree (`_LAST_STATE` -> `_STATE`) rather than one subtree.
_LAST_STATE, _STATE = None, None
37
class AppStateType(enum.Enum):
    """Kinds of app-state clients.

    NOTE(review): the members are not referenced in this section; presumably they tag the context an
    ``AppState`` is used from (a Streamlit frontend vs. the default) — confirm against callers.
    """

    # Explicit values matching what ``enum.auto()`` assigns (1, 2, ...).
    STREAMLIT = 1
    DEFAULT = 2
42
def headers_for(context: Dict[str, str]) -> Dict[str, str]:
    """Build the Lightning session HTTP headers from a plugin context.

    Args:
        context: Mapping that may contain ``token``, ``session_id`` and ``type`` entries.

    Returns:
        A dict with the three ``X-Lightning-*`` headers, in a fixed order. Missing context
        entries are sent as empty strings; any other context keys are ignored.
    """
    header_sources = (
        ("X-Lightning-Session-UUID", "token"),
        ("X-Lightning-Session-ID", "session_id"),
        ("X-Lightning-Type", "type"),
    )
    return {header: context.get(source_key, "") for header, source_key in header_sources}
50
class AppState:
    """Frontend-side view of a Lightning app's state.

    The state is pulled lazily from the app REST API server on first access. Every assignment to a
    state variable is pushed back to the server as a delta.
    """

    # Attribute names stored on the instance itself instead of being routed to the app state.
    _APP_PRIVATE_KEYS: Tuple[str, ...] = (
        "_use_localhost",
        "_host",
        "_session_id",
        "_state",
        "_last_state",
        "_url",
        "_port",
        "_request_state",
        "_store_state",
        "_send_state",
        "_my_affiliation",
        "_find_state_under_affiliation",
        "_plugin",
        "_attach_plugin",
        "_authorized",
        "is_authorized",
        "_debug",
        "_session",
    )
    # Default affiliation: the empty tuple selects the whole (root) state.
    _MY_AFFILIATION: Tuple[str, ...] = ()

    def __init__(
        self,
        host: Optional[str] = None,
        port: Optional[int] = None,
        last_state: Optional[Dict] = None,
        state: Optional[Dict] = None,
        my_affiliation: Optional[Tuple[str, ...]] = None,
        plugin: Optional["BaseStatePlugin"] = None,
    ) -> None:
        """The AppState class enables Frontend users to interact with their application state.

        When the state isn't defined, it would be pulled from the app REST API Server.
        If the state gets modified by the user, the new state would be sent to the API Server.

        Arguments:
            host: Rest API Server current host. When ``LIGHTNING_APP_STATE_URL`` is not set, it defaults
                to ``http://127.0.0.1``. Otherwise it stays ``None`` and is resolved lazily by ``_url``.
            port: Rest API Server current port. It defaults to ``APP_SERVER_PORT`` only in the localhost case.
            last_state: The state pulled on first access.
            state: The state modified by the user.
            my_affiliation: A tuple describing the affiliation this app state represents. When storing a state dict
                on this AppState, this affiliation will be used to reduce the scope of the given state.
                Defaults to ``()``, i.e. the whole state.
            plugin: A plugin to handle authorization. Defaults to ``AppStatePlugin()``.

        """
        self._use_localhost = "LIGHTNING_APP_STATE_URL" not in os.environ
        self._host = host or ("http://127.0.0.1" if self._use_localhost else None)
        self._port = port or (APP_SERVER_PORT if self._use_localhost else None)
        self._last_state = last_state
        self._state = state
        self._session_id = "1234"
        self._my_affiliation = my_affiliation if my_affiliation is not None else AppState._MY_AFFILIATION
        # HTTP status code of the last state request; None until a request has been made.
        self._authorized = None
        self._attach_plugin(plugin)
        self._session = self._configure_session()

    @property
    def _url(self) -> str:
        """Base URL of the app REST API server.

        On first use without a host, the host is resolved from the Lightning cloud app instance IP
        (when the cloud project/app ids are in the environment) or falls back to ``APP_SERVER_HOST``.
        """
        if self._host is None:
            app_ip = ""

            if "LIGHTNING_CLOUD_PROJECT_ID" in os.environ and "LIGHTNING_CLOUD_APP_ID" in os.environ:
                client = LightningClient()
                app_instance = client.lightningapp_instance_service_get_lightningapp_instance(
                    os.environ.get("LIGHTNING_CLOUD_PROJECT_ID"),
                    os.environ.get("LIGHTNING_CLOUD_APP_ID"),
                )
                app_ip = app_instance.status.ip_address

            # TODO: Don't hard code port 8080 here
            self._host = f"http://{app_ip}:8080" if app_ip else APP_SERVER_HOST
        return f"{self._host}:{self._port}" if self._use_localhost else self._host

    def _attach_plugin(self, plugin: Optional["BaseStatePlugin"]) -> None:
        """Store ``plugin``, falling back to a default ``AppStatePlugin``."""
        plugin = plugin if plugin is not None else AppStatePlugin()
        self._plugin = plugin

    @staticmethod
    def _find_state_under_affiliation(state: Dict[str, Any], my_affiliation: Tuple[str, ...]) -> Dict[str, Any]:
        """This method is used to extract the subset of the app state associated with the given affiliation.

        For example, if the affiliation is ``("root", "subflow")``, then the returned state will be
        ``state["flows"]["subflow"]``.

        The returned dict is a reference into ``state``, not a copy.

        Raises:
            ValueError: If any name of the affiliation is neither a child flow nor a child work of the
                state reached so far. This includes descending past a leaf that has no ``flows``/``works``.

        """
        children_state = state
        for name in my_affiliation:
            # ``.get`` so that a leaf state (e.g. a work without "flows") reports the documented ValueError
            # instead of leaking a KeyError.
            if name in children_state.get("flows", {}):
                children_state = children_state["flows"][name]
            elif name in children_state.get("works", {}):
                children_state = children_state["works"][name]
            else:
                raise ValueError(f"Failed to extract the state under the affiliation '{my_affiliation}'.")
        return children_state

    def _store_state(self, state: Dict[str, Any]) -> None:
        """Record a freshly fetched full state globally and scope this instance to its affiliation."""
        # Relying on the global variable to ensure the
        # deep_diff is done on the entire state.
        global _LAST_STATE
        global _STATE
        _LAST_STATE = deepcopy(state)
        _STATE = state
        # If the affiliation is passed, the AppState was created in a LightningFlow context.
        # The state should be only the one of this LightningFlow and its children.
        self._last_state = self._find_state_under_affiliation(_LAST_STATE, self._my_affiliation)
        self._state = self._find_state_under_affiliation(_STATE, self._my_affiliation)

    def send_delta(self) -> None:
        """Send the difference between the fetched full state and its locally modified version.

        The diff is computed on the module-level full states (``_LAST_STATE`` -> ``_STATE``). It is only
        sent when the plugin's ``should_update_app`` accepts it.

        Raises:
            AttributeError: If the server cannot be reached.
            Exception: If the server answers with a status code other than 200.
        """
        app_url = f"{self._url}/api/v1/delta"
        deep_diff = DeepDiff(_LAST_STATE, _STATE, verbose_level=2)
        assert self._plugin is not None
        # TODO: Find how to prevent the infinite loop on refresh without storing the DeepDiff
        if not self._plugin.should_update_app(deep_diff):
            return

        data = {"delta": json.loads(deep_diff.to_json())}
        headers = headers_for(self._plugin.get_context())
        try:
            # TODO: Send the delta directly to the REST API.
            response = self._session.post(app_url, json=data, headers=headers)
        except ConnectionError as ex:
            raise AttributeError("Failed to connect and send the app state. Is the app running?") from ex

        # Don't test ``response`` for truthiness: ``requests.Response.__bool__`` is ``Response.ok``,
        # which is False for 4xx/5xx and would make rejected deltas pass silently.
        if response.status_code != 200:
            raise Exception(f"The response from the server was {response.status_code}. Your inputs were rejected.")

    def _request_state(self) -> None:
        """Fetch the full app state from the server unless this instance already holds one.

        Sets ``_authorized`` to the HTTP status code of the last response. On a non-200 status it returns
        without storing anything, so ``_state`` stays ``None``. While the server answers 200 with an empty
        JSON object, it keeps polling with no upper bound.

        Raises:
            AttributeError: If the server cannot be reached.
        """
        if self._state is not None:
            return
        app_url = f"{self._url}/api/v1/state"
        headers = headers_for(self._plugin.get_context()) if self._plugin else {}

        while True:
            try:
                response = self._session.get(app_url, headers=headers, timeout=1)
            except ConnectionError as ex:
                raise AttributeError("Failed to connect and fetch the app state. Is the app running?") from ex

            self._authorized = response.status_code
            if self._authorized != 200:
                return

            response_json = response.json()
            if response_json != {}:
                break
            # Sometimes the state URL can return an empty JSON when things are being set-up,
            # so we wait for it to be ready before asking again.
            sleep(0.5)

        logger.debug(f"GET STATE {response} {response_json}")
        self._store_state(response_json)

    def __getattr__(self, name: str) -> Union[Any, "AppState"]:
        """Resolve ``name`` against the app state: a variable, or a child work/flow/structure.

        Only invoked when normal attribute lookup fails. Dict-valued variables go through
        ``_maybe_create_drive``. Children are wrapped in a new ``AppState`` sharing this host and port.

        Raises:
            AttributeError: If ``name`` is a private key that was never set, or is not found in the state.
        """
        if name in self._APP_PRIVATE_KEYS:
            # A private attribute reached here because it was never set (or a property raised
            # AttributeError). Report it plainly rather than fetching the state.
            raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

        # The state needs to be fetched on access if it doesn't exist.
        self._request_state()

        if name in self._state.get("vars", {}):
            value = self._state["vars"][name]
            if isinstance(value, dict):
                return _maybe_create_drive("root." + ".".join(self._my_affiliation), value)
            return value

        for section in ("works", "flows", "structures"):
            if name in self._state.get(section, {}):
                return AppState(
                    self._host,
                    self._port,
                    last_state=self._last_state[section][name],
                    state=self._state[section][name],
                )

        components = list(self._state.get("flows", {}).keys()) + list(self._state.get("works", {}).keys())
        raise AttributeError(
            f"Failed to access '{name}' through `AppState`. The state provides:"
            f" Variables: {list(self._state.get('vars', {}).keys())},"
            f" Components: {components}",
        )

    def __getitem__(self, key: str):
        """Item access mirrors state-attribute access, e.g. ``state["counter"]``."""
        return self.__getattr__(key)

    def __setattr__(self, name: str, value: Any) -> None:
        """Set a state variable and push the change to the server.

        Private keys are stored on the instance. Only existing variables may be assigned.

        Raises:
            AttributeError: When assigning a child flow/work directly, or when ``name`` is not a known variable.
        """
        if name in self._APP_PRIVATE_KEYS:
            object.__setattr__(self, name, value)
            return

        # The state needs to be fetched on access if it doesn't exist.
        self._request_state()

        # ``.get`` because work/structure states may not carry every section; a missing section must
        # surface as AttributeError, not KeyError.
        variables = self._state.get("vars", {})
        flows = self._state.get("flows", {})
        works = self._state.get("works", {})

        # TODO: Find a way to aggregate deltas to avoid making
        # request for each attribute change.
        if name in variables:
            self._state["vars"][name] = value
            self.send_delta()

        elif name in flows:
            raise AttributeError("You shouldn't set the flows directly onto the state. Use its attributes instead.")

        elif name in works:
            raise AttributeError("You shouldn't set the works directly onto the state. Use its attributes instead.")

        else:
            raise AttributeError(
                f"Failed to access '{name}' through `AppState`. The state provides:"
                f" Variables: {list(variables.keys())},"
                f" Components: {list(flows.keys()) + list(works.keys())}",
            )

    def __repr__(self) -> str:
        return str(self._state)

    def __bool__(self) -> bool:
        # Truthiness of the locally held state; deliberately does not trigger a fetch.
        return bool(self._state)

    def __len__(self) -> int:
        """Number of child flows, works and structures."""
        # The state needs to be fetched on access if it doesn't exist.
        self._request_state()

        return sum(len(self._state.get(component, {})) for component in ("flows", "works", "structures"))

    def items(self) -> List[Tuple[str, "AppState"]]:
        """Return ``(name, AppState)`` pairs for the child flows and works of this state."""
        # The state needs to be fetched on access if it doesn't exist.
        self._request_state()

        def _wrap_children(states: Dict[str, Any], last_states: Dict[str, Any]) -> List[Tuple[str, "AppState"]]:
            # One child AppState per entry, sharing this state's host and port.
            return [
                (name, AppState(self._host, self._port, last_state=last_states[name], state=state_value))
                for name, state_value in states.items()
            ]

        items: List[Tuple[str, "AppState"]] = []
        for component in ["flows", "works"]:
            items.extend(_wrap_children(self._state.get(component, {}), self._last_state.get(component, {})))

        structures = self._state.get("structures", {})
        last_structures = self._last_state.get("structures", {})
        if structures:
            # NOTE(review): this looks up "flows"/"works" directly on the "structures" mapping, whereas
            # ``__getattr__`` treats "structures" as keyed by structure name — confirm the intended schema.
            for component in ["flows", "works"]:
                items.extend(_wrap_children(structures.get(component, {}), last_structures.get(component, {})))
        return items

    @staticmethod
    def _configure_session() -> "Session":
        """Create the HTTP session used for all requests to the app server."""
        return _configure_session()