llama-index

Форк
0
198 строк · 7.0 Кб
1
from typing import Any, List, Optional
2

3
from llama_index.legacy.bridge.pydantic import Field, PrivateAttr
4
from llama_index.legacy.callbacks import CallbackManager
5
from llama_index.legacy.core.embeddings.base import (
6
    DEFAULT_EMBED_BATCH_SIZE,
7
    BaseEmbedding,
8
)
9
from llama_index.legacy.embeddings.huggingface_utils import (
10
    format_query,
11
    format_text,
12
    get_pooling_mode,
13
)
14
from llama_index.legacy.embeddings.pooling import Pooling
15
from llama_index.legacy.utils import infer_torch_device
16

17

18
class OptimumEmbedding(BaseEmbedding):
    """Embedding model backed by an ONNX Runtime model exported with Optimum.

    The model and tokenizer are loaded from ``folder_name`` unless explicitly
    provided. Sentence embeddings are produced with CLS or mean pooling and
    are optionally L2-normalized.
    """

    folder_name: str = Field(description="Folder name to load from.")
    max_length: int = Field(description="Maximum length of input.")
    pooling: str = Field(description="Pooling strategy. One of ['cls', 'mean'].")
    # Annotated as ``bool``: with a ``str`` annotation pydantic either coerces
    # the flag to "True"/"False" (both truthy, so ``normalize=False`` was
    # silently ignored) or rejects it, depending on the pydantic version.
    normalize: bool = Field(default=True, description="Normalize embeddings or not.")
    query_instruction: Optional[str] = Field(
        description="Instruction to prepend to query text."
    )
    text_instruction: Optional[str] = Field(
        description="Instruction to prepend to text."
    )
    # NOTE(review): declared but never populated by ``__init__`` below.
    cache_folder: Optional[str] = Field(
        description="Cache folder for huggingface files."
    )

    _model: Any = PrivateAttr()
    _tokenizer: Any = PrivateAttr()
    _device: Any = PrivateAttr()

    def __init__(
        self,
        folder_name: str,
        pooling: Optional[str] = None,
        max_length: Optional[int] = None,
        normalize: bool = True,
        query_instruction: Optional[str] = None,
        text_instruction: Optional[str] = None,
        model: Optional[Any] = None,
        tokenizer: Optional[Any] = None,
        embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
        callback_manager: Optional[CallbackManager] = None,
        device: Optional[str] = None,
    ):
        """Load (or accept) the ONNX model and tokenizer and configure pooling.

        Args:
            folder_name: Folder or hub id handed to ``from_pretrained`` for the
                model/tokenizer and used to look up the pooling mode.
            pooling: ``"cls"`` or ``"mean"``; inferred with
                ``get_pooling_mode`` when omitted.
            max_length: Tokenizer truncation length; read from
                ``model.config.max_position_embeddings`` when omitted.
            normalize: Whether to L2-normalize the output embeddings.
            query_instruction: Instruction passed to ``format_query``.
            text_instruction: Instruction passed to ``format_text``.
            model: Pre-loaded model; skips loading from ``folder_name``.
            tokenizer: Pre-loaded tokenizer; skips loading from ``folder_name``.
            embed_batch_size: Forwarded to ``BaseEmbedding``.
            callback_manager: Forwarded to ``BaseEmbedding``.
            device: Torch device string; ``infer_torch_device()`` when omitted.

        Raises:
            ImportError: If ``optimum``/``transformers`` cannot be imported.
            ValueError: If ``max_length`` is omitted and the model config does
                not expose ``max_position_embeddings``.
            NotImplementedError: If the pooling mode is not a ``Pooling`` value.
        """
        try:
            from optimum.onnxruntime import ORTModelForFeatureExtraction
            from transformers import AutoTokenizer
        except ImportError as exc:
            raise ImportError(
                "OptimumEmbedding requires transformers to be installed.\n"
                "Please install transformers with "
                "`pip install transformers optimum[exporters]`."
            ) from exc

        self._model = model or ORTModelForFeatureExtraction.from_pretrained(folder_name)
        self._tokenizer = tokenizer or AutoTokenizer.from_pretrained(folder_name)
        self._device = device or infer_torch_device()

        if max_length is None:
            try:
                max_length = int(self._model.config.max_position_embeddings)
            except Exception as exc:
                raise ValueError(
                    "Unable to find max_length from model config. "
                    "Please provide max_length."
                ) from exc

        if not pooling:
            # Resolve the pooling mode from the model's name/path. Passing the
            # optional ``model`` argument (``None`` unless the caller supplied
            # a model object) cannot identify which pooling config to use.
            pooling = get_pooling_mode(folder_name)
        try:
            pooling = Pooling(pooling)
        except ValueError as exc:
            raise NotImplementedError(
                f"Pooling {pooling} unsupported, please pick one in"
                f" {[p.value for p in Pooling]}."
            ) from exc

        super().__init__(
            embed_batch_size=embed_batch_size,
            callback_manager=callback_manager,
            folder_name=folder_name,
            max_length=max_length,
            pooling=pooling,
            normalize=normalize,
            query_instruction=query_instruction,
            text_instruction=text_instruction,
        )

    @classmethod
    def class_name(cls) -> str:
        """Return the registered class name."""
        return "OptimumEmbedding"

    @classmethod
    def create_and_save_optimum_model(
        cls,
        model_name_or_path: str,
        output_path: str,
        export_kwargs: Optional[dict] = None,
    ) -> None:
        """Export a transformers model to ONNX and save it with its tokenizer.

        Args:
            model_name_or_path: Source model id or path for ``from_pretrained``.
            output_path: Directory the ONNX model and tokenizer are saved to.
            export_kwargs: Extra keyword arguments for the ONNX export call.

        Raises:
            ImportError: If ``optimum``/``transformers`` cannot be imported.
        """
        try:
            from optimum.onnxruntime import ORTModelForFeatureExtraction
            from transformers import AutoTokenizer
        except ImportError as exc:
            raise ImportError(
                "OptimumEmbedding requires transformers to be installed.\n"
                "Please install transformers with "
                "`pip install transformers optimum[exporters]`."
            ) from exc

        export_kwargs = export_kwargs or {}
        model = ORTModelForFeatureExtraction.from_pretrained(
            model_name_or_path, export=True, **export_kwargs
        )
        tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

        model.save_pretrained(output_path)
        tokenizer.save_pretrained(output_path)
        print(
            f"Saved optimum model to {output_path}. Use it with "
            f"`embed_model = OptimumEmbedding(folder_name='{output_path}')`."
        )

    def _mean_pooling(self, model_output: Any, attention_mask: Any) -> Any:
        """Mean-pool token embeddings, weighting each token by the attention mask.

        Masked-out (padding) positions contribute neither to the sum nor to
        the token count; the count is clamped to avoid division by zero.
        """
        import torch

        # First element of model_output contains all token embeddings.
        token_embeddings = model_output[0]
        # Put the mask on the embeddings' device so the product below never
        # mixes devices (e.g. CPU ONNX output vs. a mask on CUDA/MPS).
        input_mask_expanded = (
            attention_mask.to(token_embeddings.device)
            .unsqueeze(-1)
            .expand(token_embeddings.size())
            .float()
        )
        summed = torch.sum(token_embeddings * input_mask_expanded, 1)
        counts = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return summed / counts

    def _cls_pooling(self, model_output: Any) -> Any:
        """Use the first (CLS) token of the first output as the embedding."""
        return model_output[0][:, 0]

    def _embed(self, sentences: List[str]) -> List[List[float]]:
        """Tokenize, run the ONNX model, pool, and optionally normalize.

        Returns one embedding (a list of floats) per input sentence. An empty
        input returns ``[]`` without calling the tokenizer or the model.
        """
        if not sentences:
            return []

        encoded_input = self._tokenizer(
            sentences,
            padding=True,
            max_length=self.max_length,
            truncation=True,
            return_tensors="pt",
        )

        # The model is invoked without token_type_ids.
        encoded_input.pop("token_type_ids", None)

        model_output = self._model(**encoded_input)

        if self.pooling == "cls":
            embeddings = self._cls_pooling(model_output)
        else:
            # _mean_pooling moves the mask to the output's device itself.
            embeddings = self._mean_pooling(
                model_output, encoded_input["attention_mask"]
            )

        if self.normalize:
            import torch

            embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)

        return embeddings.tolist()

    # NOTE(review): ``model_name`` below is inherited from BaseEmbedding and is
    # not set from ``folder_name`` in ``__init__``; confirm that the default
    # instruction lookup in format_query/format_text should use it.
    def _get_query_embedding(self, query: str) -> List[float]:
        """Embed a single query after applying the query instruction."""
        query = format_query(query, self.model_name, self.query_instruction)
        return self._embed([query])[0]

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """Async wrapper around the synchronous query embedding."""
        return self._get_query_embedding(query)

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Async wrapper around the synchronous text embedding."""
        return self._get_text_embedding(text)

    def _get_text_embedding(self, text: str) -> List[float]:
        """Embed a single text after applying the text instruction."""
        text = format_text(text, self.model_name, self.text_instruction)
        return self._embed([text])[0]

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Embed a batch of texts after applying the text instruction."""
        texts = [
            format_text(text, self.model_name, self.text_instruction) for text in texts
        ]
        return self._embed(texts)
199

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.