pytorch-lightning

Форк
0
263 строки · 9.0 Кб
1
# Copyright The Lightning AI team.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

15
import functools
16
import json
17
import pathlib
18
from typing import Any, Dict, Generic, List, Type, TypeVar
19

20
from fastapi import Response, status
21
from fastapi.encoders import jsonable_encoder
22
from lightning_utilities.core.imports import RequirementCache
23
from pydantic import BaseModel, parse_obj_as
24

25
if RequirementCache("pydantic>=2.0.0"):
26
    from pydantic.v1.main import ModelMetaclass
27
else:
28
    from pydantic.main import ModelMetaclass
29

30
from lightning.app.utilities.app_helpers import Logger
31
from lightning.app.utilities.imports import _is_sqlmodel_available
32

33
if _is_sqlmodel_available():
34
    from sqlalchemy.inspection import inspect as sqlalchemy_inspect
35
    from sqlmodel import JSON, Session, SQLModel, TypeDecorator, select
36

37
logger = Logger(__name__)
38
engine = None
39

40
T = TypeVar("T")
41

42

43
# Taken from https://github.com/tiangolo/sqlmodel/issues/63#issuecomment-1081555082
def _pydantic_column_type(pydantic_type: Any) -> Any:
    """This function enables to support JSON types with SQLModel.

    It builds a SQLAlchemy ``TypeDecorator`` that serializes a pydantic model (or any
    JSON-encodable value) into a JSON column on write, and parses it back into
    ``pydantic_type`` on read.

    Example::

        from sqlmodel import Field, SQLModel
        from sqlalchemy import Column

        class TrialConfig(SQLModel, table=False):
            ...
            params: Dict[str, float] = Field(sa_column=Column(_pydantic_column_type(Dict[str, float])))

    """

    class PydanticJSONType(TypeDecorator, Generic[T]):
        # Underlying column type: store the value as JSON.
        impl = JSON()

        def __init__(
            self,
            json_encoder=json,  # any object exposing ``dumps`` (defaults to the stdlib json module)
        ):
            self.json_encoder = json_encoder
            super().__init__()

        def bind_processor(self, dialect):
            """Return a callable that converts a Python value to its DB-bound form."""
            impl_processor = self.impl.bind_processor(dialect)
            dumps = self.json_encoder.dumps
            if impl_processor:
                # The JSON impl has its own bind processor: hand it the
                # jsonable-encoded value and let it do the serialization.

                def process(value: T):
                    if value is not None:
                        if isinstance(pydantic_type, ModelMetaclass):
                            # This allows to assign non-InDB models and if they're
                            # compatible, they're directly parsed into the InDB
                            # representation, thus hiding the implementation in the
                            # background. However, the InDB model will still be returned
                            value_to_dump = pydantic_type.from_orm(value)
                        else:
                            value_to_dump = value
                        value = jsonable_encoder(value_to_dump)
                    return impl_processor(value)

            else:
                # No impl processor: serialize to a JSON string ourselves.

                def process(value):
                    if isinstance(pydantic_type, ModelMetaclass):
                        # This allows to assign non-InDB models and if they're
                        # compatible, they're directly parsed into the InDB
                        # representation, thus hiding the implementation in the
                        # background. However, the InDB model will still be returned
                        value_to_dump = pydantic_type.from_orm(value)
                    else:
                        value_to_dump = value
                    return dumps(jsonable_encoder(value_to_dump))

            return process

        def result_processor(self, dialect, coltype) -> T:
            """Return a callable that converts a DB row value back to ``pydantic_type``."""
            impl_processor = self.impl.result_processor(dialect, coltype)
            if impl_processor:

                def process(value):
                    value = impl_processor(value)
                    if value is None:
                        return None

                    data = value
                    # Explicitly use the generic directly, not type(T)
                    return parse_obj_as(pydantic_type, data)

            else:

                def process(value):
                    if value is None:
                        return None

                    # Explicitly use the generic directly, not type(T)
                    return parse_obj_as(pydantic_type, value)

            return process

        def compare_values(self, x, y):
            # Plain equality is enough to detect in-place mutations for flush purposes.
            return x == y

    return PydanticJSONType
129

130

131
@functools.lru_cache(maxsize=128)
def _get_primary_key(model_type: Type["SQLModel"]) -> str:
    """Return the name of the single primary-key column of ``model_type``.

    Raises:
        ValueError: If the model declares zero or multiple primary-key columns.
    """
    pk_columns = sqlalchemy_inspect(model_type).primary_key

    if len(pk_columns) == 1:
        return pk_columns[0].name

    raise ValueError(f"The model {model_type.__name__} should have a single primary key field.")
139

140

141
class _GeneralModel(BaseModel):
    """Transport envelope for a serialized model plus the database auth token."""

    # Name of the SQLModel class the payload belongs to.
    cls_name: str
    # JSON-serialized model payload (empty string for class-only requests).
    data: str
    # Shared secret echoed back by the client on every request.
    token: str

    def convert_to_model(self, models: Dict[str, BaseModel]):
        """Re-hydrate ``data`` using the model class registered under ``cls_name``."""
        model_cls = models[self.cls_name]
        return model_cls.parse_raw(self.data)

    @classmethod
    def from_obj(cls, obj, token):
        """Build an envelope carrying a serialized model instance."""
        return cls(cls_name=obj.__class__.__name__, data=obj.json(), token=token)

    @classmethod
    def from_cls(cls, obj_cls, token):
        """Build an envelope referencing a model class with no payload."""
        return cls(cls_name=obj_cls.__name__, data="", token=token)
164

165

166
class _SelectAll:
    """FastAPI endpoint callable returning every row of the requested table.

    Args:
        models: Mapping from model class name to its SQLModel class.
        token: Optional shared secret; when set, requests must echo it back.
    """

    def __init__(self, models, token):
        # Fix: the original ``print(models, token)`` leaked the auth token to
        # stdout on startup; log only the registered models instead.
        logger.debug(f"Registered models: {models}")
        self.models = models
        self.token = token

    def __call__(self, data: Dict, response: Response):
        # Reject the request when a token is configured and does not match.
        if self.token and data["token"] != self.token:
            response.status_code = status.HTTP_401_UNAUTHORIZED
            return {"status": "failure", "reason": "Unauthorized request to the database."}

        with Session(engine) as session:
            cls: Type["SQLModel"] = self.models[data["cls_name"]]
            statement = select(cls)
            results = session.exec(statement)
            return results.all()
182

183

184
class _Insert:
    """FastAPI endpoint callable inserting one row into the matching table.

    Args:
        models: Mapping from model class name to its SQLModel class.
        token: Optional shared secret; when set, requests must echo it back.
    """

    def __init__(self, models, token):
        self.models = models
        self.token = token

    def __call__(self, data: Dict, response: Response):
        # Reject the request when a token is configured and does not match.
        if self.token and data["token"] != self.token:
            response.status_code = status.HTTP_401_UNAUTHORIZED
            return {"status": "failure", "reason": "Unauthorized request to the database."}

        model_cls = self.models[data["cls_name"]]
        row = model_cls.parse_raw(data["data"])
        with Session(engine) as session:
            session.add(row)
            session.commit()
            # Refresh so DB-generated fields (e.g. the primary key) are populated.
            session.refresh(row)
            return row
200

201

202
class _Update:
    """FastAPI endpoint callable updating one row identified by its primary key.

    Args:
        models: Mapping from model class name to its SQLModel class.
        token: Optional shared secret; when set, requests must echo it back.
    """

    def __init__(self, models, token):
        self.models = models
        self.token = token

    def __call__(self, data: Dict, response: Response):
        # Reject the request when a token is configured and does not match.
        if self.token and data["token"] != self.token:
            response.status_code = status.HTTP_401_UNAUTHORIZED
            return {"status": "failure", "reason": "Unauthorized request to the database."}

        with Session(engine) as session:
            update_data = self.models[data["cls_name"]].parse_raw(data["data"])
            primary_key = _get_primary_key(update_data.__class__)
            identifier = getattr(update_data.__class__, primary_key, None)
            # Look up the existing row by its primary-key value.
            statement = select(update_data.__class__).where(identifier == getattr(update_data, primary_key))
            results = session.exec(statement)
            result = results.one()
            for k, v in vars(update_data).items():
                # Fix: skip the actual primary-key column (previously hard-coded
                # as "id", inconsistent with the ``_get_primary_key`` lookup above)
                # and SQLAlchemy's internal state attribute.
                if k in (primary_key, "_sa_instance_state"):
                    continue
                if getattr(result, k) != v:
                    setattr(result, k, v)
            session.add(result)
            session.commit()
            session.refresh(result)
            return None
228

229

230
class _Delete:
    """FastAPI endpoint callable deleting one row identified by its primary key.

    Args:
        models: Mapping from model class name to its SQLModel class.
        token: Optional shared secret; when set, requests must echo it back.
    """

    def __init__(self, models, token):
        self.models = models
        self.token = token

    def __call__(self, data: Dict, response: Response):
        # Reject the request when a token is configured and does not match.
        if self.token and data["token"] != self.token:
            response.status_code = status.HTTP_401_UNAUTHORIZED
            return {"status": "failure", "reason": "Unauthorized request to the database."}

        with Session(engine) as session:
            to_delete = self.models[data["cls_name"]].parse_raw(data["data"])
            model_cls = to_delete.__class__
            pk_name = _get_primary_key(model_cls)
            pk_column = getattr(model_cls, pk_name, None)
            # Locate the stored row matching the payload's primary-key value.
            match = session.exec(select(model_cls).where(pk_column == getattr(to_delete, pk_name))).one()
            session.delete(match)
            session.commit()
            return None
250

251

252
def _create_database(db_filename: str, models: List[Type["SQLModel"]], echo: bool = False):
    """Create (or open) the SQLite database backing the given models.

    Args:
        db_filename: Path of the SQLite file to create or open.
        models: SQLModel table classes whose tables should be created.
        echo: Forwarded to SQLAlchemy; when True, log every emitted SQL statement.
    """
    global engine

    from sqlmodel import create_engine

    db_path = pathlib.Path(db_filename).resolve()
    engine = create_engine(f"sqlite:///{db_path}", echo=echo)

    logger.debug(f"Creating the following tables {models}")
    try:
        SQLModel.metadata.create_all(engine)
    except Exception as ex:
        # Best effort: creation failures (e.g. tables already exist) are logged,
        # never fatal.
        logger.debug(ex)
264

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.