import logging
import os
from typing import List, Optional, Union

from ray import serve

from rayllm.backend.llm.embedding.embedding_engine import EmbeddingEngine
from rayllm.backend.llm.embedding.embedding_models import EmbeddingApp
from rayllm.backend.server.models import (
    AviaryModelResponse,
    GenerationRequest,
    QueuePriority,
    SchedulingMetadata,
)
from rayllm.backend.server.utils import get_response_for_error
from rayllm.common.models import Prompt

logger = logging.getLogger(__name__)


class EmbeddingDeploymentImpl:
    _generation_request_cls = GenerationRequest
    _default_engine_cls = EmbeddingEngine

    # Ray Serve supports async constructors, which lets us await engine startup.
    async def __init__(
        self, base_config: EmbeddingApp, *, engine_cls=None, generation_request_cls=None
    ):
        self.base_config = base_config
        self.config_store = {}

        engine_cls = engine_cls or self._default_engine_cls
        self._generation_request_cls = (
            generation_request_cls or self._generation_request_cls
        )

        # stream() is wrapped by @serve.batch below; reconfigure its batching
        # knobs from the engine config at construction time.
        self.stream.set_max_batch_size(base_config.engine_config.max_batch_size)
        self.stream.set_batch_wait_timeout_s(
            base_config.engine_config.batch_wait_timeout_s
        )

        self.engine = engine_cls(base_config)
        await self.engine.start()

    def _parse_response(
        self, response: Union[AviaryModelResponse, Exception], request_id: str
    ) -> AviaryModelResponse:
        # The engine reports per-request failures as Exception objects inside
        # the batch; convert them to structured error responses.
        if isinstance(response, Exception):
            return get_response_for_error(response, request_id=request_id)
        return response

    @serve.batch(max_batch_size=1, batch_wait_timeout_s=0.1)
    async def stream(
        self,
        request_ids: List[str],
        prompts: List[Prompt],
        priorities: Optional[List[QueuePriority]] = None,
    ):
        """A thin wrapper around EmbeddingEngine.generate().

        1. Load the model to disk
        2. Format parameters correctly
        3. Forward request to EmbeddingEngine.generate()
        """
        if priorities is None:
            priorities = [QueuePriority.GENERATE_TEXT for _ in request_ids]

        prompt_texts = [prompt.prompt for prompt in prompts]

        logger.info(
            f"Received streaming requests ({len(request_ids)}) {','.join(request_ids)}"
        )

        embedding_request = self._generation_request_cls(
            prompt=prompt_texts,
            request_id=request_ids,
            scheduling_metadata=SchedulingMetadata(
                request_id=request_ids, priority=priorities
            ),
        )

        async for batched_aviary_model_response in self.engine.generate(
            embedding_request
        ):
            logger.info(
                f"Finished generating for streaming requests ({len(request_ids)}) {','.join(request_ids)}"
            )
            yield [
                self._parse_response(response, request_id)
                for response, request_id in zip(
                    batched_aviary_model_response, request_ids
                )
            ]
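
    # Illustration (not in the original module) of how the batched entry point
    # is consumed: each caller passes scalar arguments, serve.batch gathers
    # concurrent calls into the list-typed parameters above, and every yielded
    # list is fanned back out so each caller sees only its own responses.
    # Hypothetical caller:
    #
    #   async for response in impl.stream("req-0", Prompt(prompt="hello")):
    #       ...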

    async def check_health(self):
        return await self.engine.check_health()


@serve.deployment(
    autoscaling_config={
        "min_replicas": 1,
        "initial_replicas": 1,
        "max_replicas": 10,  # assumption: this value was elided in the fragment
        "target_num_ongoing_requests_per_replica": int(
            os.environ.get("AVIARY_ROUTER_TARGET_NUM_ONGOING_REQUESTS_PER_REPLICA", 10)
        ),
    },
    max_concurrent_queries=20,
    health_check_period_s=30,
    health_check_timeout_s=30,
)
class EmbeddingDeployment(EmbeddingDeploymentImpl):
    ...
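

# Usage sketch (assumption, not part of the original file): bind the deployment
# with an EmbeddingApp config and run it with Ray Serve. The config object below
# is schematic; its real fields come from the rayllm config schema, of which
# only engine_config.max_batch_size and engine_config.batch_wait_timeout_s are
# referenced above.
#
#   from ray import serve
#
#   app = EmbeddingDeployment.bind(base_config=my_embedding_app_config)
#   serve.run(app)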