llama-index
104 lines · 3.6 KB
from typing import Any, List, Optional

from llama_index.legacy.bridge.pydantic import Field, PrivateAttr
from llama_index.legacy.callbacks import CallbackManager
from llama_index.legacy.core.embeddings.base import (
    DEFAULT_EMBED_BATCH_SIZE,
    BaseEmbedding,
)
from llama_index.legacy.embeddings.huggingface_utils import (
    DEFAULT_INSTRUCT_MODEL,
    get_query_instruct_for_model_name,
    get_text_instruct_for_model_name,
)
14
15
class InstructorEmbedding(BaseEmbedding):
    """Embedding model backed by ``InstructorEmbedding.INSTRUCTOR``.

    Every string is encoded as an ``[instruction, text]`` pair. Queries are
    paired with ``query_instruction`` and documents with ``text_instruction``;
    when either is ``None``, a default instruction for ``model_name`` is
    looked up via the ``huggingface_utils`` helpers.
    """

    query_instruction: Optional[str] = Field(
        description="Instruction to prepend to query text."
    )
    text_instruction: Optional[str] = Field(
        description="Instruction to prepend to text."
    )
    cache_folder: Optional[str] = Field(
        description="Cache folder for huggingface files."
    )

    # The loaded INSTRUCTOR model. Declared as a private attribute so it is
    # not treated as a pydantic field.
    _model: Any = PrivateAttr()

    def __init__(
        self,
        model_name: str = DEFAULT_INSTRUCT_MODEL,
        query_instruction: Optional[str] = None,
        text_instruction: Optional[str] = None,
        embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
        cache_folder: Optional[str] = None,
        device: Optional[str] = None,
        callback_manager: Optional[CallbackManager] = None,
    ):
        """Load the INSTRUCTOR model and initialize the embedding fields.

        Args:
            model_name: Model identifier passed to ``INSTRUCTOR`` and stored
                on the instance (used to pick default instructions).
            query_instruction: Instruction paired with query strings; ``None``
                selects the model's default query instruction.
            text_instruction: Instruction paired with document strings;
                ``None`` selects the model's default text instruction.
            embed_batch_size: Batch size forwarded to ``BaseEmbedding``.
            cache_folder: Cache directory forwarded to ``INSTRUCTOR``.
            device: Device forwarded to ``INSTRUCTOR``; not stored.
            callback_manager: Callback manager forwarded to ``BaseEmbedding``.

        Raises:
            ImportError: If the ``InstructorEmbedding`` package is missing.
        """
        try:
            from InstructorEmbedding import INSTRUCTOR
        except ImportError as err:
            raise ImportError(
                "InstructorEmbedding requires instructor to be installed.\n"
                "Please install it with `pip install InstructorEmbedding`."
            ) from err
        # NOTE(review): the private attribute is assigned before
        # super().__init__(); this depends on the pydantic version exposed by
        # llama_index.legacy.bridge.pydantic allowing pre-init private
        # assignment — order kept as in the original.
        self._model = INSTRUCTOR(model_name, cache_folder=cache_folder, device=device)

        super().__init__(
            embed_batch_size=embed_batch_size,
            callback_manager=callback_manager,
            model_name=model_name,
            query_instruction=query_instruction,
            text_instruction=text_instruction,
            cache_folder=cache_folder,
        )

    @classmethod
    def class_name(cls) -> str:
        """Return the registered name of this embedding class."""
        return "InstructorEmbedding"

    def _format_query_text(self, query_text: str) -> List[str]:
        """Pair a query with its instruction as ``[instruction, query_text]``.

        Uses ``query_instruction`` (falling back to the model default when it
        is ``None``). The previous implementation read ``text_instruction``
        here, so a configured query instruction was never applied.
        """
        instruction = self.query_instruction
        if instruction is None:
            instruction = get_query_instruct_for_model_name(self.model_name)
        return [instruction, query_text]

    def _format_text(self, text: str) -> List[str]:
        """Pair a document with its instruction as ``[instruction, text]``.

        Uses ``text_instruction``, falling back to the model default when it
        is ``None``.
        """
        instruction = self.text_instruction
        if instruction is None:
            instruction = get_text_instruct_for_model_name(self.model_name)
        return [instruction, text]

    def _embed(self, instruct_sentence_pairs: List[List[str]]) -> List[List[float]]:
        """Encode ``[instruction, text]`` pairs with the INSTRUCTOR model.

        Returns one embedding (as a plain Python list) per input pair.
        """
        # presumably encode() returns a numpy array; .tolist() converts it to
        # nested Python lists — TODO confirm against the INSTRUCTOR version.
        return self._model.encode(instruct_sentence_pairs).tolist()

    def _get_query_embedding(self, query: str) -> List[float]:
        """Embed a single query string."""
        query_pair = self._format_query_text(query)
        return self._embed([query_pair])[0]

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """Embed a single query string (delegates to the synchronous path)."""
        return self._get_query_embedding(query)

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Embed a single document string (delegates to the synchronous path)."""
        return self._get_text_embedding(text)

    def _get_text_embedding(self, text: str) -> List[float]:
        """Embed a single document string."""
        text_pair = self._format_text(text)
        return self._embed([text_pair])[0]

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Embed a batch of document strings, one embedding per input."""
        # Avoid handing an empty batch to the model; the result is trivially
        # an empty list.
        if not texts:
            return []
        text_pairs = [self._format_text(text) for text in texts]
        return self._embed(text_pairs)