llama-index
189 строк · 6.7 Кб
1"""Base embeddings file."""
2
3import asyncio4from abc import abstractmethod5from typing import Coroutine, List, Tuple6
7from llama_index.legacy.callbacks.schema import CBEventType, EventPayload8from llama_index.legacy.core.embeddings.base import (9BaseEmbedding,10Embedding,11)
12from llama_index.legacy.schema import ImageType13from llama_index.legacy.utils import get_tqdm_iterable14
15
class MultiModalEmbedding(BaseEmbedding):
    """Base class for Multi Modal embeddings.

    Adds image-embedding entry points (single and batched, sync and async)
    on top of ``BaseEmbedding``. Subclasses must implement
    ``_get_image_embedding`` and ``_aget_image_embedding``; they may override
    ``_get_image_embeddings`` / ``_aget_image_embeddings`` when the backend
    supports true batch queries.
    """

    @abstractmethod
    def _get_image_embedding(self, img_file_path: ImageType) -> Embedding:
        """
        Embed the input image synchronously.

        Subclasses should implement this method. Reference get_image_embedding's
        docstring for more information.
        """

    @abstractmethod
    async def _aget_image_embedding(self, img_file_path: ImageType) -> Embedding:
        """
        Embed the input image asynchronously.

        Subclasses should implement this method. Reference get_image_embedding's
        docstring for more information.
        """

    def get_image_embedding(self, img_file_path: ImageType) -> Embedding:
        """
        Embed the input image.

        Wraps ``_get_image_embedding`` in an ``EMBEDDING`` callback event: the
        start payload carries ``self.to_dict()``, the end payload carries the
        input image and the resulting embedding.

        Args:
            img_file_path: The image to embed.

        Returns:
            The embedding produced by ``_get_image_embedding``.
        """
        with self.callback_manager.event(
            CBEventType.EMBEDDING, payload={EventPayload.SERIALIZED: self.to_dict()}
        ) as event:
            image_embedding = self._get_image_embedding(img_file_path)

            event.on_end(
                payload={
                    EventPayload.CHUNKS: [img_file_path],
                    EventPayload.EMBEDDINGS: [image_embedding],
                },
            )
        return image_embedding

    async def aget_image_embedding(self, img_file_path: ImageType) -> Embedding:
        """
        Embed the input image asynchronously.

        Async counterpart of ``get_image_embedding``: awaits
        ``_aget_image_embedding`` inside the same ``EMBEDDING`` callback event.

        Args:
            img_file_path: The image to embed.

        Returns:
            The embedding produced by ``_aget_image_embedding``.
        """
        with self.callback_manager.event(
            CBEventType.EMBEDDING, payload={EventPayload.SERIALIZED: self.to_dict()}
        ) as event:
            image_embedding = await self._aget_image_embedding(img_file_path)

            event.on_end(
                payload={
                    EventPayload.CHUNKS: [img_file_path],
                    EventPayload.EMBEDDINGS: [image_embedding],
                },
            )
        return image_embedding

    def _get_image_embeddings(self, img_file_paths: List[ImageType]) -> List[Embedding]:
        """
        Embed the input sequence of image synchronously.

        Subclasses can implement this method if batch queries are supported.
        """
        # Default implementation just loops over _get_image_embedding
        return [
            self._get_image_embedding(img_file_path) for img_file_path in img_file_paths
        ]

    async def _aget_image_embeddings(
        self, img_file_paths: List[ImageType]
    ) -> List[Embedding]:
        """
        Embed the input sequence of image asynchronously.

        Subclasses can implement this method if batch queries are supported.
        """
        # Default implementation fans out to _aget_image_embedding; gather
        # returns results in the same order as the inputs.
        return await asyncio.gather(
            *[
                self._aget_image_embedding(img_file_path)
                for img_file_path in img_file_paths
            ]
        )

    def get_image_embedding_batch(
        self, img_file_paths: List[ImageType], show_progress: bool = False
    ) -> List[Embedding]:
        """
        Get a list of image embeddings, with batching.

        Images are grouped into batches of ``self.embed_batch_size`` (the last
        batch may be smaller). Each batch is embedded through
        ``_get_image_embeddings`` inside its own ``EMBEDDING`` callback event.

        Args:
            img_file_paths: Images to embed.
            show_progress: Forwarded to ``get_tqdm_iterable`` to optionally
                display a progress bar over the input images.

        Returns:
            One embedding per input image, in input order.
        """
        cur_batch: List[ImageType] = []
        result_embeddings: List[Embedding] = []

        queue_with_progress = enumerate(
            get_tqdm_iterable(
                img_file_paths, show_progress, "Generating image embeddings"
            )
        )

        last_idx = len(img_file_paths) - 1
        for idx, img_file_path in queue_with_progress:
            cur_batch.append(img_file_path)
            if idx == last_idx or len(cur_batch) == self.embed_batch_size:
                # flush
                with self.callback_manager.event(
                    CBEventType.EMBEDDING,
                    payload={EventPayload.SERIALIZED: self.to_dict()},
                ) as event:
                    embeddings = self._get_image_embeddings(cur_batch)
                    result_embeddings.extend(embeddings)
                    event.on_end(
                        payload={
                            EventPayload.CHUNKS: cur_batch,
                            EventPayload.EMBEDDINGS: embeddings,
                        },
                    )
                # Rebind (do not clear): the event payload still references
                # the flushed list.
                cur_batch = []

        return result_embeddings

    async def aget_image_embedding_batch(
        self, img_file_paths: List[ImageType], show_progress: bool = False
    ) -> List[Embedding]:
        """
        Asynchronously get a list of image embeddings, with batching.

        Images are grouped into batches of ``self.embed_batch_size`` (the last
        batch may be smaller). An ``EMBEDDING`` callback event is started for
        every batch up front, all batches are awaited concurrently through
        ``_aget_image_embeddings``, and each event is then ended with its own
        batch and embeddings.

        Args:
            img_file_paths: Images to embed.
            show_progress: When true and ``tqdm`` is importable, display a
                progress bar that advances once per completed batch. If
                ``tqdm`` is not installed the flag is silently ignored.

        Returns:
            One embedding per input image, in input order.
        """
        cur_batch: List[ImageType] = []
        callback_payloads: List[Tuple[str, List[ImageType]]] = []
        embeddings_coroutines: List[Coroutine] = []

        last_idx = len(img_file_paths) - 1
        for idx, img_file_path in enumerate(img_file_paths):
            cur_batch.append(img_file_path)
            if idx == last_idx or len(cur_batch) == self.embed_batch_size:
                # flush
                event_id = self.callback_manager.on_event_start(
                    CBEventType.EMBEDDING,
                    payload={EventPayload.SERIALIZED: self.to_dict()},
                )
                callback_payloads.append((event_id, cur_batch))
                embeddings_coroutines.append(self._aget_image_embeddings(cur_batch))
                # Rebind (do not clear): the payload and the coroutine still
                # reference the flushed list.
                cur_batch = []

        progress_bar = None
        if show_progress:
            # Only the import is guarded, so an ImportError raised inside a
            # batch coroutine is not mistaken for "tqdm is missing".
            try:
                from tqdm.auto import tqdm
            except ImportError:
                # Best effort: without tqdm, run without a progress bar.
                pass
            else:
                progress_bar = tqdm(
                    total=len(embeddings_coroutines),
                    desc="Generating image embeddings",
                )

        async def _await_batch(batch_coroutine: Coroutine) -> List[Embedding]:
            """Await one batch and tick the progress bar when it completes."""
            batch_embeddings = await batch_coroutine
            if progress_bar is not None:
                progress_bar.update(1)
            return batch_embeddings

        # asyncio.gather returns results in submission order regardless of
        # which batch finishes first. Collecting via asyncio.as_completed would
        # yield completion order and misalign embeddings with their images and
        # with callback_payloads below.
        try:
            nested_embeddings = await asyncio.gather(
                *[_await_batch(coroutine) for coroutine in embeddings_coroutines]
            )
        finally:
            if progress_bar is not None:
                progress_bar.close()

        # flatten the list of per-batch embedding lists
        result_embeddings: List[Embedding] = [
            embedding for embeddings in nested_embeddings for embedding in embeddings
        ]

        for (event_id, image_batch), embeddings in zip(
            callback_payloads, nested_embeddings
        ):
            self.callback_manager.on_event_end(
                CBEventType.EMBEDDING,
                payload={
                    EventPayload.CHUNKS: image_batch,
                    EventPayload.EMBEDDINGS: embeddings,
                },
                event_id=event_id,
            )

        return result_embeddings
