# pytorch-lightning
# 574 lines · 21.3 KB
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import base64
import multiprocessing
import pickle
import queue  # needed as import instead from/import for mocking in tests
import time
import warnings
from abc import ABC, abstractmethod
from enum import Enum
from pathlib import Path
from typing import Any, List, Optional, Tuple
from urllib.parse import urljoin

import backoff
import requests
from requests.exceptions import ConnectionError, ConnectTimeout, ReadTimeout

from lightning.app.core.constants import (
    BATCH_DELTA_COUNT,
    HTTP_QUEUE_REFRESH_INTERVAL,
    HTTP_QUEUE_REQUESTS_PER_SECOND,
    HTTP_QUEUE_TOKEN,
    HTTP_QUEUE_URL,
    LIGHTNING_DIR,
    QUEUE_DEBUG_ENABLED,
    REDIS_HOST,
    REDIS_PASSWORD,
    REDIS_PORT,
    REDIS_QUEUES_READ_DEFAULT_TIMEOUT,
    STATE_UPDATE_TIMEOUT,
    WARNING_QUEUE_SIZE,
)
from lightning.app.utilities.app_helpers import Logger
from lightning.app.utilities.imports import _is_redis_available, requires
from lightning.app.utilities.network import HTTPClient

if _is_redis_available():
    import redis

# Module-level logger; ``debug_log_callback`` writes HTTP queue debug messages through it.
logger = Logger(__name__)

# Base names of the queues built by ``QueuingSystem``. Its ``get_*_queue`` helpers prefix a name with
# ``"<queue_id>_"`` when a queue id is given, and per-work queues additionally append ``"_<work_name>"``.

# App-level queues.
READINESS_QUEUE_CONSTANT = "READINESS_QUEUE"
ERROR_QUEUE_CONSTANT = "ERROR_QUEUE"
DELTA_QUEUE_CONSTANT = "DELTA_QUEUE"
HAS_SERVER_STARTED_CONSTANT = "HAS_SERVER_STARTED_QUEUE"
CALLER_QUEUE_CONSTANT = "CALLER_QUEUE"

# ``API_*`` queues. Note: ``QueuingSystem.get_api_delta_queue`` reuses ``DELTA_QUEUE_CONSTANT``
# rather than ``API_DELTA_QUEUE_CONSTANT``.
API_STATE_PUBLISH_QUEUE_CONSTANT = "API_STATE_PUBLISH_QUEUE"
API_DELTA_QUEUE_CONSTANT = "API_DELTA_QUEUE"
API_REFRESH_QUEUE_CONSTANT = "API_REFRESH_QUEUE"
API_RESPONSE_QUEUE_CONSTANT = "API_RESPONSE_QUEUE"

# Per-work queues (suffixed with the work name).
ORCHESTRATOR_REQUEST_CONSTANT = "ORCHESTRATOR_REQUEST"
ORCHESTRATOR_RESPONSE_CONSTANT = "ORCHESTRATOR_RESPONSE"
ORCHESTRATOR_COPY_REQUEST_CONSTANT = "ORCHESTRATOR_COPY_REQUEST"
ORCHESTRATOR_COPY_RESPONSE_CONSTANT = "ORCHESTRATOR_COPY_RESPONSE"
WORK_QUEUE_CONSTANT = "WORK_QUEUE"
FLOW_TO_WORKS_DELTA_QUEUE_CONSTANT = "FLOW_TO_WORKS_DELTA_QUEUE"

class QueuingSystem(Enum):
    """Backend used to create the app's named queues.

    Every ``get_*_queue`` helper derives a queue name from a module-level ``*_CONSTANT`` base name and
    passes it to :meth:`get_queue`. A truthy ``queue_id`` prefixes the name with ``"<queue_id>_"``;
    per-work queues are also suffixed with ``"_<work_name>"``.
    """

    MULTIPROCESS = "multiprocess"
    REDIS = "redis"
    HTTP = "http"

    @staticmethod
    def _scoped(name: str, queue_id: Optional[str]) -> str:
        """Return ``name``, prefixed with ``"<queue_id>_"`` only when ``queue_id`` is truthy."""
        if not queue_id:
            return name
        return f"{queue_id}_{name}"

    def get_queue(self, queue_name: str) -> "BaseQueue":
        """Create the queue ``queue_name`` on this backend.

        ``MULTIPROCESS`` uses ``STATE_UPDATE_TIMEOUT`` and ``REDIS`` uses ``REDIS_QUEUES_READ_DEFAULT_TIMEOUT``
        as the default read timeout. Any other member yields an ``HTTPQueue`` (default timeout
        ``STATE_UPDATE_TIMEOUT``) wrapped in a ``RateLimitedQueue`` at ``HTTP_QUEUE_REQUESTS_PER_SECOND``.
        """
        if self is QueuingSystem.MULTIPROCESS:
            return MultiProcessQueue(queue_name, default_timeout=STATE_UPDATE_TIMEOUT)
        if self is QueuingSystem.REDIS:
            return RedisQueue(queue_name, default_timeout=REDIS_QUEUES_READ_DEFAULT_TIMEOUT)
        http_queue = HTTPQueue(queue_name, default_timeout=STATE_UPDATE_TIMEOUT)
        return RateLimitedQueue(http_queue, HTTP_QUEUE_REQUESTS_PER_SECOND)

    def get_api_response_queue(self, queue_id: Optional[str] = None) -> "BaseQueue":
        """Queue named from ``API_RESPONSE_QUEUE_CONSTANT``."""
        return self.get_queue(self._scoped(API_RESPONSE_QUEUE_CONSTANT, queue_id))

    def get_readiness_queue(self, queue_id: Optional[str] = None) -> "BaseQueue":
        """Queue named from ``READINESS_QUEUE_CONSTANT``."""
        return self.get_queue(self._scoped(READINESS_QUEUE_CONSTANT, queue_id))

    def get_delta_queue(self, queue_id: Optional[str] = None) -> "BaseQueue":
        """Queue named from ``DELTA_QUEUE_CONSTANT``."""
        return self.get_queue(self._scoped(DELTA_QUEUE_CONSTANT, queue_id))

    def get_error_queue(self, queue_id: Optional[str] = None) -> "BaseQueue":
        """Queue named from ``ERROR_QUEUE_CONSTANT``."""
        return self.get_queue(self._scoped(ERROR_QUEUE_CONSTANT, queue_id))

    def get_has_server_started_queue(self, queue_id: Optional[str] = None) -> "BaseQueue":
        """Queue named from ``HAS_SERVER_STARTED_CONSTANT``."""
        return self.get_queue(self._scoped(HAS_SERVER_STARTED_CONSTANT, queue_id))

    def get_caller_queue(self, work_name: str, queue_id: Optional[str] = None) -> "BaseQueue":
        """Per-work queue named ``[<queue_id>_]CALLER_QUEUE_<work_name>``."""
        return self.get_queue(self._scoped(f"{CALLER_QUEUE_CONSTANT}_{work_name}", queue_id))

    def get_api_state_publish_queue(self, queue_id: Optional[str] = None) -> "BaseQueue":
        """Queue named from ``API_STATE_PUBLISH_QUEUE_CONSTANT``."""
        return self.get_queue(self._scoped(API_STATE_PUBLISH_QUEUE_CONSTANT, queue_id))

    # TODO: This is hack, so we can remove this queue entirely when fully optimized.
    def get_api_delta_queue(self, queue_id: Optional[str] = None) -> "BaseQueue":
        """Same name as :meth:`get_delta_queue` (``DELTA_QUEUE_CONSTANT``, not ``API_DELTA_QUEUE_CONSTANT``)."""
        return self.get_queue(self._scoped(DELTA_QUEUE_CONSTANT, queue_id))

    def get_orchestrator_request_queue(self, work_name: str, queue_id: Optional[str] = None) -> "BaseQueue":
        """Per-work queue named ``[<queue_id>_]ORCHESTRATOR_REQUEST_<work_name>``."""
        return self.get_queue(self._scoped(f"{ORCHESTRATOR_REQUEST_CONSTANT}_{work_name}", queue_id))

    def get_orchestrator_response_queue(self, work_name: str, queue_id: Optional[str] = None) -> "BaseQueue":
        """Per-work queue named ``[<queue_id>_]ORCHESTRATOR_RESPONSE_<work_name>``."""
        return self.get_queue(self._scoped(f"{ORCHESTRATOR_RESPONSE_CONSTANT}_{work_name}", queue_id))

    def get_orchestrator_copy_request_queue(self, work_name: str, queue_id: Optional[str] = None) -> "BaseQueue":
        """Per-work queue named ``[<queue_id>_]ORCHESTRATOR_COPY_REQUEST_<work_name>``."""
        return self.get_queue(self._scoped(f"{ORCHESTRATOR_COPY_REQUEST_CONSTANT}_{work_name}", queue_id))

    def get_orchestrator_copy_response_queue(self, work_name: str, queue_id: Optional[str] = None) -> "BaseQueue":
        """Per-work queue named ``[<queue_id>_]ORCHESTRATOR_COPY_RESPONSE_<work_name>``."""
        return self.get_queue(self._scoped(f"{ORCHESTRATOR_COPY_RESPONSE_CONSTANT}_{work_name}", queue_id))

    def get_work_queue(self, work_name: str, queue_id: Optional[str] = None) -> "BaseQueue":
        """Per-work queue named ``[<queue_id>_]WORK_QUEUE_<work_name>``."""
        return self.get_queue(self._scoped(f"{WORK_QUEUE_CONSTANT}_{work_name}", queue_id))

    def get_flow_to_work_delta_queue(self, work_name: str, queue_id: Optional[str] = None) -> "BaseQueue":
        """Per-work queue named ``[<queue_id>_]FLOW_TO_WORKS_DELTA_QUEUE_<work_name>``."""
        return self.get_queue(self._scoped(f"{FLOW_TO_WORKS_DELTA_QUEUE_CONSTANT}_{work_name}", queue_id))

class BaseQueue(ABC):
    """Interface shared by all queue backends, modelled on Python's ``queue.Queue`` API.

    Implementations provide ``put``, ``get`` and ``batch_get`` and expose ``name`` and
    ``default_timeout`` attributes (the base ``__init__`` stores both).
    """

    @abstractmethod
    def __init__(self, name: str, default_timeout: float):
        # Abstract so every subclass defines its own constructor; subclasses may still call it
        # through ``super().__init__`` to store the two common attributes.
        self.name = name
        self.default_timeout = default_timeout

    @abstractmethod
    def put(self, item: Any) -> None:
        """Add ``item`` to the queue."""

    @abstractmethod
    def get(self, timeout: Optional[float] = None) -> Any:
        """Remove and return the left-most element of the queue.

        Parameters
        ----------
        timeout:
            Read timeout in seconds. A value of ``0`` means ``self.default_timeout`` is used instead,
            and ``None`` blocks indefinitely.

        """

    @abstractmethod
    def batch_get(self, timeout: Optional[float] = None, count: Optional[int] = None) -> List[Any]:
        """Remove and return the left-most elements of the queue.

        Parameters
        ----------
        timeout:
            Read timeout in seconds. A value of ``0`` means ``self.default_timeout`` is used instead,
            and ``None`` blocks indefinitely.
        count:
            The number of elements to fetch from the queue.

        """

    @property
    def is_running(self) -> bool:
        """Whether the queue is usable; always ``True`` here.

        Subclasses override this with a backend-specific liveness check when they have one.

        """
        return True

class MultiProcessQueue(BaseQueue):
    """Queue backed by a ``multiprocessing`` queue created from the ``"spawn"`` start-method context."""

    def __init__(self, name: str, default_timeout: float) -> None:
        self.name = name
        self.default_timeout = default_timeout
        context = multiprocessing.get_context("spawn")
        self.queue = context.Queue()

    def put(self, item: Any) -> None:
        self.queue.put(item)

    def get(self, timeout: Optional[float] = None) -> Any:
        """Remove and return the left-most element of the queue.

        Parameters
        ----------
        timeout:
            Seconds to wait for an element. ``0`` means ``self.default_timeout`` is used instead, and
            ``None`` blocks indefinitely.

        Raises
        ------
        queue.Empty:
            If no element becomes available within the timeout.

        """
        if timeout == 0:
            timeout = self.default_timeout
        # Always block: with ``block=False`` multiprocessing ignores ``timeout`` and raises ``queue.Empty``
        # immediately, which would silently turn every positive timeout into a non-blocking read.
        return self.queue.get(block=True, timeout=timeout)

    def batch_get(self, timeout: Optional[float] = None, count: Optional[int] = None) -> List[Any]:
        """Return a one-element list holding the next element.

        For multiprocessing it is enough to collect the upmost element, so ``count`` is ignored.
        ``timeout`` has the same meaning as in :meth:`get`.

        """
        return [self.get(timeout=timeout)]

class RedisQueue(BaseQueue):
    """Queue stored in a Redis list: ``put`` appends with RPUSH and ``get`` pops with BLPOP.

    Items are serialized with :mod:`pickle`, so only point this at a Redis server you trust.
    """

    # Message for the ``ConnectionError`` raised whenever the Redis client cannot reach the server.
    _CONNECTION_ERROR_MESSAGE = (
        "Your app failed because it couldn't connect to Redis. "
        "Please try running your app again. "
        "If the issue persists, please contact support@lightning.ai"
    )

    @requires("redis")
    def __init__(
        self,
        name: str,
        default_timeout: float,
        host: Optional[str] = None,
        port: Optional[int] = None,
        password: Optional[str] = None,
    ):
        """
        Parameters
        ----------
        name:
            The name of the list to use
        default_timeout:
            Default timeout, in seconds, for redis reads
        host:
            The hostname of the redis server; falls back to ``REDIS_HOST`` when falsy
        port:
            The port of the redis server; falls back to ``REDIS_PORT`` when falsy
        password:
            Redis password; falls back to ``REDIS_PASSWORD`` when falsy

        Raises
        ------
        ValueError:
            If ``name`` is ``None``.
        """
        if name is None:
            raise ValueError("You must specify a name for the queue")
        self.host = host or REDIS_HOST
        self.port = port or REDIS_PORT
        self.password = password or REDIS_PASSWORD
        self.name = name
        self.default_timeout = default_timeout
        self.redis = redis.Redis(host=self.host, port=self.port, password=self.password)

    def put(self, item: Any) -> None:
        """Pickle ``item`` and append it to the Redis list.

        Warns when the list already holds ``WARNING_QUEUE_SIZE`` or more elements.

        Raises
        ------
        ConnectionError:
            (``requests.exceptions.ConnectionError``) if Redis cannot be reached.
        """
        from lightning.app.core.work import LightningWork

        # TODO: Be careful to handle with a lock if another thread needs
        # to access the work backend one day.
        if isinstance(item, LightningWork):
            # The backend isn't picklable (TypeError: cannot pickle '_thread.RLock' object), so detach it
            # while pickling and always reattach it, even if pickling fails.
            backend = item._backend
            item._backend = None
            try:
                value = pickle.dumps(item)
            finally:
                item._backend = backend
        else:
            value = pickle.dumps(item)

        queue_len = self.length()
        if queue_len >= WARNING_QUEUE_SIZE:
            warnings.warn(
                f"The Redis Queue {self.name} length is larger than the "
                f"recommended length of {WARNING_QUEUE_SIZE}. "
                f"Found {queue_len}. This might cause your application to crash, "
                "please investigate this."
            )
        try:
            self.redis.rpush(self.name, value)
        except redis.exceptions.ConnectionError as err:
            raise ConnectionError(self._CONNECTION_ERROR_MESSAGE) from err

    def get(self, timeout: Optional[float] = None) -> Any:
        """Returns the left most element of the redis queue.

        Parameters
        ----------
        timeout:
            Read timeout in seconds. A value of 0 means ``self.default_timeout`` is used instead,
            and ``None`` blocks indefinitely.

        Raises
        ------
        queue.Empty:
            If no element arrived before the timeout expired.
        ConnectionError:
            If Redis cannot be reached.

        """
        if timeout is None:
            # A BLPOP timeout of 0 means "block indefinitely" in Redis.
            timeout = 0
        elif timeout == 0:
            timeout = self.default_timeout

        try:
            out = self.redis.blpop([self.name], timeout=timeout)
        except redis.exceptions.ConnectionError as err:
            raise ConnectionError(self._CONNECTION_ERROR_MESSAGE) from err

        if out is None:
            raise queue.Empty
        # ``out`` is the (list name, value) pair returned by BLPOP.
        return pickle.loads(out[1])

    def batch_get(self, timeout: Optional[float] = None, count: Optional[int] = None) -> List[Any]:
        """Return a one-element list holding the next element; ``count`` is ignored by this backend."""
        return [self.get(timeout=timeout)]

    def clear(self) -> None:
        """Clear all elements in the queue."""
        self.redis.delete(self.name)

    def length(self) -> int:
        """Returns the number of elements in the queue.

        Raises
        ------
        ConnectionError:
            If Redis cannot be reached.

        """
        try:
            return self.redis.llen(self.name)
        except redis.exceptions.ConnectionError as err:
            raise ConnectionError(self._CONNECTION_ERROR_MESSAGE) from err

    @property
    def is_running(self) -> bool:
        """Pinging the redis server to see if it is alive."""
        try:
            return self.redis.ping()
        except redis.exceptions.ConnectionError:
            return False

    def to_dict(self) -> dict:
        """Serialize the constructor arguments, plus a ``"type"`` discriminator.

        NOTE(review): the Redis password is included verbatim; avoid logging this dict.
        """
        return {
            "type": "redis",
            "name": self.name,
            "default_timeout": self.default_timeout,
            "host": self.host,
            "port": self.port,
            "password": self.password,
        }

    @classmethod
    def from_dict(cls, state: dict) -> "RedisQueue":
        """Rebuild a queue from a dict such as the one returned by :meth:`to_dict`.

        The ``"type"`` discriminator written by :meth:`to_dict` is not an ``__init__`` argument, so it
        is ignored; the input dict is not modified.
        """
        kwargs = {key: value for key, value in state.items() if key != "type"}
        return cls(**kwargs)

class RateLimitedQueue(BaseQueue):
    def __init__(self, queue: BaseQueue, requests_per_second: float):
        """This is a queue wrapper that throttles reads made too quickly.

        ``get`` and ``batch_get`` sleep until at least ``1 / requests_per_second`` seconds have passed
        since the previous read. ``put`` is forwarded to the wrapped queue immediately.

        Args:
            queue: The queue to wrap.
            requests_per_second: The target number of ``get``/``batch_get`` requests per second. Must be positive.

        Raises:
            ValueError: If ``requests_per_second`` is not positive.

        """
        if requests_per_second <= 0:
            raise ValueError(f"`requests_per_second` must be positive, got {requests_per_second}.")
        self.name = queue.name
        self.default_timeout = queue.default_timeout

        self._queue = queue
        self._seconds_per_request = 1 / requests_per_second

        # ``time.monotonic()`` of the previous read; -inf lets the first read through without waiting.
        self._last_get = float("-inf")

    @property
    def is_running(self) -> bool:
        """Delegates to the wrapped queue."""
        return self._queue.is_running

    def _wait_until_allowed(self, last_time: float) -> None:
        """Sleep until ``self._seconds_per_request`` seconds have passed since ``last_time``.

        ``last_time`` must come from ``time.monotonic()``. Unlike ``time.time()``, it cannot jump
        backwards, which would otherwise produce an arbitrarily long sleep.
        """
        remaining = self._seconds_per_request - (time.monotonic() - last_time)
        if remaining > 0:
            time.sleep(remaining)

    def get(self, timeout: Optional[float] = None) -> Any:
        """Throttled :meth:`BaseQueue.get` on the wrapped queue."""
        self._wait_until_allowed(self._last_get)
        self._last_get = time.monotonic()
        return self._queue.get(timeout=timeout)

    def batch_get(self, timeout: Optional[float] = None, count: Optional[int] = None) -> List[Any]:
        """Throttled :meth:`BaseQueue.batch_get` on the wrapped queue."""
        self._wait_until_allowed(self._last_get)
        self._last_get = time.monotonic()
        # Forward ``count`` so the caller's requested batch size reaches the wrapped queue.
        return self._queue.batch_get(timeout=timeout, count=count)

    def put(self, item: Any) -> None:
        """Forward ``item`` to the wrapped queue without throttling."""
        return self._queue.put(item)

class HTTPQueue(BaseQueue):
    """Queue served by the HTTP queue service at ``HTTP_QUEUE_URL``; items are exchanged as pickles."""

    def __init__(self, name: str, default_timeout: float) -> None:
        """
        Parameters
        ----------
        name:
            The name of the Queue to use. In the current implementation, we expect the name to be of the format
            `appID_queueName`. Based on this assumption, we try to fetch the app id and the queue name by splitting
            the `name` argument.
        default_timeout:
            Default read timeout, in seconds, for this queue

        Raises
        ------
        ValueError:
            If ``name`` is ``None``.
        """
        if name is None:
            raise ValueError("You must specify a name for the queue")
        self.app_id, self._name_suffix = self._split_app_id_and_queue_name(name)
        self.name = name  # keeping the name for debugging
        self.default_timeout = default_timeout
        self.client = HTTPClient(base_url=HTTP_QUEUE_URL, auth_token=HTTP_QUEUE_TOKEN, log_callback=debug_log_callback)

    @property
    def is_running(self) -> bool:
        """Return True if GET ``<HTTP_QUEUE_URL>/health`` answers 200 (1 second request timeout)."""
        try:
            url = urljoin(HTTP_QUEUE_URL, "health")
            resp = requests.get(
                url,
                headers={"Authorization": f"Bearer {HTTP_QUEUE_TOKEN}"},
                timeout=1,
            )
            if resp.status_code == 200:
                return True
        except (ConnectionError, ConnectTimeout, ReadTimeout):
            return False
        return False

    def get(self, timeout: Optional[float] = None) -> Any:
        """Pop the left-most element of the HTTP queue.

        Parameters
        ----------
        timeout:
            ``None`` blocks: the queue is polled, sleeping ``HTTP_QUEUE_REFRESH_INTERVAL`` seconds between
            attempts (also after HTTP errors), until an element is returned. ``0`` makes exactly one request.
            Any other value polls until ``timeout`` seconds have elapsed.

        Returns
        -------
        The unpickled element, or ``None`` in three cases:

        * ``timeout == 0`` and the request fails with ``requests.exceptions.HTTPError``;
        * ``timeout > self.default_timeout`` and a request fails with an HTTP error;
        * the ``timeout`` elapses without an element.

        Raises
        ------
        ValueError:
            If no app id could be extracted from the queue name.
        queue.Empty:
            If ``timeout == 0`` and the queue is empty or the service is unreachable.

        """
        if not self.app_id:
            raise ValueError(f"App ID couldn't be extracted from the queue name: {self.name}")

        # ``None`` means blocking: loop over single requests to mimic this behavior.
        if timeout is None:
            while True:
                try:
                    return self._get()
                except (queue.Empty, requests.exceptions.HTTPError):
                    # Also wait after HTTP errors; otherwise a failing service gets hit in a tight loop.
                    time.sleep(HTTP_QUEUE_REFRESH_INTERVAL)

        # Make one request and return the result.
        if timeout == 0:
            try:
                return self._get()
            except requests.exceptions.HTTPError:
                return None

        # ``timeout`` is some value: loop until it is reached.
        # Note: In theory, there isn't a need for a sleep as the queue shouldn't block the flow if the queue
        # is empty. However, as the HTTP server can saturate, sleep between attempts when a timeout higher
        # than the default timeout is provided.
        throttle = timeout > self.default_timeout
        start_time = time.monotonic()
        while (time.monotonic() - start_time) < timeout:
            try:
                return self._get()
            except requests.exceptions.HTTPError:
                if throttle:
                    return None
            except queue.Empty:
                pass
            if throttle:
                time.sleep(0.05)
        return None

    def _get(self) -> Any:
        """Make a single ``pop`` request.

        Raises ``queue.Empty`` when the service answers 204 or cannot be reached (``ConnectionError``).
        A ``requests.exceptions.HTTPError`` is not caught here and reaches the caller.
        """
        try:
            resp = self.client.post(f"v1/{self.app_id}/{self._name_suffix}", query_params={"action": "pop"})
            if resp.status_code == 204:
                raise queue.Empty
            # NOTE(review): ``pickle.loads`` can execute arbitrary code from the payload; this fully trusts
            # the service at ``HTTP_QUEUE_URL``.
            return pickle.loads(resp.content)
        except ConnectionError:
            # Note: If the Http Queue service isn't available,
            # we consider the queue is empty to avoid failing the app.
            raise queue.Empty from None

    def batch_get(self, timeout: Optional[float] = None, count: Optional[int] = None) -> List[Any]:
        """Fetch ``count`` elements (``BATCH_DELTA_COUNT`` when falsy) with one ``popCount`` request.

        ``timeout`` is accepted for interface compatibility but unused: exactly one request is made.
        Elements arrive base64-encoded and pickled.

        Raises
        ------
        ValueError:
            If no app id could be extracted from the queue name.
        queue.Empty:
            If the service answers 204 or cannot be reached.

        """
        if not self.app_id:
            raise ValueError(f"App ID couldn't be extracted from the queue name: {self.name}")

        try:
            resp = self.client.post(
                f"v1/{self.app_id}/{self._name_suffix}",
                query_params={"action": "popCount", "count": str(count or BATCH_DELTA_COUNT)},
            )
            if resp.status_code == 204:
                raise queue.Empty
            # NOTE(review): unpickles data from the remote service; see ``_get``.
            return [pickle.loads(base64.b64decode(data)) for data in resp.json()]
        except ConnectionError:
            # Note: If the Http Queue service isn't available,
            # we consider the queue is empty to avoid failing the app.
            raise queue.Empty from None

    @backoff.on_exception(backoff.expo, (RuntimeError, requests.exceptions.HTTPError))
    def put(self, item: Any) -> None:
        """Pickle ``item`` and push it onto the HTTP queue.

        A non-201 response raises ``RuntimeError``. ``backoff`` retries that and
        ``requests.exceptions.HTTPError`` with exponential waits; no ``max_tries`` is passed here.
        Warns when the queue already holds ``WARNING_QUEUE_SIZE`` or more elements.

        Raises
        ------
        ValueError:
            If no app id could be extracted from the queue name (not retried).

        """
        if not self.app_id:
            raise ValueError(f"The Lightning App ID couldn't be extracted from the queue name: {self.name}")

        value = pickle.dumps(item)
        queue_len = self.length()
        if queue_len >= WARNING_QUEUE_SIZE:
            warnings.warn(
                f"The Queue {self._name_suffix} length is larger than the recommended length of {WARNING_QUEUE_SIZE}. "
                f"Found {queue_len}. This might cause your application to crash, please investigate this."
            )
        resp = self.client.post(f"v1/{self.app_id}/{self._name_suffix}", data=value, query_params={"action": "push"})
        if resp.status_code != 201:
            raise RuntimeError(f"Failed to push to queue: {self._name_suffix}")

    def length(self) -> int:
        """Return the queue length reported by the service, or 0 if the request raises an HTTP error.

        Raises
        ------
        ValueError:
            If no app id could be extracted from the queue name.

        """
        if not self.app_id:
            raise ValueError(f"App ID couldn't be extracted from the queue name: {self.name}")

        try:
            # Relative path, like every other endpoint used by this class, so it resolves against
            # ``HTTP_QUEUE_URL`` in the same way.
            val = self.client.get(f"v1/{self.app_id}/{self._name_suffix}/length")
            return int(val.text)
        except requests.exceptions.HTTPError:
            return 0

    @staticmethod
    def _split_app_id_and_queue_name(queue_name: str) -> Tuple[str, str]:
        """This splits the app id and the queue name into two parts.

        The name is split on its first ``"_"``. Without an underscore, the app id is ``""``.

        This can be brittle, as if the queue name creation logic changes, the response values from here wouldn't be
        accurate. Remove this eventually and let the Queue class take app id and name of the queue as arguments

        """
        if "_" not in queue_name:
            return "", queue_name
        app_id, queue_name = queue_name.split("_", 1)
        return app_id, queue_name

    def to_dict(self) -> dict:
        """Serialize the constructor arguments, plus a ``"type"`` discriminator."""
        return {
            "type": "http",
            "name": self.name,
            "default_timeout": self.default_timeout,
        }

    @classmethod
    def from_dict(cls, state: dict) -> "HTTPQueue":
        """Rebuild a queue from a dict such as the one returned by :meth:`to_dict`.

        The ``"type"`` discriminator written by :meth:`to_dict` is not an ``__init__`` argument, so it
        is ignored; the input dict is not modified.
        """
        kwargs = {key: value for key, value in state.items() if key != "type"}
        return cls(**kwargs)

def debug_log_callback(message: str, *args: Any, **kwargs: Any) -> None:
    """Forward ``message`` (with ``args``/``kwargs``) to ``logger.info`` only while queue debugging is on.

    Debugging is on when ``QUEUE_DEBUG_ENABLED`` is truthy, or when a file named
    ``QUEUE_DEBUG_ENABLED`` exists inside ``LIGHTNING_DIR``. The file is looked up on every call,
    and only when the flag itself is falsy.
    """
    if not QUEUE_DEBUG_ENABLED and not (Path(LIGHTNING_DIR) / "QUEUE_DEBUG_ENABLED").exists():
        return
    logger.info(message, *args, **kwargs)