llama-index

Форк
0
118 строк · 3.7 Кб
1
import logging
2
from typing import Any, List
3

4
import requests
5
from requests.adapters import HTTPAdapter, Retry
6

7
from llama_index.legacy.embeddings.base import BaseEmbedding
8

9
logger = logging.getLogger(__name__)
10

11

12
class LLMRailsEmbedding(BaseEmbedding):
13
    """LLMRails embedding models.
14

15
    This class provides an interface to generate embeddings using a model deployed
16
    in an LLMRails cluster. It requires a model_id of the model deployed in the cluster and api key you can obtain
17
    from https://console.llmrails.com/api-keys.
18

19
    """
20

21
    model_id: str
22
    api_key: str
23
    session: requests.Session
24

25
    @classmethod
26
    def class_name(self) -> str:
27
        return "LLMRailsEmbedding"
28

29
    def __init__(
30
        self,
31
        api_key: str,
32
        model_id: str = "embedding-english-v1",  # or embedding-multi-v1
33
        **kwargs: Any,
34
    ):
35
        retry = Retry(
36
            total=3,
37
            connect=3,
38
            read=2,
39
            allowed_methods=["POST"],
40
            backoff_factor=2,
41
            status_forcelist=[502, 503, 504],
42
        )
43
        session = requests.Session()
44
        session.mount("https://api.llmrails.com", HTTPAdapter(max_retries=retry))
45
        session.headers = {"X-API-KEY": api_key}
46
        super().__init__(model_id=model_id, api_key=api_key, session=session, **kwargs)
47

48
    def _get_embedding(self, text: str) -> List[float]:
49
        """
50
        Generate an embedding for a single query text.
51

52
        Args:
53
            text (str): The query text to generate an embedding for.
54

55
        Returns:
56
            List[float]: The embedding for the input query text.
57
        """
58
        try:
59
            response = self.session.post(
60
                "https://api.llmrails.com/v1/embeddings",
61
                json={"input": [text], "model": self.model_id},
62
            )
63

64
            response.raise_for_status()
65
            return response.json()["data"][0]["embedding"]
66

67
        except requests.exceptions.HTTPError as e:
68
            logger.error(f"Error while embedding text {e}.")
69
            raise ValueError(f"Unable to embed given text {e}")
70

71
    async def _aget_embedding(self, text: str) -> List[float]:
72
        """
73
        Generate an embedding for a single query text.
74

75
        Args:
76
            text (str): The query text to generate an embedding for.
77

78
        Returns:
79
            List[float]: The embedding for the input query text.
80
        """
81
        try:
82
            import httpx
83
        except ImportError:
84
            raise ImportError(
85
                "The httpx library is required to use the async version of "
86
                "this function. Install it with `pip install httpx`."
87
            )
88

89
        try:
90
            async with httpx.AsyncClient() as client:
91
                response = await client.post(
92
                    "https://api.llmrails.com/v1/embeddings",
93
                    headers={"X-API-KEY": self.api_key},
94
                    json={"input": [text], "model": self.model_id},
95
                )
96

97
                response.raise_for_status()
98

99
            return response.json()["data"][0]["embedding"]
100

101
        except httpx._exceptions.HTTPError as e:
102
            logger.error(f"Error while embedding text {e}.")
103
            raise ValueError(f"Unable to embed given text {e}")
104

105
    def _get_text_embedding(self, text: str) -> List[float]:
106
        return self._get_embedding(text)
107

108
    def _get_query_embedding(self, query: str) -> List[float]:
109
        return self._get_embedding(query)
110

111
    async def _aget_query_embedding(self, query: str) -> List[float]:
112
        return await self._aget_embedding(query)
113

114
    async def _aget_text_embedding(self, query: str) -> List[float]:
115
        return await self._aget_embedding(query)
116

117

118
LLMRailsEmbeddings = LLMRailsEmbedding
119

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.