pytorch-lightning

Форк
0
243 строки · 8.3 Кб
1
# Copyright The Lightning AI team.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

15
import asyncio
16
import os
17
import sqlite3
18
import sys
19
import tempfile
20
import threading
21
import traceback
22
from typing import List, Optional, Type, Union
23

24
import uvicorn
25
from fastapi import FastAPI
26
from uvicorn import run
27

28
from lightning.app.components.database.utilities import _create_database, _Delete, _Insert, _SelectAll, _Update
29
from lightning.app.core.work import LightningWork
30
from lightning.app.storage import Drive
31
from lightning.app.utilities.app_helpers import Logger
32
from lightning.app.utilities.imports import _is_sqlmodel_available
33
from lightning.app.utilities.packaging.build_config import BuildConfig
34

35
if _is_sqlmodel_available():
36
    from sqlmodel import SQLModel
37
else:
38
    SQLModel = object
39

40

41
# Module-level logger; used by ``Database.store_database`` to report successful Drive uploads at debug level.
logger = Logger(__name__)
42

43

44
# Required to avoid Uvicorn Server overriding Lightning App signal handlers.
# Discussions: https://github.com/encode/uvicorn/discussions/1708
class _DatabaseUvicornServer(uvicorn.Server):
    """Uvicorn server that leaves signal handling to the Lightning App.

    ``serve()`` is scheduled as a task on an event loop that is kept alive with ``run_forever()``.
    """

    # Not read anywhere in this module; kept as a class-level attribute for compatibility.
    has_started_queue = None

    def run(self, sockets=None) -> None:
        """Create an event loop, schedule ``serve()`` on it and run the loop forever.

        ``asyncio.get_event_loop()`` is deprecated when no current loop is set (DeprecationWarning since Python 3.12,
        slated to become an error), so a new loop is created and registered explicitly. ``asyncio.new_event_loop()``
        goes through the current event-loop policy, so a policy installed by ``config.setup_event_loop()`` still
        applies.
        """
        self.config.setup_event_loop()
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        # Hold a strong reference to the task so it cannot be garbage-collected while the loop is running.
        self._serve_task = loop.create_task(self.serve(sockets=sockets))
        loop.run_forever()

    def install_signal_handlers(self):
        """Ignore Uvicorn Signal Handlers."""
57

58

59
# Serializes database snapshots: both the periodic store thread (``Database.periodic_store_database``) and
# ``Database.on_exit`` hold this lock while calling ``store_database``, so those callers never upload concurrently.
_lock = threading.Lock()
60

61

62
class Database(LightningWork):
    def __init__(
        self,
        models: Union[Type["SQLModel"], List[Type["SQLModel"]]],
        db_filename: str = "database.db",
        store_interval: int = 10,
        debug: bool = False,
    ) -> None:
        """The Database Component enables you to interact with an SQLite database to store structured information
        about your application.

        The provided models are SQLModel tables.

        Arguments:
            models: A SQLModel table or a list of SQLModel tables to be added to the database.
            db_filename: The name of the SQLite database file.
            store_interval: Time interval (in seconds) at which the database is periodically synchronized to the Drive.
                            Note that the database is also always synchronized on exit.
            debug: Whether to run the database in debug mode.

        Example::

            from typing import List
            from sqlmodel import SQLModel, Field
            from uuid import uuid4

            from lightning.app import LightningFlow, LightningApp
            from lightning.app.components.database import Database, DatabaseClient

            class CounterModel(SQLModel, table=True):
                __table_args__ = {"extend_existing": True}

                id: int = Field(default=None, primary_key=True)
                count: int


            class Flow(LightningFlow):

                def __init__(self):
                    super().__init__()
                    self._private_token = uuid4().hex
                    self.db = Database(models=[CounterModel])
                    self._client = None
                    self.counter = 0

                def run(self):
                    self.db.run(token=self._private_token)

                    if not self.db.alive():
                        return

                    if self.counter == 0:
                        self._client = DatabaseClient(
                            model=CounterModel,
                            db_url=self.db.url,
                            token=self._private_token,
                        )

                    rows = self._client.select_all()

                    print(f"{self.counter}: {rows}")

                    if not rows:
                        self._client.insert(CounterModel(count=0))
                    else:
                        row: CounterModel = rows[0]
                        row.count += 1
                        self._client.update(row)

                    if self.counter >= 100:
                        row: CounterModel = rows[0]
                        self._client.delete(row)
                        self.stop()

                    self.counter += 1

            app = LightningApp(Flow())

        If you want to use nested SQLModels, we provide a utility to do so as follows:

        Example::

            from typing import List
            from sqlmodel import SQLModel, Field
            from sqlalchemy import Column

            from lightning.app.components.database.utilities import pydantic_column_type

            class KeyValuePair(SQLModel):
                name: str
                value: str

            class CounterModel(SQLModel, table=True):
                __table_args__ = {"extend_existing": True}

                name: int = Field(default=None, primary_key=True)

                # RIGHT THERE ! You need to use Field and Column with the `pydantic_column_type` utility.
                kv: List[KeyValuePair] = Field(..., sa_column=Column(pydantic_column_type(List[KeyValuePair])))

        """
        super().__init__(parallel=True, cloud_build_config=BuildConfig(["sqlmodel"]))
        self.db_filename = db_filename
        # Directory part of ``db_filename`` ("" for a bare filename); used as the Drive root when restoring in ``run``.
        self._root_folder = os.path.dirname(db_filename)
        self.debug = debug
        self.store_interval = store_interval
        # Always keep a list so callers may pass a single model or several.
        self._models = models if isinstance(models, list) else [models]
        # Both are created in ``run``; ``None`` means the server (and its store thread) has not been started.
        self._store_thread = None
        self._exit_event = None

    def store_database(self) -> None:
        """Snapshot the SQLite database and upload the copy to the ``lit://database`` Drive.

        A copy is taken with :meth:`sqlite3.Connection.backup` into a temporary directory, and that copy (not the
        live file) is put on the Drive. Any exception is reported with a traceback and not re-raised, so a failed
        upload does not stop the periodic store thread or the exit hook.
        """
        from contextlib import closing

        try:
            with tempfile.TemporaryDirectory() as tmpdir:
                tmp_db_filename = os.path.join(tmpdir, os.path.basename(self.db_filename))

                # ``closing`` guarantees both connections are closed even when ``backup`` raises, so no handles leak
                # and the temporary directory can be cleaned up.
                with closing(sqlite3.connect(self.db_filename)) as source, closing(
                    sqlite3.connect(tmp_db_filename)
                ) as dest:
                    source.backup(dest)

                drive = Drive("lit://database", component_name=self.name, root_folder=tmpdir)
                drive.put(os.path.basename(tmp_db_filename))

            logger.debug("Stored the database to the Drive.")
        except Exception:
            # Best-effort: report and carry on. ``print_exc`` prints the traceback itself and returns ``None``;
            # wrapping it in ``print`` (as before) emitted a spurious "None" line.
            traceback.print_exc()

    def periodic_store_database(self, store_interval):
        """Store the database every ``store_interval`` seconds until the exit event is set.

        Meant to run on the background thread started by ``run``. Waiting on the event, rather than sleeping,
        lets ``on_exit`` wake the loop immediately.
        """
        while not self._exit_event.is_set():
            with _lock:
                self.store_database()
            self._exit_event.wait(store_interval)

    def run(self, token: Optional[str] = None) -> None:
        """Restore the database from the Drive if present, then serve it over HTTP.

        Arguments:
            token: Token used to protect the database access. Ensure you don't expose it through the App State.
        """
        # Restore a previously stored copy of the database, if the Drive has one.
        drive = Drive("lit://database", component_name=self.name, root_folder=self._root_folder)
        filenames = drive.list(component_name=self.name)
        if self.db_filename in filenames:
            drive.get(self.db_filename)
            print("Retrieved the database from Drive.")

        app = FastAPI()

        _create_database(self.db_filename, self._models, self.debug)
        models = {m.__name__: m for m in self._models}
        app.post("/select_all/")(_SelectAll(models, token))
        app.post("/insert/")(_Insert(models, token))
        app.post("/update/")(_Update(models, token))
        app.post("/delete/")(_Delete(models, token))

        # ``uvicorn.run`` looks ``Server`` up in ``uvicorn.main``; patching it makes the call below build the
        # subclass that does not install its own signal handlers.
        sys.modules["uvicorn.main"].Server = _DatabaseUvicornServer

        self._exit_event = threading.Event()
        self._store_thread = threading.Thread(target=self.periodic_store_database, args=(self.store_interval,))
        self._store_thread.start()

        run(app, host=self.host, port=self.port, log_level="error")

    def alive(self) -> bool:
        """Hack: return whether the server looks alive.

        This only checks that ``db_url`` is not an empty string (i.e. an address is available); it does not
        contact the server.
        """
        return self.db_url != ""

    @property
    def db_url(self) -> Optional[str]:
        """URL under which the database server can be reached.

        When ``LIGHTNING_APP_STATE_URL`` is not in the environment this is ``self.url``. Otherwise it is built from
        the public IP, falling back to the internal IP; when that IP value is an empty string, it is returned as-is.
        """
        use_localhost = "LIGHTNING_APP_STATE_URL" not in os.environ
        if use_localhost:
            return self.url
        ip_addr = self.public_ip or self.internal_ip
        if ip_addr != "":
            return f"http://{ip_addr}:{self.port}"
        return ip_addr

    def on_exit(self):
        """Stop the periodic store thread and push a final snapshot to the Drive.

        Does nothing when ``run`` was never called in this process (the exit event is only created there);
        previously this case raised ``AttributeError`` on ``None`` before anything was stored.
        """
        if self._exit_event is None:
            return
        self._exit_event.set()
        with _lock:
            self.store_database()
244

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.