llama-index
148 строк · 5.2 Кб
1from typing import Callable, List, Optional, Union2
3from llama_index.legacy.bridge.pydantic import Field4from llama_index.legacy.callbacks import CallbackManager5from llama_index.legacy.core.embeddings.base import (6DEFAULT_EMBED_BATCH_SIZE,7BaseEmbedding,8Embedding,9)
10from llama_index.legacy.embeddings.huggingface_utils import format_query, format_text11
# Base URL used when no ``base_url`` is supplied: a service listening on the
# local machine (127.0.0.1) at port 8080.
DEFAULT_URL = "http://127.0.0.1:8080"
class TextEmbeddingsInference(BaseEmbedding):
    """Embedding model backed by a text embeddings HTTP service.

    Texts are POSTed to ``{base_url}/embed`` as JSON
    ``{"inputs": [...], "truncate": truncate_text}`` and the decoded JSON
    response is returned as the list of embedding vectors. Queries are first
    passed through ``format_query`` with ``query_instruction``; documents are
    passed through ``format_text`` with ``text_instruction``.
    """

    base_url: str = Field(
        default=DEFAULT_URL,
        description="Base URL for the text embeddings service.",
    )
    query_instruction: Optional[str] = Field(
        default=None,
        description="Instruction to prepend to query text.",
    )
    text_instruction: Optional[str] = Field(
        default=None,
        description="Instruction to prepend to text.",
    )
    timeout: float = Field(
        default=60.0,
        description="Timeout in seconds for the request.",
    )
    truncate_text: bool = Field(
        default=True,
        description="Whether to truncate text or not when generating embeddings.",
    )
    auth_token: Optional[Union[str, Callable[[str], str]]] = Field(
        default=None,
        description="Authentication token or authentication token generating function for authenticated requests",
    )

    def __init__(
        self,
        model_name: str,
        base_url: str = DEFAULT_URL,
        text_instruction: Optional[str] = None,
        query_instruction: Optional[str] = None,
        embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
        timeout: float = 60.0,
        truncate_text: bool = True,
        callback_manager: Optional[CallbackManager] = None,
        auth_token: Optional[Union[str, Callable[[str], str]]] = None,
    ):
        """Create a client for a text embeddings service.

        Args:
            model_name: Model name; stored on the instance and passed to
                ``format_query`` / ``format_text`` when formatting inputs.
            base_url: Base URL of the service; ``/embed`` is appended to it.
            text_instruction: Instruction passed to ``format_text`` for documents.
            query_instruction: Instruction passed to ``format_query`` for queries.
            embed_batch_size: Forwarded to ``BaseEmbedding``.
            timeout: Per-request timeout in seconds, given to httpx.
            truncate_text: Sent as the ``truncate`` flag of every request.
            callback_manager: Forwarded to ``BaseEmbedding``.
            auth_token: Either a fixed ``Authorization`` header value, or a
                callable that receives ``base_url`` and returns that value.

        Raises:
            ImportError: If ``httpx`` is not installed.
        """
        try:
            import httpx  # noqa
        except ImportError as err:
            raise ImportError(
                "TextEmbeddingsInference requires httpx to be installed.\n"
                "Please install httpx with `pip install httpx`."
            ) from err

        super().__init__(
            base_url=base_url,
            model_name=model_name,
            text_instruction=text_instruction,
            query_instruction=query_instruction,
            embed_batch_size=embed_batch_size,
            timeout=timeout,
            truncate_text=truncate_text,
            callback_manager=callback_manager,
            auth_token=auth_token,
        )

    @classmethod
    def class_name(cls) -> str:
        return "TextEmbeddingsInference"

    def _embed_url(self) -> str:
        """Return the ``/embed`` endpoint URL, tolerating a trailing slash on base_url."""
        return f"{self.base_url.rstrip('/')}/embed"

    def _build_headers(self) -> dict:
        """Return request headers, adding ``Authorization`` when a token is configured.

        A callable ``auth_token`` is invoked with ``base_url`` every time the
        headers are built, i.e. once per request.
        """
        headers = {"Content-Type": "application/json"}
        if self.auth_token is not None:
            if callable(self.auth_token):
                headers["Authorization"] = self.auth_token(self.base_url)
            else:
                headers["Authorization"] = self.auth_token
        return headers

    def _build_payload(self, texts: List[str]) -> dict:
        """Return the JSON body for an ``/embed`` request."""
        return {"inputs": texts, "truncate": self.truncate_text}

    def _call_api(self, texts: List[str]) -> List[List[float]]:
        """Synchronously embed ``texts`` with one POST to the service.

        Raises:
            httpx.HTTPStatusError: If the service answers with a non-success status.
        """
        import httpx

        # Nothing to embed: skip the round trip instead of sending an empty batch.
        if not texts:
            return []

        with httpx.Client() as client:
            response = client.post(
                self._embed_url(),
                headers=self._build_headers(),
                json=self._build_payload(texts),
                timeout=self.timeout,
            )

        # Surface HTTP errors instead of returning the error body as embeddings.
        response.raise_for_status()
        return response.json()

    async def _acall_api(self, texts: List[str]) -> List[List[float]]:
        """Asynchronously embed ``texts`` with one POST to the service.

        Raises:
            httpx.HTTPStatusError: If the service answers with a non-success status.
        """
        import httpx

        # Nothing to embed: skip the round trip instead of sending an empty batch.
        if not texts:
            return []

        async with httpx.AsyncClient() as client:
            response = await client.post(
                self._embed_url(),
                headers=self._build_headers(),
                json=self._build_payload(texts),
                timeout=self.timeout,
            )

        # Surface HTTP errors instead of returning the error body as embeddings.
        response.raise_for_status()
        return response.json()

    def _get_query_embedding(self, query: str) -> List[float]:
        """Get query embedding."""
        query = format_query(query, self.model_name, self.query_instruction)
        return self._call_api([query])[0]

    def _get_text_embedding(self, text: str) -> List[float]:
        """Get text embedding."""
        text = format_text(text, self.model_name, self.text_instruction)
        return self._call_api([text])[0]

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Get text embeddings."""
        texts = [
            format_text(text, self.model_name, self.text_instruction) for text in texts
        ]
        return self._call_api(texts)

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """Get query embedding async."""
        query = format_query(query, self.model_name, self.query_instruction)
        return (await self._acall_api([query]))[0]

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Get text embedding async."""
        text = format_text(text, self.model_name, self.text_instruction)
        return (await self._acall_api([text]))[0]

    async def _aget_text_embeddings(self, texts: List[str]) -> List[Embedding]:
        """Get text embeddings async."""
        texts = [
            format_text(text, self.model_name, self.text_instruction) for text in texts
        ]
        return await self._acall_api(texts)