lavkach3

Форк
0
329 строк · 12.8 Кб
1
import asyncio
2
import logging
3
import uuid
4
from collections import defaultdict
5
from dataclasses import dataclass
6
from typing import Any, Generic, Optional, Type, TypeVar, Tuple
7
from uuid import uuid4
8
from starlette.requests import Request
9
from fastapi_filter.contrib.sqlalchemy import Filter
10
from httpx import AsyncClient as asyncclient
11
from pydantic import BaseModel
12
from sqlalchemy import select, Row, RowMapping, inspect
13
from sqlalchemy.exc import IntegrityError, InvalidRequestError
14
from sqlalchemy.ext.asyncio import AsyncSession
15
from starlette.exceptions import HTTPException
16

17
from core.db import Base
18
from core.db.session import Base, session
19
from core.fastapi.middlewares.authentication import CurrentUser
20
from core.helpers.broker import list_brocker
21
from core.helpers.cache import CacheTag
22
from core.schemas import BaseFilter
23
from core.service_config import config
24

25
# Type variables that parametrize BaseService through its Generic[...] base.
ModelType = TypeVar("ModelType", bound=Base)
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)
FilterSchemaType = TypeVar("FilterSchemaType", bound=Filter)

# NOTE(review): not referenced by the code visible in this module; presumably
# consumed elsewhere — confirm before removing.
before_fields = [
    'role_ids',
    'company_ids',
    'is_admin',
    'store_id',
]

# NOTE(review): configuring the root logger at import time affects the whole
# process; kept because other code may rely on it.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
33
from types import FunctionType
34

35
@dataclass
class Model:
    """Plain record pairing a service with a model.

    NOTE(review): not used by the code visible in this module; both fields are
    typed ``object`` because their concrete types are not established here.
    """

    service: object  # presumably a BaseService (sub)class or instance — confirm
    model: object  # presumably the SQLAlchemy model class — confirm
39

40

41
def import_service(service_name):
    """Resolve a dotted path such as ``"package.module.ClassName"`` to an object.

    The first component is imported as a top-level module. Each following
    component is looked up as an attribute of the object resolved so far. If
    the attribute is missing because it is a submodule that has not been
    imported yet, the dotted prefix is imported explicitly.

    Args:
        service_name: dotted path to a module, or to an attribute inside one.

    Returns:
        The object the full path refers to.

    Raises:
        ImportError: if the first component cannot be imported.
        AttributeError: if a later component is neither an attribute nor an
            importable submodule.
    """
    import importlib

    components = service_name.split('.')
    mod = importlib.import_module(components[0])
    path = components[0]
    for comp in components[1:]:
        path = f'{path}.{comp}'
        try:
            mod = getattr(mod, comp)
        except AttributeError as attr_error:
            # A package does not expose its submodules as attributes until
            # they are imported, so try importing the prefix as a module.
            try:
                mod = importlib.import_module(path)
            except ImportError:
                # Keep the original contract: an unresolved name is an AttributeError.
                raise attr_error from None
    return mod
47

48

49
def is_pydantic(obj: object | list):
50
    """ Checks whether an object is pydantic. """
51
    if isinstance(obj, list):
52
        for i in obj:
53
            return type(i).__class__.__name__ == "ModelMetaclass"
54
    return type(obj).__class__.__name__ == "ModelMetaclass"
55

56

57
def model_to_entity(schema):
    """Recursively convert pydantic schemas into SQLAlchemy model instances.

    - A list is converted element by element, and a new list is returned.
    - A pydantic model is turned into a dict of its fields. Each field value is
      converted recursively, and the result is passed as keyword arguments to
      the schema's ``Config.orm_model``.
    - A dict has its values converted *in place*, and the same dict is returned.
    - Anything else is returned unchanged.

    Raises:
        AttributeError: if a pydantic schema has no ``Config.orm_model``.
    """
    # Lists must be handled before is_pydantic(): it reports a non-empty list
    # of models as pydantic, which used to send lists down the single-model
    # branch (dict(list) / list.Config) and crash.
    if isinstance(schema, list):
        return [model_to_entity(item) for item in schema]

    if is_pydantic(schema):
        # Only the orm_model lookup is guarded, so AttributeErrors raised while
        # converting nested schemas keep their own (accurate) model name.
        try:
            orm_model = schema.Config.orm_model
        except AttributeError as exc:
            model_name = schema.__class__.__name__
            raise AttributeError(
                f"Failed converting pydantic model: {model_name}.Config.orm_model not specified."
            ) from exc
        return orm_model(**model_to_entity(dict(schema)))

    if isinstance(schema, dict):
        # Reassigning existing keys does not resize the dict, so this is safe
        # while iterating.
        for key, value in schema.items():
            schema[key] = model_to_entity(value)

    return schema
79

80

81
class LocalCache:
    """Process-local two-level store, used as ``_data[table_name][object_id] -> object``."""

    def __init__(self):
        # The outer level creates a bucket per key on first access. Inner
        # buckets are defaultdicts without a factory: ``[]`` on a missing id
        # raises KeyError, and ``.get()`` returns None.
        self._data = defaultdict(defaultdict)


# Single module-wide instance; every BaseCache shares its ``_data``.
localcache = LocalCache()
88

89

90
class BaseCache:
91
    service: 'BaseService'
92

93
    def __init__(self, service: 'BaseService'):
94
        self.service = service
95
        self.cache = localcache._data
96

97
    def get(self, id: uuid.UUID):
98
        return self.cache[self.service.model.__tablename__].get(id)
99

100
    def set(self, sql_obj: list | object):
101
        if isinstance(sql_obj, list):
102
            for obj in sql_obj:
103
                self.set(obj)
104
            return sql_obj
105
        self.cache[self.service.model.__tablename__][sql_obj.id] = sql_obj  # type: ignore
106
        return sql_obj.id  # type: ignore
107

108
    def delete(self, id: uuid.UUID):
109
        self.cache[self.service.model.__tablename__].pop(id, False)
110
        return True
111

112

113
class BaseService(Generic[ModelType, CreateSchemaType, UpdateSchemaType, FilterSchemaType]):
    """Generic async CRUD service for one SQLAlchemy model.

    Subclasses bind a concrete model plus its create/update schemas. Every
    instance works against the module-level ``session`` and resolves related
    services through ``request.scope['env']``.
    """

    def __init__(
            self,
            request: Request,
            model: Type[ModelType],
            create_schema: Type[CreateSchemaType],
            update_schema: Type[UpdateSchemaType],
            **kwargs
    ):
        """
        Args:
            request: incoming request. ``request.user`` becomes the acting user
                and ``request.scope['env']`` the service registry. A
                ``CurrentUser`` may be passed instead and is then used directly
                as the acting user.
            model: SQLAlchemy model class managed by this service.
            create_schema: pydantic schema used for creates (and to validate
                dict input to ``_create``).
            update_schema: pydantic schema used for updates.
            **kwargs: accepted for subclass compatibility; ignored here.
        """
        if isinstance(request, CurrentUser):
            # Use the given user instance (the CurrentUser class itself was
            # stored here before).
            self.user = request
        else:
            self.user = request.user
        self.model = model
        self.create_schema = create_schema
        self.update_schema = update_schema
        # Keep this request instance (the Request class itself was stored before).
        self.request = request
        # NOTE(review): nothing in this file shows CurrentUser carrying a
        # ``scope``, so the CurrentUser path likely still fails here — confirm
        # how ``env`` should be obtained in that case.
        self.env = request.scope['env']
        self.session = session
        self.basecache = BaseCache(self)

    def sudo(self):
        """Switch to a synthetic admin user (random id) and return ``self`` for chaining."""
        self.user = CurrentUser(id=uuid4(), is_admin=True)
        return self

    async def _safe_rollback(self) -> None:
        """Roll back the session, logging instead of raising if the rollback fails.

        This keeps the original error as the one reported to the caller.
        """
        try:
            await self.session.rollback()
        except Exception:
            logger.exception("Session rollback failed")

    @staticmethod
    def _integrity_http_error(exc: IntegrityError) -> HTTPException:
        """Map an IntegrityError to 409 for duplicate keys and 500 otherwise."""
        if "duplicate key" in str(exc):
            return HTTPException(status_code=409, detail=f"Conflict Error entity {str(exc)}")
        return HTTPException(status_code=500, detail=f"ERROR:  {str(exc)}")

    async def _get(self, id: Any, for_update=False) -> Row | RowMapping:
        """Load the entity whose primary key equals ``id``.

        Args:
            id: primary-key value.
            for_update: when true, lock the row with ``SELECT ... FOR UPDATE``.

        Raises:
            HTTPException: 404 if no row matches.
        """
        query = select(self.model).where(self.model.id == id)
        if for_update:
            # Select.with_for_update() is generative and returns a new
            # statement; it must be re-assigned or no lock is taken.
            query = query.with_for_update()
        result = await self.session.execute(query)
        entity = result.scalars().first()
        if not entity:
            raise HTTPException(status_code=404, detail="Not found")
        return entity

    async def get(self, id: Any, for_update=False):
        """Public wrapper around :meth:`_get` (override point for subclasses)."""
        return await self._get(id, for_update)

    async def _list(self, _filter: FilterSchemaType | dict, size: int = 100):
        """Run a filtered SELECT limited to ``size`` rows and return the entities.

        A plain dict is first converted with the model's filter schema from
        ``env``. Except for the ``company``, ``user`` and ``bus`` models, results
        are restricted to the acting user's ``company_id``.
        """
        if isinstance(_filter, dict) and not isinstance(_filter, BaseFilter):
            _filter = self.env[self.model.__tablename__].schemas.filter(**_filter)
        # Compare with ``is not False`` rather than truthiness: a None
        # company_id is still applied as a filter value instead of silently
        # lifting the company restriction.
        if self.user.company_id is not False and self.model.__tablename__ not in ('company', 'user', 'bus'):
            setattr(_filter, 'company_id__in', [self.user.company_id])
        query = _filter.filter(select(self.model)).limit(size)  # type: ignore
        # Default to None so filters without an ``order_by`` field do not raise.
        if getattr(_filter, 'order_by', None):
            query = _filter.sort(query)  # type: ignore
        executed = await self.session.execute(query)
        return executed.scalars().all()

    async def list(self, _filter: FilterSchemaType | dict, size: int = 100):
        """Public wrapper around :meth:`_list`."""
        return await self._list(_filter, size)

    async def _create_children(self, entity, children, commit: bool) -> None:
        """Create deferred child rows whose ``<this table>_id`` points at ``entity``.

        Args:
            entity: the freshly inserted parent (already flushed or committed).
            children: ``(create_method, create_schema_instance)`` pairs.
            commit: forwarded to each child's create call.
        """
        fk_name = f'{self.model.__tablename__}_id'
        for create_method, child in children:
            if commit:
                # A child commit may expire the parent's attributes, so reload
                # before reading entity.id.
                await self.session.refresh(entity)
            setattr(child, fk_name, entity.id)
            await create_method(obj=child, commit=commit, parent=entity)

    async def _create(self, obj: CreateSchemaType | dict, commit=True, parent: Any = None) -> ModelType:
        """Insert a new entity built from ``obj``.

        Nested lists of pydantic models are handled item by item. Items that
        carry an ``id`` are updated through their own service right away.
        Items without one are created after the parent, with
        ``<this table>_id`` set to the new parent's id.

        Args:
            obj: create-schema instance, or a dict validated with ``create_schema``.
            commit: commit and refresh when true; otherwise only flush.
            parent: accepted because callers (e.g. ``_update``) pass the owning
                entity; unused by this base implementation.

        Raises:
            HTTPException: 422 if a dict fails validation; 409/500 on DB errors.
        """
        if isinstance(obj, dict):
            try:
                obj = self.create_schema(**obj)
            except Exception as ex:
                raise HTTPException(status_code=422, detail=str(ex)) from ex

        exclude_rel: set = set()
        pending_children: list = []
        for key, value in obj.__dict__.items():
            if not is_pydantic(value):
                continue
            if not isinstance(value, list):
                continue  # TODO: support a single nested model, not only lists
            for _obj in value:
                rel_service = self.env[_obj.Config.orm_model.__tablename__].service
                if getattr(_obj, 'id', None):
                    await rel_service.update(id=_obj.id, obj=_obj, commit=False)
                else:
                    create_obj = rel_service.create_schema(**_obj.model_dump())
                    pending_children.append((rel_service.create, create_obj))
            # Relationship payloads are persisted separately, never as columns.
            exclude_rel.add(key)

        entity = self.model(**obj.model_dump(exclude=exclude_rel))
        # An explicit company_id on the schema (even None) wins over the user's.
        entity.company_id = obj.company_id if hasattr(obj, 'company_id') else self.user.company_id
        self.session.add(entity)

        if not commit:
            await self.session.flush([entity])
            # New children used to be dropped silently on this path.
            await self._create_children(entity, pending_children, commit=False)
            return entity

        try:
            await self.session.commit()
            await self.session.refresh(entity)
            await self._create_children(entity, pending_children, commit=True)
            await self.session.refresh(entity)
        except IntegrityError as e:
            await self._safe_rollback()
            raise self._integrity_http_error(e) from e
        except TimeoutError:
            # Best effort: only the reload is retried, not the commit.
            await asyncio.sleep(1)
            await self.session.refresh(entity)
        except HTTPException:
            # Errors from child services already carry a meaningful status code.
            await self._safe_rollback()
            raise
        except Exception as e:
            await self._safe_rollback()
            raise HTTPException(status_code=409, detail=f"Conflict Error entity {str(e)}") from e
        return entity

    async def create(self, obj: CreateSchemaType | dict, commit=True, parent: Any = None) -> ModelType:
        """Create an entity via :meth:`_create`, then emit a ``'create'`` notification.

        The notification is skipped for the ``bus`` model. ``parent`` is
        accepted because ``_create`` passes it when creating child rows; the
        base implementation does not use it.
        """
        entity = await self._create(obj, commit=commit)
        if self.model.__tablename__ != 'bus':
            await entity.notify('create')
        return entity

    async def _update(self, id: Any, obj: UpdateSchemaType, commit=True) -> tuple[Row, list]:
        """Apply the explicitly set fields of ``obj`` to entity ``id``.

        Nested lists of pydantic models are dispatched to the related service.
        Items with an ``id`` are updated. Items without one are created with
        ``<this table>_id`` set to ``id``.

        Returns:
            ``(entity, updated_fields)``. ``updated_fields`` lists the scalar
            fields whose value changed, plus the keys of nested lists in which
            a related entity reported changes.

        Raises:
            HTTPException: 404 if the entity does not exist; 409/500 on commit failure.
        """
        entity: Row = await self.get(id, for_update=True)
        if not entity:
            raise HTTPException(status_code=404, detail=f"Not Found with id {id}")

        to_set: list = []
        updated_fields: list = []
        for key, value in obj.__dict__.items():
            # Only fields the caller actually supplied take part in the update.
            if key not in obj.model_fields_set:
                continue
            if not is_pydantic(value):
                # The id from the call wins over any id carried in the payload.
                to_set.append((key, id if key == 'id' else value))
                continue
            # NOTE(review): assumes ``value`` is a list of models, as in _create.
            for _obj in value:
                rel_service = self.env[_obj.Config.orm_model.__tablename__].service
                if getattr(_obj, 'id', None):
                    rel_entity, rel_updated = await rel_service._update(id=_obj.id, obj=_obj, commit=False)
                    self.session.add(rel_entity)
                    if rel_updated:
                        updated_fields.append(key)
                else:
                    child_data = _obj.model_dump()
                    child_data[f'{self.model.__tablename__}_id'] = id
                    create_obj = rel_service.create_schema(**child_data)
                    await rel_service._create(obj=create_obj, parent=entity, commit=False)

        await self.session.refresh(entity)
        for field_name, new_value in to_set:
            if not getattr(entity, field_name) == new_value:
                setattr(entity, field_name, new_value)
                updated_fields.append(field_name)
        try:
            self.session.add(entity)
        except InvalidRequestError as ex:
            logger.warning(ex)

        if not commit:
            return entity, updated_fields
        try:
            await self.session.commit()
        except IntegrityError as e:
            await self._safe_rollback()
            raise self._integrity_http_error(e) from e
        except Exception as e:
            await self._safe_rollback()
            raise HTTPException(status_code=500, detail=f"ERROR:  {str(e)}") from e
        return await self.get(id), updated_fields

    async def prepere_bus(self, entity: ModelType, method: str) -> dict:
        """Build the bus payload describing ``method`` applied to ``entity``.

        ``company_id`` falls back to ``entity.id`` for entities without a
        ``company_id`` attribute (presumably the company model itself — confirm).
        """
        table = self.model.__tablename__
        return {
            'cache_tag': CacheTag.MODEL,
            'message': f'{table.capitalize()} is {method.capitalize()}',
            'company_id': entity.company_id if hasattr(entity, 'company_id') else entity.id,
            'vars': {
                'id': entity.id,
                'lsn': entity.lsn,
                'model': table,
                'method': method,
            },
        }

    async def update(self, id: Any, obj: UpdateSchemaType, commit=True) -> Row:
        """Update entity ``id`` via :meth:`_update`, then emit an ``'update'`` notification."""
        entity, updated_fields = await self._update(id, obj, commit=commit)
        await entity.notify('update', updated_fields)
        return entity

    async def _delete(self, id: Any):
        """Delete entity ``id`` and commit; return ``(True, bus_payload)``.

        The bus payload is built before deletion, while the entity's attributes
        can still be read.
        """
        entity = await self.get(id)
        message = await self.prepere_bus(entity, 'delete')
        await self.session.delete(entity)
        try:
            await self.session.commit()
        except IntegrityError as e:
            await self._safe_rollback()
            raise self._integrity_http_error(e) from e
        except Exception as e:
            await self._safe_rollback()
            raise HTTPException(status_code=500, detail=f"ERROR:  {str(e)}") from e
        return True, message

    async def delete(self, id: Any) -> bool:
        """Delete entity ``id`` and emit a ``'delete'`` notification."""
        res, message = await self._delete(id)
        # notify is called unbound, with the model class standing in for
        # ``self`` (presumably because the row no longer exists).
        await self.model.notify(self.model, method='delete', message=message)
        return res

    @classmethod
    def init(cls, request: Request):
        """Alternate constructor returning ``cls(request)``.

        This only works for subclasses whose ``__init__`` needs nothing but
        the request; the base ``__init__`` also requires model and schemas.
        """
        return cls(request)
330

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.