# pytorch-lightning (fork) — 574 lines · 21.3 KB
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
import multiprocessing
import pickle
import queue  # needed as import instead from/import for mocking in tests
import time
import warnings
from abc import ABC, abstractmethod
from enum import Enum
from pathlib import Path
from typing import Any, List, Optional, Tuple
from urllib.parse import urljoin

import backoff
import requests
from requests.exceptions import ConnectionError, ConnectTimeout, ReadTimeout

from lightning.app.core.constants import (
    BATCH_DELTA_COUNT,
    HTTP_QUEUE_REFRESH_INTERVAL,
    HTTP_QUEUE_REQUESTS_PER_SECOND,
    HTTP_QUEUE_TOKEN,
    HTTP_QUEUE_URL,
    LIGHTNING_DIR,
    QUEUE_DEBUG_ENABLED,
    REDIS_HOST,
    REDIS_PASSWORD,
    REDIS_PORT,
    REDIS_QUEUES_READ_DEFAULT_TIMEOUT,
    STATE_UPDATE_TIMEOUT,
    WARNING_QUEUE_SIZE,
)
from lightning.app.utilities.app_helpers import Logger
from lightning.app.utilities.imports import _is_redis_available, requires
from lightning.app.utilities.network import HTTPClient

if _is_redis_available():
    import redis
logger = Logger(__name__)


# Base names used to build the fully-qualified queue identifiers. The final queue
# name is composed in ``QueuingSystem`` as ``[{queue_id}_]{CONSTANT}[_{work_name}]``.
READINESS_QUEUE_CONSTANT = "READINESS_QUEUE"
ERROR_QUEUE_CONSTANT = "ERROR_QUEUE"
DELTA_QUEUE_CONSTANT = "DELTA_QUEUE"
HAS_SERVER_STARTED_CONSTANT = "HAS_SERVER_STARTED_QUEUE"
CALLER_QUEUE_CONSTANT = "CALLER_QUEUE"
API_STATE_PUBLISH_QUEUE_CONSTANT = "API_STATE_PUBLISH_QUEUE"
API_DELTA_QUEUE_CONSTANT = "API_DELTA_QUEUE"
API_REFRESH_QUEUE_CONSTANT = "API_REFRESH_QUEUE"
ORCHESTRATOR_REQUEST_CONSTANT = "ORCHESTRATOR_REQUEST"
ORCHESTRATOR_RESPONSE_CONSTANT = "ORCHESTRATOR_RESPONSE"
ORCHESTRATOR_COPY_REQUEST_CONSTANT = "ORCHESTRATOR_COPY_REQUEST"
ORCHESTRATOR_COPY_RESPONSE_CONSTANT = "ORCHESTRATOR_COPY_RESPONSE"
WORK_QUEUE_CONSTANT = "WORK_QUEUE"
API_RESPONSE_QUEUE_CONSTANT = "API_RESPONSE_QUEUE"
FLOW_TO_WORKS_DELTA_QUEUE_CONSTANT = "FLOW_TO_WORKS_DELTA_QUEUE"
class QueuingSystem(Enum):
    """Available transports for the application's inter-process queues."""

    MULTIPROCESS = "multiprocess"
    REDIS = "redis"
    HTTP = "http"

    @staticmethod
    def _compose_name(base: str, queue_id: Optional[str], work_name: Optional[str] = None) -> str:
        """Build a queue name following the ``[{queue_id}_]{base}[_{work_name}]`` pattern."""
        suffix = base if work_name is None else f"{base}_{work_name}"
        return f"{queue_id}_{suffix}" if queue_id else suffix

    def get_queue(self, queue_name: str) -> "BaseQueue":
        """Instantiate the concrete queue implementation matching this transport."""
        if self == QueuingSystem.MULTIPROCESS:
            return MultiProcessQueue(queue_name, default_timeout=STATE_UPDATE_TIMEOUT)
        if self == QueuingSystem.REDIS:
            return RedisQueue(queue_name, default_timeout=REDIS_QUEUES_READ_DEFAULT_TIMEOUT)
        # HTTP queues are wrapped in a rate limiter so polling cannot saturate the backend.
        return RateLimitedQueue(
            HTTPQueue(queue_name, default_timeout=STATE_UPDATE_TIMEOUT), HTTP_QUEUE_REQUESTS_PER_SECOND
        )

    def get_api_response_queue(self, queue_id: Optional[str] = None) -> "BaseQueue":
        return self.get_queue(self._compose_name(API_RESPONSE_QUEUE_CONSTANT, queue_id))

    def get_readiness_queue(self, queue_id: Optional[str] = None) -> "BaseQueue":
        return self.get_queue(self._compose_name(READINESS_QUEUE_CONSTANT, queue_id))

    def get_delta_queue(self, queue_id: Optional[str] = None) -> "BaseQueue":
        return self.get_queue(self._compose_name(DELTA_QUEUE_CONSTANT, queue_id))

    def get_error_queue(self, queue_id: Optional[str] = None) -> "BaseQueue":
        return self.get_queue(self._compose_name(ERROR_QUEUE_CONSTANT, queue_id))

    def get_has_server_started_queue(self, queue_id: Optional[str] = None) -> "BaseQueue":
        return self.get_queue(self._compose_name(HAS_SERVER_STARTED_CONSTANT, queue_id))

    def get_caller_queue(self, work_name: str, queue_id: Optional[str] = None) -> "BaseQueue":
        return self.get_queue(self._compose_name(CALLER_QUEUE_CONSTANT, queue_id, work_name))

    def get_api_state_publish_queue(self, queue_id: Optional[str] = None) -> "BaseQueue":
        return self.get_queue(self._compose_name(API_STATE_PUBLISH_QUEUE_CONSTANT, queue_id))

    # TODO: This is hack, so we can remove this queue entirely when fully optimized.
    def get_api_delta_queue(self, queue_id: Optional[str] = None) -> "BaseQueue":
        # Intentionally reuses DELTA_QUEUE_CONSTANT rather than API_DELTA_QUEUE_CONSTANT.
        return self.get_queue(self._compose_name(DELTA_QUEUE_CONSTANT, queue_id))

    def get_orchestrator_request_queue(self, work_name: str, queue_id: Optional[str] = None) -> "BaseQueue":
        return self.get_queue(self._compose_name(ORCHESTRATOR_REQUEST_CONSTANT, queue_id, work_name))

    def get_orchestrator_response_queue(self, work_name: str, queue_id: Optional[str] = None) -> "BaseQueue":
        return self.get_queue(self._compose_name(ORCHESTRATOR_RESPONSE_CONSTANT, queue_id, work_name))

    def get_orchestrator_copy_request_queue(self, work_name: str, queue_id: Optional[str] = None) -> "BaseQueue":
        return self.get_queue(self._compose_name(ORCHESTRATOR_COPY_REQUEST_CONSTANT, queue_id, work_name))

    def get_orchestrator_copy_response_queue(self, work_name: str, queue_id: Optional[str] = None) -> "BaseQueue":
        return self.get_queue(self._compose_name(ORCHESTRATOR_COPY_RESPONSE_CONSTANT, queue_id, work_name))

    def get_work_queue(self, work_name: str, queue_id: Optional[str] = None) -> "BaseQueue":
        return self.get_queue(self._compose_name(WORK_QUEUE_CONSTANT, queue_id, work_name))

    def get_flow_to_work_delta_queue(self, work_name: str, queue_id: Optional[str] = None) -> "BaseQueue":
        return self.get_queue(self._compose_name(FLOW_TO_WORKS_DELTA_QUEUE_CONSTANT, queue_id, work_name))
class BaseQueue(ABC):
    """Abstract queue interface mirroring the API of Python's standard ``queue.Queue``."""

    @abstractmethod
    def __init__(self, name: str, default_timeout: float):
        self.name = name
        self.default_timeout = default_timeout

    @abstractmethod
    def put(self, item: Any) -> None:
        """Append ``item`` to the queue."""

    @abstractmethod
    def get(self, timeout: Optional[float] = None) -> Any:
        """Pop and return the left-most element of the queue.

        Parameters
        ----------
        timeout:
            Read timeout in seconds; when the given timeout is 0, ``self.default_timeout`` is used.
            A timeout of None can be used to block indefinitely.

        """

    @abstractmethod
    def batch_get(self, timeout: Optional[float] = None, count: Optional[int] = None) -> List[Any]:
        """Pop and return the left-most elements of the queue.

        Parameters
        ----------
        timeout:
            Read timeout in seconds; when the given timeout is 0, ``self.default_timeout`` is used.
            A timeout of None can be used to block indefinitely.
        count:
            The number of elements to get from the queue.

        """

    @property
    def is_running(self) -> bool:
        """Whether the queue backend is reachable; True by default.

        Child classes should override this property with a real health check as required.

        """
        return True
class MultiProcessQueue(BaseQueue):
    """Queue backed by a ``multiprocessing`` queue created with the "spawn" start method."""

    def __init__(self, name: str, default_timeout: float) -> None:
        self.name = name
        self.default_timeout = default_timeout
        # "spawn" is used so the queue works identically across platforms (e.g. macOS/Windows).
        context = multiprocessing.get_context("spawn")
        self.queue = context.Queue()

    def put(self, item: Any) -> None:
        """Append ``item`` to the queue."""
        self.queue.put(item)

    def get(self, timeout: Optional[float] = None) -> Any:
        """Pop the left-most element, waiting up to ``timeout`` seconds (None blocks forever).

        Raises ``queue.Empty`` if nothing arrives before the timeout expires.
        """
        if timeout == 0:
            timeout = self.default_timeout
        # Bug fix: the previous code passed ``block=(timeout is None)``, which made
        # ``multiprocessing.Queue.get`` ignore any non-None timeout and return immediately.
        # ``block=True`` honors the timeout, and with ``timeout=None`` still blocks
        # indefinitely, matching the documented BaseQueue contract.
        return self.queue.get(block=True, timeout=timeout)

    def batch_get(self, timeout: Optional[float] = None, count: Optional[int] = None) -> List[Any]:
        """Pop a single-element batch; multiprocessing queues are drained one item at a time."""
        if timeout == 0:
            timeout = self.default_timeout
        # For multiprocessing, we simply collect the single left-most element.
        return [self.queue.get(block=True, timeout=timeout)]
class RedisQueue(BaseQueue):
    """Queue implementation backed by a Redis list.

    Items are pickled and appended with ``rpush``; reads use the blocking ``blpop``.
    """

    @requires("redis")
    def __init__(
        self,
        name: str,
        default_timeout: float,
        host: Optional[str] = None,
        port: Optional[int] = None,
        password: Optional[str] = None,
    ):
        """
        Parameters
        ----------
        name:
            The name of the Redis list to use as the queue
        default_timeout:
            Default timeout for redis read
        host:
            The hostname of the redis server
        port:
            The port of the redis server
        password:
            Redis password
        """
        if name is None:
            raise ValueError("You must specify a name for the queue")
        # Fall back to the globally configured connection settings when not provided.
        self.host = host or REDIS_HOST
        self.port = port or REDIS_PORT
        self.password = password or REDIS_PASSWORD
        self.name = name
        self.default_timeout = default_timeout
        self.redis = redis.Redis(host=self.host, port=self.port, password=self.password)

    def put(self, item: Any) -> None:
        """Pickle ``item`` and append it to the right end of the Redis list.

        ``LightningWork`` items have their backend temporarily detached before pickling
        and re-attached afterwards.
        """
        from lightning.app.core.work import LightningWork

        is_work = isinstance(item, LightningWork)

        # TODO: Be careful to handle with a lock if another thread needs
        # to access the work backend one day.
        # The backend isn't picklable
        # Raises a TypeError: cannot pickle '_thread.RLock' object
        if is_work:
            backend = item._backend
            item._backend = None

        value = pickle.dumps(item)
        queue_len = self.length()
        if queue_len >= WARNING_QUEUE_SIZE:
            warnings.warn(
                f"The Redis Queue {self.name} length is larger than the "
                f"recommended length of {WARNING_QUEUE_SIZE}. "
                f"Found {queue_len}. This might cause your application to crash, "
                "please investigate this."
            )
        try:
            self.redis.rpush(self.name, value)
        except redis.exceptions.ConnectionError:
            raise ConnectionError(
                "Your app failed because it couldn't connect to Redis. "
                "Please try running your app again. "
                "If the issue persists, please contact support@lightning.ai"
            )

        # Restore the backend reference that was detached before pickling.
        if is_work:
            item._backend = backend

    def get(self, timeout: Optional[float] = None) -> Any:
        """Returns the left most element of the redis queue.

        Parameters
        ----------
        timeout:
            Read timeout in seconds, in case of input timeout is 0, the `self.default_timeout` is used.
            A timeout of None can be used to block indefinitely.

        Raises
        ------
        queue.Empty
            If no element arrives before the timeout expires.

        """
        if timeout is None:
            # this means it's blocking in redis
            timeout = 0
        elif timeout == 0:
            timeout = self.default_timeout

        try:
            out = self.redis.blpop([self.name], timeout=timeout)
        except redis.exceptions.ConnectionError:
            raise ConnectionError(
                "Your app failed because it couldn't connect to Redis. "
                "Please try running your app again. "
                "If the issue persists, please contact support@lightning.ai"
            )

        if out is None:
            raise queue.Empty
        # ``blpop`` returns a ``(key, value)`` pair; the payload is the second item.
        return pickle.loads(out[1])

    def batch_get(self, timeout: Optional[float] = None, count: Optional[int] = None) -> Any:
        # Redis reads are performed one element at a time, so ``count`` is ignored here.
        return [self.get(timeout=timeout)]

    def clear(self) -> None:
        """Clear all elements in the queue."""
        self.redis.delete(self.name)

    def length(self) -> int:
        """Returns the number of elements in the queue."""
        try:
            return self.redis.llen(self.name)
        except redis.exceptions.ConnectionError:
            raise ConnectionError(
                "Your app failed because it couldn't connect to Redis. "
                "Please try running your app again. "
                "If the issue persists, please contact support@lightning.ai"
            )

    @property
    def is_running(self) -> bool:
        """Pinging the redis server to see if it is alive."""
        try:
            return self.redis.ping()
        except redis.exceptions.ConnectionError:
            return False

    def to_dict(self) -> dict:
        """Serialize the connection settings so the queue can be re-created elsewhere."""
        return {
            "type": "redis",
            "name": self.name,
            "default_timeout": self.default_timeout,
            "host": self.host,
            "port": self.port,
            "password": self.password,
        }

    @classmethod
    def from_dict(cls, state: dict) -> "RedisQueue":
        # NOTE(review): ``to_dict`` includes a "type" key that ``__init__`` does not accept;
        # presumably the caller strips it before calling — verify against the dispatcher.
        return cls(**state)
class RateLimitedQueue(BaseQueue):
    def __init__(self, queue: BaseQueue, requests_per_second: float):
        """This is a queue wrapper that will block on get or batch_get calls if they are made too quickly.

        Args:
            queue: The queue to wrap.
            requests_per_second: The target number of get or put requests per second.

        """
        self.name = queue.name
        self.default_timeout = queue.default_timeout

        self._queue = queue
        # Minimum spacing between two consecutive reads.
        self._seconds_per_request = 1 / requests_per_second

        # Timestamp of the last read; 0.0 means no request has been made yet.
        self._last_get = 0.0

    @property
    def is_running(self) -> bool:
        """Delegates the health check to the wrapped queue."""
        return self._queue.is_running

    def _wait_until_allowed(self, last_time: float) -> None:
        """Sleep just long enough to keep the observed request rate under the limit."""
        elapsed = time.time() - last_time
        if elapsed < self._seconds_per_request:
            time.sleep(self._seconds_per_request - elapsed)

    def get(self, timeout: Optional[float] = None) -> Any:
        self._wait_until_allowed(self._last_get)
        self._last_get = time.time()
        return self._queue.get(timeout=timeout)

    def batch_get(self, timeout: Optional[float] = None, count: Optional[int] = None) -> Any:
        self._wait_until_allowed(self._last_get)
        self._last_get = time.time()
        # Bug fix: ``count`` was previously dropped and never forwarded to the wrapped queue,
        # so callers could not control the batch size.
        return self._queue.batch_get(timeout=timeout, count=count)

    def put(self, item: Any) -> None:
        # Writes are intentionally not rate limited; only reads are throttled.
        return self._queue.put(item)
class HTTPQueue(BaseQueue):
    """Queue implementation that proxies all operations to an HTTP queue service."""

    def __init__(self, name: str, default_timeout: float) -> None:
        """
        Parameters
        ----------
        name:
            The name of the Queue to use. In the current implementation, we expect the name to be of the format
            `appID_queueName`. Based on this assumption, we try to fetch the app id and the queue name by splitting
            the `name` argument.
        default_timeout:
            Default timeout for queue reads
        """
        if name is None:
            raise ValueError("You must specify a name for the queue")
        self.app_id, self._name_suffix = self._split_app_id_and_queue_name(name)
        self.name = name  # keeping the name for debugging
        self.default_timeout = default_timeout
        self.client = HTTPClient(base_url=HTTP_QUEUE_URL, auth_token=HTTP_QUEUE_TOKEN, log_callback=debug_log_callback)

    @property
    def is_running(self) -> bool:
        """Pinging the http queue server's health endpoint to see if it is alive."""
        try:
            url = urljoin(HTTP_QUEUE_URL, "health")
            resp = requests.get(
                url,
                headers={"Authorization": f"Bearer {HTTP_QUEUE_TOKEN}"},
                timeout=1,
            )
            if resp.status_code == 200:
                return True
        except (ConnectionError, ConnectTimeout, ReadTimeout):
            return False
        return False

    def get(self, timeout: Optional[float] = None) -> Any:
        """Pop one element, emulating blocking semantics on top of single HTTP requests.

        Parameters
        ----------
        timeout:
            ``None`` polls until an element arrives, ``0`` performs exactly one request,
            and any other value polls until the deadline passes (then returns ``None``).

        """
        if not self.app_id:
            raise ValueError(f"App ID couldn't be extracted from the queue name: {self.name}")

        # it's a blocking call, we need to loop and call the backend to mimic this behavior
        if timeout is None:
            while True:
                try:
                    try:
                        return self._get()
                    except requests.exceptions.HTTPError:
                        # Server-side error: retry immediately.
                        pass
                except queue.Empty:
                    # Empty queue: back off briefly before polling again.
                    time.sleep(HTTP_QUEUE_REFRESH_INTERVAL)

        # make one request and return the result
        if timeout == 0:
            try:
                return self._get()
            except requests.exceptions.HTTPError:
                return None

        # timeout is some value - loop until the timeout is reached
        start_time = time.time()
        while (time.time() - start_time) < timeout:
            try:
                try:
                    return self._get()
                except requests.exceptions.HTTPError:
                    if timeout > self.default_timeout:
                        return None
                    raise queue.Empty
            except queue.Empty:
                # Note: In theory, there isn't a need for a sleep as the queue shouldn't
                # block the flow if the queue is empty.
                # However, as the Http Server can saturate,
                # let's add a sleep here if a higher timeout is provided
                # than the default timeout
                if timeout > self.default_timeout:
                    time.sleep(0.05)
        return None

    def _get(self) -> Any:
        """Make a single "pop" request; raise ``queue.Empty`` when nothing is available."""
        try:
            resp = self.client.post(f"v1/{self.app_id}/{self._name_suffix}", query_params={"action": "pop"})
            if resp.status_code == 204:
                # 204 No Content signals an empty queue.
                raise queue.Empty
            return pickle.loads(resp.content)
        except ConnectionError:
            # Note: If the Http Queue service isn't available,
            # we consider the queue is empty to avoid failing the app.
            raise queue.Empty

    def batch_get(self, timeout: Optional[float] = None, count: Optional[int] = None) -> List[Any]:
        """Pop up to ``count`` elements with a single "popCount" request.

        NOTE(review): ``timeout`` is accepted for interface compatibility but is not used
        by this implementation; exactly one request is made regardless.
        """
        try:
            resp = self.client.post(
                f"v1/{self.app_id}/{self._name_suffix}",
                query_params={"action": "popCount", "count": str(count or BATCH_DELTA_COUNT)},
            )
            if resp.status_code == 204:
                raise queue.Empty
            # The service returns a JSON array of base64-encoded pickled payloads.
            return [pickle.loads(base64.b64decode(data)) for data in resp.json()]
        except ConnectionError:
            # Note: If the Http Queue service isn't available,
            # we consider the queue is empty to avoid failing the app.
            raise queue.Empty

    @backoff.on_exception(backoff.expo, (RuntimeError, requests.exceptions.HTTPError))
    def put(self, item: Any) -> None:
        """Pickle ``item`` and push it to the queue service, retrying with exponential backoff."""
        if not self.app_id:
            raise ValueError(f"The Lightning App ID couldn't be extracted from the queue name: {self.name}")

        value = pickle.dumps(item)
        queue_len = self.length()
        if queue_len >= WARNING_QUEUE_SIZE:
            warnings.warn(
                f"The Queue {self._name_suffix} length is larger than the recommended length of {WARNING_QUEUE_SIZE}. "
                f"Found {queue_len}. This might cause your application to crash, please investigate this."
            )
        resp = self.client.post(f"v1/{self.app_id}/{self._name_suffix}", data=value, query_params={"action": "push"})
        if resp.status_code != 201:
            # Raising here triggers the ``backoff`` retry decorator above.
            raise RuntimeError(f"Failed to push to queue: {self._name_suffix}")

    def length(self) -> int:
        """Return the number of elements currently in the queue (0 when the request fails)."""
        if not self.app_id:
            raise ValueError(f"App ID couldn't be extracted from the queue name: {self.name}")

        try:
            val = self.client.get(f"/v1/{self.app_id}/{self._name_suffix}/length")
            return int(val.text)
        except requests.exceptions.HTTPError:
            return 0

    @staticmethod
    def _split_app_id_and_queue_name(queue_name: str) -> Tuple[str, str]:
        """This splits the app id and the queue name into two parts.

        This can be brittle, as if the queue name creation logic changes, the response values from here wouldn't be
        accurate. Remove this eventually and let the Queue class take app id and name of the queue as arguments

        """
        if "_" not in queue_name:
            # No app id present; treat the whole string as the queue name.
            return "", queue_name
        app_id, queue_name = queue_name.split("_", 1)
        return app_id, queue_name

    def to_dict(self) -> dict:
        """Serialize the queue settings so it can be re-created elsewhere."""
        return {
            "type": "http",
            "name": self.name,
            "default_timeout": self.default_timeout,
        }

    @classmethod
    def from_dict(cls, state: dict) -> "HTTPQueue":
        # NOTE(review): assumes the caller strips the "type" key first, since ``__init__``
        # takes no such parameter — verify against the dispatcher.
        return cls(**state)
def debug_log_callback(message: str, *args: Any, **kwargs: Any) -> None:
    """Forward *message* to the logger when queue debugging is switched on.

    Debugging is enabled either by the ``QUEUE_DEBUG_ENABLED`` constant or by the
    presence of a ``QUEUE_DEBUG_ENABLED`` file inside the Lightning directory.
    """
    debug_flag_file = Path(LIGHTNING_DIR) / "QUEUE_DEBUG_ENABLED"
    if QUEUE_DEBUG_ENABLED or debug_flag_file.exists():
        logger.info(message, *args, **kwargs)
# Cookie notice (hosting-page footer, not part of this module): the hosting site
# uses cookies in accordance with its Privacy Policy and Cookie Policy; consent
# is given via the "Accept" button and can be withdrawn in browser settings.