pytorch-lightning
211 строк · 7.6 Кб
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14
15import socket
16from functools import wraps
17from typing import Any, Callable, Dict, Optional
18from urllib.parse import urljoin
19
20import requests
21
22# for backwards compatibility
23from lightning_cloud.rest_client import GridRestClient, LightningClient, create_swagger_client # noqa: F401
24from requests import Session
25from requests.adapters import HTTPAdapter
26from requests.exceptions import ConnectionError, ConnectTimeout, ReadTimeout
27from urllib3.util.retry import Retry
28
29from lightning.app.core import constants
30from lightning.app.utilities.app_helpers import Logger
31
logger = Logger(__name__)


# Every port handed out by this module's port-finding helpers is remembered here for the rest of the session,
# which keeps those helpers from returning the same port twice (the probe socket is closed before returning).
_reserved_ports = set()
37
38
def find_free_network_port() -> int:
    """Finds a free port on localhost.

    When ``constants.LIGHTNING_CLOUDSPACE_HOST`` is set, the search is delegated to
    ``_find_free_network_port_cloudspace`` so that only ports from the cloudspace's exposed range are returned.

    Otherwise the OS is asked for an ephemeral port by binding a TCP socket to port 0. A port that is already in
    ``_reserved_ports`` is rejected and the OS is asked again, up to 10 attempts in total.

    Returns:
        The chosen port number. It is also recorded in ``_reserved_ports``.

    Raises:
        RuntimeError: If all 10 attempts produced ports that were already reserved.

    Note:
        The probe socket is closed before the port is returned, so the port is only known to have been free at
        probe time.

    """
    if constants.LIGHTNING_CLOUDSPACE_HOST is not None:
        return _find_free_network_port_cloudspace()

    for _ in range(10):
        # The context manager closes the probe socket even if ``bind`` or ``getsockname`` raises.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            sock.bind(("", 0))
            port = sock.getsockname()[1]
        if port not in _reserved_ports:
            break
    else:
        # Prevent an infinite loop, if we tried 10 times and didn't get a free port then something is wrong
        raise RuntimeError(
            "Couldn't find a free port. Please open an issue at `https://github.com/Lightning-AI/lightning/issues`."
        )

    _reserved_ports.add(port)
    return port
63
64
def _find_free_network_port_cloudspace() -> int:
    """Finds a free port in the exposed range when running in a cloudspace.

    Candidates run from ``constants.APP_SERVER_PORT + 1`` up to, but excluding,
    ``constants.APP_SERVER_PORT + constants.LIGHTNING_CLOUDSPACE_EXPOSED_PORT_COUNT``, tried in ascending order.
    A candidate is skipped if it is already in ``_reserved_ports`` or if creating/binding a socket on it raises
    ``OSError``.

    Returns:
        The first usable port. It is also recorded in ``_reserved_ports``.

    Raises:
        RuntimeError: If every candidate is reserved or cannot be bound.

    """
    first_port = constants.APP_SERVER_PORT + 1  # constants.APP_SERVER_PORT is reserved for the app server
    end_port = constants.APP_SERVER_PORT + constants.LIGHTNING_CLOUDSPACE_EXPOSED_PORT_COUNT

    for port in range(first_port, end_port):
        if port in _reserved_ports:
            continue

        try:
            # The context manager closes the probe socket on every path, including when ``bind`` fails because the
            # port is taken, so scanning busy ports does not leak file descriptors.
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
                sock.bind(("", port))
        except OSError:
            continue

        _reserved_ports.add(port)
        return port

    # This error should never happen. An app using this many ports would probably fail on a single machine anyway.
    raise RuntimeError(f"All {constants.LIGHTNING_CLOUDSPACE_EXPOSED_PORT_COUNT} ports are already in use.")
85
86
# Maximum number of retries configured by `_configure_session` and `HTTPClient`.
_CONNECTION_RETRY_TOTAL = 2_880
# `backoff_factor` handed to urllib3's `Retry`; it scales the exponential wait between retries.
_CONNECTION_RETRY_BACKOFF_FACTOR = 0.5
# Default request timeout, in seconds, used by `CustomRetryAdapter` when none is given to it.
_DEFAULT_REQUEST_TIMEOUT = 30
90
91
def _configure_session() -> Session:
    """Build a ``requests`` session whose HTTP and HTTPS traffic goes through a patient retrying adapter.

    The adapter retries up to ``_CONNECTION_RETRY_TOTAL`` times, also retrying on status codes 429, 500, 502, 503
    and 504. This lets callers keep waiting while the application server comes up.

    Returns:
        A new ``requests.Session`` with the retrying adapter mounted for ``https://`` and ``http://``.

    """
    # wait time between retries increases exponentially according to: backoff_factor * (2 ** (retry - 1))
    adapter = HTTPAdapter(
        max_retries=Retry(
            total=_CONNECTION_RETRY_TOTAL,
            backoff_factor=_CONNECTION_RETRY_BACKOFF_FACTOR,
            status_forcelist=[429, 500, 502, 503, 504],
        )
    )
    session = requests.Session()
    for scheme in ("https://", "http://"):
        session.mount(scheme, adapter)
    return session
109
110
def _check_service_url_is_ready(url: str, timeout: float = 5, metadata="") -> bool:
    """Probe ``url`` with a single GET request and report whether the service answered.

    Args:
        url: The URL to request.
        timeout: Timeout in seconds forwarded to ``requests.get``.
        metadata: Extra text appended to the debug message logged when the service is not reachable.

    Returns:
        ``True`` if the response status is 200 or 404. ``False`` for any other status, or when the request fails
        with one of the ``ConnectionError``, ``ConnectTimeout`` or ``ReadTimeout`` exceptions imported at module
        level. Any other exception propagates.

    """
    try:
        reply = requests.get(url, timeout=timeout)
    except (ConnectionError, ConnectTimeout, ReadTimeout):
        logger.debug(f"The url {url} is not ready. {metadata}")
        return False
    # A 404 is accepted too, presumably because it shows a server is answering even if the path is not served.
    return reply.status_code in (200, 404)
118
119
class CustomRetryAdapter(HTTPAdapter):
    """An ``HTTPAdapter`` that applies a default timeout to every request sent through it.

    Args:
        *args: Forwarded to ``HTTPAdapter`` (for example ``max_retries``).
        timeout: Keyword-only. Default timeout in seconds used when a request has none. Defaults to
            ``_DEFAULT_REQUEST_TIMEOUT``.
        **kwargs: Remaining keyword arguments, forwarded to ``HTTPAdapter``.

    """

    def __init__(self, *args: Any, **kwargs: Any):
        # ``timeout`` is not an ``HTTPAdapter`` argument, so remove it before calling the parent constructor.
        self.timeout = kwargs.pop("timeout", _DEFAULT_REQUEST_TIMEOUT)
        super().__init__(*args, **kwargs)

    def send(self, request, **kwargs: Any):
        """Send ``request``, substituting the adapter's default timeout when no timeout was given.

        ``requests.Session`` always forwards a ``timeout`` keyword to the adapter and sets it to ``None`` when the
        caller did not pass one. Checking only for a missing key would therefore never apply the default, so a
        ``None`` value is treated the same as an absent one.

        """
        if kwargs.get("timeout") is None:
            kwargs["timeout"] = self.timeout
        return super().send(request, **kwargs)
128
129
def _http_method_logger_wrapper(func: Callable) -> Callable:
    """Returns the function decorated by a wrapper that logs the message using the `log_function` hook.

    The logged message contains:

    * the HTTP method (the wrapped function's name, upper-cased);
    * the request path and the client's ``base_url``;
    * any non-empty ``query_params``;
    * either the response status and reason, or the error that the call raised.

    Args:
        func: A method of an object exposing ``base_url`` and ``log_function``. The request path is its first
            argument after ``self`` (or its ``path`` keyword argument).

    Returns:
        The wrapped method. Exceptions raised by ``func`` are logged and then re-raised unchanged.

    """

    @wraps(func)
    def wrapped(self: "HTTPClient", *args: Any, **kwargs: Any) -> Any:
        # The path is normally positional, but accept it as a keyword too so `client.get(path=...)` does not fail
        # with an IndexError before the request is even sent.
        path = args[0] if args else kwargs.get("path")
        message = f"HTTPClient: Method: {func.__name__.upper()}, Path: {path}\n"
        message += f" Base URL: {self.base_url}\n"
        params = kwargs.get("query_params", {})
        if params:
            message += f" Params: {params}\n"
        try:
            resp: requests.Response = func(self, *args, **kwargs)
        except Exception as ex:
            # Failed calls (e.g. HTTP errors raised by a `raise_for_status` response hook) are the ones most worth
            # logging, so report them before propagating the exception.
            self.log_function(message + f" Error: {type(ex).__name__}: {ex}")
            raise
        message += f" Response: {resp.status_code} {resp.reason}"
        self.log_function(message)
        return resp

    return wrapped
146
147
def _response(r, *args: Any, **kwargs: Any):
    """Response hook that delegates to ``r.raise_for_status()``.

    For a ``requests.Response`` this raises ``requests.HTTPError`` on error statuses.

    Args:
        r: The response object the hook is invoked with.
        *args: Ignored; accepted so the hook tolerates extra positional arguments.
        **kwargs: Ignored; accepted so the hook tolerates extra keyword arguments.

    Returns:
        Whatever ``r.raise_for_status()`` returns.

    """
    outcome = r.raise_for_status()
    return outcome
150
151
class HTTPClient:
    """A wrapper class around the requests library which handles chores like logging, retries, and timeouts
    automatically.

    Args:
        base_url: URL that request paths are resolved against with ``urllib.parse.urljoin``.
        auth_token: If given, sent as a ``Bearer`` token in the ``Authorization`` header of every request.
        log_callback: Optional callable that receives the log messages produced by the request methods. When
            omitted, the no-op :meth:`log_function` is used.

    The client can be used as a context manager; leaving the ``with`` block closes the underlying session.

    """

    def __init__(
        self, base_url: str, auth_token: Optional[str] = None, log_callback: Optional[Callable] = None
    ) -> None:
        self.base_url = base_url
        retry_strategy = Retry(
            # wait time between retries increases exponentially according to: backoff_factor * (2 ** (retry - 1))
            # but the maximum wait time is 120 secs. By setting a large value (2880), we'll make sure clients
            # are going to be alive for a very long time (~ 4 days) but retries every 120 seconds
            total=_CONNECTION_RETRY_TOTAL,
            backoff_factor=_CONNECTION_RETRY_BACKOFF_FACTOR,
            status_forcelist=[
                408,  # Request Timeout
                429,  # Too Many Requests
                500,  # Internal Server Error
                502,  # Bad Gateway
                503,  # Service Unavailable
                504,  # Gateway Timeout
            ],
        )
        adapter = CustomRetryAdapter(max_retries=retry_strategy, timeout=_DEFAULT_REQUEST_TIMEOUT)
        self.session = requests.Session()

        # `_response` calls `raise_for_status()` on every response the session receives.
        self.session.hooks = {"response": _response}

        self.session.mount("http://", adapter)
        self.session.mount("https://", adapter)

        if auth_token:
            self.session.headers.update({"Authorization": f"Bearer {auth_token}"})

        # When a callback is supplied, this instance attribute shadows the no-op `log_function` method.
        self.log_function = log_callback or self.log_function

    @_http_method_logger_wrapper
    def get(self, path: str, *, query_params: Optional[Dict] = None):
        """Send a GET request to ``path`` (resolved against ``base_url``) with optional query parameters."""
        url = urljoin(self.base_url, path)
        return self.session.get(url, params=query_params)

    @_http_method_logger_wrapper
    def post(self, path: str, *, query_params: Optional[Dict] = None, data: Optional[bytes] = None, json: Any = None):
        """Send a POST request to ``path`` with optional query parameters and a ``data`` or ``json`` body."""
        url = urljoin(self.base_url, path)
        return self.session.post(url, data=data, params=query_params, json=json)

    @_http_method_logger_wrapper
    def delete(self, path: str, *, query_params: Optional[Dict] = None):
        """Send a DELETE request to ``path`` (resolved against ``base_url``) with optional query parameters."""
        url = urljoin(self.base_url, path)
        return self.session.delete(url, params=query_params)

    def close(self) -> None:
        """Close the underlying ``requests.Session`` and release its pooled connections."""
        self.session.close()

    def __enter__(self) -> "HTTPClient":
        return self

    def __exit__(self, *exc_info: Any) -> None:
        # Returning None lets any exception raised inside the ``with`` block propagate.
        self.close()

    def log_function(self, message: str, *args, **kwargs: Any):
        """This function is used to log the messages in the client, it can be overridden by caller to customise the
        logging logic.

        We enabled customisation here instead of just using `logger.debug` because HTTP logging can be very noisy, but
        it is crucial for finding bugs when we have them.

        """
        pass
212