ray-llm

Форк
0
/
vllm_deployment.py 
91 строка · 2.6 Кб
1
import logging
2
import os
3

4
from ray import serve
5

6
from rayllm.backend.llm.vllm.vllm_engine import VLLMEngine
7
from rayllm.backend.llm.vllm.vllm_models import (
8
    VLLMApp,
9
    VLLMGenerationRequest,
10
    VLLMSamplingParams,
11
)
12
from rayllm.backend.server.models import (
13
    QueuePriority,
14
    SchedulingMetadata,
15
)
16
from rayllm.common.models import Prompt
17

18
logger = logging.getLogger(__name__)
19

20

21
class VLLMDeploymentImpl:
22
    _generation_request_cls = VLLMGenerationRequest
23
    _default_engine_cls = VLLMEngine
24

25
    async def __init__(
26
        self, base_config: VLLMApp, *, engine_cls=None, generation_request_cls=None
27
    ):
28
        self.base_config = base_config
29
        self.config_store = {}  # type: ignore
30

31
        engine_cls = engine_cls or self._default_engine_cls
32
        self._generation_request_cls = (
33
            generation_request_cls or self._generation_request_cls
34
        )
35

36
        self.engine = engine_cls(base_config)
37
        await self.engine.start()
38

39
    async def stream(
40
        self,
41
        request_id: str,
42
        prompt: Prompt,
43
        priority=QueuePriority.GENERATE_TEXT,
44
    ):
45
        """A thin wrapper around VLLMEngine.generate().
46
        1. Load the model to disk
47
        2. Format parameters correctly
48
        3. Forward request to VLLMEngine.generate()
49
        """
50

51
        prompt_text = (
52
            self.base_config.engine_config.generation.prompt_format.generate_prompt(
53
                prompt
54
            )
55
        )
56
        sampling_params = VLLMSamplingParams.merge_generation_params(
57
            prompt, self.base_config.engine_config.generation
58
        )
59

60
        logger.info(f"Received streaming request {request_id}")
61
        vllm_request = self._generation_request_cls(
62
            prompt=prompt_text,
63
            request_id=request_id,
64
            sampling_params=sampling_params,
65
            scheduling_metadata=SchedulingMetadata(
66
                request_id=request_id, priority=priority
67
            ),
68
        )
69
        async for aviary_model_response in self.engine.generate(vllm_request):
70
            yield aviary_model_response
71

72
    async def check_health(self):
73
        return await self.engine.check_health()
74

75

76
@serve.deployment(
77
    # TODO make this configurable in aviary run
78
    autoscaling_config={
79
        "min_replicas": 1,
80
        "initial_replicas": 1,
81
        "max_replicas": 10,
82
        "target_num_ongoing_requests_per_replica": int(
83
            os.environ.get("AVIARY_ROUTER_TARGET_NUM_ONGOING_REQUESTS_PER_REPLICA", 10)
84
        ),
85
    },
86
    max_concurrent_queries=20,  # Maximum backlog for a single replica
87
    health_check_period_s=30,
88
    health_check_timeout_s=30,
89
)
90
class VLLMDeployment(VLLMDeploymentImpl):
91
    ...
92

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.