pytorch-lightning
243 lines · 8.3 KB
1# Copyright The Lightning AI team.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import asyncio16import os17import sqlite318import sys19import tempfile20import threading21import traceback22from typing import List, Optional, Type, Union23
24import uvicorn25from fastapi import FastAPI26from uvicorn import run27
28from lightning.app.components.database.utilities import _create_database, _Delete, _Insert, _SelectAll, _Update29from lightning.app.core.work import LightningWork30from lightning.app.storage import Drive31from lightning.app.utilities.app_helpers import Logger32from lightning.app.utilities.imports import _is_sqlmodel_available33from lightning.app.utilities.packaging.build_config import BuildConfig34
# SQLModel is an optional dependency. When it is not installed, bind the name to ``object`` so that
# ``SQLModel`` is always defined in this module.
if not _is_sqlmodel_available():
    SQLModel = object
else:
    from sqlmodel import SQLModel


logger = Logger(__name__)
43
# Uvicorn's stock Server installs its own signal handlers, which would override the ones set up by the
# Lightning App. This subclass keeps Uvicorn out of signal handling entirely.
# Discussions: https://github.com/encode/uvicorn/discussions/1708
class _DatabaseUvicornServer(uvicorn.Server):
    """Uvicorn server used by the Database component that never installs process signal handlers."""

    # NOTE(review): not read anywhere in this file; presumably consumed elsewhere — confirm before removing.
    has_started_queue = None

    def run(self, sockets=None):
        """Prepare the event loop, schedule ``serve`` on it and keep the loop running forever."""
        self.config.setup_event_loop()
        event_loop = asyncio.get_event_loop()
        serving = self.serve(sockets=sockets)
        asyncio.ensure_future(serving)
        event_loop.run_forever()

    def install_signal_handlers(self):
        """Intentionally a no-op: signal handling is left to the Lightning App."""
        return None


# Module-wide lock shared by every Database instance in this process. Both the periodic store thread and
# ``Database.on_exit`` hold it while snapshotting the SQLite file to the Drive.
_lock = threading.Lock()
61
class Database(LightningWork):
    def __init__(
        self,
        models: Union[Type["SQLModel"], List[Type["SQLModel"]]],
        db_filename: str = "database.db",
        store_interval: int = 10,
        debug: bool = False,
    ) -> None:
        """The Database Component lets you store structured information about your application in an SQLite
        database and interact with it over HTTP.

        The provided models are SQLModel tables.

        Arguments:
            models: A SQLModel table, or a list (or tuple) of SQLModel tables, to add to the database.
            db_filename: The name of the SQLite database file.
            store_interval: Time interval (in seconds) at which the database is periodically synchronized to the Drive.
                Note that the database is also always synchronized on exit.
            debug: Whether to run the database in debug mode.

        Example::

            from typing import List
            from sqlmodel import SQLModel, Field
            from uuid import uuid4

            from lightning.app import LightningFlow, LightningApp
            from lightning.app.components.database import Database, DatabaseClient

            class CounterModel(SQLModel, table=True):
                __table_args__ = {"extend_existing": True}

                id: int = Field(default=None, primary_key=True)
                count: int


            class Flow(LightningFlow):

                def __init__(self):
                    super().__init__()
                    self._private_token = uuid4().hex
                    self.db = Database(models=[CounterModel])
                    self._client = None
                    self.counter = 0

                def run(self):
                    self.db.run(token=self._private_token)

                    if not self.db.alive():
                        return

                    if self.counter == 0:
                        self._client = DatabaseClient(
                            model=CounterModel,
                            db_url=self.db.url,
                            token=self._private_token,
                        )

                    rows = self._client.select_all()

                    print(f"{self.counter}: {rows}")

                    if not rows:
                        self._client.insert(CounterModel(count=0))
                    else:
                        row: CounterModel = rows[0]
                        row.count += 1
                        self._client.update(row)

                    if self.counter >= 100:
                        row: CounterModel = rows[0]
                        self._client.delete(row)
                        self.stop()

                    self.counter += 1

            app = LightningApp(Flow())

        If you want to use nested SQLModels, we provide a utility to do so as follows:

        Example::

            from typing import List
            from sqlmodel import SQLModel, Field
            from sqlalchemy import Column

            from lightning.app.components.database.utilities import pydantic_column_type

            class KeyValuePair(SQLModel):
                name: str
                value: str

            class CounterModel(SQLModel, table=True):
                __table_args__ = {"extend_existing": True}

                name: int = Field(default=None, primary_key=True)

                # Nested models must be declared with ``Field`` and ``Column``, using the
                # ``pydantic_column_type`` utility as the column type.
                kv: List[KeyValuePair] = Field(..., sa_column=Column(pydantic_column_type(List[KeyValuePair])))

        """
        super().__init__(parallel=True, cloud_build_config=BuildConfig(["sqlmodel"]))
        self.db_filename = db_filename
        # Local folder containing the database file; used as the Drive root folder when restoring it in ``run``.
        self._root_folder = os.path.dirname(db_filename)
        self.debug = debug
        self.store_interval = store_interval
        # Accept a single model or a list/tuple of models; always keep a list internally.
        self._models = list(models) if isinstance(models, (list, tuple)) else [models]
        # Both are created by ``run``; they stay ``None`` until then.
        self._store_thread = None
        self._exit_event = None

    def store_database(self):
        """Snapshot the SQLite database into a temporary folder and upload the copy to the Drive.

        This is best effort: any failure is reported by printing its traceback and is not re-raised, so the
        periodic store thread keeps running.
        """
        try:
            with tempfile.TemporaryDirectory() as tmpdir:
                tmp_db_filename = os.path.join(tmpdir, os.path.basename(self.db_filename))

                # SQLite's online backup API copies the database even while other connections are using it.
                # Connections are closed in ``finally`` so a failed backup does not leak them.
                source = sqlite3.connect(self.db_filename)
                try:
                    dest = sqlite3.connect(tmp_db_filename)
                    try:
                        source.backup(dest)
                    finally:
                        dest.close()
                finally:
                    source.close()

                drive = Drive("lit://database", component_name=self.name, root_folder=tmpdir)
                drive.put(os.path.basename(tmp_db_filename))

                logger.debug("Stored the database to the Drive.")
        except Exception:
            # ``traceback.print_exc`` writes the traceback itself and returns None; wrapping it in ``print``
            # used to emit a stray "None" line after every traceback.
            traceback.print_exc()

    def periodic_store_database(self, store_interval):
        """Store the database every ``store_interval`` seconds until the exit event is set.

        The first snapshot is taken immediately. Each snapshot is taken while holding the module-level lock.
        """
        while not self._exit_event.is_set():
            with _lock:
                self.store_database()
            # ``Event.wait`` returns early as soon as ``on_exit`` sets the event.
            self._exit_event.wait(store_interval)

    def run(self, token: Optional[str] = None) -> None:
        """Serve the database over HTTP and keep it synchronized with the Drive.

        Restores the database file from the Drive when a stored copy exists, then initializes it for the
        configured models via ``_create_database``. It exposes the ``/select_all/``, ``/insert/``, ``/update/``
        and ``/delete/`` endpoints and starts the periodic Drive sync thread. Finally it runs Uvicorn on
        ``self.host``/``self.port``.

        Arguments:
            token: Token used to protect the database access. Ensure you don't expose it through the App State.
        """
        drive = Drive("lit://database", component_name=self.name, root_folder=self._root_folder)
        filenames = drive.list(component_name=self.name)
        if self.db_filename in filenames:
            drive.get(self.db_filename)
            print("Retrieved the database from Drive.")

        app = FastAPI()

        _create_database(self.db_filename, self._models, self.debug)
        models = {m.__name__: m for m in self._models}
        app.post("/select_all/")(_SelectAll(models, token))
        app.post("/insert/")(_Insert(models, token))
        app.post("/update/")(_Update(models, token))
        app.post("/delete/")(_Delete(models, token))

        # Make ``uvicorn.run`` build our server subclass, which does not override the App's signal handlers.
        sys.modules["uvicorn.main"].Server = _DatabaseUvicornServer

        self._exit_event = threading.Event()
        self._store_thread = threading.Thread(target=self.periodic_store_database, args=(self.store_interval,))
        self._store_thread.start()

        # Module-level ``uvicorn.run`` (imported at file top), not this method.
        run(app, host=self.host, port=self.port, log_level="error")

    def alive(self) -> bool:
        """Hack: Return whether the server is alive, approximated by ``db_url`` being non-empty."""
        return self.db_url != ""

    @property
    def db_url(self) -> Optional[str]:
        """URL at which the database server can be reached.

        When ``LIGHTNING_APP_STATE_URL`` is not set, this returns ``self.url``. Otherwise the URL is built from
        the public IP (falling back to the internal IP) and ``self.port``. An empty string is returned when no IP
        is known yet.
        """
        use_localhost = "LIGHTNING_APP_STATE_URL" not in os.environ
        if use_localhost:
            return self.url
        ip_addr = self.public_ip or self.internal_ip
        if ip_addr != "":
            return f"http://{ip_addr}:{self.port}"
        return ip_addr

    def on_exit(self):
        """Stop the periodic sync and push a final snapshot of the database to the Drive."""
        if self._exit_event is None:
            # ``run`` never started, so there is nothing to stop or persist. Snapshotting here could create
            # (``sqlite3.connect`` creates missing files) and upload an empty database over the stored one.
            return
        self._exit_event.set()
        with _lock:
            self.store_database()