llama-index

Форк
0
104 строки · 3.6 Кб
1
from typing import Any, List, Optional
2

3
from llama_index.legacy.bridge.pydantic import Field, PrivateAttr
4
from llama_index.legacy.callbacks import CallbackManager
5
from llama_index.legacy.core.embeddings.base import (
6
    DEFAULT_EMBED_BATCH_SIZE,
7
    BaseEmbedding,
8
)
9
from llama_index.legacy.embeddings.huggingface_utils import (
10
    DEFAULT_INSTRUCT_MODEL,
11
    get_query_instruct_for_model_name,
12
    get_text_instruct_for_model_name,
13
)
14

15

16
class InstructorEmbedding(BaseEmbedding):
17
    query_instruction: Optional[str] = Field(
18
        description="Instruction to prepend to query text."
19
    )
20
    text_instruction: Optional[str] = Field(
21
        description="Instruction to prepend to text."
22
    )
23
    cache_folder: Optional[str] = Field(
24
        description="Cache folder for huggingface files."
25
    )
26

27
    _model: Any = PrivateAttr()
28

29
    def __init__(
30
        self,
31
        model_name: str = DEFAULT_INSTRUCT_MODEL,
32
        query_instruction: Optional[str] = None,
33
        text_instruction: Optional[str] = None,
34
        embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
35
        cache_folder: Optional[str] = None,
36
        device: Optional[str] = None,
37
        callback_manager: Optional[CallbackManager] = None,
38
    ):
39
        try:
40
            from InstructorEmbedding import INSTRUCTOR
41
        except ImportError:
42
            raise ImportError(
43
                "InstructorEmbedding requires instructor to be installed.\n"
44
                "Please install transformers with `pip install InstructorEmbedding`."
45
            )
46
        self._model = INSTRUCTOR(model_name, cache_folder=cache_folder, device=device)
47

48
        super().__init__(
49
            embed_batch_size=embed_batch_size,
50
            callback_manager=callback_manager,
51
            model_name=model_name,
52
            query_instruction=query_instruction,
53
            text_instruction=text_instruction,
54
            cache_folder=cache_folder,
55
        )
56

57
    @classmethod
58
    def class_name(cls) -> str:
59
        return "InstructorEmbedding"
60

61
    def _format_query_text(self, query_text: str) -> List[str]:
62
        """Format query text."""
63
        instruction = self.text_instruction
64

65
        if instruction is None:
66
            instruction = get_query_instruct_for_model_name(self.model_name)
67

68
        return [instruction, query_text]
69

70
    def _format_text(self, text: str) -> List[str]:
71
        """Format text."""
72
        instruction = self.text_instruction
73

74
        if instruction is None:
75
            instruction = get_text_instruct_for_model_name(self.model_name)
76

77
        return [instruction, text]
78

79
    def _embed(self, instruct_sentence_pairs: List[List[str]]) -> List[List[float]]:
80
        """Embed sentences."""
81
        return self._model.encode(instruct_sentence_pairs).tolist()
82

83
    def _get_query_embedding(self, query: str) -> List[float]:
84
        """Get query embedding."""
85
        query_pair = self._format_query_text(query)
86
        return self._embed([query_pair])[0]
87

88
    async def _aget_query_embedding(self, query: str) -> List[float]:
89
        """Get query embedding async."""
90
        return self._get_query_embedding(query)
91

92
    async def _aget_text_embedding(self, text: str) -> List[float]:
93
        """Get text embedding async."""
94
        return self._get_text_embedding(text)
95

96
    def _get_text_embedding(self, text: str) -> List[float]:
97
        """Get text embedding."""
98
        text_pair = self._format_text(text)
99
        return self._embed([text_pair])[0]
100

101
    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
102
        """Get text embeddings."""
103
        text_pairs = [self._format_text(text) for text in texts]
104
        return self._embed(text_pairs)
105

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.