llama-index

Форк
0
148 строк · 5.2 Кб
1
from typing import Callable, List, Optional, Union
2

3
from llama_index.legacy.bridge.pydantic import Field
4
from llama_index.legacy.callbacks import CallbackManager
5
from llama_index.legacy.core.embeddings.base import (
6
    DEFAULT_EMBED_BATCH_SIZE,
7
    BaseEmbedding,
8
    Embedding,
9
)
10
from llama_index.legacy.embeddings.huggingface_utils import format_query, format_text
11

12
# Base URL used for the embeddings service when the caller does not supply
# one (it is the default of both the ``base_url`` field and the constructor
# argument of ``TextEmbeddingsInference``).
DEFAULT_URL = "http://127.0.0.1:8080"
13

14

15
class TextEmbeddingsInference(BaseEmbedding):
    """Embedding model backed by an HTTP text-embeddings service.

    Texts are sent as a JSON ``POST`` to ``{base_url}/embed`` with the body
    ``{"inputs": [...], "truncate": <bool>}``; the decoded JSON response is
    returned as the embeddings (presumably one vector per input, in input
    order -- this depends on the server and is not validated here).

    ``httpx`` is imported lazily so the module can be imported without it;
    constructing the class fails with ``ImportError`` if it is missing.
    """

    base_url: str = Field(
        default=DEFAULT_URL,
        description="Base URL for the text embeddings service.",
    )
    query_instruction: Optional[str] = Field(
        default=None,
        description="Instruction to prepend to query text.",
    )
    text_instruction: Optional[str] = Field(
        default=None,
        description="Instruction to prepend to text.",
    )
    timeout: float = Field(
        default=60.0,
        description="Timeout in seconds for the request.",
    )
    truncate_text: bool = Field(
        default=True,
        description="Whether to truncate text or not when generating embeddings.",
    )
    auth_token: Optional[Union[str, Callable[[str], str]]] = Field(
        default=None,
        description="Authentication token or authentication token generating function for authenticated requests",
    )

    def __init__(
        self,
        model_name: str,
        base_url: str = DEFAULT_URL,
        text_instruction: Optional[str] = None,
        query_instruction: Optional[str] = None,
        embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
        timeout: float = 60.0,
        truncate_text: bool = True,
        callback_manager: Optional[CallbackManager] = None,
        auth_token: Optional[Union[str, Callable[[str], str]]] = None,
    ):
        """Create the embedding client.

        Args:
            model_name: Model name; passed to the text/query formatters.
            base_url: Base URL of the embeddings service.
            text_instruction: Optional instruction passed to ``format_text``.
            query_instruction: Optional instruction passed to ``format_query``.
            embed_batch_size: Batch size forwarded to the base class.
            timeout: Per-request timeout in seconds.
            truncate_text: Sent to the service as the ``truncate`` flag.
            callback_manager: Optional callback manager forwarded to the
                base class.
            auth_token: Either the literal ``Authorization`` header value, or
                a callable that receives ``base_url`` and returns that value.

        Raises:
            ImportError: If ``httpx`` is not installed.
        """
        # Fail early, at construction time, rather than on the first request.
        try:
            import httpx  # noqa
        except ImportError as err:
            raise ImportError(
                "TextEmbeddingsInference requires httpx to be installed.\n"
                "Please install httpx with `pip install httpx`."
            ) from err

        super().__init__(
            base_url=base_url,
            model_name=model_name,
            text_instruction=text_instruction,
            query_instruction=query_instruction,
            embed_batch_size=embed_batch_size,
            timeout=timeout,
            truncate_text=truncate_text,
            callback_manager=callback_manager,
            auth_token=auth_token,
        )

    @classmethod
    def class_name(cls) -> str:
        """Return the class identifier string."""
        return "TextEmbeddingsInference"

    def _request_headers(self) -> dict:
        """Build the HTTP headers shared by the sync and async request paths.

        The ``Authorization`` value is used exactly as provided (or as
        returned by the callable); no scheme prefix is added here.
        """
        headers = {"Content-Type": "application/json"}
        if self.auth_token is not None:
            if callable(self.auth_token):
                headers["Authorization"] = self.auth_token(self.base_url)
            else:
                headers["Authorization"] = self.auth_token
        return headers

    def _request_payload(self, texts: List[str]) -> dict:
        """Build the JSON body sent to the ``/embed`` endpoint."""
        return {"inputs": texts, "truncate": self.truncate_text}

    def _embed_url(self) -> str:
        """Return the endpoint URL, tolerating a trailing slash on ``base_url``."""
        return f"{self.base_url.rstrip('/')}/embed"

    def _call_api(self, texts: List[str]) -> List[List[float]]:
        """POST ``texts`` to the service synchronously and return the JSON reply.

        Raises:
            httpx.HTTPStatusError: If the service answers with a 4xx/5xx status.
        """
        import httpx

        with httpx.Client() as client:
            response = client.post(
                self._embed_url(),
                headers=self._request_headers(),
                json=self._request_payload(texts),
                timeout=self.timeout,
            )

        # Surface HTTP errors instead of returning an error body as embeddings.
        response.raise_for_status()
        return response.json()

    async def _acall_api(self, texts: List[str]) -> List[List[float]]:
        """POST ``texts`` to the service asynchronously and return the JSON reply.

        Raises:
            httpx.HTTPStatusError: If the service answers with a 4xx/5xx status.
        """
        import httpx

        async with httpx.AsyncClient() as client:
            response = await client.post(
                self._embed_url(),
                headers=self._request_headers(),
                json=self._request_payload(texts),
                timeout=self.timeout,
            )

        # Surface HTTP errors instead of returning an error body as embeddings.
        response.raise_for_status()
        return response.json()

    def _get_query_embedding(self, query: str) -> List[float]:
        """Get query embedding."""
        query = format_query(query, self.model_name, self.query_instruction)
        return self._call_api([query])[0]

    def _get_text_embedding(self, text: str) -> List[float]:
        """Get text embedding."""
        text = format_text(text, self.model_name, self.text_instruction)
        return self._call_api([text])[0]

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Get text embeddings."""
        texts = [
            format_text(text, self.model_name, self.text_instruction) for text in texts
        ]
        return self._call_api(texts)

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """Get query embedding async."""
        query = format_query(query, self.model_name, self.query_instruction)
        return (await self._acall_api([query]))[0]

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Get text embedding async."""
        text = format_text(text, self.model_name, self.text_instruction)
        return (await self._acall_api([text]))[0]

    async def _aget_text_embeddings(self, texts: List[str]) -> List[Embedding]:
        """Get text embeddings async."""
        texts = [
            format_text(text, self.model_name, self.text_instruction) for text in texts
        ]
        return await self._acall_api(texts)
149

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.