pytorch-lightning
737 lines · 29.2 KB
1# Copyright The Lightning AI team.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import logging16import os17import pickle18import queue19import threading20import warnings21from copy import deepcopy22from time import time23from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union24
25from deepdiff import DeepDiff, Delta26from lightning_utilities.core.apply_func import apply_to_collection27
28import lightning.app29from lightning.app import _console30from lightning.app.api.request_types import _APIRequest, _CommandRequest, _DeltaRequest31from lightning.app.core.constants import (32BATCH_DELTA_COUNT,33DEBUG_ENABLED,34FLOW_DURATION_SAMPLES,35FLOW_DURATION_THRESHOLD,36FRONTEND_DIR,37STATE_ACCUMULATE_WAIT,38)
39from lightning.app.core.queues import BaseQueue40from lightning.app.core.work import LightningWork41from lightning.app.frontend import Frontend42from lightning.app.storage import Drive, Path, Payload43from lightning.app.storage.path import _storage_root_dir44from lightning.app.utilities import frontend45from lightning.app.utilities.app_helpers import (46Logger,47_delta_to_app_state_delta,48_LightningAppRef,49_should_dispatch_app,50)
51from lightning.app.utilities.app_status import AppStatus52from lightning.app.utilities.commands.base import _process_requests53from lightning.app.utilities.component import _convert_paths_after_init, _validate_root_flow54from lightning.app.utilities.enum import AppStage, CacheCallsKeys55from lightning.app.utilities.exceptions import CacheMissException, ExitAppException, LightningFlowException56from lightning.app.utilities.layout import _collect_layout57from lightning.app.utilities.proxies import ComponentDelta58from lightning.app.utilities.scheduler import SchedulerThread59from lightning.app.utilities.tree import breadth_first60from lightning.app.utilities.warnings import LightningFlowWarning61
62if TYPE_CHECKING:63from lightning.app.core.flow import LightningFlow64from lightning.app.runners.backends.backend import Backend, WorkManager65from lightning.app.runners.runtime import Runtime66from lightning.app.utilities.packaging.cloud_compute import CloudCompute67
68
69logger = Logger(__name__)70
71
72class LightningApp:73def __init__(74self,75root: Union["LightningFlow", LightningWork],76flow_cloud_compute: Optional["CloudCompute"] = None,77log_level: str = "info",78info: Optional[frontend.AppInfo] = None,79root_path: str = "",80) -> None:81"""The Lightning App, or App in short runs a tree of one or more components that interact to create end-to-end82applications. There are two kinds of components: :class:`~lightning.app.core.flow.LightningFlow` and
83:class:`~lightning.app.core.work.LightningWork`. This modular design enables you to reuse components created by
84other users.
85
86The Lightning App alternatively run an event loop triggered by delta changes sent from
87either :class:`~lightning.app.core.work.LightningWork` or from the Lightning UI.
88Once deltas are received, the Lightning App runs
89the :class:`~lightning.app.core.flow.LightningFlow` provided.
90
91Arguments:
92root: The root ``LightningFlow`` or ``LightningWork`` component, that defines all the app's nested
93components, running infinitely. It must define a `run()` method that the app can call.
94flow_cloud_compute: The default Cloud Compute used for flow, Rest API and frontend's.
95log_level: The log level for the app, one of [`info`, `debug`].
96This can be helpful when reporting bugs on Lightning repo.
97info: Provide additional info about the app which will be used to update html title,
98description and image meta tags and specify any additional tags as list of html strings.
99root_path: Set this to `/path` if you want to run your app behind a proxy at `/path` leave empty for "/".
100For instance, if you want to run your app at `https://customdomain.com/myapp`,
101set `root_path` to `/myapp`.
102You can learn more about proxy `here <https://www.fortinet.com/resources/cyberglossary/proxy-server>`_.
103
104"""
105
106self.root_path = root_path # when running behind a proxy107self.info = info108
109from lightning.app.core.flow import _RootFlow110
111if isinstance(root, LightningWork):112root = _RootFlow(root)113
114_validate_root_flow(root)115self._root = root116self.flow_cloud_compute = flow_cloud_compute or lightning.app.CloudCompute(name="flow-lite")117
118# queues definition.119self.delta_queue: Optional[BaseQueue] = None120self.readiness_queue: Optional[BaseQueue] = None121self.api_response_queue: Optional[BaseQueue] = None122self.api_publish_state_queue: Optional[BaseQueue] = None123self.api_delta_queue: Optional[BaseQueue] = None124self.error_queue: Optional[BaseQueue] = None125self.request_queues: Optional[Dict[str, BaseQueue]] = None126self.response_queues: Optional[Dict[str, BaseQueue]] = None127self.copy_request_queues: Optional[Dict[str, BaseQueue]] = None128self.copy_response_queues: Optional[Dict[str, BaseQueue]] = None129self.caller_queues: Optional[Dict[str, BaseQueue]] = None130self.flow_to_work_delta_queues: Optional[Dict[str, BaseQueue]] = None131self.work_queues: Optional[Dict[str, BaseQueue]] = None132self.commands: Optional[List] = None133
134self.should_publish_changes_to_api = False135self.component_affiliation = None136self.backend: Optional["Backend"] = None137_LightningAppRef.connect(self)138self.processes: Dict[str, "WorkManager"] = {}139self.frontends: Dict[str, Frontend] = {}140self.stage = AppStage.RUNNING141self._has_updated: bool = True142self._schedules: Dict[str, Dict] = {}143self.threads: List[threading.Thread] = []144self.exception = None145self.collect_changes: bool = True146
147self.status: Optional[AppStatus] = None148# TODO: Enable ready locally for opening the UI.149self.ready = False150
151# NOTE: Checkpointing is disabled by default for the time being. We152# will enable it when resuming from full checkpoint is supported. Also,153# we will need to revisit the logic at _should_snapshot, since right now154# we are writing checkpoints too often, and this is expensive.155self.checkpointing: bool = False156
157self._update_layout()158self._update_status()159
160self.is_headless: Optional[bool] = None161
162self._original_state: Optional[dict] = None163self._last_state: dict = self.state164self.state_accumulate_wait = STATE_ACCUMULATE_WAIT165
166self._last_run_time: float = 0.0167self._run_times: list = []168
169# Path attributes can't get properly attached during the initialization, because the full name170# is only available after all Flows and Works have been instantiated.171_convert_paths_after_init(self.root) # type: ignore[arg-type]172
173if log_level not in ("debug", "info"):174raise Exception(f"Log Level should be in ['debug', 'info']. Found {log_level}")175
176# Lazily enable debugging.177if log_level == "debug" or DEBUG_ENABLED:178if not DEBUG_ENABLED:179os.environ["LIGHTNING_DEBUG"] = "2"180_console.setLevel(logging.DEBUG)181
182logger.debug(f"ENV: {os.environ}")183
184if _should_dispatch_app():185os.environ["LIGHTNING_DISPATCHED"] = "1"186from lightning.app.runners import MultiProcessRuntime187
188MultiProcessRuntime(self).dispatch()189
    def _update_index_file(self) -> None:
        """Regenerate the frontend ``index.html`` with this app's meta info and root path."""
        # update index.html,
        # this should happen once for all apps before the ui server starts running.
        frontend.update_index_file(FRONTEND_DIR, info=self.info, root_path=self.root_path)
195def get_component_by_name(self, component_name: str) -> Union["LightningFlow", LightningWork]:196"""Returns the instance corresponding to the given component name."""197from lightning.app.structures import Dict as LightningDict198from lightning.app.structures import List as LightningList199from lightning.app.utilities.types import ComponentTuple200
201if component_name == "root":202return self.root203if not component_name.startswith("root."):204raise ValueError(f"Invalid component name {component_name}. Name must start with 'root'")205
206current = self.root207for child_name in component_name.split(".")[1:]:208if isinstance(current, LightningDict):209child = current[child_name]210elif isinstance(current, LightningList):211child = current[int(child_name)]212else:213child = getattr(current, child_name, None)214if not isinstance(child, ComponentTuple):215raise AttributeError(f"Component '{current.name}' has no child component with name '{child_name}'.")216current = child # type: ignore[assignment]217return current218
    def _reset_original_state(self) -> None:
        """Restore the state snapshotted at the start of ``_run`` (used when the app restarts)."""
        assert self._original_state is not None
        self.set_state(self._original_state)
    @property
    def root(self) -> Union["LightningFlow", LightningWork]:
        """Returns the root component of the application."""
        return self._root
    @property
    def state(self) -> dict:
        """Return the current state of the application."""
        state = self.root.state
        # Inject the app-level stage so state consumers can observe lifecycle transitions.
        state["app_state"] = {"stage": self.stage.value}
        return state
    @property
    def state_vars(self) -> dict:
        """Return the current state restricted to the user defined variables of the application."""
        state_vars = self.root.state_vars
        state_vars["app_state"] = {"stage": self.stage.value}
        return state_vars
    @property
    def state_with_changes(self) -> dict:
        """Return the current state with the new changes of the application."""
        state_with_changes = self.root.state_with_changes
        state_with_changes["app_state"] = {"stage": self.stage.value}
        return state_with_changes
    def set_state(self, state: dict) -> None:
        """Method to set a new app state set to the application."""
        # Record the incoming state (minus its "changes" entries) so future diffs use it as base.
        self.set_last_state(state)
        self.root.set_state(state)
        self.stage = AppStage(state["app_state"]["stage"])
    @property
    def last_state(self) -> dict:
        """Returns the latest state."""
        # This is the base state used for delta computation; "changes" entries were stripped.
        return self._last_state
    @property
    def checkpoint_dir(self) -> str:
        """Directory where app checkpoints are written (``<storage root>/checkpoints``)."""
        return os.path.join(str(_storage_root_dir()), "checkpoints")
264def remove_changes_(self, state: dict) -> None:265for _, child in state["flows"].items():266self.remove_changes(child)267state["changes"] = {}268
    def remove_changes(self, state: dict) -> dict:
        """Return a deep copy of ``state`` with all ``changes`` entries cleared."""
        state = deepcopy(state)
        for _, child in state["flows"].items():
            self.remove_changes_(child)
        state["changes"] = {}
        return state
    def set_last_state(self, state: dict) -> None:
        """Store ``state`` (stripped of its ``changes`` entries) as the base for future diffs."""
        self._last_state = self.remove_changes(state)
    @staticmethod
    def populate_changes(last_state: dict, new_state: dict) -> dict:
        """Annotate ``new_state`` in place with per-component ``changes`` entries describing how its
        ``vars`` differ from ``last_state``, and return it."""
        diff = DeepDiff(last_state, new_state, view="tree", verbose_level=2)

        changes_categories = [diff[key] for key in diff.to_dict()]

        if not changes_categories:
            return new_state

        for change_category in changes_categories:
            for entry in change_category:
                state_el = new_state
                change = entry.path(output_format="list")
                # Only user variables (paths traversing a "vars" key) are tracked as changes.
                if "vars" not in change:
                    continue
                for change_el in change:
                    if change_el == "vars":
                        if "changes" not in state_el:
                            state_el["changes"] = {}
                        state_el["changes"][change[-1]] = {"from": entry.t1, "to": entry.t2}
                        break
                    # move down in the dictionary
                    state_el = state_el[change_el]
        return new_state
304@staticmethod305def get_state_changed_from_queue(q: BaseQueue, timeout: Optional[float] = None) -> Optional[dict]:306try:307timeout = timeout or q.default_timeout308return q.get(timeout=timeout)309except queue.Empty:310return None311
312@staticmethod313def batch_get_state_changed_from_queue(q: BaseQueue, timeout: Optional[float] = None) -> List[dict]:314try:315timeout = timeout or q.default_timeout316return q.batch_get(timeout=timeout, count=BATCH_DELTA_COUNT)317except queue.Empty:318return []319
    def check_error_queue(self) -> None:
        """Move the app to the FAILED stage if a work/process pushed an exception to the error queue."""
        exception: Exception = self.get_state_changed_from_queue(self.error_queue)  # type: ignore[assignment,arg-type]
        if isinstance(exception, Exception):
            self.exception = exception
            self.stage = AppStage.FAILED
326@property327def flows(self) -> List[Union[LightningWork, "LightningFlow"]]:328"""Returns all the flows defined within this application."""329return [self.root] + list(self.root.flows.values())330
    @property
    def works(self) -> List[LightningWork]:
        """Returns all the works defined within this application (collected recursively)."""
        return self.root.works(recurse=True)
    @property
    def named_works(self) -> List[Tuple[str, LightningWork]]:
        """Returns all the works defined within this application with their names."""
        return self.root.named_works(recurse=True)
341def _collect_deltas_from_ui_and_work_queues(self) -> List[Union[Delta, _APIRequest, _CommandRequest]]:342# The aggregation would try to get as many deltas as possible343# from both the `api_delta_queue` and `delta_queue`344# during the `state_accumulate_wait` time.345# The while loop can exit sooner if both queues are empty.346
347deltas = []348api_or_command_request_deltas = []349t0 = time()350
351while (time() - t0) < self.state_accumulate_wait:352# TODO: Fetch all available deltas at once to reduce queue calls.353received_deltas: List[Union[_DeltaRequest, _APIRequest, _CommandRequest, ComponentDelta]] = (354self.batch_get_state_changed_from_queue(355self.delta_queue # type: ignore[assignment,arg-type]356)357)358if len(received_deltas) == []:359break360
361for delta in received_deltas:362if isinstance(delta, _DeltaRequest):363deltas.append(delta.delta)364elif isinstance(delta, ComponentDelta):365logger.debug(f"Received from {delta.id} : {delta.delta.to_dict()}")366work = None367try:368work = self.get_component_by_name(delta.id)369except (KeyError, AttributeError) as ex:370logger.error(f"The component {delta.id} couldn't be accessed. Exception: {ex}")371
372if work:373delta = _delta_to_app_state_delta(374self.root, # type: ignore[arg-type]375work,376deepcopy(delta.delta),377)378deltas.append(delta)379else:380api_or_command_request_deltas.append(delta)381
382if api_or_command_request_deltas:383_process_requests(self, api_or_command_request_deltas)384
385for delta in deltas:386# When aggregating deltas from the UI and the Works, and over the accumulation time window,387# it can happen that deltas from these different sources disagree. Since deltas are computed on the Work388# and UI side separately, correctness of the aggregation can only be guaranteed if both components compute389# the delta based on the same base state. But this assumption does not hold in general, and there is no way390# for the Flow to reject or resolve these deltas properly at the moment. Hence, we decide to ignore391# errors coming from deepdiff when adding deltas together by setting:392delta.log_errors = False # type: ignore[union-attr]393delta.raise_errors = False # type: ignore[union-attr]394return deltas395
    def maybe_apply_changes(self) -> Optional[bool]:
        """Get the deltas from both the flow queue and the work queue, merge the two deltas and update the state."""
        self._send_flow_to_work_deltas(self.state)

        if not self.collect_changes:
            return None

        deltas = self._collect_deltas_from_ui_and_work_queues()

        if not deltas:
            # Path and Drive aren't processed by DeepDiff, so we need to convert them to dict.
            last_state = apply_to_collection(self.last_state, (Path, Drive), lambda x: x.to_dict())
            state = apply_to_collection(self.state, (Path, Drive), lambda x: x.to_dict())
            # When no deltas are received from the Rest API or work queues,
            # we need to check if the flow modified the state and populate changes.
            deep_diff = DeepDiff(last_state, state, verbose_level=2)

            if "unprocessed" in deep_diff:
                # pop the unprocessed key.
                unprocessed = deep_diff.pop("unprocessed")
                logger.warn(f"It seems delta differentiation resulted in {unprocessed}. Open an issue on Github.")

            if deep_diff:
                # TODO: Resolve changes with ``CacheMissException``.
                # new_state = self.populate_changes(self.last_state, self.state)
                self.set_last_state(self.state)
                self._has_updated = True
            return False

        logger.debug(f"Received {[d.to_dict() for d in deltas]}")

        # 2: Collect the state
        state = self.state

        # 3: Apply the state delta
        for delta in deltas:
            try:
                state += delta
            except Exception as ex:
                raise Exception(f"Current State {state}, {delta.to_dict()}") from ex

        # new_state = self.populate_changes(self.last_state, state)
        self.set_state(state)
        self._has_updated = True
        return None
    def run_once(self) -> bool:
        """Method used to collect changes and run the root Flow once.

        Returns ``True`` when the app is done (stopping, failed, or interrupted).
        """
        done = False
        self._last_run_time = 0.0

        if self.backend is not None:
            self.backend.update_work_statuses(self.works)

        self._update_layout()
        self._update_status()
        self.maybe_apply_changes()

        if self.checkpointing and self._should_snapshot():
            self._dump_checkpoint()

        if self.stage == AppStage.BLOCKING:
            return done

        if self.stage in (AppStage.STOPPING, AppStage.FAILED):
            return True

        if self.stage == AppStage.RESTARTING:
            return self._apply_restarting()

        t0 = time()

        try:
            self.check_error_queue()
            # Execute the flow only if:
            # - There are state changes
            # - It is the first execution of the flow
            if self._has_updated:
                self.root.run()
        except CacheMissException:
            self._on_cache_miss_exception()
        except LightningFlowException:
            done = True
            self.stage = AppStage.FAILED
        except (ExitAppException, KeyboardInterrupt):
            done = True
            self.stage = AppStage.STOPPING

        # The UI becomes ready once the root flow reports ready; never flips back to False here.
        if not self.ready:
            self.ready = self.root.ready

        self._last_run_time = time() - t0

        self.on_run_once_end()
        return done
    def _reset_run_time_monitor(self) -> None:
        """Reset the rolling window of flow ``run`` durations used to detect slow flows."""
        self._run_times = [0.0] * FLOW_DURATION_SAMPLES
495def _update_run_time_monitor(self) -> None:496self._run_times[:-1] = self._run_times[1:]497self._run_times[-1] = self._last_run_time498
499# Here we underestimate during the first FLOW_DURATION_SAMPLES500# iterations, but that's ok for our purposes501avg_elapsed_time = sum(self._run_times) / FLOW_DURATION_SAMPLES502
503if avg_elapsed_time > FLOW_DURATION_THRESHOLD:504warnings.warn(505"The execution of the `run` method of the root flow is taking too long. "506"Flow is supposed to only host coordination logic, while currently it is"507"likely to contain long-running calls, code that performs meaningful "508"computations or that makes blocking or asynchronous calls to third-party "509"services. If that is the case, you should move those pieces to a Work, "510"and make sure Flow can complete its execution in under a second.",511LightningFlowWarning,512)513
    def _run(self) -> bool:
        """Entry point of the LightningApp.

        This would be dispatched by the Runtime objects.

        """
        # Keep a pristine snapshot so the app can be restored on restart (see ``_apply_restarting``).
        self._original_state = deepcopy(self.state)
        done = False

        self.ready = self.root.ready

        self._start_with_flow_works()

        if self.should_publish_changes_to_api and self.api_publish_state_queue is not None:
            self.api_publish_state_queue.put((self.state_vars, self.status))

        self._reset_run_time_monitor()

        while not done:
            done = self.run_once()

            self._update_run_time_monitor()

            # Publish the new state to the REST API only when something actually changed.
            if self._has_updated and self.should_publish_changes_to_api and self.api_publish_state_queue is not None:
                self.api_publish_state_queue.put((self.state_vars, self.status))

            self._has_updated = False

        self._on_run_end()

        return True
    def _update_layout(self) -> None:
        """Resolve frontend URLs (when a backend is attached) and refresh every flow's layout."""
        if self.backend:
            self.backend.resolve_url(self, base_url=None)

        for component in breadth_first(self.root, types=(lightning.app.LightningFlow,)):  # type: ignore[arg-type]
            layout = _collect_layout(self, component)
            component._layout = layout
    def _update_status(self) -> None:
        """Recompute the aggregated ``AppStatus`` from UI readiness and all work statuses."""
        old_status = self.status

        work_statuses = {}
        assert self.root is not None
        for work in breadth_first(self.root, types=(lightning.app.LightningWork,)):  # type: ignore[arg-type]
            work_statuses[work.name] = work.status

        self.status = AppStatus(
            is_ui_ready=self.ready,
            work_statuses=work_statuses,
        )

        # If the work statuses changed, the state delta will trigger an update.
        # If ready has changed, we trigger an update manually.
        if self.status != old_status:
            self._has_updated = True
    def _apply_restarting(self) -> bool:
        """Restore the original state and resume in the BLOCKING stage; never terminates the loop."""
        self._reset_original_state()
        # apply stage after restoring the original state.
        self.stage = AppStage.BLOCKING
        return False
578def _has_work_finished(self, work: LightningWork) -> bool:579latest_call_hash = work._calls[CacheCallsKeys.LATEST_CALL_HASH]580if latest_call_hash is None:581return False582return "ret" in work._calls[latest_call_hash]583
584def _collect_work_finish_status(self) -> dict:585work_finished_status = {work.name: self._has_work_finished(work) for work in self.works}586assert len(work_finished_status) == len(self.works)587return work_finished_status588
589def _should_snapshot(self) -> bool:590if len(self.works) == 0:591return True592if self._has_updated:593work_finished_status = self._collect_work_finish_status()594if work_finished_status:595return all(work_finished_status.values())596return True597return False598
    def state_dict(self) -> Dict:
        """Returns the whole app state, used as the checkpoint payload."""
        return self.state
    def load_state_dict(self, state: Dict) -> None:
        """Restore the app (flows, works and stage) from a previously captured state dict."""
        self.set_state(state)
605def load_state_dict_from_checkpoint_dir(606self,607checkpoints_dir: str,608version: Optional[int] = None,609) -> None:610if not os.path.exists(checkpoints_dir):611raise FileNotFoundError(f"The provided directory `{checkpoints_dir}` doesn't exist.")612checkpoints = [f for f in os.listdir(checkpoints_dir) if f.startswith("v_") and f.endswith(".json")]613if not checkpoints:614raise Exception(f"No checkpoints where found in `{checkpoints_dir}`.")615
616if version is None:617# take the latest checkpoint.618version = sorted(int(c.split("_")[1]) for c in checkpoints)[-1]619
620available_checkpoints = [c for c in checkpoints if c.startswith(f"v_{version}_")]621if not available_checkpoints:622raise FileNotFoundError(f"The version `{version}` wasn't found in {checkpoints}.")623if len(available_checkpoints) > 1:624raise Exception(f"Found 2 checkpoints `{available_checkpoints}`with the same version.")625checkpoint_path = os.path.join(checkpoints_dir, available_checkpoints[0])626with open(checkpoint_path, "rb") as fo:627state = pickle.load(fo)628self.load_state_dict(state)629
    def _dump_checkpoint(self) -> Optional[str]:
        """Write the current app state to a new versioned checkpoint file and return its path.

        Returns ``None`` when the checkpoint directory points to S3 (remote saving isn't supported).
        """
        checkpoints_dir = self.checkpoint_dir
        # TODO: Add supports to remotely saving checkpoints.
        if checkpoints_dir.startswith("s3:"):
            return None
        os.makedirs(checkpoints_dir, exist_ok=True)

        # Get all current version within the provided folder and sort them
        checkpoint_versions = sorted(
            int(f.split("_")[1]) for f in os.listdir(checkpoints_dir) if f.startswith("v_") and f.endswith(".json")
        )

        previous_version = checkpoint_versions[-1] if checkpoint_versions else -1

        # NOTE(review): the file carries a ``.json`` suffix but the payload is pickled — see
        # ``load_state_dict_from_checkpoint_dir``, which reads it back with pickle.
        checkpoint_path = os.path.join(checkpoints_dir, f"v_{previous_version + 1}_{time()}.json")

        with open(checkpoint_path, "wb") as f:
            pickle.dump(self.state_dict(), f)
        return checkpoint_path
    def connect(self, runtime: "Runtime") -> None:
        """Override to customize your application to the runtime."""
        # Intentionally a no-op hook; subclasses may attach runtime-specific resources here.
        pass
    def _on_cache_miss_exception(self) -> None:
        """Refresh the layout when the flow run was interrupted by a ``CacheMissException``."""
        if self._has_updated:
            self._update_layout()
658def _register_schedule(self, schedule_hash: str, schedule_metadata: Dict) -> None:659# create a thread only if a user uses the flow's schedule method.660if not self._schedules:661scheduler_thread = SchedulerThread(self)662scheduler_thread.setDaemon(True)663self.threads.append(scheduler_thread)664self.threads[-1].start()665self._schedules[schedule_hash] = deepcopy(schedule_metadata)666
667def on_run_once_end(self) -> None:668if not self._schedules:669return670# disable any flow schedules.671for flow in self.flows:672flow._disable_running_schedules()673
    def _on_run_end(self) -> None:
        """Undo the debug-mode side effects enabled in ``__init__``."""
        if os.getenv("LIGHTNING_DEBUG") == "2":
            del os.environ["LIGHTNING_DEBUG"]
            _console.setLevel(logging.INFO)
679@staticmethod680def _extract_vars_from_component_name(component_name: str, state: dict) -> Optional[dict]:681child = state682for child_name in component_name.split(".")[1:]:683if child_name in child["flows"]:684child = child["flows"][child_name]685elif "structures" in child and child_name in child["structures"]:686child = child["structures"][child_name]687elif child_name in child["works"]:688child = child["works"][child_name]689else:690return None691
692# Filter private keys and drives693return {694k: v695for k, v in child["vars"].items()696if (697not k.startswith("_")698and not (isinstance(v, dict) and v.get("type", None) == "__drive__")699and not (isinstance(v, (Payload, Path)))700)701}702
    def _send_flow_to_work_deltas(self, state: dict) -> None:
        """Push flow-side state changes down to each running work through its dedicated delta queue."""
        if not self.flow_to_work_delta_queues:
            return

        for w in self.works:
            if not w.has_started:
                continue

            # Don't send changes when the state has been just sent.
            if w.run.has_sent:
                continue

            state_work = self._extract_vars_from_component_name(w.name, state)
            last_state_work = self._extract_vars_from_component_name(w.name, self._last_state)

            # Note: The work was dynamically created or deleted.
            if state_work is None or last_state_work is None:
                continue

            deep_diff = DeepDiff(last_state_work, state_work, verbose_level=2).to_dict()

            if "unprocessed" in deep_diff:
                deep_diff.pop("unprocessed")

            if deep_diff:
                logger.debug(f"Sending deep_diff to {w.name} : {deep_diff}")
                self.flow_to_work_delta_queues[w.name].put(deep_diff)
731def _start_with_flow_works(self) -> None:732for w in self.works:733if w._start_with_flow:734parallel = w.parallel735w._parallel = True736w.start()737w._parallel = parallel738