15
from abc import ABC, abstractmethod
16
from functools import partial
17
from typing import TYPE_CHECKING, Any, Callable, List, Optional
19
from lightning.app.core.queues import QueuingSystem
20
from lightning.app.utilities.proxies import ProxyWorkRun, unwrap
27
"""The Backend provides an interface for the framework to communicate with resources in the cloud."""
29
def __init__(self, entrypoint_file: str, queues: QueuingSystem, queue_id: str) -> None:
    """Record the backend configuration on the instance.

    Args:
        entrypoint_file: The app's entry-point file (presumably a path); stored as ``self.entrypoint_file``.
        queues: Queuing system the backend requests its queues from; stored as ``self.queues``.
        queue_id: Identifier forwarded as ``queue_id`` whenever queues are requested; stored as ``self.queue_id``.
    """
    self.entrypoint_file = entrypoint_file
    self.queues: QueuingSystem = queues
    self.queue_id = queue_id
35
def create_work(self, app: "lightning.app.LightningApp", work: "lightning.app.LightningWork") -> None:
39
def update_work_statuses(self, works: List["lightning.app.LightningWork"]) -> None:
43
def stop_all_works(self, works: List["lightning.app.LightningWork"]) -> None:
47
def resolve_url(self, app, base_url: Optional[str] = None) -> None:
51
def stop_work(self, app: "lightning.app.LightningApp", work: "lightning.app.LightningWork") -> None:
54
def _dynamic_run_wrapper(
57
app: "lightning.app.LightningApp",
58
work: "lightning.app.LightningWork",
65
f"Failed to create process for {work.__class__.__name__}."
66
f" Make sure to set this work as an attribute of a `LightningFlow` before calling the run method."
70
self._register_queues(app, work)
75
self.create_work(app, work)
81
work.run = ProxyWorkRun(
85
caller_queue=app.caller_queues[work.name],
89
return work.run(*args, **kwargs)
91
def _wrap_run_method(self, app: "lightning.app.LightningApp", work: "lightning.app.LightningWork"):
92
if work.run.__name__ == "_dynamic_run_wrapper":
95
work.run = partial(self._dynamic_run_wrapper, app=app, work=work, work_run=unwrap(work.run))
97
def _prepare_queues(self, app: "lightning.app.LightningApp"):
98
kw = {"queue_id": self.queue_id}
99
app.delta_queue = self.queues.get_delta_queue(**kw)
100
app.readiness_queue = self.queues.get_readiness_queue(**kw)
101
app.api_response_queue = self.queues.get_api_response_queue(**kw)
102
app.error_queue = self.queues.get_error_queue(**kw)
103
app.api_publish_state_queue = self.queues.get_api_state_publish_queue(**kw)
104
app.api_delta_queue = app.delta_queue
105
app.request_queues = {}
106
app.response_queues = {}
107
app.copy_request_queues = {}
108
app.copy_response_queues = {}
109
app.caller_queues = {}
111
app.flow_to_work_delta_queues = {}
113
def _register_queues(self, app, work):
114
kw = {"queue_id": self.queue_id, "work_name": work.name}
115
app.request_queues.update({work.name: self.queues.get_orchestrator_request_queue(**kw)})
116
app.response_queues.update({work.name: self.queues.get_orchestrator_response_queue(**kw)})
117
app.copy_request_queues.update({work.name: self.queues.get_orchestrator_copy_request_queue(**kw)})
118
app.copy_response_queues.update({work.name: self.queues.get_orchestrator_copy_response_queue(**kw)})
119
app.caller_queues.update({work.name: self.queues.get_caller_queue(**kw)})
120
app.flow_to_work_delta_queues.update({work.name: self.queues.get_flow_to_work_delta_queue(**kw)})
123
class WorkManager(ABC):
124
"""The work manager is the interface the backend and the runtime use to control a LightningWork."""
126
def __init__(self, app: "lightning.app.LightningApp", work: "lightning.app.LightningWork"):
130
def start(self) -> None:
134
def kill(self) -> None:
138
def restart(self) -> None:
142
def is_alive(self) -> bool: