pytorch-lightning
498 строк · 18.2 Кб
1# Copyright The Lightning AI team.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import asyncio
16import contextlib
17import json
18import os
19import queue
20import socket
21import sys
22import traceback
23from copy import deepcopy
24from multiprocessing import Queue
25from pathlib import Path
26from tempfile import TemporaryDirectory
27from threading import Event, Lock, Thread
28from time import sleep
29from typing import Dict, List, Mapping, Optional, Union
30
31import uvicorn
32from deepdiff import DeepDiff, Delta
33from fastapi import FastAPI, File, HTTPException, Request, Response, UploadFile, WebSocket, status
34from fastapi.middleware.cors import CORSMiddleware
35from fastapi.params import Header
36from fastapi.responses import HTMLResponse, JSONResponse
37from fastapi.staticfiles import StaticFiles
38from fastapi.templating import Jinja2Templates
39from pydantic import BaseModel
40from websockets.exceptions import ConnectionClosed
41
42from lightning.app.api.http_methods import _HttpMethod
43from lightning.app.api.request_types import _DeltaRequest
44from lightning.app.core.constants import (
45ENABLE_PULLING_STATE_ENDPOINT,
46ENABLE_PUSHING_STATE_ENDPOINT,
47ENABLE_STATE_WEBSOCKET,
48ENABLE_UPLOAD_ENDPOINT,
49FRONTEND_DIR,
50get_cloud_queue_type,
51)
52from lightning.app.core.flow import LightningFlow
53from lightning.app.core.queues import QueuingSystem
54from lightning.app.core.work import LightningWork
55from lightning.app.storage import Drive
56from lightning.app.utilities.app_helpers import InMemoryStateStore, Logger, StateStore
57from lightning.app.utilities.app_status import AppStatus
58from lightning.app.utilities.cloud import is_running_in_cloud
59from lightning.app.utilities.component import _context
60from lightning.app.utilities.enum import ComponentContext, OpenAPITags
61
# TODO: fixed uuid for now, it will come from the FastAPI session
# The state store is currently keyed by this single id for every client session.
TEST_SESSION_UUID = "1234"

STATE_EVENT = "State changed"

# Directory of the frontend's static assets, mounted under ``/static`` further down.
frontend_static_dir = os.path.join(FRONTEND_DIR, "static")

# Queue used to forward UI/API deltas to the app; assigned by ``start_server``.
api_app_delta_queue: Optional[Queue] = None

template: dict = {"ui": {}, "app": {}}
# Used by ``frontend_route`` to render the frontend's ``index.html``.
templates = Jinja2Templates(directory=FRONTEND_DIR)

# TODO: try to avoid using global var for state store
# ``start_server`` may replace this with a caller-provided store.
global_app_state_store = InMemoryStateStore()
global_app_state_store.add(TEST_SESSION_UUID)

# Taken around state-store and ``responses_store`` access by the ``UIRefresher`` thread
# and by some of the request handlers.
lock = Lock()

# Set by ``start_server`` (``app_spec``, ``app_annotations``) and by ``UIRefresher`` (``app_status``).
app_spec: Optional[List] = None
app_status: Optional[AppStatus] = None
app_annotations: Optional[List] = None

# In the future, this would be abstracted to support horizontal scaling.
# Maps response ids to the responses received from the app (filled by ``UIRefresher``).
responses_store: Dict = {}

logger = Logger(__name__)
88
# This can be replaced with a consumer that publishes states in a kv-store
# in a serverless architecture


class UIRefresher(Thread):
    """Daemon thread that copies app updates from queues into this module's globals.

    Every ``refresh_interval`` seconds it makes one non-blocking poll of each queue:

    * ``api_publish_state_queue`` yields ``(state, app_status)`` pairs. The state is stored in
      ``global_app_state_store`` under ``TEST_SESSION_UUID``, and the status goes to ``app_status``.
    * ``api_response_queue`` yields lists of ``{"id": ..., "response": ...}`` dicts, which are
      recorded in ``responses_store``.
    """

    def __init__(
        self,
        api_publish_state_queue: Queue,
        api_response_queue: Queue,
        refresh_interval: float = 0.1,
    ) -> None:
        super().__init__(daemon=True)
        self.api_publish_state_queue = api_publish_state_queue
        self.api_response_queue = api_response_queue
        self._exit_event = Event()
        self.refresh_interval = refresh_interval

    def run(self) -> None:
        """Poll the queues until ``join`` sets the exit event; print, then re-raise, any error."""
        # TODO: Create multiple threads to handle the background logic
        # TODO: Investigate the use of `parallel=True`
        try:
            while not self._exit_event.is_set():
                self.run_once()
                # Note: Sleep to reduce queue calls.
                sleep(self.refresh_interval)
        except Exception:
            traceback.print_exc()
            raise

    def run_once(self) -> None:
        """Apply at most one published state and at most one batch of responses, without blocking."""
        global app_status

        try:
            new_state, app_status = self.api_publish_state_queue.get(timeout=0)
        except queue.Empty:
            pass
        else:
            with lock:
                global_app_state_store.set_app_state(TEST_SESSION_UUID, new_state)

        try:
            batch = self.api_response_queue.get(timeout=0)
        except queue.Empty:
            return
        with lock:
            # TODO: Abstract the responses store to support horizontal scaling.
            for item in batch:
                responses_store[item["id"]] = item["response"]

    def join(self, timeout: Optional[float] = None) -> None:
        """Ask the polling loop to stop, then wait for the thread as ``Thread.join`` does."""
        self._exit_event.set()
        super().join(timeout)
136
137
class StateUpdate(BaseModel):
    """Request body model wrapping a full app ``state`` mapping (empty by default)."""

    state: Dict = {}
140
141
# OpenAPI tag metadata: (tag, description) pairs expanded into the dicts FastAPI expects.
openapi_tags = [
    {"name": tag_name, "description": tag_description}
    for tag_name, tag_description in (
        (
            OpenAPITags.APP_CLIENT_COMMAND,
            "The App Endpoints to be triggered exclusively from the CLI",
        ),
        (
            OpenAPITags.APP_COMMAND,
            "The App Endpoints that can be triggered equally from the CLI or from a Http Request",
        ),
        (
            OpenAPITags.APP_API,
            "The App Endpoints that can be triggered exclusively from a Http Request",
        ),
    )
]

# NOTE(review): the routes in this module are registered on ``fastapi_service``, not on ``app``;
# ``app`` only receives the tag metadata above. Confirm it is still used elsewhere.
app = FastAPI(openapi_tags=openapi_tags)

fastapi_service = FastAPI()

# NOTE(review): with ``allow_origins=["*"]`` and ``allow_credentials=True``, Starlette's
# CORSMiddleware reflects the caller's Origin. Any site can therefore make credentialed
# requests to this server. Confirm that this is intended.
fastapi_service.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
168
169
# General sequence is:
# * an update is generated in the UI
# * the value and the location in the state (or the whole state, easier)
#   is sent to the REST API along with the session UID
# * the previous state is loaded from the cache, the delta is generated
# * the previous state is set as set_state, the delta is provided as
#   delta
# * the app applies the delta and runs the entry_fn, which eventually
#   leads to another state
# * the new state is published through the API
# * the UI is updated with the new value of the state
# Before the above happens, we need to refactor App so that it doesn't
# rely on timeouts, but on sequences of updates (and alignments between
# ranks)
@fastapi_service.get("/api/v1/state", response_class=JSONResponse)
async def get_state(
    response: Response,
    x_lightning_type: Optional[str] = Header(None),
    x_lightning_session_uuid: Optional[str] = Header(None),
    x_lightning_session_id: Optional[str] = Header(None),
) -> Mapping:
    """Return the latest published app state and record it as the state served to the UI.

    ``POST /api/v1/state`` later diffs UI updates against this served state. Both session
    headers must be present, but every request is currently mapped to ``TEST_SESSION_UUID``.
    When ``ENABLE_PULLING_STATE_ENDPOINT`` is off, it responds 405 with a failure dict.

    Raises:
        Exception: when a session header is missing (surfaces to the client as a 500).
    """
    # NOTE(review): a missing header is a client error; HTTPException(400) would report it better.
    if x_lightning_session_uuid is None:
        raise Exception("Missing X-Lightning-Session-UUID header")
    if x_lightning_session_id is None:
        raise Exception("Missing X-Lightning-Session-ID header")

    if ENABLE_PULLING_STATE_ENDPOINT:
        with lock:
            current_state = global_app_state_store.get_app_state(TEST_SESSION_UUID)
            global_app_state_store.set_served_state(TEST_SESSION_UUID, current_state)
        return current_state

    response.status_code = status.HTTP_405_METHOD_NOT_ALLOWED
    return {"status": "failure", "reason": "This endpoint is disabled."}
205
206
207def _get_component_by_name(component_name: str, state: dict) -> Union[LightningFlow, LightningWork]:
208child = state
209for child_name in component_name.split(".")[1:]:
210try:
211child = child["flows"][child_name]
212except KeyError:
213child = child["structures"][child_name]
214
215if isinstance(child["vars"]["_layout"], list):
216assert len(child["vars"]["_layout"]) == 1
217return child["vars"]["_layout"][0]["target"]
218return child["vars"]["_layout"]["target"]
219
220
@fastapi_service.get("/api/v1/layout", response_class=JSONResponse)
async def get_layout() -> str:
    """Return the root flow's layout, serialized with ``json.dumps``.

    Entries whose ``content`` is a component path (``"root...."``) have it replaced by that
    component's layout target. Like ``get_state``, this also records the state as served.
    """
    with lock:
        session_uuid = TEST_SESSION_UUID
        current_state = global_app_state_store.get_app_state(session_uuid)
        global_app_state_store.set_served_state(session_uuid, current_state)
        # Work on a copy so the stored state is never modified.
        entries = deepcopy(current_state["vars"]["_layout"])
        for entry in entries:
            content = entry["content"]
            if content.startswith("root."):
                entry["content"] = _get_component_by_name(content, current_state)
        return json.dumps(entries)
232
233
@fastapi_service.get("/api/v1/spec", response_class=JSONResponse)
async def get_spec(
    response: Response,
    x_lightning_session_uuid: Optional[str] = Header(None),
    x_lightning_session_id: Optional[str] = Header(None),
) -> Union[List, Dict]:
    """Return the app spec handed to ``start_server``, or ``[]`` when it is unset or empty.

    Both session headers are required. This endpoint shares the ``ENABLE_PULLING_STATE_ENDPOINT``
    flag with ``GET /api/v1/state`` and responds 405 with a failure dict when the flag is off.

    Raises:
        Exception: when a session header is missing.
    """
    if x_lightning_session_uuid is None:
        raise Exception("Missing X-Lightning-Session-UUID header")
    if x_lightning_session_id is None:
        raise Exception("Missing X-Lightning-Session-ID header")

    if ENABLE_PULLING_STATE_ENDPOINT:
        return app_spec if app_spec else []

    response.status_code = status.HTTP_405_METHOD_NOT_ALLOWED
    return {"status": "failure", "reason": "This endpoint is disabled."}
251
252
@fastapi_service.post("/api/v1/delta")
async def post_delta(
    request: Request,
    response: Response,
    x_lightning_type: Optional[str] = Header(None),
    x_lightning_session_uuid: Optional[str] = Header(None),
    x_lightning_session_id: Optional[str] = Header(None),
) -> Optional[Dict]:
    """Forward a client-computed state delta to the app (mainly used by the streamlit integration).

    The JSON body's ``"delta"`` entry is wrapped in a deepdiff ``Delta`` and put on
    ``api_app_delta_queue``. Returns None on success. When ``ENABLE_PUSHING_STATE_ENDPOINT``
    is off, it returns a failure dict with status 405.

    Raises:
        Exception: when a session header is missing.
    """
    if x_lightning_session_uuid is None:
        raise Exception("Missing X-Lightning-Session-UUID header")
    if x_lightning_session_id is None:
        raise Exception("Missing X-Lightning-Session-ID header")

    if not ENABLE_PUSHING_STATE_ENDPOINT:
        response.status_code = status.HTTP_405_METHOD_NOT_ALLOWED
        return {"status": "failure", "reason": "This endpoint is disabled."}

    payload: Dict = await request.json()
    assert api_app_delta_queue is not None
    # NOTE(review): besides mappings, deepdiff's ``Delta`` also accepts a serialized (dumped)
    # payload, which it deserializes. ``payload["delta"]`` comes straight from the request
    # body; confirm that only mapping payloads are expected and consider rejecting anything else.
    delta = Delta(payload["delta"])
    api_app_delta_queue.put(_DeltaRequest(delta=delta))
    return None
277
278
@fastapi_service.post("/api/v1/state")
async def post_state(
    request: Request,
    response: Response,
    x_lightning_type: Optional[str] = Header(None),
    x_lightning_session_uuid: Optional[str] = Header(None),
    x_lightning_session_id: Optional[str] = Header(None),
) -> Optional[Dict]:
    """Turn a UI state update into a delta and forward it to the app.

    The JSON body is either ``{"stage": ...}``, where only ``app_state.stage`` changes, or
    ``{"state": {...}}``, which carries a full state. The delta is computed against the state
    last served to the UI (recorded by ``get_state``/``get_layout``) and put on
    ``api_app_delta_queue``.

    Returns:
        None on success. When ``ENABLE_PUSHING_STATE_ENDPOINT`` is off, a failure dict with
        status 405.

    Raises:
        Exception: when a session header is missing.
    """
    if x_lightning_session_uuid is None:
        raise Exception("Missing X-Lightning-Session-UUID header")
    if x_lightning_session_id is None:
        raise Exception("Missing X-Lightning-Session-ID header")

    # Bail out before reading (untrusted) request data when pushing is disabled.
    if not ENABLE_PUSHING_STATE_ENDPOINT:
        response.status_code = status.HTTP_405_METHOD_NOT_ALLOWED
        return {"status": "failure", "reason": "This endpoint is disabled."}

    body: Dict = await request.json()

    # The delta must be computed against the latest state *sent to the UI* by the API,
    # not the latest state produced by the app (see the sequencing comment above ``get_state``).
    # All sessions currently share ``TEST_SESSION_UUID``. The served state is written under
    # ``lock``, so it is read under the same lock.
    with lock:
        last_state = global_app_state_store.get_served_state(TEST_SESSION_UUID)

    if "stage" in body:
        state = deepcopy(last_state)
        state["app_state"]["stage"] = body["stage"]
    else:
        state = body["state"]

    deep_diff = DeepDiff(last_state, state, verbose_level=2)
    assert api_app_delta_queue is not None
    api_app_delta_queue.put(_DeltaRequest(delta=Delta(deep_diff)))
    return None
316
317
@fastapi_service.put("/api/v1/upload_file/{filename}")
async def upload_file(response: Response, filename: str, uploaded_file: UploadFile = File(...)) -> Union[str, dict]:
    """Store an uploaded file in the ``lit://uploaded_files`` Drive.

    The upload is streamed into a temporary directory under ``filename`` and then put on the
    Drive inside a work context.

    Returns:
        A success message. Otherwise a failure dict: status 405 when ``ENABLE_UPLOAD_ENDPOINT``
        is off, and status 400 when ``filename`` is not a plain file name.
    """
    if not ENABLE_UPLOAD_ENDPOINT:
        response.status_code = status.HTTP_405_METHOD_NOT_ALLOWED
        return {"status": "failure", "reason": "This endpoint is disabled."}

    # ``filename`` comes from the URL and is joined into a filesystem path. Reject anything that
    # is not a bare file name (``..``, drive prefixes, or ``\`` separators on Windows) so the
    # write cannot escape the temporary directory.
    if filename in ("", ".", "..") or os.path.basename(filename) != filename:
        response.status_code = status.HTTP_400_BAD_REQUEST
        return {"status": "failure", "reason": f"Invalid filename: {filename!r}"}

    with TemporaryDirectory() as tmp:
        drive = Drive(
            "lit://uploaded_files",
            component_name="file_server",
            allow_duplicates=True,
            root_folder=tmp,
        )
        tmp_file = os.path.join(tmp, filename)

        with open(tmp_file, "wb") as f:
            while True:
                # Note: The 8192 number doesn't have a strong reason.
                chunk = await uploaded_file.read(8192)
                if not chunk:
                    break
                f.write(chunk)

        with _context(str(ComponentContext.WORK)):
            drive.put(filename)
    return f"Successfully uploaded '{filename}' to the Drive"
344
345
@fastapi_service.get("/api/v1/status", response_model=AppStatus)
async def get_status() -> AppStatus:
    """Get the current status of the app and works.

    Raises:
        HTTPException: 503 until ``UIRefresher`` has received a first status from the app.
    """
    current_status = app_status
    if current_status is not None:
        return current_status
    raise HTTPException(
        status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="App status hasn't been reported yet."
    )
355
356
@fastapi_service.get("/api/v1/annotations", response_class=JSONResponse)
async def get_annotations() -> Union[List, Dict]:
    """Get the annotations associated with this app.

    ``start_server`` loads them from ``lightning-annotations.json`` in the working directory.
    Returns ``[]`` when nothing was loaded or the file held an empty value.
    """
    if app_annotations:
        return app_annotations
    return []
362
363
@fastapi_service.get("/healthz", status_code=200)
async def healthz(response: Response) -> dict:
    """Health check endpoint used in the cloud FastAPI servers to check the status periodically.

    Returns:
        ``{"status": "ok"}``, or a failure dict with status 500 in two cases: when running in the
        cloud and the queue backend reports it is not running, or when the stored app state is
        empty.
    """
    # check the queue status only if running in cloud
    if is_running_in_cloud():
        queue_obj = QueuingSystem(get_cloud_queue_type()).get_queue(queue_name="healthz")
        # this is only being implemented on Redis Queue. For HTTP Queue, it doesn't make sense to have every single
        # app checking the status of the Queue server
        if not queue_obj.is_running:
            response.status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
            return {"status": "failure", "reason": "Redis is not available"}

    # Take the same lock as ``get_state``/``get_layout``: the ``UIRefresher`` thread writes the
    # store concurrently.
    with lock:
        state = global_app_state_store.get_app_state(TEST_SESSION_UUID)
        # NOTE(review): a health probe also overwrites the served state that ``post_state``
        # diffs UI updates against. Confirm this side effect is intended.
        global_app_state_store.set_served_state(TEST_SESSION_UUID, state)

    if not state:
        response.status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
        return {"status": "failure", "reason": f"State is empty {state}"}
    return {"status": "ok"}
382
383
# Creates session websocket connection to notify client about any state changes
# TODO: the websocket instance needs to be stored based on session id so it is accessible in the api layer
@fastapi_service.websocket("/api/v1/ws")
async def websocket_endpoint(websocket: WebSocket) -> None:
    """Send the state-store counter to the client each time it changes.

    The counter is polled every 10 ms, and each new value is sent as a text frame. When
    ``ENABLE_STATE_WEBSOCKET`` is off, the socket is closed right after being accepted.
    """
    await websocket.accept()
    if not ENABLE_STATE_WEBSOCKET:
        await websocket.close()
        return
    try:
        last_sent = global_app_state_store.counter
        while True:
            # Read the counter once per iteration. It presumably changes from the ``UIRefresher``
            # thread. Re-reading it after the ``await`` below would mark a change made during the
            # send as already delivered, so the client would never be told about it.
            current = global_app_state_store.counter
            if current != last_sent:
                await websocket.send_text(f"{current}")
                last_sent = current
                logger.debug("Updated websocket.")
            await asyncio.sleep(0.01)
    except ConnectionClosed:
        logger.debug("Websocket connection closed")
        await websocket.close()
403
404
async def api_catch_all(request: Request, full_path: str) -> None:
    """Reply 404 to any ``/api...`` path that no registered API route matched."""
    raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Not found")
407
408
# Serve frontend from a static directory using FastAPI.
# ``check_dir=False`` skips Starlette's existence check on the directory when the app is created.
fastapi_service.mount(
    path="/static",
    app=StaticFiles(directory=frontend_static_dir, check_dir=False),
    name="static",
)
411
412
async def frontend_route(request: Request, full_path: str):  # type: ignore[no-untyped-def]
    """Render the frontend's ``index.html`` for any path, leaving routing to the client.

    When ``pytest`` has been imported, an empty string is returned instead of rendering.
    """
    if "pytest" not in sys.modules:
        return templates.TemplateResponse("index.html", {"request": request})
    return ""
417
418
def register_global_routes() -> None:
    """Register the catch-all GET routes on ``fastapi_service``.

    Routes are matched in registration order, so ``start_server`` calls this after every other
    route has been added.
    """
    # Catch-all for nonexistent API routes (since we define a catch-all for client-side routing)
    register_api_fallback = fastapi_service.get("/api{full_path:path}", response_class=JSONResponse)
    register_api_fallback(api_catch_all)

    register_frontend = fastapi_service.get("/{full_path:path}", response_class=HTMLResponse)
    register_frontend(frontend_route)
423
424
class LightningUvicornServer(uvicorn.Server):
    """``uvicorn.Server`` that can report on a queue once it has started.

    ``start_server`` installs this class as ``uvicorn.main.Server`` so that ``uvicorn.run`` uses
    it. When ``has_started_queue`` is set, ``"SERVER_HAS_STARTED"`` is put on that queue once
    ``self.started`` becomes true.
    """

    has_started_queue: Optional[Queue] = None

    def run(self, sockets: Optional[List[socket.socket]] = None) -> None:
        """Schedule ``serve`` (and the start-up reporter) on a new event loop and run it forever."""
        self.config.setup_event_loop()
        # Create the loop explicitly. ``asyncio.get_event_loop()`` without a running loop is
        # deprecated. ``new_event_loop`` still honours the policy set by ``setup_event_loop``.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        # Keep strong references: the event loop holds tasks only weakly, so a task nobody
        # references may be garbage-collected before it finishes.
        self._lightning_tasks = [loop.create_task(self.serve(sockets=sockets))]
        if self.has_started_queue:
            self._lightning_tasks.append(loop.create_task(self.check_is_started(self.has_started_queue)))
        loop.run_forever()

    async def check_is_started(self, queue: Queue) -> None:
        """Poll ``self.started`` every 0.1 s and put ``"SERVER_HAS_STARTED"`` on ``queue`` once set."""
        while not self.started:
            await asyncio.sleep(0.1)
        queue.put("SERVER_HAS_STARTED")
440
441
def start_server(
    api_publish_state_queue: Queue,
    api_delta_queue: Queue,
    api_response_queue: Queue,
    has_started_queue: Optional[Queue] = None,
    host: str = "127.0.0.1",
    port: int = 8000,
    root_path: str = "",
    uvicorn_run: bool = True,
    spec: Optional[List] = None,
    apis: Optional[List[_HttpMethod]] = None,
    app_state_store: Optional[StateStore] = None,
) -> UIRefresher:
    """Configure the module globals, start the ``UIRefresher`` and optionally run uvicorn.

    Args:
        api_publish_state_queue: Queue of ``(state, app_status)`` pairs published by the app.
        api_delta_queue: Queue on which UI deltas and user-API requests are forwarded to the app.
        api_response_queue: Queue of response batches consumed by ``UIRefresher``.
        has_started_queue: If given, receives ``"SERVER_HAS_STARTED"`` once uvicorn is up.
        host: Host to bind. Any ``scheme://`` prefix is stripped.
        port: Port to bind.
        root_path: ASGI root path passed to uvicorn.
        uvicorn_run: When False, only the globals and the refresher thread are set up.
        spec: Value served by ``/api/v1/spec``.
        apis: User-defined HTTP methods to register on ``fastapi_service``.
        app_state_store: Replaces the default in-memory state store when given.

    Returns:
        The started ``UIRefresher`` thread. With ``uvicorn_run=True``, it is returned only after
        ``uvicorn.run`` returns.
    """
    global api_app_delta_queue
    global global_app_state_store
    global app_spec
    global app_annotations

    app_spec = spec
    api_app_delta_queue = api_delta_queue

    if app_state_store is not None:
        global_app_state_store = app_state_store  # type: ignore[assignment]

    global_app_state_store.add(TEST_SESSION_UUID)

    # Load annotations (JSON is UTF-8; do not depend on the platform's locale encoding).
    annotations_path = Path("lightning-annotations.json").resolve()
    if annotations_path.exists():
        with open(annotations_path, encoding="utf-8") as f:
            app_annotations = json.load(f)

    # ``UIRefresher`` is constructed as a daemon thread already. The deprecated
    # ``Thread.setDaemon`` call is therefore unnecessary.
    refresher = UIRefresher(api_publish_state_queue, api_response_queue)
    refresher.start()

    if uvicorn_run:
        # Drop any scheme prefix such as "http://" (a no-op when there is none).
        host = host.split("//")[-1]
        if host == "0.0.0.0":  # noqa: S104
            logger.info("Your app has started.")
        else:
            logger.info(f"Your app has started. View it in your browser: http://{host}:{port}/view")
        if has_started_queue:
            LightningUvicornServer.has_started_queue = has_started_queue
            # uvicorn is doing some uglyness by replacing uvicorn.main by click command.
            sys.modules["uvicorn.main"].Server = LightningUvicornServer

        # Register the user API.
        if apis:
            for api in apis:
                api.add_route(fastapi_service, api_app_delta_queue, responses_store)

        # The catch-all routes must come last because routes are matched in registration order.
        register_global_routes()

        uvicorn.run(app=fastapi_service, host=host, port=port, log_level="error", root_path=root_path)

    return refresher
499