pytorch-lightning

Форк
0
753 строки · 29.7 Кб
1
# Copyright The Lightning AI team.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

15
import asyncio
16
import logging
17
import time
18
import uuid
19
from itertools import cycle
20
from typing import Any, Dict, List, Optional, Tuple, Type, Union
21
from typing import SupportsFloat as Numeric
22

23
import requests
24
import uvicorn
25
from fastapi import FastAPI, HTTPException, Request
26
from fastapi.middleware.cors import CORSMiddleware
27
from fastapi.responses import RedirectResponse
28
from pydantic import BaseModel
29
from starlette.staticfiles import StaticFiles
30

31
from lightning.app.components.serve.cold_start_proxy import ColdStartProxy
32
from lightning.app.core.flow import LightningFlow
33
from lightning.app.core.work import LightningWork
34
from lightning.app.utilities.app_helpers import Logger
35
from lightning.app.utilities.cloud import is_running_in_cloud
36
from lightning.app.utilities.imports import _is_aiohttp_available, requires
37
from lightning.app.utilities.packaging.cloud_compute import CloudCompute
38

39
if _is_aiohttp_available():
40
    import aiohttp
41
    import aiohttp.client_exceptions
42

43
logger = Logger(__name__)
44

45

46
class _TrackableFastAPI(FastAPI):
    """FastAPI application that carries request-tracking counters as instance attributes."""

    def __init__(self, *args: Any, **kwargs: Any):
        """Create the FastAPI app and initialise every tracking counter to zero.

        All positional and keyword arguments are forwarded unchanged to ``FastAPI.__init__``.
        """
        super().__init__(*args, **kwargs)
        # Assigned left to right: total requests seen, requests in flight, duration of the last request.
        self.global_request_count = self.num_current_requests = self.last_processing_time = 0
54

55

56
def _maybe_raise_granular_exception(exception: Exception) -> None:
57
    """Handle an exception from hitting the model servers."""
58
    if not isinstance(exception, Exception):
59
        return
60

61
    if isinstance(exception, HTTPException):
62
        raise exception
63

64
    if isinstance(exception, aiohttp.client_exceptions.ServerDisconnectedError):
65
        raise HTTPException(500, "Worker Server Disconnected") from exception
66

67
    if isinstance(exception, aiohttp.client_exceptions.ClientError):
68
        logging.exception(exception)
69
        raise HTTPException(500, "Worker Server error") from exception
70

71
    if isinstance(exception, asyncio.TimeoutError):
72
        raise HTTPException(408, "Request timed out") from exception
73

74
    if isinstance(exception, Exception) and exception.args[0] == "Server disconnected":
75
        raise HTTPException(500, "Worker Server disconnected") from exception
76

77
    logging.exception(exception)
78
    raise HTTPException(500, exception.args[0]) from exception
79

80

81
class _SysInfo(BaseModel):
    """Schema of the payload served by the load balancer's ``/system/info`` endpoint."""

    # Number of servers currently known to the load balancer.
    num_workers: int
    # URLs of those servers.
    servers: List[str]
    # Requests to the prediction endpoint that are currently in flight.
    num_requests: int
    # Duration of the most recently completed request.
    # NOTE(review): the value assigned to this field is a ``time.time()`` difference (a float
    # number of seconds) while the field is declared as ``int`` - confirm this is intended.
    processing_time: int
    # Total number of requests to the prediction endpoint seen so far.
    global_request_count: int
87

88

89
class _BatchRequestModel(BaseModel):
    """Payload that wraps the inputs of several requests into a single batched call."""

    # One entry per batched request, in the order the requests were queued.
    inputs: List[Any]
91

92

93
def _create_fastapi(title: str) -> _TrackableFastAPI:
    """Create a trackable FastAPI app with permissive CORS and two utility routes.

    Args:
        title: Title given to the FastAPI application.

    Returns:
        The configured app. ``GET /`` redirects to ``/docs`` (and is hidden from the schema),
        ``GET /num-requests`` returns the app's ``num_current_requests`` counter.

    """
    app = _TrackableFastAPI(title=title)

    # Every origin, method and header is accepted, with credentials allowed.
    cors_settings = {
        "allow_origins": ["*"],
        "allow_credentials": True,
        "allow_methods": ["*"],
        "allow_headers": ["*"],
    }
    app.add_middleware(CORSMiddleware, **cors_settings)

    async def docs():
        return RedirectResponse("/docs")

    async def num_requests() -> int:
        return app.num_current_requests

    # Registered in the same order as the handlers are declared above.
    app.get("/", include_in_schema=False)(docs)
    app.get("/num-requests")(num_requests)

    return app
113

114

115
class _LoadBalancer(LightningWork):
116
    r"""The LoadBalancer is a LightningWork component that collects the requests and sends them to the prediction API
117
    asynchronously using RoundRobin scheduling. It also performs auto batching of the incoming requests.
118

119
    After enabling, you will be required to send the username and password in the request header for the private endpoints.
120

121
    Args:
122
        input_type: Input type.
123
        output_type: Output type.
124
        endpoint: The REST API path.
125
        max_batch_size: The number of requests processed at once.
126
        timeout_batching: The number of seconds to wait before sending the requests to process in order to allow for
127
            requests to be batched. In any case, requests are processed as soon as `max_batch_size` is reached.
128
        timeout_keep_alive: The number of seconds until it closes Keep-Alive connections if no new data is received.
129
        timeout_inference_request: The number of seconds to wait for inference.
130
        api_name: The name to be displayed on the UI. Normally, it is the name of the work class
131
        cold_start_proxy: The proxy service to use while the work is cold starting.
132
        **kwargs: Arguments passed to :func:`LightningWork.init` like ``CloudCompute``, ``BuildConfig``, etc.
133

134
    """
135

136
    @requires(["aiohttp"])
    def __init__(
        self,
        input_type: Type[BaseModel],
        output_type: Type[BaseModel],
        endpoint: str,
        max_batch_size: int = 8,
        # all timeout args are in seconds
        timeout_batching: float = 1,
        timeout_keep_alive: int = 60,
        timeout_inference_request: int = 60,
        api_name: Optional[str] = "API",  # used for displaying the name in the UI
        cold_start_proxy: Union[ColdStartProxy, str, None] = None,
        **kwargs: Any,
    ) -> None:
        """Store the configuration and initialise the empty batching/bookkeeping state.

        The work is always created with ``CloudCompute("default")``; the remaining ``kwargs`` are
        forwarded to the parent constructor.

        Raises:
            ValueError: If ``cold_start_proxy`` is truthy but neither a ``str`` nor a ``ColdStartProxy``.

        """
        super().__init__(cloud_compute=CloudCompute("default"), **kwargs)
        self._input_type = input_type
        self._output_type = output_type
        self._timeout_keep_alive = timeout_keep_alive
        self._timeout_inference_request = timeout_inference_request
        self.servers = []
        self.max_batch_size = max_batch_size
        self.timeout_batching = timeout_batching
        self._iter = None
        self._batch = []
        self._responses = {}  # {request_id: response}
        self._last_batch_sent = None
        self._server_status = {}
        self._api_name = api_name
        self.ready = False

        # The endpoint is always stored with a leading slash.
        self.endpoint = endpoint if endpoint.startswith("/") else "/" + endpoint
        self._fastapi_app = None

        self._cold_start_proxy = None
        if cold_start_proxy:
            if isinstance(cold_start_proxy, str):
                # A plain string is taken as the URL of the proxy.
                proxy = ColdStartProxy(proxy_url=cold_start_proxy)
            elif isinstance(cold_start_proxy, ColdStartProxy):
                proxy = cold_start_proxy
            else:
                raise ValueError("cold_start_proxy must be of type ColdStartProxy or str")
            self._cold_start_proxy = proxy
181

182
    def get_internal_url(self) -> str:
183
        if not self._public_ip:
184
            raise ValueError("Public IP not set")
185
        return f"http://{self._public_ip}:{self._port}"
186

187
    async def send_batch(self, batch: List[Tuple[str, _BatchRequestModel]], server_url: str):
        """POST one batch to ``server_url`` and store a result for every request in the batch.

        Each batch entry is a ``(request_id, payload)`` pair. On success the items of the
        response's ``"outputs"`` list are stored in ``self._responses`` under the matching request
        ids; on any exception, that exception is stored for every request id instead. In both
        cases the server is marked as free again if it is still registered.

        """
        request_ids = [entry[0] for entry in batch]
        payload = _BatchRequestModel(inputs=[entry[1] for entry in batch])

        try:
            # Mark the server as busy for the duration of the call.
            self._server_status[server_url] = False
            json_headers = {
                "accept": "application/json",
                "Content-Type": "application/json",
            }
            async with aiohttp.ClientSession() as session, session.post(
                f"{server_url}{self.endpoint}",
                json=payload.dict(),
                timeout=self._timeout_inference_request,
                headers=json_headers,
            ) as response:
                if response.status == 408:
                    raise HTTPException(408, "Request timed out")
                response.raise_for_status()
                body = await response.json()
                outputs = body["outputs"]
                if len(batch) != len(outputs):
                    raise RuntimeError(f"result has {len(outputs)} items but batch is {len(batch)}")
                self._responses.update(zip(request_ids, outputs))
        except Exception as ex:
            # Every request of the batch receives the same failure.
            self._responses.update(dict.fromkeys(request_ids, ex))
        finally:
            # resetting the server status so other requests can be
            # scheduled on this node
            if server_url in self._server_status:
                # TODO - if the server returns an error, track that so
                #  we don't send more requests to it
                self._server_status[server_url] = True
223

224
    def _find_free_server(self) -> Optional[str]:
225
        existing = set(self._server_status.keys())
226
        for server in existing:
227
            status = self._server_status.get(server, None)
228
            if status is None:
229
                logger.error("Server is not found in the status list. This should not happen.")
230
            if status:
231
                return server
232
        return None
233

234
    async def consumer(self):
235
        """The consumer process that continuously checks for new requests and sends them to the API.
236

237
        Two instances of this function should not be running with shared `_state_server` as that would create race
238
        conditions
239

240
        """
241
        while True:
242
            await asyncio.sleep(0.05)
243
            batch = self._batch[: self.max_batch_size]
244
            is_batch_ready = len(batch) == self.max_batch_size
245
            if len(batch) > 0 and self._last_batch_sent is None:
246
                self._last_batch_sent = time.time()
247

248
            if self._last_batch_sent:
249
                is_batch_timeout = time.time() - self._last_batch_sent > self.timeout_batching
250
            else:
251
                is_batch_timeout = False
252

253
            server_url = self._find_free_server()
254
            # setting the server status to be busy! This will be reset by
255
            # the send_batch function after the server responds
256
            if server_url is None:
257
                continue
258
            if batch and (is_batch_ready or is_batch_timeout):
259
                self._server_status[server_url] = False
260
                # find server with capacity
261
                # Saving a reference to the result of this function, protects the task disappearing mid-execution
262
                # https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task
263
                task = asyncio.create_task(self.send_batch(batch, server_url))  # noqa: F841
264
                # resetting the batch array, TODO - not locking the array
265
                self._batch = self._batch[len(batch) :]
266
                self._last_batch_sent = time.time()
267

268
    async def process_request(self, data: BaseModel, request_id=None):
        """Queue one request for batching and wait until its result is available.

        Args:
            data: The request payload.
            request_id: Identifier used to match the response; a random hex id is generated when omitted.

        Returns:
            The result stored for the request, or whatever the cold start proxy returns when the
            request is handed over to it (no servers registered, or no processing capacity left).

        Raises:
            HTTPException: With status 503 when there are no servers and no cold start proxy, or as
                raised by ``_maybe_raise_granular_exception`` when the stored result is an exception.

        """
        request_id = uuid.uuid4().hex if request_id is None else request_id
        proxy = self._cold_start_proxy

        if not self.servers:
            if not proxy:
                # sleeping to trigger the scale up
                raise HTTPException(503, "None of the workers are healthy!, try again in a few seconds")
            # if no servers are available, proxy the request to cold start proxy handler
            return await proxy.handle_request(data)

        # if out of capacity, proxy the request to cold start proxy handler
        if proxy and not self._has_processing_capacity():
            return await proxy.handle_request(data)

        # if we have capacity, process the request
        self._batch.append((request_id, data))
        while True:
            await asyncio.sleep(0.05)
            if request_id not in self._responses:
                continue
            outcome = self._responses.pop(request_id)
            _maybe_raise_granular_exception(outcome)
            return outcome
292

293
    def _has_processing_capacity(self):
294
        """This function checks if we have processing capacity for one more request or not.
295

296
        Depends on the value from here, we decide whether we should proxy the request or not
297

298
        """
299
        if not self._fastapi_app:
300
            return False
301
        active_server_count = len(self.servers)
302
        max_processable = self.max_batch_size * active_server_count
303
        current_req_count = self._fastapi_app.num_current_requests
304
        return current_req_count < max_processable
305

306
    def run(self):
        """Build the load balancer's FastAPI app and serve it with uvicorn.

        The app gets a middleware that counts requests to ``self.endpoint``, startup/shutdown hooks
        that start and cancel the ``consumer`` task, the ``/system/info`` and
        ``/system/update-servers`` routes, the prediction route at ``self.endpoint`` and, when
        available, the static endpoint-info page. ``self.ready`` is set to ``True`` right before
        ``uvicorn.run`` is called.

        """
        logger.info(f"servers: {self.servers}")

        self._iter = cycle(self.servers)

        fastapi_app = _create_fastapi("Load Balancer")
        fastapi_app.SEND_TASK = None
        self._fastapi_app = fastapi_app

        # Bound to a local name so it can be used as the annotation of ``balance_api`` below.
        input_type = self._input_type

        @fastapi_app.middleware("http")
        async def current_request_counter(request: Request, call_next):
            # Only requests to the prediction endpoint are counted.
            if request.scope["path"] != self.endpoint:
                return await call_next(request)
            fastapi_app.global_request_count += 1
            fastapi_app.num_current_requests += 1
            start_time = time.time()
            try:
                response = await call_next(request)
                fastapi_app.last_processing_time = time.time() - start_time
                return response
            finally:
                # Release the slot even when the request fails; otherwise the in-flight counter
                # would stay inflated after an error.
                fastapi_app.num_current_requests -= 1

        @fastapi_app.on_event("startup")
        async def startup_event():
            fastapi_app.SEND_TASK = asyncio.create_task(self.consumer())

        @fastapi_app.on_event("shutdown")
        def shutdown_event():
            # The task only exists once the startup hook has run.
            if fastapi_app.SEND_TASK is not None:
                fastapi_app.SEND_TASK.cancel()

        @fastapi_app.get("/system/info", response_model=_SysInfo)
        async def sys_info():
            return _SysInfo(
                num_workers=len(self.servers),
                servers=self.servers,
                num_requests=fastapi_app.num_current_requests,
                processing_time=fastapi_app.last_processing_time,
                global_request_count=fastapi_app.global_request_count,
            )

        @fastapi_app.put("/system/update-servers")
        async def update_servers(servers: List[str]):
            self.servers = servers
            self._iter = cycle(self.servers)
            updated_servers = set()
            # do not try to loop over the dict keys as the dict might change from other places
            existing_servers = list(self._server_status.keys())
            for server in servers:
                updated_servers.add(server)
                if server not in existing_servers:
                    self._server_status[server] = True
                    # The status is part of the message itself; passing it as an extra logging
                    # argument without a placeholder would drop it from the output.
                    logger.info(f"Registering server {server}. Server status: {self._server_status}")
            for existing in existing_servers:
                if existing not in updated_servers:
                    logger.info(f"De-Registering server {existing}. Server status: {self._server_status}")
                    del self._server_status[existing]

        @fastapi_app.post(self.endpoint, response_model=self._output_type)
        async def balance_api(inputs: input_type):
            return await self.process_request(inputs)

        endpoint_info_page = self._get_endpoint_info_page()
        if endpoint_info_page:
            fastapi_app.mount(
                "/endpoint-info", StaticFiles(directory=endpoint_info_page.serve_dir, html=True), name="static"
            )

        logger.info(f"Your load balancer has started. The endpoint is 'http://{self.host}:{self.port}{self.endpoint}'")
        self.ready = True
        uvicorn.run(
            fastapi_app,
            host=self.host,
            port=self.port,
            loop="uvloop",
            timeout_keep_alive=self._timeout_keep_alive,
            access_log=False,
        )
385

386
    def update_servers(self, server_works: List[LightningWork]):
387
        """Updates works that load balancer distributes requests to.
388

389
        AutoScaler uses this method to increase/decrease the number of works.
390

391
        """
392
        old_server_urls = set(self.servers)
393
        current_server_urls = {
394
            f"http://{server._public_ip}:{server.port}" for server in server_works if server._internal_ip
395
        }
396

397
        # doing nothing if no server work has been added/removed
398
        if old_server_urls == current_server_urls:
399
            return
400

401
        # checking if the url is ready or not
402
        available_urls = set()
403
        for url in current_server_urls:
404
            try:
405
                _ = requests.get(url)
406
            except requests.exceptions.ConnectionError:
407
                continue
408
            else:
409
                available_urls.add(url)
410
        if old_server_urls == available_urls:
411
            return
412

413
        newly_added = available_urls - old_server_urls
414
        if newly_added:
415
            logger.info(f"servers added: {newly_added}")
416

417
        deleted = old_server_urls - available_urls
418
        if deleted:
419
            logger.info(f"servers deleted: {deleted}")
420
        self.send_request_to_update_servers(list(available_urls))
421

422
    def send_request_to_update_servers(self, servers: List[str]):
        """PUT the given server URLs to this load balancer's ``/system/update-servers`` route.

        When the internal URL is not available yet, a warning is logged and nothing is sent.
        An error status in the response raises through ``raise_for_status``.

        """
        try:
            base_url = self.get_internal_url()
        except ValueError:
            logger.warn("Cannot update servers as internal_url is not set")
            return
        update_route = f"{base_url}/system/update-servers"
        requests.put(update_route, json=servers, timeout=10).raise_for_status()
430

431
    @staticmethod
432
    def _get_sample_dict_from_datatype(datatype: Any) -> dict:
433
        if not hasattr(datatype, "schema"):
434
            # not a pydantic model
435
            raise TypeError(f"datatype must be a pydantic model, for the UI to be generated. but got {datatype}")
436

437
        if hasattr(datatype, "get_sample_data"):
438
            return datatype.get_sample_data()
439

440
        datatype_props = datatype.schema()["properties"]
441
        out: Dict[str, Any] = {}
442
        lut = {"string": "data string", "number": 0.0, "integer": 0, "boolean": False}
443
        for k, v in datatype_props.items():
444
            if v["type"] not in lut:
445
                raise TypeError("Unsupported type")
446
            out[k] = lut[v["type"]]
447
        return out
448

449
    def get_code_sample(self, url: str) -> Optional[str]:
450
        input_type: Any = self._input_type
451
        output_type: Any = self._output_type
452

453
        if not (hasattr(input_type, "request_code_sample") and hasattr(output_type, "response_code_sample")):
454
            return None
455
        return f"{input_type.request_code_sample(url)}\n{output_type.response_code_sample()}"
456

457
    def _get_endpoint_info_page(self) -> Optional["APIAccessFrontend"]:  # noqa: F821
        """Create the ``APIAccessFrontend`` that describes the prediction endpoint.

        Returns:
            The frontend object, or ``None`` when ``lightning_api_access`` is not installed or when no
            sample request/response can be generated for the input/output types.

        """
        try:
            from lightning_api_access import APIAccessFrontend
        except ModuleNotFoundError:
            logger.warn(
                "Some dependencies to run the UI are missing. To resolve, run `pip install lightning-api-access`"
            )
            return None

        url = (
            f"{self._future_url}{self.endpoint}"
            if is_running_in_cloud()
            else f"http://localhost:{self.port}{self.endpoint}"
        )

        api_entry = {"name": self._api_name, "url": url, "method": "POST", "request": None, "response": None}
        code_samples = self.get_code_sample(url)
        if code_samples:
            api_entry["code_sample"] = code_samples
            # TODO also set request/response for JS UI
            return APIAccessFrontend(apis=[api_entry])

        # Without code samples, fall back to placeholder payloads generated from the types.
        try:
            sample_request = self._get_sample_dict_from_datatype(self._input_type)
            sample_response = self._get_sample_dict_from_datatype(self._output_type)
        except TypeError:
            return None

        api_entry["request"] = sample_request
        api_entry["response"] = sample_response
        return APIAccessFrontend(apis=[api_entry])
486

487

488
class AutoScaler(LightningFlow):
489
    """The ``AutoScaler`` can be used to automatically change the number of replicas of the given server in response to
490
    changes in the number of incoming requests. Incoming requests will be batched and balanced across the replicas.
491

492
    Args:
493
        min_replicas: The number of works to start when app initializes.
494
        max_replicas: The max number of works to spawn to handle the incoming requests.
495
        scale_out_interval: The number of seconds to wait before checking whether to increase the number of servers.
496
        scale_in_interval: The number of seconds to wait before checking whether to decrease the number of servers.
497
        endpoint: Provide the REST API path.
498
        max_batch_size: (auto-batching) The number of requests to process at once.
499
        timeout_batching: (auto-batching) The number of seconds to wait before sending the requests to process.
500
        input_type: Input type.
501
        output_type: Output type.
502
        cold_start_proxy: If provided, the proxy will be used while the worker machines are warming up.
503

504
    .. testcode::
505

506
        from lightning.app import LightningApp
507
        from lightning.app.components import AutoScaler
508

509
        # Example 1: Auto-scaling serve component out-of-the-box
510
        app = LightningApp(
511
            AutoScaler(
512
                MyPythonServer,
513
                min_replicas=1,
514
                max_replicas=8,
515
                scale_out_interval=10,
516
                scale_in_interval=10,
517
            )
518
        )
519

520

521
        # Example 2: Customizing the scaling logic
522
        class MyAutoScaler(AutoScaler):
523
            def scale(self, replicas: int, metrics: dict) -> int:
524
                pending_requests_per_running_or_pending_work = metrics["pending_requests"] / (
525
                    replicas + metrics["pending_works"]
526
                )
527

528
                # upscale
529
                max_requests_per_work = self.max_batch_size
530
                if pending_requests_per_running_or_pending_work >= max_requests_per_work:
531
                    return replicas + 1
532

533
                # downscale
534
                min_requests_per_work = max_requests_per_work * 0.25
535
                if pending_requests_per_running_or_pending_work < min_requests_per_work:
536
                    return replicas - 1
537

538
                return replicas
539

540

541
        app = LightningApp(
542
            MyAutoScaler(
543
                MyPythonServer,
544
                min_replicas=1,
545
                max_replicas=8,
546
                scale_out_interval=10,
547
                scale_in_interval=10,
548
                max_batch_size=8,  # for auto batching
549
                timeout_batching=1,  # for auto batching
550
            )
551
        )
552

553
    """
554

555
    def __init__(
        self,
        work_cls: Type[LightningWork],
        min_replicas: int = 1,
        max_replicas: int = 4,
        scale_out_interval: Numeric = 10,
        scale_in_interval: Numeric = 10,
        max_batch_size: int = 8,
        timeout_batching: float = 1,
        endpoint: str = "api/predict",
        input_type: Type[BaseModel] = Dict,
        output_type: Type[BaseModel] = Dict,
        cold_start_proxy: Union[ColdStartProxy, str, None] = None,
        *work_args: Any,
        **work_kwargs: Any,
    ) -> None:
        """Store the scaling configuration and create the load balancer work.

        ``work_args`` and ``work_kwargs`` are kept and used later to instantiate ``work_cls``.
        No replica is created here: ``num_replicas`` starts at 0.

        Raises:
            ValueError: If ``max_replicas`` is smaller than ``min_replicas``.

        """
        super().__init__()
        self.num_replicas = 0
        self._work_registry = {}  # {replica index: attribute name holding the work}

        self._work_cls = work_cls
        self._work_args = work_args
        self._work_kwargs = work_kwargs

        self._input_type = input_type
        self._output_type = output_type
        self.scale_out_interval = scale_out_interval
        self.scale_in_interval = scale_in_interval
        self.max_batch_size = max_batch_size

        if max_replicas < min_replicas:
            raise ValueError(
                f"`max_replicas={max_replicas}` must be greater than or equal to `min_replicas={min_replicas}`."
            )
        self.max_replicas = max_replicas
        self.min_replicas = min_replicas
        self._last_autoscale = time.time()
        self.fake_trigger = 0

        self.load_balancer = _LoadBalancer(
            input_type=self._input_type,
            output_type=self._output_type,
            endpoint=endpoint,
            max_batch_size=max_batch_size,
            timeout_batching=timeout_batching,
            cache_calls=True,
            parallel=True,
            api_name=self._work_cls.__name__,
            cold_start_proxy=cold_start_proxy,
        )
605

606
    @property
607
    def ready(self) -> bool:
608
        return self.load_balancer.ready
609

610
    @property
611
    def workers(self) -> List[LightningWork]:
612
        return [self.get_work(i) for i in range(self.num_replicas)]
613

614
    def create_work(self) -> LightningWork:
615
        """Replicates a LightningWork instance with args and kwargs provided via ``__init__``."""
616
        cloud_compute = self._work_kwargs.get("cloud_compute", None)
617
        self._work_kwargs.update({
618
            "start_with_flow": False,
619
            "cloud_compute": cloud_compute.clone() if cloud_compute else None,
620
        })
621
        return self._work_cls(*self._work_args, **self._work_kwargs)
622

623
    def add_work(self, work) -> str:
624
        """Adds a new LightningWork instance.
625

626
        Returns:
627
            The name of the new work attribute.
628

629
        """
630
        work_attribute = uuid.uuid4().hex
631
        work_attribute = f"worker_{self.num_replicas}_{str(work_attribute)}"
632
        setattr(self, work_attribute, work)
633
        self._work_registry[self.num_replicas] = work_attribute
634
        self.num_replicas += 1
635
        return work_attribute
636

637
    def remove_work(self, index: int) -> str:
638
        """Removes the ``index`` th LightningWork instance."""
639
        work_attribute = self._work_registry[index]
640
        del self._work_registry[index]
641
        work = getattr(self, work_attribute)
642
        work.stop()
643
        self.num_replicas -= 1
644
        return work_attribute
645

646
    def get_work(self, index: int) -> LightningWork:
647
        """Returns the ``LightningWork`` instance with the given index."""
648
        work_attribute = self._work_registry[index]
649
        return getattr(self, work_attribute)
650

651
    def run(self):
652
        if not self.load_balancer.is_running:
653
            self.load_balancer.run()
654
        for work in self.workers:
655
            work.run()
656
        if self.load_balancer.url:
657
            self.fake_trigger += 1  # Note: change state to keep calling `run`.
658
            self.autoscale()
659

660
    def scale(self, replicas: int, metrics: dict) -> int:
661
        """The default scaling logic that users can override.
662

663
        Args:
664
            replicas: The number of running works.
665
            metrics: ``metrics['pending_requests']`` is the total number of requests that are currently pending.
666
                ``metrics['pending_works']`` is the number of pending works.
667

668
        Returns:
669
            The target number of running works. The value will be adjusted after this method runs
670
            so that it satisfies ``min_replicas<=replicas<=max_replicas``.
671

672
        """
673
        pending_requests = metrics["pending_requests"]
674
        active_or_pending_works = replicas + metrics["pending_works"]
675

676
        if active_or_pending_works == 0:
677
            return 1 if pending_requests > 0 else 0
678

679
        pending_requests_per_running_or_pending_work = pending_requests / active_or_pending_works
680

681
        # scale out if the number of pending requests exceeds max batch size.
682
        max_requests_per_work = self.max_batch_size
683
        if pending_requests_per_running_or_pending_work >= max_requests_per_work:
684
            return replicas + 1
685

686
        # scale in if the number of pending requests is below 25% of max_requests_per_work
687
        min_requests_per_work = max_requests_per_work * 0.25
688
        if pending_requests_per_running_or_pending_work < min_requests_per_work:
689
            return replicas - 1
690

691
        return replicas
692

693
    @property
    def num_pending_requests(self) -> int:
        """Fetches the number of pending requests via load balancer.

        The value is read from the load balancer's ``/num-requests`` route. When the load balancer
        has no internal URL yet, a warning is logged and 0 is returned.

        """
        try:
            load_balancer_url = self.load_balancer.get_internal_url()
        except ValueError:
            logger.warn("Cannot fetch the number of pending requests as internal_url is not set")
            return 0
        # Bounded wait so that an unresponsive load balancer cannot block the caller forever;
        # the same 10 seconds are used when updating the load balancer's server list.
        response = requests.get(f"{load_balancer_url}/num-requests", timeout=10)
        return int(response.json())
702

703
    @property
704
    def num_pending_works(self) -> int:
705
        """The number of pending works."""
706
        return sum(work.is_pending for work in self.workers)
707

708
    def autoscale(self) -> None:
        """Adjust the number of works based on the target number returned by ``self.scale``.

        Works are only added or removed when more than ``scale_out_interval`` /
        ``scale_in_interval`` seconds have passed since the last change. The load balancer is
        handed the current list of workers at the end of every call.

        """
        current_metrics = dict(
            pending_requests=self.num_pending_requests,
            pending_works=self.num_pending_works,
        )

        # ensure min_replicas <= num_replicas <= max_replicas
        requested = self.scale(self.num_replicas, current_metrics)
        num_target_workers = max(self.min_replicas, min(self.max_replicas, requested))

        # scale-out
        if time.time() - self._last_autoscale > self.scale_out_interval:
            # TODO figuring out number of workers to add only based on num_replicas isn't right because pending works
            #  are not added to num_replicas
            missing = num_target_workers - self.num_replicas
            for _ in range(missing):
                logger.info(f"Scaling out from {self.num_replicas} to {self.num_replicas + 1}")
                # TODO: move works into structures
                created_name = self.add_work(self.create_work())
                logger.info(f"Work created: '{created_name}'")
            if missing > 0:
                self._last_autoscale = time.time()

        # scale-in
        if time.time() - self._last_autoscale > self.scale_in_interval:
            # TODO figuring out number of workers to remove only based on num_replicas isn't right because pending works
            #  are not added to num_replicas
            surplus = self.num_replicas - num_target_workers
            for _ in range(surplus):
                logger.info(f"Scaling in from {self.num_replicas} to {self.num_replicas - 1}")
                # The most recently added replica is removed first.
                dropped_name = self.remove_work(self.num_replicas - 1)
                logger.info(f"Work removed: '{dropped_name}'")
            if surplus > 0:
                self._last_autoscale = time.time()

        self.load_balancer.update_servers(self.workers)
748

749
    def configure_layout(self):
750
        return [
751
            {"name": "Endpoint Info", "content": f"{self.load_balancer.url}/endpoint-info"},
752
            {"name": "Swagger", "content": self.load_balancer.url},
753
        ]
754

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.