llama-index
118 строк · 3.7 Кб
import logging
from typing import Any, List

import requests
from requests.adapters import HTTPAdapter, Retry

from llama_index.legacy.embeddings.base import BaseEmbedding

# Module-level logger; used to report failed embedding requests.
logger = logging.getLogger(__name__)
11
class LLMRailsEmbedding(BaseEmbedding):
    """LLMRails embedding models.

    This class provides an interface to generate embeddings using a model
    deployed in an LLMRails cluster. It requires the ``model_id`` of the model
    deployed in the cluster and an API key, which you can obtain from
    https://console.llmrails.com/api-keys.

    Synchronous calls go through a shared ``requests.Session`` configured with
    a retry policy. Asynchronous calls open a short-lived ``httpx.AsyncClient``
    per request and do not use that retry policy (``httpx`` must be installed).
    """

    # Identifier of the LLMRails model sent as "model" in each request.
    model_id: str
    # API key sent in the "X-API-KEY" header.
    api_key: str
    # Pre-configured session (retries + auth header) used by the sync path.
    session: requests.Session

    @classmethod
    def class_name(cls) -> str:
        """Return the name this embedding class is registered under."""
        return "LLMRailsEmbedding"

    def __init__(
        self,
        api_key: str,
        model_id: str = "embedding-english-v1",  # or embedding-multi-v1
        **kwargs: Any,
    ):
        """Create the embedding client.

        Args:
            api_key (str): LLMRails API key.
            model_id (str): Model to embed with.
            **kwargs: Forwarded to ``BaseEmbedding``.
        """
        # Retry POSTs on connection/read failures and on 502/503/504
        # responses, with exponential backoff.
        retry = Retry(
            total=3,
            connect=3,
            read=2,
            allowed_methods=["POST"],
            backoff_factor=2,
            status_forcelist=[502, 503, 504],
        )
        session = requests.Session()
        session.mount("https://api.llmrails.com", HTTPAdapter(max_retries=retry))
        # Update instead of replacing the headers mapping so requests' default
        # session headers (User-Agent, Accept-Encoding, ...) are preserved.
        session.headers.update({"X-API-KEY": api_key})
        super().__init__(model_id=model_id, api_key=api_key, session=session, **kwargs)

    def _get_embedding(self, text: str) -> List[float]:
        """Generate an embedding for a single text.

        Args:
            text (str): The text to generate an embedding for.

        Returns:
            List[float]: The embedding for the input text.

        Raises:
            ValueError: If the API responds with an HTTP error status.
                Other ``requests`` exceptions (e.g. connection failures) are
                not caught here and propagate unchanged.
        """
        try:
            response = self.session.post(
                "https://api.llmrails.com/v1/embeddings",
                json={"input": [text], "model": self.model_id},
            )
            response.raise_for_status()
        except requests.exceptions.HTTPError as e:
            logger.error(f"Error while embedding text {e}.")
            raise ValueError(f"Unable to embed given text {e}") from e

        # The request carries a single input, so take the first result.
        return response.json()["data"][0]["embedding"]

    async def _aget_embedding(self, text: str) -> List[float]:
        """Asynchronously generate an embedding for a single text.

        Args:
            text (str): The text to generate an embedding for.

        Returns:
            List[float]: The embedding for the input text.

        Raises:
            ImportError: If ``httpx`` is not installed.
            ValueError: If the request fails with any ``httpx.HTTPError``
                (HTTP error status or transport failure).
        """
        try:
            import httpx
        except ImportError as e:
            raise ImportError(
                "The httpx library is required to use the async version of "
                "this function. Install it with `pip install httpx`."
            ) from e

        try:
            async with httpx.AsyncClient() as client:
                response = await client.post(
                    "https://api.llmrails.com/v1/embeddings",
                    headers={"X-API-KEY": self.api_key},
                    json={"input": [text], "model": self.model_id},
                )
                response.raise_for_status()
        # Public alias of the same class the original referenced via the
        # private ``httpx._exceptions`` module.
        except httpx.HTTPError as e:
            logger.error(f"Error while embedding text {e}.")
            raise ValueError(f"Unable to embed given text {e}") from e

        # Non-streaming httpx responses are fully read, so the body is still
        # available after the client has been closed.
        return response.json()["data"][0]["embedding"]

    def _get_text_embedding(self, text: str) -> List[float]:
        """Embed a document text (same endpoint as queries)."""
        return self._get_embedding(text)

    def _get_query_embedding(self, query: str) -> List[float]:
        """Embed a query text (same endpoint as documents)."""
        return self._get_embedding(query)

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """Asynchronously embed a query text."""
        return await self._aget_embedding(query)

    async def _aget_text_embedding(self, query: str) -> List[float]:
        """Asynchronously embed a document text.

        NOTE(review): the parameter is named ``query`` although it receives
        document text; kept as-is for backward compatibility.
        """
        return await self._aget_embedding(query)
117
# Plural alias kept so code importing ``LLMRailsEmbeddings`` keeps working;
# it refers to the very same class object as ``LLMRailsEmbedding``.
LLMRailsEmbeddings = LLMRailsEmbedding