# llama-index
#
# Forks: 0
# 354 lines · 12.2 KB
1
"""Base embeddings file."""
2

3
import asyncio
from abc import abstractmethod
from enum import Enum
from typing import Any, Callable, Coroutine, List, Optional, Tuple

import numpy as np

from llama_index.legacy.bridge.pydantic import Field, validator
from llama_index.legacy.callbacks.base import CallbackManager
from llama_index.legacy.callbacks.schema import CBEventType, EventPayload
from llama_index.legacy.constants import (
    DEFAULT_EMBED_BATCH_SIZE,
)
from llama_index.legacy.schema import BaseNode, MetadataMode, TransformComponent
from llama_index.legacy.utils import get_tqdm_iterable
18

19
# An embedding vector is represented throughout this module as a plain
# Python list of floats.
# TODO: change to numpy array
Embedding = List[float]
21

22

23
class SimilarityMode(str, Enum):
    """Supported ways of scoring how close two embeddings are.

    Members subclass ``str``, so each one compares equal to its raw string
    value (for example ``SimilarityMode.DEFAULT == "cosine"``).
    """

    # Cosine similarity; the mode used by ``similarity`` when none is given.
    DEFAULT = "cosine"
    # Raw inner product of the two vectors.
    DOT_PRODUCT = "dot_product"
    # Euclidean distance; ``similarity`` reports it negated.
    EUCLIDEAN = "euclidean"
29

30

31
def mean_agg(embeddings: List[Embedding]) -> Embedding:
    """Mean aggregation for embeddings.

    Args:
        embeddings: Non-empty list of equal-length embedding vectors.

    Returns:
        The element-wise mean, as a list of built-in Python floats.

    Raises:
        ValueError: If ``embeddings`` is empty (there is nothing to average).
    """
    # ``len`` rather than truthiness so a 2-D ndarray argument also works.
    if len(embeddings) == 0:
        raise ValueError("Cannot aggregate an empty list of embeddings.")
    # ``tolist`` turns numpy scalars back into built-in floats, matching the
    # ``Embedding`` (List[float]) alias; ``list(ndarray)`` would not.
    return np.array(embeddings).mean(axis=0).tolist()
34

35

36
def similarity(
    embedding1: Embedding,
    embedding2: Embedding,
    mode: SimilarityMode = SimilarityMode.DEFAULT,
) -> float:
    """Get embedding similarity.

    Args:
        embedding1: First embedding vector.
        embedding2: Second embedding vector, expected to be the same length
            as ``embedding1``.
        mode: Which score to compute.
            - ``EUCLIDEAN`` gives the negated Euclidean distance, so a larger
              value still means "more similar".
            - ``DOT_PRODUCT`` gives the raw inner product.
            - Any other mode (the default, cosine) gives the cosine
              similarity.

    Returns:
        The score as a built-in ``float``. For cosine similarity, ``0.0`` is
        returned when either vector has zero norm, since the cosine is
        undefined there.
    """
    if mode == SimilarityMode.EUCLIDEAN:
        # Using -euclidean distance as similarity to achieve same ranking order
        return -float(np.linalg.norm(np.array(embedding1) - np.array(embedding2)))
    if mode == SimilarityMode.DOT_PRODUCT:
        return float(np.dot(embedding1, embedding2))

    product = np.dot(embedding1, embedding2)
    norm = np.linalg.norm(embedding1) * np.linalg.norm(embedding2)
    if norm == 0:
        # Dividing by zero would yield NaN (plus a numpy RuntimeWarning),
        # which breaks any ranking/sorting done on these scores.
        return 0.0
    return float(product / norm)
51

52

53
class BaseEmbedding(TransformComponent):
    """Base class for embeddings.

    Concrete subclasses must implement ``_get_query_embedding``,
    ``_aget_query_embedding`` and ``_get_text_embedding``. The other private
    hooks (``_aget_text_embedding``, ``_get_text_embeddings`` and
    ``_aget_text_embeddings``) have default implementations built on those.
    Subclasses may override them when the backend offers native async or
    batched calls.

    The public ``get_*`` / ``aget_*`` methods wrap those hooks in
    ``CBEventType.EMBEDDING`` callback events. Calling an instance on a list
    of nodes embeds each node's ``MetadataMode.EMBED`` content and stores the
    result on ``node.embedding``.
    """

    model_name: str = Field(
        default="unknown", description="The name of the embedding model."
    )
    embed_batch_size: int = Field(
        default=DEFAULT_EMBED_BATCH_SIZE,
        description="The batch size for embedding calls.",
        gt=0,
        # pydantic's inclusive upper-bound keyword is ``le``; the former
        # ``lte`` is not a recognised constraint, so the bound was never
        # enforced.
        le=2048,
    )
    callback_manager: CallbackManager = Field(
        default_factory=lambda: CallbackManager([]), exclude=True
    )

    class Config:
        arbitrary_types_allowed = True

    @validator("callback_manager", pre=True)
    def _validate_callback_manager(
        cls, v: Optional[CallbackManager]
    ) -> CallbackManager:
        """Replace an explicitly passed ``None`` with an empty manager."""
        if v is None:
            return CallbackManager([])
        return v

    @abstractmethod
    def _get_query_embedding(self, query: str) -> Embedding:
        """
        Embed the input query synchronously.

        Subclasses should implement this method. Reference get_query_embedding's
        docstring for more information.
        """

    @abstractmethod
    async def _aget_query_embedding(self, query: str) -> Embedding:
        """
        Embed the input query asynchronously.

        Subclasses should implement this method. Reference get_query_embedding's
        docstring for more information.
        """

    def get_query_embedding(self, query: str) -> Embedding:
        """
        Embed the input query.

        When embedding a query, depending on the model, a special instruction
        can be prepended to the raw query string. For example, "Represent the
        question for retrieving supporting documents: ". If you're curious,
        other examples of predefined instructions can be found in
        embeddings/huggingface_utils.py.

        The call is reported as a single ``EMBEDDING`` callback event whose
        end payload carries the query and its embedding.
        """
        with self.callback_manager.event(
            CBEventType.EMBEDDING, payload={EventPayload.SERIALIZED: self.to_dict()}
        ) as event:
            query_embedding = self._get_query_embedding(query)

            event.on_end(
                payload={
                    EventPayload.CHUNKS: [query],
                    EventPayload.EMBEDDINGS: [query_embedding],
                },
            )
        return query_embedding

    async def aget_query_embedding(self, query: str) -> Embedding:
        """Get query embedding asynchronously (async twin of get_query_embedding)."""
        with self.callback_manager.event(
            CBEventType.EMBEDDING, payload={EventPayload.SERIALIZED: self.to_dict()}
        ) as event:
            query_embedding = await self._aget_query_embedding(query)

            event.on_end(
                payload={
                    EventPayload.CHUNKS: [query],
                    EventPayload.EMBEDDINGS: [query_embedding],
                },
            )
        return query_embedding

    def get_agg_embedding_from_queries(
        self,
        queries: List[str],
        agg_fn: Optional[Callable[..., Embedding]] = None,
    ) -> Embedding:
        """Get aggregated embedding from multiple queries.

        Each query is embedded with ``get_query_embedding``. The resulting
        list of embeddings is passed to ``agg_fn``, which defaults to
        ``mean_agg``.
        """
        query_embeddings = [self.get_query_embedding(query) for query in queries]
        agg_fn = agg_fn or mean_agg
        return agg_fn(query_embeddings)

    async def aget_agg_embedding_from_queries(
        self,
        queries: List[str],
        agg_fn: Optional[Callable[..., Embedding]] = None,
    ) -> Embedding:
        """Async get aggregated embedding from multiple queries.

        Queries are awaited one after another, not concurrently. ``agg_fn``
        defaults to ``mean_agg``.
        """
        query_embeddings = [await self.aget_query_embedding(query) for query in queries]
        agg_fn = agg_fn or mean_agg
        return agg_fn(query_embeddings)

    @abstractmethod
    def _get_text_embedding(self, text: str) -> Embedding:
        """
        Embed the input text synchronously.

        Subclasses should implement this method. Reference get_text_embedding's
        docstring for more information.
        """

    async def _aget_text_embedding(self, text: str) -> Embedding:
        """
        Embed the input text asynchronously.

        Subclasses can implement this method if there is a true async
        implementation. Reference get_text_embedding's docstring for more
        information.
        """
        # Default implementation just falls back on _get_text_embedding; the
        # synchronous call runs directly inside this coroutine.
        return self._get_text_embedding(text)

    def _get_text_embeddings(self, texts: List[str]) -> List[Embedding]:
        """
        Embed the input sequence of text synchronously.

        Subclasses can implement this method if batch queries are supported.
        """
        # Default implementation just loops over _get_text_embedding
        return [self._get_text_embedding(text) for text in texts]

    async def _aget_text_embeddings(self, texts: List[str]) -> List[Embedding]:
        """
        Embed the input sequence of text asynchronously.

        Subclasses can implement this method if batch queries are supported.
        The default runs ``_aget_text_embedding`` for every text concurrently
        via ``asyncio.gather``, which preserves input order.
        """
        return await asyncio.gather(
            *[self._aget_text_embedding(text) for text in texts]
        )

    def get_text_embedding(self, text: str) -> Embedding:
        """
        Embed the input text.

        When embedding text, depending on the model, a special instruction
        can be prepended to the raw text string. For example, "Represent the
        document for retrieval: ". If you're curious, other examples of
        predefined instructions can be found in embeddings/huggingface_utils.py.
        """
        with self.callback_manager.event(
            CBEventType.EMBEDDING, payload={EventPayload.SERIALIZED: self.to_dict()}
        ) as event:
            text_embedding = self._get_text_embedding(text)

            event.on_end(
                payload={
                    EventPayload.CHUNKS: [text],
                    EventPayload.EMBEDDINGS: [text_embedding],
                }
            )

        return text_embedding

    async def aget_text_embedding(self, text: str) -> Embedding:
        """Async get text embedding (async twin of get_text_embedding)."""
        with self.callback_manager.event(
            CBEventType.EMBEDDING, payload={EventPayload.SERIALIZED: self.to_dict()}
        ) as event:
            text_embedding = await self._aget_text_embedding(text)

            event.on_end(
                payload={
                    EventPayload.CHUNKS: [text],
                    EventPayload.EMBEDDINGS: [text_embedding],
                }
            )

        return text_embedding

    def get_text_embedding_batch(
        self,
        texts: List[str],
        show_progress: bool = False,
        **kwargs: Any,
    ) -> List[Embedding]:
        """Get a list of text embeddings, with batching.

        ``texts`` is processed sequentially in batches of at most
        ``embed_batch_size``. Each batch is reported as its own ``EMBEDDING``
        callback event.

        Args:
            texts: Texts to embed.
            show_progress: Passed to ``get_tqdm_iterable`` to optionally show
                a per-text progress bar.
            **kwargs: Accepted so ``__call__`` can forward arbitrary keyword
                arguments; unused here.

        Returns:
            One embedding per input text, in input order.
        """
        cur_batch: List[str] = []
        result_embeddings: List[Embedding] = []

        queue_with_progress = enumerate(
            get_tqdm_iterable(texts, show_progress, "Generating embeddings")
        )

        for idx, text in queue_with_progress:
            cur_batch.append(text)
            # Flush when the batch is full or this is the last text.
            if idx == len(texts) - 1 or len(cur_batch) == self.embed_batch_size:
                with self.callback_manager.event(
                    CBEventType.EMBEDDING,
                    payload={EventPayload.SERIALIZED: self.to_dict()},
                ) as event:
                    embeddings = self._get_text_embeddings(cur_batch)
                    result_embeddings.extend(embeddings)
                    event.on_end(
                        payload={
                            EventPayload.CHUNKS: cur_batch,
                            EventPayload.EMBEDDINGS: embeddings,
                        },
                    )
                cur_batch = []

        return result_embeddings

    async def aget_text_embedding_batch(
        self, texts: List[str], show_progress: bool = False, **kwargs: Any
    ) -> List[Embedding]:
        """Asynchronously get a list of text embeddings, with batching.

        ``texts`` is split into batches of at most ``embed_batch_size``. All
        batches are embedded concurrently, each reported as its own
        ``EMBEDDING`` callback event. The returned embeddings follow the
        order of ``texts`` regardless of which batch finishes first.

        Args:
            texts: Texts to embed.
            show_progress: Show a per-batch progress bar when ``tqdm`` can be
                imported; without ``tqdm`` no bar is shown.
            **kwargs: Accepted for parity with ``get_text_embedding_batch``
                (``acall`` forwards its keyword arguments here); unused.

        Returns:
            One embedding per input text, in input order.
        """
        cur_batch: List[str] = []
        callback_payloads: List[Tuple[str, List[str]]] = []
        embeddings_coroutines: List[Coroutine] = []
        for idx, text in enumerate(texts):
            cur_batch.append(text)
            if idx == len(texts) - 1 or len(cur_batch) == self.embed_batch_size:
                # Flush: open the callback event now; it is closed below once
                # this batch's embeddings are available.
                event_id = self.callback_manager.on_event_start(
                    CBEventType.EMBEDDING,
                    payload={EventPayload.SERIALIZED: self.to_dict()},
                )
                callback_payloads.append((event_id, cur_batch))
                embeddings_coroutines.append(self._aget_text_embeddings(cur_batch))
                cur_batch = []

        # One list of embeddings per batch, kept in batch (= input) order.
        nested_embeddings: List[List[Embedding]]
        if show_progress:
            # Only the import is guarded, so an ImportError raised by an
            # embedding backend is not mistaken for "tqdm is missing".
            try:
                from tqdm.auto import tqdm
            except ImportError:
                nested_embeddings = await asyncio.gather(*embeddings_coroutines)
            else:

                async def _tag(
                    index: int, coro: Coroutine
                ) -> Tuple[int, List[Embedding]]:
                    return index, await coro

                # ``as_completed`` yields in completion order (which drives
                # the progress bar), so each batch carries its position and
                # is written back into the slot matching its input order.
                nested_embeddings = [[] for _ in embeddings_coroutines]
                tagged = [
                    _tag(index, coro)
                    for index, coro in enumerate(embeddings_coroutines)
                ]
                for next_done in tqdm(
                    asyncio.as_completed(tagged),
                    total=len(tagged),
                    desc="Generating embeddings",
                ):
                    index, batch_embeddings = await next_done
                    nested_embeddings[index] = batch_embeddings
        else:
            nested_embeddings = await asyncio.gather(*embeddings_coroutines)

        # Flatten the per-batch lists into one list aligned with ``texts``.
        result_embeddings = [
            embedding for embeddings in nested_embeddings for embedding in embeddings
        ]

        for (event_id, text_batch), embeddings in zip(
            callback_payloads, nested_embeddings
        ):
            self.callback_manager.on_event_end(
                CBEventType.EMBEDDING,
                payload={
                    EventPayload.CHUNKS: text_batch,
                    EventPayload.EMBEDDINGS: embeddings,
                },
                event_id=event_id,
            )

        return result_embeddings

    def similarity(
        self,
        embedding1: Embedding,
        embedding2: Embedding,
        mode: SimilarityMode = SimilarityMode.DEFAULT,
    ) -> float:
        """Get embedding similarity (delegates to the module-level function)."""
        return similarity(embedding1=embedding1, embedding2=embedding2, mode=mode)

    def __call__(self, nodes: List[BaseNode], **kwargs: Any) -> List[BaseNode]:
        """Embed ``nodes`` in place and return the same list.

        Each node's ``MetadataMode.EMBED`` content is embedded via
        ``get_text_embedding_batch`` (which receives ``kwargs``). The result
        is assigned to ``node.embedding``.
        """
        embeddings = self.get_text_embedding_batch(
            [node.get_content(metadata_mode=MetadataMode.EMBED) for node in nodes],
            **kwargs,
        )

        for node, embedding in zip(nodes, embeddings):
            node.embedding = embedding

        return nodes

    async def acall(self, nodes: List[BaseNode], **kwargs: Any) -> List[BaseNode]:
        """Async version of ``__call__``; uses ``aget_text_embedding_batch``."""
        embeddings = await self.aget_text_embedding_batch(
            [node.get_content(metadata_mode=MetadataMode.EMBED) for node in nodes],
            **kwargs,
        )

        for node, embedding in zip(nodes, embeddings):
            node.embedding = embedding

        return nodes
355

# Use of cookies
#
# We use cookies in accordance with the Privacy Policy and the Cookie Policy.
#
# By clicking "Accept", you give SberTech JSC consent to process your personal data in order to improve our website and the GitVerse service and to make them more convenient to use.
#
# You can disable cookies yourself in your browser settings.