pytorch-lightning

Форк
0
211 строк · 7.6 Кб
1
# Copyright The Lightning AI team.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

15
import socket
16
from functools import wraps
17
from typing import Any, Callable, Dict, Optional
18
from urllib.parse import urljoin
19

20
import requests
21

22
# for backwards compatibility
23
from lightning_cloud.rest_client import GridRestClient, LightningClient, create_swagger_client  # noqa: F401
24
from requests import Session
25
from requests.adapters import HTTPAdapter
26
from requests.exceptions import ConnectionError, ConnectTimeout, ReadTimeout
27
from urllib3.util.retry import Retry
28

29
from lightning.app.core import constants
30
from lightning.app.utilities.app_helpers import Logger
31

32
logger = Logger(__name__)

# Ports already handed out by the port finders in this module during the lifetime of this Python process.
# A port is recorded here as soon as it is returned, so it is never returned twice even if nobody has bound it yet.
_reserved_ports: "set[int]" = set()
37

38

39
def find_free_network_port() -> int:
    """Find a TCP port that is currently free on this machine and reserve it for this process.

    When ``constants.LIGHTNING_CLOUDSPACE_HOST`` is set, the search is delegated to
    ``_find_free_network_port_cloudspace``. Otherwise the OS is asked for an ephemeral port (bind to port 0),
    up to 10 times, until it returns one that is not already in ``_reserved_ports``.

    Note that the port is only known to be free at the moment of the probe; another process may still grab it
    before the caller binds it.

    Returns:
        The port number. It is also added to ``_reserved_ports``.

    Raises:
        RuntimeError: If all 10 attempts returned ports that were already reserved.
    """
    if constants.LIGHTNING_CLOUDSPACE_HOST is not None:
        return _find_free_network_port_cloudspace()

    for _ in range(10):
        # The context manager closes the probe socket even if bind() raises, so no file descriptor leaks.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            sock.bind(("", 0))
            port = sock.getsockname()[1]
        if port not in _reserved_ports:
            break
    else:
        # Prevent an infinite loop, if we tried 10 times and didn't get a free port then something is wrong
        raise RuntimeError(
            "Couldn't find a free port. Please open an issue at `https://github.com/Lightning-AI/lightning/issues`."
        )

    _reserved_ports.add(port)
    return port
63

64

65
def _find_free_network_port_cloudspace() -> int:
    """Find a free port in the exposed port range when running in a cloudspace.

    Candidates are ``constants.APP_SERVER_PORT + 1`` up to (but excluding)
    ``constants.APP_SERVER_PORT + constants.LIGHTNING_CLOUDSPACE_EXPOSED_PORT_COUNT``. Ports already in
    ``_reserved_ports`` are skipped. A port is accepted if a TCP socket can be bound to it.

    Returns:
        The first bindable, unreserved port. It is also added to ``_reserved_ports``.

    Raises:
        RuntimeError: If every candidate port is reserved or cannot be bound.
    """
    first_port = constants.APP_SERVER_PORT + 1  # constants.APP_SERVER_PORT is reserved for the app server
    stop_port = constants.APP_SERVER_PORT + constants.LIGHTNING_CLOUDSPACE_EXPOSED_PORT_COUNT

    for port in range(first_port, stop_port):
        if port in _reserved_ports:
            continue

        try:
            # Using the socket as a context manager guarantees it is closed even when bind() fails;
            # previously a failed bind left the socket open until garbage collection.
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
                sock.bind(("", port))
        except OSError:
            # The port cannot be bound (e.g. another process is using it); try the next candidate.
            continue

        _reserved_ports.add(port)
        return port

    # This error should never happen. An app using this many ports would probably fail on a single machine anyway.
    raise RuntimeError(f"All {constants.LIGHTNING_CLOUDSPACE_EXPOSED_PORT_COUNT} ports are already in use.")
85

86

87
# Retry budget used by both `_configure_session` and `HTTPClient`: maximum number of retries per request.
_CONNECTION_RETRY_TOTAL: int = 2880
# urllib3 backoff factor; the sleep between retries grows exponentially with it.
_CONNECTION_RETRY_BACKOFF_FACTOR: float = 0.5
# Default per-request timeout applied by `CustomRetryAdapter`, in seconds.
_DEFAULT_REQUEST_TIMEOUT: int = 30
90

91

92
def _configure_session() -> Session:
    """Build a ``requests.Session`` whose HTTP and HTTPS traffic goes through a retrying adapter.

    The adapter retries up to ``_CONNECTION_RETRY_TOTAL`` times with exponential backoff
    (``_CONNECTION_RETRY_BACKOFF_FACTOR``), including on 429/500/502/503/504 responses. This gives an
    application server that is still starting up time to accept connections.

    NOTE(review): urllib3's ``Retry`` only retries on status codes for methods in its ``allowed_methods``,
    which by default excludes POST. Confirm whether POST callers rely on these retries.
    """
    session = requests.Session()
    retrying_adapter = HTTPAdapter(
        max_retries=Retry(
            # the sleep between retries grows as backoff_factor * (2 ** (retry - 1))
            total=_CONNECTION_RETRY_TOTAL,
            backoff_factor=_CONNECTION_RETRY_BACKOFF_FACTOR,
            status_forcelist=[429, 500, 502, 503, 504],
        )
    )
    for prefix in ("https://", "http://"):
        session.mount(prefix, retrying_adapter)
    return session
109

110

111
def _check_service_url_is_ready(url: str, timeout: float = 5, metadata="") -> bool:
    """Probe ``url`` with a single GET and report whether the service answered.

    Args:
        url: Address to probe.
        timeout: Timeout passed to ``requests.get``.
        metadata: Extra text appended to the debug message when the probe fails.

    Returns:
        ``True`` if the response status is 200 or 404. A 404 presumably still shows the server is up.
        ``False`` if the request fails with a connection error or times out. Other exceptions propagate.
    """
    try:
        status = requests.get(url, timeout=timeout).status_code
    except (ConnectionError, ConnectTimeout, ReadTimeout):
        logger.debug(f"The url {url} is not ready. {metadata}")
        return False
    return status in (200, 404)
118

119

120
class CustomRetryAdapter(HTTPAdapter):
    """An ``HTTPAdapter`` that applies a default timeout to every request it sends.

    Args:
        *args: Forwarded to ``HTTPAdapter``.
        timeout: Keyword-only default timeout, in seconds (or anything ``requests`` accepts as a timeout).
            It is used whenever a request does not set a timeout of its own. Defaults to
            ``_DEFAULT_REQUEST_TIMEOUT``.
        **kwargs: Forwarded to ``HTTPAdapter`` (e.g. ``max_retries``).
    """

    def __init__(self, *args: Any, **kwargs: Any):
        self.timeout = kwargs.pop("timeout", _DEFAULT_REQUEST_TIMEOUT)
        super().__init__(*args, **kwargs)

    def send(self, request, **kwargs: Any):
        """Send ``request``, filling in ``self.timeout`` when the caller did not specify a timeout."""
        # ``requests.Session`` always forwards a ``timeout`` keyword to the adapter, and it is ``None`` when the
        # caller did not set one. A key-presence default (``kwargs.get("timeout", self.timeout)``) would therefore
        # never apply, and requests would run with no timeout at all. An explicit ``None`` is treated as "unset".
        if kwargs.get("timeout") is None:
            kwargs["timeout"] = self.timeout
        return super().send(request, **kwargs)
128

129

130
def _http_method_logger_wrapper(func: Callable) -> Callable:
    """Decorate an ``HTTPClient`` method so that each call is reported through ``self.log_function``.

    The logged message contains the HTTP method (the decorated function's name, upper-cased), the request path,
    the client's ``base_url``, any non-empty ``query_params`` keyword argument, and the response status.

    If the wrapped call raises, the failure is logged and the exception is re-raised unchanged. The response
    status is used when the exception carries a ``response`` attribute, otherwise the exception itself.
    """

    @wraps(func)
    def wrapped(self: "HTTPClient", *args: Any, **kwargs: Any) -> Any:
        # ``path`` may be passed positionally or by keyword; indexing ``args[0]`` alone would fail for the latter.
        path = args[0] if args else kwargs.get("path")
        message = f"HTTPClient: Method: {func.__name__.upper()}, Path: {path}\n"
        message += f"      Base URL: {self.base_url}\n"
        params = kwargs.get("query_params", {})
        if params:
            message += f"      Params: {params}\n"
        try:
            resp: requests.Response = func(self, *args, **kwargs)
        except Exception as exc:
            # Error responses surface as exceptions (e.g. from a ``raise_for_status`` hook). Log them too,
            # since failed requests are exactly the ones worth seeing when debugging.
            failed_resp = getattr(exc, "response", None)
            if failed_resp is not None:
                message += f"      Response: {failed_resp.status_code} {failed_resp.reason}"
            else:
                message += f"      Error: {exc!r}"
            self.log_function(message)
            raise
        message += f"      Response: {resp.status_code} {resp.reason}"
        self.log_function(message)
        return resp

    return wrapped
146

147

148
def _response(r, *args: Any, **kwargs: Any):
    """Session ``response`` hook that turns HTTP error statuses into exceptions.

    ``requests`` calls the hook with the received response. Any extra positional or keyword arguments are ignored.
    ``raise_for_status`` raises for 4xx/5xx responses and otherwise returns ``None``. A hook returning ``None``
    leaves the response unchanged.
    """
    outcome = r.raise_for_status()
    return outcome
150

151

152
class HTTPClient:
    """A wrapper class around the requests library which handles chores like logging, retries, and timeouts
    automatically.

    Every request goes through a ``CustomRetryAdapter``. It retries with exponential backoff on connection errors
    and on the status codes listed in ``__init__``, and applies a default per-request timeout. A session-level
    ``response`` hook (``_response``) calls ``raise_for_status`` so that error responses surface as exceptions.

    Args:
        base_url: URL that request paths are resolved against with ``urllib.parse.urljoin``. ``urljoin`` replaces
            the last path segment of ``base_url`` unless it ends with ``/``. A ``path`` starting with ``/``
            discards the base path entirely.
        auth_token: Optional bearer token, sent as ``Authorization: Bearer <token>`` on every request.
        log_callback: Optional callable receiving one formatted message per request. When given it replaces the
            default no-op ``log_function``.
    """

    def __init__(
        self, base_url: str, auth_token: Optional[str] = None, log_callback: Optional[Callable] = None
    ) -> None:
        self.base_url = base_url
        retry_strategy = Retry(
            # wait time between retries increases exponentially according to: backoff_factor * (2 ** (retry - 1))
            # but the maximum wait time is 120 secs. By setting a large value (2880), we'll make sure clients
            # are going to be alive for a very long time (~ 4 days) but retries every 120 seconds
            # NOTE(review): urllib3's default `allowed_methods` excludes POST, so the status-code retries below do
            # not apply to `post()`. Confirm this is intended.
            total=_CONNECTION_RETRY_TOTAL,
            backoff_factor=_CONNECTION_RETRY_BACKOFF_FACTOR,
            status_forcelist=[
                408,  # Request Timeout
                429,  # Too Many Requests
                500,  # Internal Server Error
                502,  # Bad Gateway
                503,  # Service Unavailable
                504,  # Gateway Timeout
            ],
        )
        adapter = CustomRetryAdapter(max_retries=retry_strategy, timeout=_DEFAULT_REQUEST_TIMEOUT)
        self.session = requests.Session()

        # Raise on 4xx/5xx responses for every request made through this session.
        self.session.hooks = {"response": _response}

        self.session.mount("http://", adapter)
        self.session.mount("https://", adapter)

        if auth_token:
            self.session.headers.update({"Authorization": f"Bearer {auth_token}"})

        self.log_function = log_callback or self.log_function

    @_http_method_logger_wrapper
    def get(self, path: str, *, query_params: Optional[Dict] = None) -> requests.Response:
        """Send a GET request to ``path`` (resolved against ``base_url``), optionally with query parameters."""
        url = urljoin(self.base_url, path)
        return self.session.get(url, params=query_params)

    @_http_method_logger_wrapper
    def post(
        self, path: str, *, query_params: Optional[Dict] = None, data: Optional[bytes] = None, json: Any = None
    ) -> requests.Response:
        """Send a POST request to ``path`` (resolved against ``base_url``).

        Args:
            path: Path resolved against ``base_url`` with ``urljoin``.
            query_params: Optional query-string parameters.
            data: Optional raw request body.
            json: Optional JSON-serializable request body.
        """
        url = urljoin(self.base_url, path)
        return self.session.post(url, data=data, params=query_params, json=json)

    @_http_method_logger_wrapper
    def delete(self, path: str, *, query_params: Optional[Dict] = None) -> requests.Response:
        """Send a DELETE request to ``path`` (resolved against ``base_url``), optionally with query parameters."""
        url = urljoin(self.base_url, path)
        return self.session.delete(url, params=query_params)

    def log_function(self, message: str, *args: Any, **kwargs: Any) -> None:
        """This function is used to log the messages in the client, it can be overridden by caller to customise the
        logging logic.

        We enabled customisation here instead of just using `logger.debug` because HTTP logging can be very noisy, but
        it is crucial for finding bugs when we have them

        """
        pass
212

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.