import asyncio
from typing import TYPE_CHECKING, Any, List, Optional, Sequence

from llama_index.legacy.bridge.pydantic import Field, PrivateAttr
from llama_index.legacy.callbacks import CallbackManager
from llama_index.legacy.core.embeddings.base import (
    DEFAULT_EMBED_BATCH_SIZE,
    BaseEmbedding,
    Embedding,
)
from llama_index.legacy.embeddings.huggingface_utils import (
    DEFAULT_HUGGINGFACE_EMBEDDING_MODEL,
    format_query,
    format_text,
    get_pooling_mode,
)
from llama_index.legacy.embeddings.pooling import Pooling
from llama_index.legacy.llms.huggingface import HuggingFaceInferenceAPI
from llama_index.legacy.utils import get_cache_dir, infer_torch_device

if TYPE_CHECKING:
    import torch

DEFAULT_HUGGINGFACE_LENGTH = 512

class HuggingFaceEmbedding(BaseEmbedding):
    tokenizer_name: str = Field(description="Tokenizer name from HuggingFace.")
    max_length: int = Field(
        default=DEFAULT_HUGGINGFACE_LENGTH, description="Maximum length of input.", gt=0
    )
    pooling: Pooling = Field(default=None, description="Pooling strategy.")
    normalize: bool = Field(default=True, description="Normalize embeddings or not.")
    query_instruction: Optional[str] = Field(
        description="Instruction to prepend to query text."
    )
    text_instruction: Optional[str] = Field(
        description="Instruction to prepend to text."
    )
    cache_folder: Optional[str] = Field(
        description="Cache folder for huggingface files."
    )

    _model: Any = PrivateAttr()
    _tokenizer: Any = PrivateAttr()
    _device: str = PrivateAttr()

    def __init__(
        self,
        model_name: Optional[str] = None,
        tokenizer_name: Optional[str] = None,
        pooling: Optional[str] = None,
        max_length: Optional[int] = None,
        query_instruction: Optional[str] = None,
        text_instruction: Optional[str] = None,
        normalize: bool = True,
        model: Optional[Any] = None,
        tokenizer: Optional[Any] = None,
        embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
        cache_folder: Optional[str] = None,
        trust_remote_code: bool = False,
        device: Optional[str] = None,
        callback_manager: Optional[CallbackManager] = None,
    ):
        try:
            from transformers import AutoModel, AutoTokenizer
        except ImportError:
            raise ImportError(
                "HuggingFaceEmbedding requires transformers to be installed.\n"
                "Please install transformers with `pip install transformers`."
            )

        self._device = device or infer_torch_device()

        cache_folder = cache_folder or get_cache_dir()

        if model is None:  # Use model_name with AutoModel
            model_name = (
                model_name
                if model_name is not None
                else DEFAULT_HUGGINGFACE_EMBEDDING_MODEL
            )
            model = AutoModel.from_pretrained(
                model_name, cache_dir=cache_folder, trust_remote_code=trust_remote_code
            )
        elif model_name is None:  # Extract model_name from model
            model_name = model.name_or_path
        self._model = model.to(self._device)

        if tokenizer is None:  # Use tokenizer_name with AutoTokenizer
            tokenizer_name = (
                model_name or tokenizer_name or DEFAULT_HUGGINGFACE_EMBEDDING_MODEL
            )
            tokenizer = AutoTokenizer.from_pretrained(
                tokenizer_name, cache_dir=cache_folder
            )
        elif tokenizer_name is None:  # Extract tokenizer_name from tokenizer
            tokenizer_name = tokenizer.name_or_path
        self._tokenizer = tokenizer

        if max_length is None:
            try:
                max_length = int(self._model.config.max_position_embeddings)
            except AttributeError as exc:
                raise ValueError(
                    "Unable to find max_length from model config. Please specify max_length."
                ) from exc

        if not pooling:
            pooling = get_pooling_mode(model_name)
        try:
            pooling = Pooling(pooling)
        except ValueError as exc:
            raise NotImplementedError(
                f"Pooling {pooling} unsupported, please pick one in"
                f" {[p.value for p in Pooling]}."
            ) from exc

        super().__init__(
            embed_batch_size=embed_batch_size,
            callback_manager=callback_manager,
            model_name=model_name,
            tokenizer_name=tokenizer_name,
            max_length=max_length,
            pooling=pooling,
            normalize=normalize,
            query_instruction=query_instruction,
            text_instruction=text_instruction,
        )

    @classmethod
    def class_name(cls) -> str:
        return "HuggingFaceEmbedding"

    def _mean_pooling(
        self, token_embeddings: "torch.Tensor", attention_mask: "torch.Tensor"
    ) -> "torch.Tensor":
        """Mean Pooling - Take attention mask into account for correct averaging."""
        input_mask_expanded = (
            attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        )
        numerator = (token_embeddings * input_mask_expanded).sum(1)
        return numerator / input_mask_expanded.sum(1).clamp(min=1e-9)

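    # For intuition (editor's note, not in the original module): with
    # token_embeddings of shape (batch, seq, hidden) and attention_mask of
    # shape (batch, seq), the method above computes the same thing as
    #
    #     masked = token_embeddings * attention_mask.unsqueeze(-1).float()
    #     pooled = masked.sum(dim=1) / attention_mask.float().sum(
    #         dim=1, keepdim=True
    #     ).clamp(min=1e-9)
    #
    # i.e. padding positions are zeroed out before averaging over the sequence.
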
    def _embed(self, sentences: List[str]) -> List[List[float]]:
        """Embed sentences."""
        encoded_input = self._tokenizer(
            sentences,
            padding=True,
            max_length=self.max_length,
            truncation=True,
            return_tensors="pt",
        )

        # pop token_type_ids
        encoded_input.pop("token_type_ids", None)

        # move tokenizer inputs to device
        encoded_input = {
            key: val.to(self._device) for key, val in encoded_input.items()
        }

        model_output = self._model(**encoded_input)

        if self.pooling == Pooling.CLS:
            context_layer: "torch.Tensor" = model_output[0]
            embeddings = self.pooling.cls_pooling(context_layer)
        else:
            embeddings = self._mean_pooling(
                token_embeddings=model_output[0],
                attention_mask=encoded_input["attention_mask"],
            )

        if self.normalize:
            import torch

            embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)

        return embeddings.tolist()

    def _get_query_embedding(self, query: str) -> List[float]:
        """Get query embedding."""
        query = format_query(query, self.model_name, self.query_instruction)
        return self._embed([query])[0]

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """Get query embedding async."""
        return self._get_query_embedding(query)

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Get text embedding async."""
        return self._get_text_embedding(text)

    def _get_text_embedding(self, text: str) -> List[float]:
        """Get text embedding."""
        text = format_text(text, self.model_name, self.text_instruction)
        return self._embed([text])[0]

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Get text embeddings."""
        texts = [
            format_text(text, self.model_name, self.text_instruction) for text in texts
        ]
        return self._embed(texts)

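
# A minimal usage sketch for the local class above (editor's addition; the
# checkpoint name is an assumption, any encoder-style HuggingFace checkpoint
# works the same way).
def _example_local_usage() -> None:  # pragma: no cover
    embed_model = HuggingFaceEmbedding(
        model_name="sentence-transformers/all-MiniLM-L6-v2",  # assumed checkpoint
        pooling="mean",
    )
    query_vec = embed_model.get_query_embedding("What is a vector database?")
    doc_vec = embed_model.get_text_embedding("A vector database stores embeddings.")
    # Query and document vectors share the model's hidden size.
    assert len(query_vec) == len(doc_vec)

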
class HuggingFaceInferenceAPIEmbedding(HuggingFaceInferenceAPI, BaseEmbedding):  # type: ignore[misc]
    """
    Wrapper around Hugging Face's Inference API for embeddings.

    Overview of the design:
    - Uses the feature extraction task: https://huggingface.co/tasks/feature-extraction
    """

    pooling: Optional[Pooling] = Field(
        default=Pooling.CLS,
        description=(
            "Optional pooling technique to use with embeddings capability, if"
            " the model's raw output needs pooling."
        ),
    )
    query_instruction: Optional[str] = Field(
        default=None,
        description=(
            "Instruction to prepend during query embedding."
            " Use of None means infer the instruction based on the model."
            " Use of an empty string disables instruction prepending entirely."
        ),
    )
    text_instruction: Optional[str] = Field(
        default=None,
        description=(
            "Instruction to prepend during text embedding."
            " Use of None means infer the instruction based on the model."
            " Use of an empty string disables instruction prepending entirely."
        ),
    )

    @classmethod
    def class_name(cls) -> str:
        return "HuggingFaceInferenceAPIEmbedding"

    async def _async_embed_single(self, text: str) -> Embedding:
        embedding = await self._async_client.feature_extraction(text)
        if len(embedding.shape) == 1:
            return embedding.tolist()
        embedding = embedding.squeeze(axis=0)
        if len(embedding.shape) == 1:  # Some models pool internally
            return embedding.tolist()
        try:
            return self.pooling(embedding).tolist()  # type: ignore[misc]
        except TypeError as exc:
            raise ValueError(
                f"Pooling is required for {self.model_name} because it returned"
                " a > 1-D value, please specify pooling as not None."
            ) from exc

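    # Editor's note on the shape handling above (not in the original): the API
    # may return an already-pooled 1-D vector, a singleton batch that squeezes
    # down to 1-D, or raw per-token embeddings; only the last case needs the
    # client-side `pooling` callable, and it raises if `pooling` is None.
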
    async def _async_embed_bulk(self, texts: Sequence[str]) -> List[Embedding]:
        """
        Embed a sequence of text, in parallel and asynchronously.

        NOTE: this uses an externally created asyncio event loop.
        """
        tasks = [self._async_embed_single(text) for text in texts]
        return await asyncio.gather(*tasks)

    def _get_query_embedding(self, query: str) -> Embedding:
        """
        Embed the input query synchronously.

        NOTE: a new asyncio event loop is created internally for this.
        """
        return asyncio.run(self._aget_query_embedding(query))

    def _get_text_embedding(self, text: str) -> Embedding:
        """
        Embed the input text synchronously.

        NOTE: a new asyncio event loop is created internally for this.
        """
        return asyncio.run(self._aget_text_embedding(text))

    def _get_text_embeddings(self, texts: List[str]) -> List[Embedding]:
        """
        Embed the input sequence of text synchronously and in parallel.

        NOTE: a new asyncio event loop is created internally for this.
        """
        loop = asyncio.new_event_loop()
        try:
            tasks = [
                loop.create_task(self._aget_text_embedding(text)) for text in texts
            ]
            loop.run_until_complete(asyncio.wait(tasks))
        finally:
            loop.close()
        return [task.result() for task in tasks]

    async def _aget_query_embedding(self, query: str) -> Embedding:
        return await self._async_embed_single(
            text=format_query(query, self.model_name, self.query_instruction)
        )

    async def _aget_text_embedding(self, text: str) -> Embedding:
        return await self._async_embed_single(
            text=format_text(text, self.model_name, self.text_instruction)
        )

    async def _aget_text_embeddings(self, texts: List[str]) -> List[Embedding]:
        return await self._async_embed_bulk(
            texts=[
                format_text(text, self.model_name, self.text_instruction)
                for text in texts
            ]
        )


HuggingFaceInferenceAPIEmbeddings = HuggingFaceInferenceAPIEmbedding
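

# A hedged end-to-end sketch for the Inference API variant (editor's addition).
# Assumes a valid Hugging Face Hub token; the checkpoint name is an assumption
# and any feature-extraction-capable model can be substituted.
if __name__ == "__main__":
    embed_model = HuggingFaceInferenceAPIEmbedding(
        model_name="BAAI/bge-small-en-v1.5",  # assumed checkpoint
        token="hf_...",  # placeholder; use your own token or omit to use a cached login
    )
    vector = embed_model.get_query_embedding("Hello, world!")
    print(f"embedding dimension: {len(vector)}")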