ray-llm

Форк
0
/
embedding_deployment.py 
112 строк · 3.6 Кб
1
import logging
2
import os
3
from typing import List, Optional, Union
4

5
from ray import serve
6

7
from rayllm.backend.llm.embedding.embedding_engine import EmbeddingEngine
8
from rayllm.backend.llm.embedding.embedding_models import EmbeddingApp
9
from rayllm.backend.server.models import (
10
    AviaryModelResponse,
11
    GenerationRequest,
12
    QueuePriority,
13
    SchedulingMetadata,
14
)
15
from rayllm.backend.server.utils import get_response_for_error
16
from rayllm.common.models import Prompt
17

18
logger = logging.getLogger(__name__)
19

20

21
class EmbeddingDeploymentImpl:
    """Serve-side implementation of an embedding-model deployment.

    Owns an embedding engine instance and exposes a single batched
    ``stream()`` entry point: requests are collated by ``@serve.batch``
    and forwarded as one request to ``EmbeddingEngine.generate()``.
    """

    # Class-level defaults; __init__ may override them per instance via
    # its keyword-only arguments.
    _generation_request_cls = GenerationRequest
    _default_engine_cls = EmbeddingEngine

    async def __init__(
        self, base_config: EmbeddingApp, *, engine_cls=None, generation_request_cls=None
    ):
        """Configure batching, then build and start the embedding engine.

        NOTE(review): ``__init__`` is ``async`` — this only works because
        Ray Serve awaits deployment constructors; instantiating this class
        directly (outside Serve) would not run the body. Confirm against
        the Serve version in use.

        Args:
            base_config: Application config; ``base_config.engine_config``
                supplies ``max_batch_size`` and ``batch_wait_timeout_s``
                used to retune the ``stream`` batch wrapper below.
            engine_cls: Optional engine class override; defaults to
                ``_default_engine_cls`` (``EmbeddingEngine``).
            generation_request_cls: Optional request-class override;
                defaults to the class-level ``_generation_request_cls``.
        """
        self.base_config = base_config
        self.config_store = {}  # type: ignore

        engine_cls = engine_cls or self._default_engine_cls
        self._generation_request_cls = (
            generation_request_cls or self._generation_request_cls
        )

        # Retune the @serve.batch wrapper around `stream` with the values
        # from the engine config, overriding the decorator's defaults
        # (max_batch_size=1, batch_wait_timeout_s=0.1).
        self.stream.set_max_batch_size(base_config.engine_config.max_batch_size)
        self.stream.set_batch_wait_timeout_s(
            base_config.engine_config.batch_wait_timeout_s
        )

        # Construct the engine last, after batching is configured, and
        # wait for it to finish starting before the replica serves traffic.
        self.engine = engine_cls(base_config)
        await self.engine.start()

    def _parse_response(
        self, response: Union[AviaryModelResponse, Exception], request_id: str
    ):
        """Convert an engine-side exception into an error response payload;
        pass successful responses through unchanged."""
        if isinstance(response, Exception):
            return get_response_for_error(response, request_id=request_id)
        return response

    @serve.batch(max_batch_size=1, batch_wait_timeout_s=0.1)
    async def stream(
        self,
        request_ids: List[str],
        prompts: List[Prompt],
        priorities: Optional[List[QueuePriority]] = None,
    ):
        """A thin wrapper around EmbeddingEngine.generate().

        Because of ``@serve.batch``, each parameter is a per-request list
        of equal length (the batcher collates concurrent callers). Steps:
        1. Default missing priorities to ``GENERATE_TEXT``.
        2. Pack the batch into a single generation request.
        3. Forward it to ``EmbeddingEngine.generate()`` and, for every
           batched response the engine yields, emit one list with each
           entry mapped through ``_parse_response`` (errors included).
        """
        if not priorities:
            priorities = [QueuePriority.GENERATE_TEXT for _ in request_ids]

        # Engine takes raw prompt strings, not Prompt wrappers.
        prompt_texts = [prompt.prompt for prompt in prompts]

        logger.info(
            f"Received streaming requests ({len(request_ids)}) {','.join(request_ids)}"
        )
        # One request object carries the whole batch; the scheduling
        # metadata keeps the per-request ids/priorities aligned by index.
        embedding_request = self._generation_request_cls(
            prompt=prompt_texts,
            request_id=request_ids,
            sampling_params=None,
            scheduling_metadata=SchedulingMetadata(
                request_id=request_ids, priority=priorities
            ),
        )

        async for batched_aviary_model_response in self.engine.generate(
            embedding_request
        ):
            logger.info(
                f"Finished generating for streaming requests ({len(request_ids)}) {','.join(request_ids)}"
            )
            # Zip engine outputs back to their request ids; exceptions are
            # turned into error responses so one failure cannot poison the
            # rest of the batch.
            yield [
                self._parse_response(response, request_id)
                for response, request_id in zip(
                    batched_aviary_model_response, request_ids
                )
            ]

    async def check_health(self):
        """Delegate Serve health checks to the underlying engine."""
        return await self.engine.check_health()
95

96

97
@serve.deployment(
    autoscaling_config={
        # TODO make this configurable in aviary run
        "min_replicas": 1,
        "initial_replicas": 1,
        "max_replicas": 10,
        "target_num_ongoing_requests_per_replica": int(
            os.environ.get("AVIARY_ROUTER_TARGET_NUM_ONGOING_REQUESTS_PER_REPLICA", 10)
        ),
    },
    max_concurrent_queries=20,  # Maximum backlog for a single replica
    health_check_period_s=30,
    health_check_timeout_s=30,
)
class EmbeddingDeployment(EmbeddingDeploymentImpl):
    """Ray Serve deployment binding for the embedding implementation.

    Adds no behavior of its own; it only attaches the autoscaling,
    concurrency, and health-check configuration to
    ``EmbeddingDeploymentImpl``.
    """
113

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.