llama-index

Форк
0
318 строк · 11.3 Кб
1
import asyncio
2
from typing import TYPE_CHECKING, Any, List, Optional, Sequence
3

4
from llama_index.legacy.bridge.pydantic import Field, PrivateAttr
5
from llama_index.legacy.callbacks import CallbackManager
6
from llama_index.legacy.core.embeddings.base import (
7
    DEFAULT_EMBED_BATCH_SIZE,
8
    BaseEmbedding,
9
    Embedding,
10
)
11
from llama_index.legacy.embeddings.huggingface_utils import (
12
    DEFAULT_HUGGINGFACE_EMBEDDING_MODEL,
13
    format_query,
14
    format_text,
15
    get_pooling_mode,
16
)
17
from llama_index.legacy.embeddings.pooling import Pooling
18
from llama_index.legacy.llms.huggingface import HuggingFaceInferenceAPI
19
from llama_index.legacy.utils import get_cache_dir, infer_torch_device
20

21
if TYPE_CHECKING:
22
    import torch
23

24
# Default value of the `max_length` field on HuggingFaceEmbedding
# ("Maximum length of input.").
DEFAULT_HUGGINGFACE_LENGTH = 512
25

26

27
class HuggingFaceEmbedding(BaseEmbedding):
    """
    Embedding model backed by a locally loaded Hugging Face ``transformers`` model.

    The model and tokenizer are either supplied directly or loaded by name
    through ``AutoModel`` / ``AutoTokenizer``. Token-level model outputs are
    reduced to one vector per input with CLS or mean pooling and optionally
    L2-normalized.
    """

    tokenizer_name: str = Field(description="Tokenizer name from HuggingFace.")
    max_length: int = Field(
        default=DEFAULT_HUGGINGFACE_LENGTH, description="Maximum length of input.", gt=0
    )
    pooling: Pooling = Field(default=None, description="Pooling strategy.")
    normalize: bool = Field(default=True, description="Normalize embeddings or not.")
    query_instruction: Optional[str] = Field(
        description="Instruction to prepend to query text."
    )
    text_instruction: Optional[str] = Field(
        description="Instruction to prepend to text."
    )
    cache_folder: Optional[str] = Field(
        description="Cache folder for huggingface files."
    )

    _model: Any = PrivateAttr()
    _tokenizer: Any = PrivateAttr()
    _device: str = PrivateAttr()

    def __init__(
        self,
        model_name: Optional[str] = None,
        tokenizer_name: Optional[str] = None,
        pooling: Optional[str] = None,
        max_length: Optional[int] = None,
        query_instruction: Optional[str] = None,
        text_instruction: Optional[str] = None,
        normalize: bool = True,
        model: Optional[Any] = None,
        tokenizer: Optional[Any] = None,
        embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
        cache_folder: Optional[str] = None,
        trust_remote_code: bool = False,
        device: Optional[str] = None,
        callback_manager: Optional[CallbackManager] = None,
    ):
        """
        Load (or adopt) the model and tokenizer and resolve derived settings.

        Args:
            model_name: Name passed to ``AutoModel.from_pretrained`` when
                ``model`` is None; falls back to
                ``DEFAULT_HUGGINGFACE_EMBEDDING_MODEL``. When ``model`` is given
                and this is None, ``model.name_or_path`` is used.
            tokenizer_name: Name passed to ``AutoTokenizer.from_pretrained``
                when ``tokenizer`` is None. Falls back to the resolved
                ``model_name``. When ``tokenizer`` is given and this is None,
                ``tokenizer.name_or_path`` is used.
            pooling: Pooling strategy value; when falsy it is chosen by
                ``get_pooling_mode(model_name)``.
            max_length: Maximum input length given to the tokenizer. When None,
                it is read from ``model.config.max_position_embeddings``.
            query_instruction: Instruction to prepend to query text.
            text_instruction: Instruction to prepend to text.
            normalize: Whether to L2-normalize the pooled embeddings.
            model: Pre-built model object to use instead of loading by name.
            tokenizer: Pre-built tokenizer to use instead of loading by name.
            embed_batch_size: Forwarded to the base class.
            cache_folder: Download cache directory; defaults to
                ``get_cache_dir()``.
            trust_remote_code: Forwarded to ``AutoModel.from_pretrained``.
            device: Device the model is moved to; defaults to
                ``infer_torch_device()``.
            callback_manager: Forwarded to the base class.

        Raises:
            ImportError: If ``transformers`` is not installed.
            ValueError: If ``max_length`` is None and the model config has no
                ``max_position_embeddings`` attribute.
            NotImplementedError: If ``pooling`` is not a valid ``Pooling`` value.
        """
        try:
            from transformers import AutoModel, AutoTokenizer
        except ImportError:
            raise ImportError(
                "HuggingFaceEmbedding requires transformers to be installed.\n"
                "Please install transformers with `pip install transformers`."
            )

        self._device = device or infer_torch_device()

        cache_folder = cache_folder or get_cache_dir()

        if model is None:  # Use model_name with AutoModel
            model_name = (
                model_name
                if model_name is not None
                else DEFAULT_HUGGINGFACE_EMBEDDING_MODEL
            )
            model = AutoModel.from_pretrained(
                model_name, cache_dir=cache_folder, trust_remote_code=trust_remote_code
            )
        elif model_name is None:  # Extract model_name from model
            model_name = model.name_or_path
        self._model = model.to(self._device)

        if tokenizer is None:  # Use tokenizer_name with AutoTokenizer
            # An explicitly requested tokenizer must win over the model name;
            # model_name is always set by this point, so putting it first
            # would silently discard the caller's tokenizer_name.
            tokenizer_name = (
                tokenizer_name or model_name or DEFAULT_HUGGINGFACE_EMBEDDING_MODEL
            )
            tokenizer = AutoTokenizer.from_pretrained(
                tokenizer_name, cache_dir=cache_folder
            )
        elif tokenizer_name is None:  # Extract tokenizer_name from tokenizer
            tokenizer_name = tokenizer.name_or_path
        self._tokenizer = tokenizer

        if max_length is None:
            try:
                max_length = int(self._model.config.max_position_embeddings)
            except AttributeError as exc:
                raise ValueError(
                    "Unable to find max_length from model config. Please specify max_length."
                ) from exc

        if not pooling:
            pooling = get_pooling_mode(model_name)
        try:
            pooling = Pooling(pooling)
        except ValueError as exc:
            raise NotImplementedError(
                f"Pooling {pooling} unsupported, please pick one in"
                f" {[p.value for p in Pooling]}."
            ) from exc

        super().__init__(
            embed_batch_size=embed_batch_size,
            callback_manager=callback_manager,
            model_name=model_name,
            tokenizer_name=tokenizer_name,
            max_length=max_length,
            pooling=pooling,
            normalize=normalize,
            query_instruction=query_instruction,
            text_instruction=text_instruction,
        )

    @classmethod
    def class_name(cls) -> str:
        """Return the class identifier string."""
        return "HuggingFaceEmbedding"

    def _mean_pooling(
        self, token_embeddings: "torch.Tensor", attention_mask: "torch.Tensor"
    ) -> "torch.Tensor":
        """
        Mean Pooling - Take attention mask into account for correct averaging.

        Sums the token embeddings weighted by the attention mask along dim 1
        and divides by the mask sum along the same dim.
        """
        # assumes token_embeddings is (batch, seq, dim) and attention_mask is
        # (batch, seq) -- TODO confirm against the model outputs in use
        input_mask_expanded = (
            attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        )
        numerator = (token_embeddings * input_mask_expanded).sum(1)
        # clamp keeps the denominator away from zero for all-masked rows
        return numerator / input_mask_expanded.sum(1).clamp(min=1e-9)

    def _embed(self, sentences: List[str]) -> List[List[float]]:
        """
        Embed sentences.

        Tokenizes the batch (padded, truncated to ``self.max_length``), runs
        the model, pools the first model output according to ``self.pooling``,
        optionally L2-normalizes, and returns plain Python lists.

        Returns an empty list for an empty batch without touching the
        tokenizer or the model.
        """
        if not sentences:
            return []

        import torch

        encoded_input = self._tokenizer(
            sentences,
            padding=True,
            max_length=self.max_length,
            truncation=True,
            return_tensors="pt",
        )

        # token_type_ids are not passed to the model
        encoded_input.pop("token_type_ids", None)

        # move tokenizer inputs to device
        encoded_input = {
            key: val.to(self._device) for key, val in encoded_input.items()
        }

        # Inference only: skip autograd bookkeeping for the forward pass and
        # the pooling/normalization that follows.
        with torch.no_grad():
            model_output = self._model(**encoded_input)

            if self.pooling == Pooling.CLS:
                context_layer: "torch.Tensor" = model_output[0]
                embeddings = self.pooling.cls_pooling(context_layer)
            else:
                embeddings = self._mean_pooling(
                    token_embeddings=model_output[0],
                    attention_mask=encoded_input["attention_mask"],
                )

            if self.normalize:
                embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)

        return embeddings.tolist()

    def _get_query_embedding(self, query: str) -> List[float]:
        """Get query embedding."""
        query = format_query(query, self.model_name, self.query_instruction)
        return self._embed([query])[0]

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """Get query embedding async (delegates to the synchronous method)."""
        return self._get_query_embedding(query)

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Get text embedding async (delegates to the synchronous method)."""
        return self._get_text_embedding(text)

    def _get_text_embedding(self, text: str) -> List[float]:
        """Get text embedding."""
        text = format_text(text, self.model_name, self.text_instruction)
        return self._embed([text])[0]

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Get text embeddings for a batch of texts in a single model call."""
        texts = [
            format_text(text, self.model_name, self.text_instruction) for text in texts
        ]
        return self._embed(texts)
205

206

207
class HuggingFaceInferenceAPIEmbedding(HuggingFaceInferenceAPI, BaseEmbedding):  # type: ignore[misc]
    """
    Wrapper on the Hugging Face's Inference API for embeddings.

    Overview of the design:
    - Uses the feature extraction task: https://huggingface.co/tasks/feature-extraction
    - All embedding work is done by async coroutines; the synchronous
      methods run those coroutines on a freshly created event loop.
    """

    pooling: Optional[Pooling] = Field(
        default=Pooling.CLS,
        description=(
            "Optional pooling technique to use with embeddings capability, if"
            " the model's raw output needs pooling."
        ),
    )
    query_instruction: Optional[str] = Field(
        default=None,
        description=(
            "Instruction to prepend during query embedding."
            " Use of None means infer the instruction based on the model."
            " Use of empty string will defeat instruction prepending entirely."
        ),
    )
    text_instruction: Optional[str] = Field(
        default=None,
        description=(
            "Instruction to prepend during text embedding."
            " Use of None means infer the instruction based on the model."
            " Use of empty string will defeat instruction prepending entirely."
        ),
    )

    @classmethod
    def class_name(cls) -> str:
        """Return the class identifier string."""
        return "HuggingFaceInferenceAPIEmbedding"

    async def _async_embed_single(self, text: str) -> Embedding:
        """
        Embed one already-formatted text through the async client.

        The raw result is returned as-is when it is 1-D, or 1-D after removing
        the leading axis; otherwise ``self.pooling`` reduces it.

        Raises:
            ValueError: If the output still has more than one dimension and
                ``self.pooling`` is None.
        """
        # NOTE(review): `_async_client` is not defined in this class; it is
        # presumably provided by HuggingFaceInferenceAPI -- confirm.
        embedding = await self._async_client.feature_extraction(text)
        if len(embedding.shape) == 1:
            return embedding.tolist()
        embedding = embedding.squeeze(axis=0)
        if len(embedding.shape) == 1:  # Some models pool internally
            return embedding.tolist()
        # Check for a missing pooling strategy explicitly instead of catching
        # the TypeError from calling None, which would also hide genuine
        # TypeErrors raised inside the pooling function.
        if self.pooling is None:
            raise ValueError(
                f"Pooling is required for {self.model_name} because it returned"
                " a > 1-D value, please specify pooling as not None."
            )
        return self.pooling(embedding).tolist()

    async def _async_embed_bulk(self, texts: Sequence[str]) -> List[Embedding]:
        """
        Embed a sequence of text, in parallel and asynchronously.

        NOTE: this uses an externally created asyncio event loop.
        """
        tasks = [self._async_embed_single(text) for text in texts]
        return await asyncio.gather(*tasks)

    def _get_query_embedding(self, query: str) -> Embedding:
        """
        Embed the input query synchronously.

        NOTE: a new asyncio event loop is created internally for this.
        """
        return asyncio.run(self._aget_query_embedding(query))

    def _get_text_embedding(self, text: str) -> Embedding:
        """
        Embed the text query synchronously.

        NOTE: a new asyncio event loop is created internally for this.
        """
        return asyncio.run(self._aget_text_embedding(text))

    def _get_text_embeddings(self, texts: List[str]) -> List[Embedding]:
        """
        Embed the input sequence of text synchronously and in parallel.

        Results are returned in the same order as ``texts``; an empty input
        yields an empty list.

        NOTE: a new asyncio event loop is created internally for this.
        """
        # Reuse the async bulk path; gathering zero coroutines resolves to [],
        # whereas asyncio.wait() raises ValueError on an empty task list.
        return asyncio.run(self._aget_text_embeddings(texts))

    async def _aget_query_embedding(self, query: str) -> Embedding:
        """Format the query with the query instruction and embed it."""
        return await self._async_embed_single(
            text=format_query(query, self.model_name, self.query_instruction)
        )

    async def _aget_text_embedding(self, text: str) -> Embedding:
        """Format the text with the text instruction and embed it."""
        return await self._async_embed_single(
            text=format_text(text, self.model_name, self.text_instruction)
        )

    async def _aget_text_embeddings(self, texts: List[str]) -> List[Embedding]:
        """Format every text with the text instruction and embed them concurrently."""
        return await self._async_embed_bulk(
            texts=[
                format_text(text, self.model_name, self.text_instruction)
                for text in texts
            ]
        )
316

317

318
# Plural-named alias bound to the same class object as
# HuggingFaceInferenceAPIEmbedding.
# NOTE(review): presumably kept for backward compatibility -- confirm.
HuggingFaceInferenceAPIEmbeddings = HuggingFaceInferenceAPIEmbedding
319

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.