pytorch-lightning

Форк
0
498 строк · 18.2 Кб
1
# Copyright The Lightning AI team.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

15
import asyncio
16
import contextlib
17
import json
18
import os
19
import queue
20
import socket
21
import sys
22
import traceback
23
from copy import deepcopy
24
from multiprocessing import Queue
25
from pathlib import Path
26
from tempfile import TemporaryDirectory
27
from threading import Event, Lock, Thread
28
from time import sleep
29
from typing import Dict, List, Mapping, Optional, Union
30

31
import uvicorn
32
from deepdiff import DeepDiff, Delta
33
from fastapi import FastAPI, File, HTTPException, Request, Response, UploadFile, WebSocket, status
34
from fastapi.middleware.cors import CORSMiddleware
35
from fastapi.params import Header
36
from fastapi.responses import HTMLResponse, JSONResponse
37
from fastapi.staticfiles import StaticFiles
38
from fastapi.templating import Jinja2Templates
39
from pydantic import BaseModel
40
from websockets.exceptions import ConnectionClosed
41

42
from lightning.app.api.http_methods import _HttpMethod
43
from lightning.app.api.request_types import _DeltaRequest
44
from lightning.app.core.constants import (
45
    ENABLE_PULLING_STATE_ENDPOINT,
46
    ENABLE_PUSHING_STATE_ENDPOINT,
47
    ENABLE_STATE_WEBSOCKET,
48
    ENABLE_UPLOAD_ENDPOINT,
49
    FRONTEND_DIR,
50
    get_cloud_queue_type,
51
)
52
from lightning.app.core.flow import LightningFlow
53
from lightning.app.core.queues import QueuingSystem
54
from lightning.app.core.work import LightningWork
55
from lightning.app.storage import Drive
56
from lightning.app.utilities.app_helpers import InMemoryStateStore, Logger, StateStore
57
from lightning.app.utilities.app_status import AppStatus
58
from lightning.app.utilities.cloud import is_running_in_cloud
59
from lightning.app.utilities.component import _context
60
from lightning.app.utilities.enum import ComponentContext, OpenAPITags
61

62
# Every request is currently mapped onto this single hard-coded session id.
# TODO: the real id should come from the FastAPI session.
TEST_SESSION_UUID = "1234"

STATE_EVENT = "State changed"

frontend_static_dir = os.path.join(FRONTEND_DIR, "static")

# Queue on which UI-originated deltas are forwarded to the app; assigned by `start_server`.
api_app_delta_queue: Optional[Queue] = None

template: dict = dict(ui={}, app={})
templates = Jinja2Templates(directory=FRONTEND_DIR)

# TODO: try to avoid using a global variable for the state store.
global_app_state_store = InMemoryStateStore()
global_app_state_store.add(TEST_SESSION_UUID)

# Serialises access to `global_app_state_store` and `responses_store`, which the
# `UIRefresher` thread writes while request handlers read them.
lock = Lock()

# Filled in at runtime: `app_spec` / `app_annotations` by `start_server`,
# `app_status` by `UIRefresher.run_once`.
app_spec: Optional[List] = None
app_status: Optional[AppStatus] = None
app_annotations: Optional[List] = None

# Maps a response id to the response published by the app.
# In the future, this would be abstracted to support horizontal scaling.
responses_store = {}

logger = Logger(__name__)
88

89
# This can be replaced with a consumer that publishes states in a kv-store
# in a serverless architecture.


class UIRefresher(Thread):
    """Daemon thread that mirrors app-side updates into this module's shared stores.

    On every tick it polls, without blocking, for at most one ``(state, app_status)`` pair on
    ``api_publish_state_queue`` and at most one batch of responses on ``api_response_queue``.
    It writes them into ``global_app_state_store`` / ``responses_store`` while holding ``lock``.

    Args:
        api_publish_state_queue: Queue carrying ``(state, app_status)`` tuples.
        api_response_queue: Queue carrying iterables of mappings with ``"id"`` and ``"response"`` keys.
        refresh_interval: Seconds to wait between two polls.
    """

    def __init__(
        self,
        api_publish_state_queue: Queue,
        api_response_queue: Queue,
        refresh_interval: float = 0.1,
    ) -> None:
        super().__init__(daemon=True)
        self.api_publish_state_queue = api_publish_state_queue
        self.api_response_queue = api_response_queue
        self._exit_event = Event()
        self.refresh_interval = refresh_interval

    def run(self) -> None:
        # TODO: Create multiple threads to handle the background logic
        # TODO: Investigate the use of `parallel=True`
        try:
            while not self._exit_event.is_set():
                self.run_once()
                # Wait on the exit event instead of sleeping. The poll rate stays the same, but
                # `join()` wakes the thread immediately rather than after up to `refresh_interval`.
                self._exit_event.wait(self.refresh_interval)
        except Exception:
            traceback.print_exc()
            raise

    def run_once(self) -> None:
        """Drain at most one state update and one response batch, if available."""
        global app_status

        with contextlib.suppress(queue.Empty):
            state, app_status = self.api_publish_state_queue.get(timeout=0)
            with lock:
                global_app_state_store.set_app_state(TEST_SESSION_UUID, state)

        with contextlib.suppress(queue.Empty):
            responses = self.api_response_queue.get(timeout=0)
            with lock:
                # TODO: Abstract the responses store to support horizontal scaling.
                # `responses_store` is only mutated, never rebound, so no `global` is needed.
                for response in responses:
                    responses_store[response["id"]] = response["response"]

    def join(self, timeout: Optional[float] = None) -> None:
        """Signal the polling loop to stop, then wait for the thread to finish."""
        self._exit_event.set()
        super().join(timeout)
136

137

138
class StateUpdate(BaseModel):
    """Request model holding a whole app state dictionary, empty by default."""

    state: Dict = {}
140

141

142
# OpenAPI tag metadata, one entry per category of app endpoint.
openapi_tags = [
    {"name": tag_name, "description": tag_description}
    for tag_name, tag_description in (
        (
            OpenAPITags.APP_CLIENT_COMMAND,
            "The App Endpoints to be triggered exclusively from the CLI",
        ),
        (
            OpenAPITags.APP_COMMAND,
            "The App Endpoints that can be triggered equally from the CLI or from a Http Request",
        ),
        (
            OpenAPITags.APP_API,
            "The App Endpoints that can be triggered exclusively from a Http Request",
        ),
    )
]

app = FastAPI(openapi_tags=openapi_tags)

# The application `start_server` hands to uvicorn; all routes below are registered on it.
fastapi_service = FastAPI()

# NOTE(review): wildcard origins combined with `allow_credentials=True` is a very permissive
# CORS configuration; confirm it is intended.
fastapi_service.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)
168

169

170
# General sequence is:
171
# * an update is generated in the UI
172
# * the value and the location in the state (or the whole state, easier)
173
#   is sent to the REST API along with the session UID
174
# * the previous state is loaded from the cache, the delta is generated
175
# * the previous state is set as set_state, the delta is provided as
176
#   delta
177
# * the app applies the delta and runs the entry_fn, which eventually
178
#   leads to another state
179
# * the new state is published through the API
180
# * the UI is updated with the new value of the state
181
# Before the above happens, we need to refactor App so that it doesn't
182
# rely on timeouts, but on sequences of updates (and alignments between
183
# ranks)
184
@fastapi_service.get("/api/v1/state", response_class=JSONResponse)
async def get_state(
    response: Response,
    x_lightning_type: Optional[str] = Header(None),
    x_lightning_session_uuid: Optional[str] = Header(None),
    x_lightning_session_id: Optional[str] = Header(None),
) -> Mapping:
    """Return the latest app state and record it as the state last served to the UI.

    Raises a plain ``Exception`` if either session header is absent. Answers 405 when
    ``ENABLE_PULLING_STATE_ENDPOINT`` is off. The received session UUID is currently ignored
    in favour of ``TEST_SESSION_UUID``.
    """
    for header_value, error_message in (
        (x_lightning_session_uuid, "Missing X-Lightning-Session-UUID header"),
        (x_lightning_session_id, "Missing X-Lightning-Session-ID header"),
    ):
        if header_value is None:
            raise Exception(error_message)

    if not ENABLE_PULLING_STATE_ENDPOINT:
        response.status_code = status.HTTP_405_METHOD_NOT_ALLOWED
        return {"status": "failure", "reason": "This endpoint is disabled."}

    session_uuid = TEST_SESSION_UUID
    with lock:
        current_state = global_app_state_store.get_app_state(session_uuid)
        global_app_state_store.set_served_state(session_uuid, current_state)
    return current_state
205

206

207
def _get_component_by_name(component_name: str, state: dict) -> Union[LightningFlow, LightningWork]:
208
    child = state
209
    for child_name in component_name.split(".")[1:]:
210
        try:
211
            child = child["flows"][child_name]
212
        except KeyError:
213
            child = child["structures"][child_name]
214

215
    if isinstance(child["vars"]["_layout"], list):
216
        assert len(child["vars"]["_layout"]) == 1
217
        return child["vars"]["_layout"][0]["target"]
218
    return child["vars"]["_layout"]["target"]
219

220

221
@fastapi_service.get("/api/v1/layout", response_class=JSONResponse)
async def get_layout() -> str:
    """Return the root layout as a JSON string, resolving ``root.*`` component references.

    The current app state is recorded as served. In a deep copy of the root ``_layout``, each
    entry whose ``content`` starts with ``"root."`` has that content replaced by the referenced
    component's layout target.

    NOTE(review): the result is serialised with ``json.dumps`` and then returned through
    ``JSONResponse``, so clients receive a JSON-encoded string; confirm the frontend expects that.
    """
    with lock:
        current_state = global_app_state_store.get_app_state(TEST_SESSION_UUID)
        global_app_state_store.set_served_state(TEST_SESSION_UUID, current_state)
        resolved_layout = deepcopy(current_state["vars"]["_layout"])
        for entry in resolved_layout:
            content = entry["content"]
            if content.startswith("root."):
                entry["content"] = _get_component_by_name(content, current_state)
        return json.dumps(resolved_layout)
232

233

234
@fastapi_service.get("/api/v1/spec", response_class=JSONResponse)
async def get_spec(
    response: Response,
    x_lightning_session_uuid: Optional[str] = Header(None),
    x_lightning_session_id: Optional[str] = Header(None),
) -> Union[List, Dict]:
    """Return the app spec handed to ``start_server``, or an empty list when there is none.

    Raises a plain ``Exception`` if either session header is absent. Answers 405 when
    ``ENABLE_PULLING_STATE_ENDPOINT`` is off.
    """
    for header_value, error_message in (
        (x_lightning_session_uuid, "Missing X-Lightning-Session-UUID header"),
        (x_lightning_session_id, "Missing X-Lightning-Session-ID header"),
    ):
        if header_value is None:
            raise Exception(error_message)

    if ENABLE_PULLING_STATE_ENDPOINT:
        # Only reading the module-level spec, so no `global` declaration is required.
        return app_spec if app_spec else []

    response.status_code = status.HTTP_405_METHOD_NOT_ALLOWED
    return {"status": "failure", "reason": "This endpoint is disabled."}
251

252

253
@fastapi_service.post("/api/v1/delta")
async def post_delta(
    request: Request,
    response: Response,
    x_lightning_type: Optional[str] = Header(None),
    x_lightning_session_uuid: Optional[str] = Header(None),
    x_lightning_session_id: Optional[str] = Header(None),
) -> Optional[Dict]:
    """Forward a client-computed delta diff to the app; mainly used by Streamlit to update the state.

    The JSON body's ``"delta"`` entry is wrapped in a ``deepdiff.Delta`` and queued on
    ``api_app_delta_queue`` as a ``_DeltaRequest``. Raises a plain ``Exception`` if either
    session header is absent. Answers 405 when ``ENABLE_PUSHING_STATE_ENDPOINT`` is off.
    """
    for header_value, error_message in (
        (x_lightning_session_uuid, "Missing X-Lightning-Session-UUID header"),
        (x_lightning_session_id, "Missing X-Lightning-Session-ID header"),
    ):
        if header_value is None:
            raise Exception(error_message)

    if not ENABLE_PUSHING_STATE_ENDPOINT:
        response.status_code = status.HTTP_405_METHOD_NOT_ALLOWED
        return {"status": "failure", "reason": "This endpoint is disabled."}

    payload: Dict = await request.json()
    delta_queue = api_app_delta_queue
    assert delta_queue is not None
    delta_queue.put(_DeltaRequest(delta=Delta(payload["delta"])))
    return None
277

278

279
@fastapi_service.post("/api/v1/state")
async def post_state(
    request: Request,
    response: Response,
    x_lightning_type: Optional[str] = Header(None),
    x_lightning_session_uuid: Optional[str] = Header(None),
    x_lightning_session_id: Optional[str] = Header(None),
) -> Optional[Dict]:
    """Turn a UI-submitted state (or stage change) into a delta and forward it to the app.

    The delta is computed against the state last *served* to the UI, not the latest state the
    app produced, so it reflects exactly what the UI changed. The JSON body contains either
    ``"stage"`` or a full ``"state"``. With ``"stage"``, only ``app_state.stage`` of the served
    state is replaced.

    Raises:
        Exception: If the ``X-Lightning-Session-UUID`` or ``X-Lightning-Session-ID`` header is missing.

    Returns:
        ``None`` on success. A failure payload with status 405 when
        ``ENABLE_PUSHING_STATE_ENDPOINT`` is disabled.
    """
    if x_lightning_session_uuid is None:
        raise Exception("Missing X-Lightning-Session-UUID header")
    if x_lightning_session_id is None:
        raise Exception("Missing X-Lightning-Session-ID header")

    # Check the feature flag before reading the body, as `post_delta` does. A disabled endpoint
    # then always answers 405 instead of failing on a malformed or empty body.
    if not ENABLE_PUSHING_STATE_ENDPOINT:
        response.status_code = status.HTTP_405_METHOD_NOT_ALLOWED
        return {"status": "failure", "reason": "This endpoint is disabled."}

    body: Dict = await request.json()
    # The session UUID is currently fixed (see `TEST_SESSION_UUID`).
    x_lightning_session_uuid = TEST_SESSION_UUID

    # Diff against the last state served to the UI, not the last state obtained by the app.
    last_state = global_app_state_store.get_served_state(x_lightning_session_uuid)
    if "stage" in body:
        state = deepcopy(last_state)
        state["app_state"]["stage"] = body["stage"]
    else:
        state = body["state"]
    deep_diff = DeepDiff(last_state, state, verbose_level=2)

    assert api_app_delta_queue is not None
    api_app_delta_queue.put(_DeltaRequest(delta=Delta(deep_diff)))
    return None
316

317

318
@fastapi_service.put("/api/v1/upload_file/{filename}")
async def upload_file(response: Response, filename: str, uploaded_file: UploadFile = File(...)) -> Union[str, dict]:
    """Store an uploaded file in the ``lit://uploaded_files`` Drive.

    The upload is streamed in chunks into a temporary directory, then pushed to the Drive inside
    a work context.

    Args:
        response: Outgoing response, used to set the status code on failure.
        filename: Name under which the file is stored; must be a bare file name.
        uploaded_file: The multipart-uploaded file.

    Returns:
        A success message. On failure, a payload with status 405 (endpoint disabled) or
        400 (invalid file name).
    """
    if not ENABLE_UPLOAD_ENDPOINT:
        response.status_code = status.HTTP_405_METHOD_NOT_ALLOWED
        return {"status": "failure", "reason": "This endpoint is disabled."}

    # `filename` comes straight from the URL. Reject anything that is not a bare file name
    # (e.g. `..`, or `..\evil` on Windows) so the write below cannot escape the temp directory.
    if filename in ("", os.curdir, os.pardir) or os.path.basename(filename) != filename:
        response.status_code = status.HTTP_400_BAD_REQUEST
        return {"status": "failure", "reason": f"Invalid file name: {filename!r}."}

    with TemporaryDirectory() as tmp:
        drive = Drive(
            "lit://uploaded_files",
            component_name="file_server",
            allow_duplicates=True,
            root_folder=tmp,
        )
        tmp_file = os.path.join(tmp, filename)

        with open(tmp_file, "wb") as f:
            while True:
                # Note: The 8192 chunk size doesn't have a strong reason.
                content = await uploaded_file.read(8192)
                if not content:
                    break
                f.write(content)

        with _context(str(ComponentContext.WORK)):
            drive.put(filename)
    return f"Successfully uploaded '{filename}' to the Drive"
344

345

346
@fastapi_service.get("/api/v1/status", response_model=AppStatus)
async def get_status() -> AppStatus:
    """Get the current status of the app and works.

    Raises:
        HTTPException: 503 while no status has been reported yet (module-level ``app_status``
            is still ``None``).
    """
    current_status = app_status
    if current_status is not None:
        return current_status
    raise HTTPException(
        status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
        detail="App status hasn't been reported yet.",
    )
355

356

357
@fastapi_service.get("/api/v1/annotations", response_class=JSONResponse)
async def get_annotations() -> Union[List, Dict]:
    """Get the annotations associated with this app.

    Returns what ``start_server`` loaded from ``lightning-annotations.json``, or an empty list.
    """
    if app_annotations:
        return app_annotations
    return []
362

363

364
@fastapi_service.get("/healthz", status_code=200)
async def healthz(response: Response) -> dict:
    """Health check endpoint used in the cloud FastAPI servers to check the status periodically.

    Answers HTTP 500 when the queue backend reports it is not running (checked only in the cloud),
    or when the app state is still empty. As a side effect, the current state is recorded as served.
    """
    # Check the queue status only when running in the cloud.
    if is_running_in_cloud():
        queue_obj = QueuingSystem(get_cloud_queue_type()).get_queue(queue_name="healthz")
        # This is only implemented for the Redis queue. For the HTTP queue it doesn't make sense
        # to have every single app checking the status of the queue server.
        if not queue_obj.is_running:
            response.status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
            return {"status": "failure", "reason": "Redis is not available"}

    x_lightning_session_uuid = TEST_SESSION_UUID
    # Touch the shared state store under `lock`, like `get_state`, `get_layout` and the
    # `UIRefresher` thread do.
    # NOTE(review): marking the state as served here also moves the base that `post_state`
    # diffs against; confirm that is intended for a health probe.
    with lock:
        state = global_app_state_store.get_app_state(x_lightning_session_uuid)
        global_app_state_store.set_served_state(x_lightning_session_uuid, state)
    if not state:
        response.status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
        return {"status": "failure", "reason": f"State is empty {state}"}
    return {"status": "ok"}
382

383

384
# Creates a session websocket connection that notifies the client about state changes.
# The websocket instance needs to be stored based on session id so it is accessible in the api layer.
@fastapi_service.websocket("/api/v1/ws")
async def websocket_endpoint(websocket: WebSocket) -> None:
    """Push the state store's counter to the client whenever it changes.

    Polls ``global_app_state_store.counter`` every 10 ms. When the counter differs from the last
    value seen, the current counter is sent as text. If ``ENABLE_STATE_WEBSOCKET`` is off, the
    socket is closed right after being accepted.
    """
    await websocket.accept()
    if ENABLE_STATE_WEBSOCKET:
        try:
            seen_counter = global_app_state_store.counter
            while True:
                changed = global_app_state_store.counter != seen_counter
                if changed:
                    await websocket.send_text(f"{global_app_state_store.counter}")
                    seen_counter = global_app_state_store.counter
                    logger.debug("Updated websocket.")
                await asyncio.sleep(0.01)
        except ConnectionClosed:
            logger.debug("Websocket connection closed")
    await websocket.close()
403

404

405
async def api_catch_all(request: Request, full_path: str) -> None:
    """Answer 404 for any ``/api...`` path that no more specific route handled."""
    not_found = HTTPException(status_code=404, detail="Not found")
    raise not_found
407

408

409
# Serve the frontend's static assets (FRONTEND_DIR/static) under /static.
# `check_dir=False` skips StaticFiles' directory-existence check at construction time.
fastapi_service.mount(
    path="/static",
    app=StaticFiles(directory=frontend_static_dir, check_dir=False),
    name="static",
)
411

412

413
async def frontend_route(request: Request, full_path: str):  # type: ignore[no-untyped-def]
    """Render the frontend's ``index.html`` for a catch-all (client-side routed) path.

    Returns an empty string instead when running under pytest.
    """
    running_under_pytest = "pytest" in sys.modules
    if not running_under_pytest:
        return templates.TemplateResponse("index.html", {"request": request})
    return ""
417

418

419
def register_global_routes() -> None:
    """Register the catch-all GET routes on ``fastapi_service``.

    ``/api...`` paths not handled elsewhere get a JSON 404 (nonexistent API routes). Every other
    path is served the frontend page, since client-side routing needs a catch-all. These patterns
    match everything, so this is called after the user API routes are added (see ``start_server``).
    """
    catch_all_routes = (
        ("/api{full_path:path}", JSONResponse, api_catch_all),
        ("/{full_path:path}", HTMLResponse, frontend_route),
    )
    for route_path, response_cls, endpoint in catch_all_routes:
        fastapi_service.get(route_path, response_class=response_cls)(endpoint)
423

424

425
class LightningUvicornServer(uvicorn.Server):
    """``uvicorn.Server`` that can announce through a queue that it has started.

    When the class attribute ``has_started_queue`` is set, ``"SERVER_HAS_STARTED"`` is put on it
    once the server's ``started`` flag becomes true.
    """

    # Assigned by `start_server` when the caller wants a start notification.
    has_started_queue: Optional[Queue] = None

    def run(self, sockets: Optional[List[socket.socket]] = None) -> None:
        """Schedule ``serve`` (and the start watcher, if needed) and run the event loop forever."""
        self.config.setup_event_loop()
        loop = asyncio.get_event_loop()
        # Keep strong references to the scheduled tasks: the event loop only keeps weak references,
        # so an unreferenced task may be garbage-collected mid-execution. This local list stays
        # alive for as long as `run_forever()` blocks in this frame.
        background_tasks = [asyncio.ensure_future(self.serve(sockets=sockets))]
        if self.has_started_queue:
            background_tasks.append(asyncio.ensure_future(self.check_is_started(self.has_started_queue)))
        loop.run_forever()

    async def check_is_started(self, queue: Queue) -> None:
        """Poll every 100 ms until the server has started, then put ``"SERVER_HAS_STARTED"`` on ``queue``."""
        while not self.started:
            await asyncio.sleep(0.1)
        queue.put("SERVER_HAS_STARTED")
440

441

442
def start_server(
    api_publish_state_queue: Queue,
    api_delta_queue: Queue,
    api_response_queue: Queue,
    has_started_queue: Optional[Queue] = None,
    host: str = "127.0.0.1",
    port: int = 8000,
    root_path: str = "",
    uvicorn_run: bool = True,
    spec: Optional[List] = None,
    apis: Optional[List[_HttpMethod]] = None,
    app_state_store: Optional[StateStore] = None,
) -> UIRefresher:
    """Configure the module-level API state, start the ``UIRefresher`` and optionally run uvicorn.

    Args:
        api_publish_state_queue: Queue the app publishes ``(state, app_status)`` tuples to.
        api_delta_queue: Queue that UI-originated deltas are forwarded to.
        api_response_queue: Queue the app publishes request responses to.
        has_started_queue: If given, receives ``"SERVER_HAS_STARTED"`` once uvicorn has started.
        host: Host to bind; anything up to and including a ``//`` (e.g. ``http://``) is stripped.
        port: Port to bind.
        root_path: ASGI root path passed to uvicorn.
        uvicorn_run: When ``False``, only the refresher thread is started and no server is run.
        spec: App spec served by ``/api/v1/spec``.
        apis: User-defined HTTP methods to register as routes.
        app_state_store: Replacement for the default in-memory state store.

    Returns:
        The started ``UIRefresher`` thread. With ``uvicorn_run=True`` this only returns once
        ``uvicorn.run`` returns.
    """
    global api_app_delta_queue
    global global_app_state_store
    global app_spec
    global app_annotations

    app_spec = spec
    api_app_delta_queue = api_delta_queue

    if app_state_store is not None:
        global_app_state_store = app_state_store  # type: ignore[assignment]

    global_app_state_store.add(TEST_SESSION_UUID)

    # Load annotations. JSON text is UTF-8, so don't rely on the platform's default encoding.
    annotations_path = Path("lightning-annotations.json").resolve()
    if annotations_path.exists():
        with open(annotations_path, encoding="utf-8") as f:
            app_annotations = json.load(f)

    # `UIRefresher` is already created as a daemon thread by its constructor, so no extra
    # (deprecated) `setDaemon` call is needed.
    refresher = UIRefresher(api_publish_state_queue, api_response_queue)
    refresher.start()

    if uvicorn_run:
        # Strip a scheme prefix such as "http://"; hosts without "//" are left unchanged.
        host = host.split("//")[-1]
        if host == "0.0.0.0":  # noqa: S104
            logger.info("Your app has started.")
        else:
            logger.info(f"Your app has started. View it in your browser: http://{host}:{port}/view")
        if has_started_queue:
            LightningUvicornServer.has_started_queue = has_started_queue
            # uvicorn is doing some uglyness by replacing uvicorn.main by click command.
            sys.modules["uvicorn.main"].Server = LightningUvicornServer

        # Register the user API.
        for api in apis or []:
            api.add_route(fastapi_service, api_app_delta_queue, responses_store)

        # The catch-all routes are registered last, after every specific route.
        register_global_routes()

        uvicorn.run(app=fastapi_service, host=host, port=port, log_level="error", root_path=root_path)

    return refresher
499

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.