pytorch-lightning

Форк
0
93 строки · 3.1 Кб
1
# Copyright The Lightning AI team.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

15
from typing import Any, Dict, List, Optional, Type, TypeVar
16

17
import requests
18
from requests import Session
19
from requests.adapters import HTTPAdapter
20
from urllib3.util.retry import Retry
21

22
from lightning.app.components.database.utilities import _GeneralModel
23

24
_CONNECTION_RETRY_TOTAL = 5
25
_CONNECTION_RETRY_BACKOFF_FACTOR = 1
26

27

28
def _configure_session() -> Session:
29
    """Configures the session for GET and POST requests.
30

31
    It enables a generous retrial strategy that waits for the application server to connect.
32

33
    """
34
    retry_strategy = Retry(
35
        # wait time between retries increases exponentially according to: backoff_factor * (2 ** (retry - 1))
36
        total=_CONNECTION_RETRY_TOTAL,
37
        backoff_factor=_CONNECTION_RETRY_BACKOFF_FACTOR,
38
        status_forcelist=[429, 500, 502, 503, 504],
39
    )
40
    adapter = HTTPAdapter(max_retries=retry_strategy)
41
    http = requests.Session()
42
    http.mount("https://", adapter)
43
    http.mount("http://", adapter)
44
    return http
45

46

47
T = TypeVar("T")
48

49

50
class DatabaseClient:
51
    def __init__(self, db_url: str, token: Optional[str] = None, model: Optional[T] = None) -> None:
52
        self.db_url = db_url
53
        self.model = model
54
        self.token = token or ""
55
        self._session = None
56

57
    def select_all(self, model: Optional[Type[T]] = None) -> List[T]:
58
        cls = model if model else self.model
59
        resp = self.session.post(
60
            self.db_url + "/select_all/", data=_GeneralModel.from_cls(cls, token=self.token).json()
61
        )
62
        assert resp.status_code == 200
63
        return [cls(**data) for data in resp.json()]
64

65
    def insert(self, model: T) -> None:
66
        resp = self.session.post(
67
            self.db_url + "/insert/",
68
            data=_GeneralModel.from_obj(model, token=self.token).json(),
69
        )
70
        assert resp.status_code == 200
71

72
    def update(self, model: T) -> None:
73
        resp = self.session.post(
74
            self.db_url + "/update/",
75
            data=_GeneralModel.from_obj(model, token=self.token).json(),
76
        )
77
        assert resp.status_code == 200
78

79
    def delete(self, model: T) -> None:
80
        resp = self.session.post(
81
            self.db_url + "/delete/",
82
            data=_GeneralModel.from_obj(model, token=self.token).json(),
83
        )
84
        assert resp.status_code == 200
85

86
    @property
87
    def session(self):
88
        if self._session is None:
89
            self._session = _configure_session()
90
        return self._session
91

92
    def to_dict(self) -> Dict[str, Any]:
93
        return {"db_url": self.db_url, "model": self.model.__name__ if self.model else None}
94

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.