pytorch-lightning
93 строки · 3.1 Кб
1# Copyright The Lightning AI team.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15from typing import Any, Dict, List, Optional, Type, TypeVar
16
17import requests
18from requests import Session
19from requests.adapters import HTTPAdapter
20from urllib3.util.retry import Retry
21
22from lightning.app.components.database.utilities import _GeneralModel
23
24_CONNECTION_RETRY_TOTAL = 5
25_CONNECTION_RETRY_BACKOFF_FACTOR = 1
26
27
def _configure_session() -> Session:
    """Build a ``requests`` session whose HTTP and HTTPS adapters retry failed requests.

    The retry policy is deliberately generous so that requests keep being retried while the
    application server comes up and starts accepting connections.

    NOTE(review): urllib3's ``Retry`` applies ``status_forcelist`` only to methods listed in
    ``allowed_methods``, whose default excludes POST; connection errors are retried for any
    method. Confirm against the installed urllib3 version before relying on 5xx retries for POST.

    Returns:
        A new ``requests.Session`` with the retrying adapter mounted for both schemes.

    """
    # Sleep between attempts grows exponentially: backoff_factor * (2 ** (retry - 1)).
    policy = Retry(
        total=_CONNECTION_RETRY_TOTAL,
        backoff_factor=_CONNECTION_RETRY_BACKOFF_FACTOR,
        status_forcelist=[429, 500, 502, 503, 504],
    )
    retrying_adapter = HTTPAdapter(max_retries=policy)
    session = requests.Session()
    for scheme in ("https://", "http://"):
        session.mount(scheme, retrying_adapter)
    return session
45
46
# Generic placeholder for the model type handled by ``DatabaseClient`` (instances and classes).
T = TypeVar("T")
48
49
class DatabaseClient:
    """Thin HTTP client for the database server reachable at ``db_url``.

    Every operation is a POST whose JSON body is built by ``_GeneralModel`` and carries ``token``.
    The underlying ``requests`` session (with retries) is created lazily on first use.

    Args:
        db_url: Base URL of the server. A trailing ``/`` is tolerated when endpoint URLs are built.
        token: Optional token passed to ``_GeneralModel`` with every request; ``None`` becomes ``""``.
        model: Default model class used by :meth:`select_all` when no model is passed explicitly.

    """

    def __init__(self, db_url: str, token: Optional[str] = None, model: Optional[Type[T]] = None) -> None:
        self.db_url = db_url
        self.model = model
        self.token = token or ""
        # Built on first access of ``session`` so the retrying session is only created when needed.
        self._session: Optional[Session] = None

    def select_all(self, model: Optional[Type[T]] = None) -> List[T]:
        """Fetch all rows of a model class from the server.

        Args:
            model: Model class to query; falls back to ``self.model`` when not given.

        Returns:
            One ``model(**row)`` instance per item in the server's JSON response.

        Raises:
            ValueError: If neither ``model`` nor ``self.model`` is set.
            AssertionError: If the server does not answer with HTTP 200.

        """
        cls = model if model else self.model
        if cls is None:
            raise ValueError("`select_all` needs a model class: pass `model=` or set it on the client.")
        resp = self._post("/select_all/", _GeneralModel.from_cls(cls, token=self.token).json())
        return [cls(**data) for data in resp.json()]

    def insert(self, model: T) -> None:
        """Send the model instance ``model`` to the server's ``/insert/`` endpoint.

        Raises:
            AssertionError: If the server does not answer with HTTP 200.

        """
        self._post("/insert/", _GeneralModel.from_obj(model, token=self.token).json())

    def update(self, model: T) -> None:
        """Send the model instance ``model`` to the server's ``/update/`` endpoint.

        Raises:
            AssertionError: If the server does not answer with HTTP 200.

        """
        self._post("/update/", _GeneralModel.from_obj(model, token=self.token).json())

    def delete(self, model: T) -> None:
        """Send the model instance ``model`` to the server's ``/delete/`` endpoint.

        Raises:
            AssertionError: If the server does not answer with HTTP 200.

        """
        self._post("/delete/", _GeneralModel.from_obj(model, token=self.token).json())

    def _post(self, endpoint: str, data: Any) -> requests.Response:
        """POST ``data`` to ``endpoint`` (relative to ``db_url``) and require an HTTP 200 reply.

        Args:
            endpoint: Path beginning with ``/``, e.g. ``"/insert/"``.
            data: Request body, as produced by ``_GeneralModel(...).json()``.

        Returns:
            The successful response.

        Raises:
            AssertionError: If the status code is not 200. This is raised explicitly rather than via
                ``assert`` so the check is not stripped under ``python -O``, while keeping the
                exception type callers previously received.

        """
        # Strip a trailing slash so "http://host/" does not yield "http://host//insert/".
        url = self.db_url.rstrip("/") + endpoint
        resp = self.session.post(url, data=data)
        if resp.status_code != 200:
            raise AssertionError(
                f"Database request to {url} failed with status {resp.status_code}: {resp.text}"
            )
        return resp

    @property
    def session(self) -> Session:
        """Return the shared retrying session, creating it via ``_configure_session`` on first use."""
        if self._session is None:
            self._session = _configure_session()
        return self._session

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict description: the base URL and the default model's class name (or ``None``)."""
        return {"db_url": self.db_url, "model": self.model.__name__ if self.model else None}
94