pytorch-lightning
582 строки · 18.6 Кб
1# Copyright The Lightning AI team.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import abc16import asyncio17import builtins18import enum19import functools20import inspect21import json22import logging23import os24import sys25import threading26import time27from abc import ABC, abstractmethod28from contextlib import contextmanager29from copy import deepcopy30from dataclasses import dataclass, field31from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Mapping, Optional, Tuple, Type32from unittest.mock import MagicMock33
34import websockets35from deepdiff import Delta36
37import lightning.app38from lightning.app.utilities.exceptions import LightningAppStateException39from lightning.app.utilities.tree import breadth_first40
41if TYPE_CHECKING:42from lightning.app.core.app import LightningApp43from lightning.app.core.flow import LightningFlow44from lightning.app.utilities.types import Component45
46logger = logging.getLogger(__name__)47
48
@dataclass
class StateEntry:
    """Record holding the latest state shared through the app REST API.

    Each instance gets its own empty dictionaries by default, so entries never
    share mutable state with one another.
    """

    # Mapping stored as the app state for this entry.
    app_state: Mapping = field(default_factory=dict)
    # Mapping stored as the served state for this entry.
    served_state: Mapping = field(default_factory=dict)
    # Session identifier attached to this entry; ``None`` until one is set.
    session_id: Optional[str] = None
57
class StateStore(ABC):
    """Abstract key/value store tracking the app state and the served app state per key.

    Every method, including ``__init__``, is abstract: a concrete subclass has to
    implement all of them before it can be instantiated.
    """

    @abstractmethod
    def __init__(self):
        """Set up the backing storage."""

    @abstractmethod
    def add(self, k: str):
        """Create a new empty state under the key ``k``."""

    @abstractmethod
    def remove(self, k: str):
        """Delete the state stored under the key ``k``."""

    @abstractmethod
    def get_app_state(self, k: str) -> Mapping:
        """Return the app state stored under the key ``k``."""

    @abstractmethod
    def get_served_state(self, k: str) -> Mapping:
        """Return the last served app state stored under the key ``k``."""

    @abstractmethod
    def get_served_session_id(self, k: str) -> str:
        """Return the session id of the state stored under the key ``k``."""

    @abstractmethod
    def set_app_state(self, k: str, v: Mapping):
        """Set the app state of the entry stored under the key ``k``."""

    @abstractmethod
    def set_served_state(self, k: str, v: Mapping):
        """Set the served state of the entry stored under the key ``k``."""

    @abstractmethod
    def set_served_session_id(self, k: str, v: str):
        """Set the session id of the entry stored under the key ``k``."""
105
class InMemoryStateStore(StateStore):
    """Dictionary-backed state store used to track state through the app REST API.

    ``store`` maps a key to its ``StateEntry``; ``counter`` is incremented each
    time an app state is successfully stored.
    """

    def __init__(self):
        self.store = {}
        self.counter = 0

    def add(self, k):
        """Register a fresh, empty ``StateEntry`` under ``k`` (replacing any existing one)."""
        self.store[k] = StateEntry()

    def remove(self, k):
        """Drop the entry stored under ``k``; raises ``KeyError`` if it does not exist."""
        self.store.pop(k)

    def get_app_state(self, k):
        """Return the app state of the entry stored under ``k``."""
        entry = self.store[k]
        return entry.app_state

    def get_served_state(self, k):
        """Return the served state of the entry stored under ``k``."""
        entry = self.store[k]
        return entry.served_state

    def get_served_session_id(self, k):
        """Return the session id of the entry stored under ``k``."""
        entry = self.store[k]
        return entry.session_id

    def set_app_state(self, k, v):
        """Store a deep copy of ``v`` as the app state of ``k`` and bump ``counter``.

        Raises:
            LightningAppStateException: If the measured size of ``v`` exceeds
                ``APP_STATE_MAX_SIZE_BYTES``.
        """
        # NOTE: ``sys.getsizeof`` is shallow, it does not follow references into nested containers.
        state_size = sys.getsizeof(v)
        if state_size > lightning.app.core.constants.APP_STATE_MAX_SIZE_BYTES:
            raise LightningAppStateException(
                f"App state size is {state_size} bytes, which is larger than the recommended size "
                f"of {lightning.app.core.constants.APP_STATE_MAX_SIZE_BYTES}. Please investigate this."
            )
        snapshot = deepcopy(v)
        self.store[k].app_state = snapshot
        self.counter += 1

    def set_served_state(self, k, v):
        """Store a deep copy of ``v`` as the served state of ``k``."""
        snapshot = deepcopy(v)
        self.store[k].served_state = snapshot

    def set_served_session_id(self, k, v):
        """Record ``v`` as the session id of ``k``."""
        self.store[k].session_id = v
144
145class _LightningAppRef:146_app_instance: Optional["LightningApp"] = None147
148@classmethod149def connect(cls, app_instance: "LightningApp") -> None:150cls._app_instance = app_instance151
152@classmethod153def get_current(cls) -> Optional["LightningApp"]:154if cls._app_instance:155return cls._app_instance156return None157
158
def affiliation(component: "Component") -> Tuple[str, ...]:
    """Return the dotted name of ``component`` as a tuple, without its first segment.

    A component named ``"root"`` or ``""`` has an empty affiliation.
    """
    name = component.name
    if name == "root" or name == "":
        return ()
    # The first segment is dropped; only the path below it is kept.
    _, *below_root = name.split(".")
    return tuple(below_root)
165
class AppStateType(str, enum.Enum):
    """String-valued enumeration of the app state kinds."""

    STREAMLIT = "STREAMLIT"
    DEFAULT = "DEFAULT"
170
class BaseStatePlugin(abc.ABC):
    """Abstract base for state plugins.

    Subclasses must implement ``should_update_app``, ``get_context`` and
    ``render_non_authorized``.
    """

    def __init__(self):
        # Starts as ``None``; nothing in this class assigns another value.
        self.authorized = None

    @abc.abstractmethod
    def should_update_app(self, deep_diff):
        """Abstract hook receiving ``deep_diff``."""

    @abc.abstractmethod
    def get_context(self):
        """Abstract hook returning the plugin context."""

    @abc.abstractmethod
    def render_non_authorized(self):
        """Abstract hook for the non-authorized case."""
187
class AppStatePlugin(BaseStatePlugin):
    """State plugin reporting the ``DEFAULT`` app state type."""

    def should_update_app(self, deep_diff):
        """Return ``deep_diff`` unchanged."""
        return deep_diff

    def get_context(self):
        """Return the context dictionary carrying the ``DEFAULT`` state type."""
        context = {"type": AppStateType.DEFAULT.value}
        return context

    def render_non_authorized(self):
        """Do nothing."""
        return None
198
def target_fn():
    """Listen on the app websocket and ask the Streamlit sessions to rerun on every message.

    The function resolves the Streamlit runtime (new or legacy API) and, if a
    runtime exists, blocks inside ``asyncio.run`` until the websocket is closed
    cleanly. Reruns are throttled so that consecutive reruns are at least one
    second apart.
    """
    try:
        # streamlit >= 1.14.0
        from streamlit import runtime

        get_instance = runtime.get_instance
        exists = runtime.exists()
    except ImportError:
        # Older versions
        from streamlit.server.server import Server

        get_instance = Server.get_current
        exists = bool(Server._singleton)

    async def update_fn():
        runtime_instance = get_instance()
        # The sessions are captured once, before connecting to the websocket.
        sessions = list(runtime_instance._session_info_by_id.values())
        url = (
            "localhost:8080"
            if "LIGHTNING_APP_STATE_URL" in os.environ
            else f"localhost:{lightning.app.core.constants.APP_SERVER_PORT}"
        )
        ws_url = f"ws://{url}/api/v1/ws"
        last_updated = time.time()
        async with websockets.connect(ws_url) as websocket:
            while True:
                try:
                    _ = await websocket.recv()

                    # Throttle the reruns. ``asyncio.sleep`` yields to the event loop, whereas a
                    # blocking ``time.sleep`` would stall every other task scheduled on it.
                    while (time.time() - last_updated) < 1:
                        await asyncio.sleep(0.1)
                    for session in sessions:
                        session = session.session
                        session.request_rerun(session._client_state)
                    last_updated = time.time()
                except websockets.exceptions.ConnectionClosedOK:
                    # The websocket is not enabled
                    break

    if exists:
        asyncio.run(update_fn())
241
class StreamLitStatePlugin(BaseStatePlugin):
    """State plugin that starts the websocket listener thread for a Streamlit session."""

    def __init__(self):
        super().__init__()
        import streamlit as st

        # Start the listener only once per Streamlit session: the thread is stored in the
        # session state and its presence there guards against starting a second one.
        if hasattr(st, "session_state") and "websocket_thread" not in st.session_state:
            # ``daemon=True`` replaces the deprecated ``Thread.setDaemon(True)`` call.
            thread = threading.Thread(target=target_fn, daemon=True)
            st.session_state.websocket_thread = thread
            thread.start()

    def should_update_app(self, deep_diff):
        """Return ``deep_diff`` unchanged."""
        return deep_diff

    def get_context(self):
        """Return the context dictionary for this plugin."""
        # NOTE(review): this reports the DEFAULT type rather than STREAMLIT - confirm it is intended.
        return {"type": AppStateType.DEFAULT.value}

    def render_non_authorized(self):
        """Do nothing."""
        pass
262
def is_overridden(method_name: str, instance: Optional[object] = None, parent: Optional[Type[object]] = None) -> bool:
    """Tell whether ``instance`` overrides ``method_name`` relative to ``parent``.

    When ``parent`` is omitted it is inferred as ``LightningFlow`` or
    ``LightningWork`` from the type of ``instance``. The comparison itself is
    delegated to ``lightning_utilities.core.overrides.is_overridden``.

    Returns:
        ``False`` when ``instance`` is ``None``, otherwise the delegate's result.

    Raises:
        ValueError: If no parent was given and none could be inferred.
    """
    if instance is None:
        return False

    resolved_parent = parent
    if resolved_parent is None:
        if isinstance(instance, lightning.app.LightningFlow):
            resolved_parent = lightning.app.LightningFlow
        elif isinstance(instance, lightning.app.LightningWork):
            resolved_parent = lightning.app.LightningWork
        else:
            raise ValueError("Expected a parent")

    from lightning_utilities.core.overrides import is_overridden as _delegate

    return _delegate(method_name, instance, resolved_parent)
277
def _is_json_serializable(x: Any) -> bool:
    """Report whether ``x`` can be encoded as JSON with ``LightningJSONEncoder``."""
    # Fast path: exact primitive types are accepted without attempting an encode.
    if type(x) in lightning.app.core.constants.SUPPORTED_PRIMITIVE_TYPES:
        return True
    try:
        json.dumps(x, cls=LightningJSONEncoder)
    except (TypeError, OverflowError):
        # OverflowError is raised if a number is too large to encode.
        return False
    else:
        return True
290
def _set_child_name(component: "Component", child: "Component", new_name: str) -> str:
    """Name ``child`` as ``<component.name>.<new_name>``, rename its descendants, and return the name."""
    child_name = f"{component.name}.{new_name}"
    child._name = child_name

    # The name changed, so the names of everything below this child are recomputed recursively.
    if isinstance(child, lightning.app.core.LightningFlow):
        # Flows first, then works, then structures; each registry is read only when its turn comes.
        for registry in ("_flows", "_works", "_structures"):
            for attr_name in getattr(child, registry):
                _set_child_name(child, getattr(child, attr_name), attr_name)
    if isinstance(child, lightning.app.structures.Dict):
        for key, member in child.items():
            _set_child_name(child, member, key)
    if isinstance(child, lightning.app.structures.List):
        for member in child:
            # Members keep the last segment of their current name.
            _set_child_name(child, member, member.name.rpartition(".")[2])

    return child_name
316
def _component_state_prefix(root: "LightningFlow", component: "Component") -> str:
    """Build the state path prefix (``root['flows']['x']...``) that leads from ``root`` to ``component``."""
    prefix = "root"
    for _, c in _walk_to_component(root, component):
        if isinstance(c, lightning.app.core.LightningWork):
            prefix += "['works']"

        if isinstance(c, lightning.app.core.LightningFlow):
            prefix += "['flows']"

        if isinstance(c, (lightning.app.structures.Dict, lightning.app.structures.List)):
            prefix += "['structures']"

        c_n = c.name.split(".")[-1]
        prefix += f"['{c_n}']"
    return prefix


def _delta_to_app_state_delta(root: "LightningFlow", component: "Component", delta: Delta) -> Delta:
    """Re-key a delta expressed relative to ``component`` so it is relative to the state of ``root``.

    Every key of the delta starts with the word ``root``; that word is replaced by
    the path leading from ``root`` to ``component``.

    Args:
        root: The flow at the top of the component tree.
        component: The component the incoming delta is relative to.
        delta: The delta to convert.

    Returns:
        A new ``Delta`` built from the re-keyed dictionary.
    """
    delta_dict = delta.to_dict()
    # The prefix only depends on ``root`` and ``component``, so it is computed once (and only
    # when there is at least one key to rewrite) instead of once per delta key.
    new_prefix = None
    for changed in delta_dict.values():
        for delta_key in changed.copy():
            val = changed[delta_key]

            if new_prefix is None:
                new_prefix = _component_state_prefix(root, component)

            delta_key_without_root = delta_key[4:]  # the first 4 chars are the word 'root', strip it
            new_key = new_prefix + delta_key_without_root
            if new_key != delta_key:
                changed[new_key] = val
                del changed[delta_key]

    return Delta(delta_dict)
345
def _walk_to_component(
    root: "LightningFlow",
    component: "Component",
) -> Generator[Tuple["Component", "Component"], None, None]:
    """Walk the tree from ``root`` down to ``component``, following the segments of its name.

    For every step of the descent a ``(parent, child)`` tuple is yielded.

    """
    from lightning.app.structures import Dict, List

    current = root
    # The first segment of the name ('root') is not part of the path to follow.
    for segment in component.name.split(".")[1:]:
        if isinstance(current, Dict):
            descendant = current[segment]
        elif isinstance(current, List):
            # List structures are indexed by position.
            descendant = current[int(segment)]
        else:
            descendant = getattr(current, segment)
        yield current, descendant
        current = descendant
367
368def _collect_child_process_pids(pid: int) -> List[int]:369"""Function to return the list of child process pid's of a process."""370processes = os.popen("ps -ej | grep -i 'python' | grep -v 'grep' | awk '{ print $2,$3 }'").read()371processes = [p.split(" ") for p in processes.split("\n")[:-1]]372return [int(child) for child, parent in processes if parent == str(pid) and child != str(pid)]373
374
def _print_to_logger_info(*args: Any, **kwargs: Any):
    """Log the space-joined string form of ``args`` at info level; ``kwargs`` are ignored."""
    # TODO Find a better way to re-direct print to loggers.
    message = " ".join(map(str, args))
    lightning.app._logger.info(message)
379
def convert_print_to_logger_info(func: Callable) -> Callable:
    """Decorate ``func`` so that every ``print`` made while it runs goes through ``_print_to_logger_info``.

    The built-in ``print`` is swapped for the duration of the call and restored
    afterwards, including when ``func`` raises.

    Args:
        func: The callable to wrap.

    Returns:
        The wrapped callable, which returns whatever ``func`` returns.
    """

    @functools.wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        # Go through the ``builtins`` module: ``__builtins__`` is a CPython implementation detail
        # that is a dict in imported modules but a module in ``__main__``.
        original_print = builtins.print
        builtins.print = _print_to_logger_info
        try:
            return func(*args, **kwargs)
        finally:
            # Always restore ``print``, otherwise an exception would leave it redirected.
            builtins.print = original_print

    return wrapper
393
def pretty_state(state: Dict) -> Dict:
    """Return a copy of ``state`` without the variables whose name starts with an underscore.

    The ``flows`` and ``works`` sections are processed recursively. A section is
    only present in the result when it ends up with at least one entry.
    """
    cleaned = {}

    visible_vars = {name: value for name, value in state["vars"].items() if not name.startswith("_")}
    if visible_vars:
        cleaned["vars"] = visible_vars

    for section in ("flows", "works"):
        if section not in state:
            continue
        children = {name: pretty_state(child) for name, child in state[section].items()}
        if children:
            cleaned[section] = children

    return cleaned
414
class LightningJSONEncoder(json.JSONEncoder):
    """JSON encoder that serializes objects exposing a callable ``__json__`` attribute."""

    def default(self, obj: Any) -> Any:
        """Return ``obj.__json__()`` when available, otherwise defer to the standard encoder."""
        to_json = getattr(obj, "__json__", None)
        if not callable(to_json):
            # The base implementation raises ``TypeError`` for unsupported objects.
            return super().default(obj)
        return to_json()
421
class Logger:
    """Thin wrapper around a standard ``logging`` logger used to improve the debugging experience.

    The level of the wrapped logger is set lazily, on the first ``warn``, ``debug``
    or ``error`` call: ``DEBUG`` when the ``LIGHTNING_DEBUG`` environment variable
    holds a non-zero integer, ``INFO`` otherwise.
    """

    def __init__(self, name: str):
        self.logger = logging.getLogger(name)
        # Resolved on first use by ``_set_level``.
        self.level = None

    def info(self, msg, *args: Any, **kwargs: Any):
        """Log ``msg`` at info level (does not trigger the lazy level setup)."""
        self.logger.info(msg, *args, **kwargs)

    def warn(self, msg, *args: Any, **kwargs: Any):
        """Log ``msg`` at warning level."""
        self._set_level()
        # ``logging.Logger.warn`` is deprecated; ``warning`` is the supported spelling.
        self.logger.warning(msg, *args, **kwargs)

    def debug(self, msg, *args: Any, **kwargs: Any):
        """Log ``msg`` at debug level."""
        self._set_level()
        self.logger.debug(msg, *args, **kwargs)

    def error(self, msg, *args: Any, **kwargs: Any):
        """Log ``msg`` at error level."""
        self._set_level()
        self.logger.error(msg, *args, **kwargs)

    def _set_level(self):
        """Lazily set the level once set by the users."""
        # Set on the first warn, debug or error call.
        if self.level is None:
            self.level = logging.DEBUG if bool(int(os.getenv("LIGHTNING_DEBUG", "0"))) else logging.INFO
            self.logger.setLevel(self.level)
451
452def _state_dict(flow: "LightningFlow"):453state = {}454flows = [flow] + list(flow.flows.values())455for f in flows:456state[f.name] = f.state_dict()457for w in flow.works():458state[w.name] = w.state459return state460
461
462def _load_state_dict(root_flow: "LightningFlow", state: Dict[str, Any], strict: bool = True) -> None:463"""This function is used to reload the state assuming dynamic components creation.464
465When a component isn't found but its state exists, its state is passed up to its closest existing parent.
466
467Arguments:
468root_flow: The flow at the top of the component tree.
469state: The collected state dict.
470strict: Whether to validate all components have been re-created.
471
472"""
473# 1: Reload the state of the existing works474for w in root_flow.works():475w.set_state(state.pop(w.name))476
477# 2: Collect the existing flows478flows = [root_flow] + list(root_flow.flows.values())479flow_map = {f.name: f for f in flows}480
481# 3: Find the state of the all dynamic components482dynamic_components = {k: v for k, v in state.items() if k not in flow_map}483
484# 4: Propagate the state of the dynamic components to their closest parents485dynamic_children_state = {}486for name, component_state in dynamic_components.items():487affiliation = name.split(".")488for idx in range(0, len(affiliation)):489parent_name = ".".join(affiliation[:-idx])490has_matched = False491for flow_name, flow in flow_map.items():492if flow_name == parent_name:493if flow_name not in dynamic_children_state:494dynamic_children_state[flow_name] = {}495
496dynamic_children_state[flow_name].update({name.replace(parent_name + ".", ""): component_state})497has_matched = True498break499if has_matched:500break501
502# 5: Reload the flow states503for flow_name, flow in flow_map.items():504flow.load_state_dict(state.pop(flow_name), dynamic_children_state.get(flow_name, {}), strict=strict)505
506# 6: Verify all dynamic components has been re-created.507if strict:508components_names = (509[root_flow.name] + [f.name for f in root_flow.flows.values()] + [w.name for w in root_flow.works()]510)511for component_name in dynamic_components:512if component_name not in components_names:513raise Exception(f"The component {component_name} was re-created during state reloading.")514
515
516class _MagicMockJsonSerializable(MagicMock):517@staticmethod518def __json__():519return "{}"520
521
522def _mock_import(*args, original_fn=None):523try:524return original_fn(*args)525except Exception:526return _MagicMockJsonSerializable()527
528
@contextmanager
def _mock_missing_imports():
    """Context manager that routes ``builtins.__import__`` through ``_mock_import``.

    The original import function is put back on exit, whether or not the body raised.
    """
    real_import = builtins.__import__
    patched_import = functools.partial(_mock_import, original_fn=real_import)
    builtins.__import__ = patched_import
    try:
        yield
    finally:
        builtins.__import__ = real_import
538
def is_static_method(klass_or_instance, attr) -> bool:
    """Tell whether ``attr`` is declared as a ``staticmethod`` on the given class or instance."""
    # ``getattr_static`` fetches the raw attribute without triggering the descriptor protocol,
    # so the ``staticmethod`` wrapper itself is still visible.
    raw_attribute = inspect.getattr_static(klass_or_instance, attr)
    return isinstance(raw_attribute, staticmethod)
542
543def _lightning_dispatched() -> bool:544return bool(int(os.getenv("LIGHTNING_DISPATCHED", 0)))545
546
547def _using_debugger() -> bool:548"""This method is used to detect whether the app is run with a debugger attached."""549if "LIGHTNING_DETECTED_DEBUGGER" in os.environ:550return True551
552# Collect the information about the process.553parent_process = os.popen(f"ps -ax | grep -i {os.getpid()} | grep -v grep").read()554
555# Detect whether VSCode or PyCharm debugger are used556use_debugger = "debugpy" in parent_process or "pydev" in parent_process557
558# Store the result to avoid multiple popen calls.559if use_debugger:560os.environ["LIGHTNING_DETECTED_DEBUGGER"] = "1"561return use_debugger562
563
def _should_dispatch_app() -> bool:
    """Return ``True`` when the app is not dispatched yet, no state URL is set and a debugger is attached."""
    if _lightning_dispatched():
        return False
    if "LIGHTNING_APP_STATE_URL" in os.environ:
        return False
    # Checked last so the debugger detection is skipped when an earlier condition already fails.
    return _using_debugger()
572
573def _is_headless(app: "LightningApp") -> bool:574"""Utility which returns True if the given App has no ``Frontend`` objects or URLs exposed through575``configure_layout``."""
576if app.frontends:577return False578for component in breadth_first(app.root, types=(lightning.app.LightningFlow,)):579for entry in component._layout:580if "target" in entry:581return False582return True583