1
from typing import Any, Dict, TypeVar
3
from sqlalchemy import delete, insert, select, update
4
from sqlalchemy.inspection import inspect
6
from app.db.sqlalchemy import AsyncSession
12
"""CRUD operations for models."""
def __init__(self, session: AsyncSession, cls_model: Any):
    """Bind the CRUD helper to a session and a model class.

    Args:
        session: session whose ``execute`` runs every statement built by
            this helper.
        cls_model: model class that the generated statements target and
            whose primary key is looked up via ``inspect``.
    """
    self._session = session
    self._cls_model = cls_model
async def create(self, *, model_data: Dict[str, Any]) -> Any:
    """Insert one row built from ``model_data``.

    Args:
        model_data: column values, passed as keyword arguments to
            ``insert(...).values``.

    Returns:
        The ``inserted_primary_key`` of the execution result.

    This method only executes the INSERT; it does not call commit itself.
    """
    query = insert(self._cls_model).values(**model_data)
    res = await self._session.execute(query)
    return res.inserted_primary_key
29
# NOTE(review): this is a fragment of the "update by primary key" method.
# Its `def` line, the other parameters (presumably `pkey_val`, which the
# WHERE clause below references), and the `query = (` / closing `)` lines
# are not visible in this paste — restore them before relying on this code.
model_data: Dict[str, Any],
31
"""Update object by primary key."""
32
# Only the first primary-key column is used; for a composite key the
# filter covers that column alone.
primary_key = inspect(self._cls_model).primary_key[0]
34
update(self._cls_model)
35
.where(primary_key == pkey_val)
37
# NOTE(review): the line that applies `model_data` (presumably
# `.values(**model_data)`) is not visible here — confirm.
.execution_options(synchronize_session="fetch")
40
await self._session.execute(query)
async def delete(self, *, pkey_val: Any) -> None:
    """Delete the object whose primary key equals ``pkey_val``.

    Args:
        pkey_val: value compared against the model's first primary-key
            column.

    This method only executes the DELETE; it does not call commit itself.
    """
    # Compare against the primary-key Column itself rather than
    # ``getattr(model, column.name)``: the DB column name may differ from the
    # mapped attribute key (e.g. ``id_ = Column("id", ...)``), in which case
    # the getattr lookup raises AttributeError. This also matches how the
    # other primary-key lookups in this class build their WHERE clause.
    # NOTE: only the first primary-key column is used (composite keys are
    # filtered on that column alone).
    primary_key = inspect(self._cls_model).primary_key[0]
    query = (
        delete(self._cls_model)
        .where(primary_key == pkey_val)
        .execution_options(synchronize_session="fetch")
    )
    await self._session.execute(query)
async def get(self, *, pkey_val: Any) -> Any:
    """Return the object whose primary key equals ``pkey_val``.

    Args:
        pkey_val: value compared against the model's first primary-key
            column.

    Raises:
        sqlalchemy.exc.NoResultFound: if no row matches (raised by
            ``ScalarResult.one()``).
        sqlalchemy.exc.MultipleResultsFound: if more than one row matches,
            which is possible with a composite key since only the first
            primary-key column is filtered on.
    """
    primary_key = inspect(self._cls_model).primary_key[0]
    query = select(self._cls_model).where(primary_key == pkey_val)
    rows = await self._session.execute(query)
    return rows.scalars().one()
async def get_or_none(self, *, pkey_val: Any) -> Any:
    """Return the object whose primary key equals ``pkey_val``, or ``None``.

    Args:
        pkey_val: value compared against the model's first primary-key
            column.

    Returns:
        The matching object, or ``None`` when no row matches.

    Raises:
        sqlalchemy.exc.MultipleResultsFound: if more than one row matches
            (possible with a composite key, since only the first
            primary-key column is filtered on).
    """
    primary_key = inspect(self._cls_model).primary_key[0]
    query = select(self._cls_model).where(primary_key == pkey_val)
    rows = await self._session.execute(query)
    # The visible body had no return statement (always yielding None);
    # one_or_none() gives the "object or None" contract the docstring promises.
    return rows.scalars().one_or_none()
72
# NOTE(review): the `def` line of this method (presumably `get_all`) is not
# visible in this paste — confirm its name and signature.
"""Get all objects by db model."""
73
# Unfiltered SELECT over the model: no WHERE clause, limit or pagination.
query = select(self._cls_model)
75
rows = await self._session.execute(query)
76
# Materializes every matching object into a list.
return rows.scalars().all()
async def get_by_field(self, *, field: str, field_value: Any) -> Any:
    """Return all objects whose attribute ``field`` equals ``field_value``.

    Args:
        field: name of a model attribute, resolved with ``getattr`` on the
            model class.
        field_value: value the attribute must equal.

    Returns:
        A list of the matching objects (possibly empty).

    Raises:
        AttributeError: if the model class has no attribute named ``field``.
    """
    # NOTE(review): `field` is resolved with getattr on the model class, so it
    # should come from trusted code rather than raw user input — confirm
    # against callers.
    query = select(self._cls_model).where(
        getattr(self._cls_model, field) == field_value
    )
    rows = await self._session.execute(query)
    return rows.scalars().all()