llama-index

Форк
0
189 строк · 6.7 Кб
1
"""Base embeddings file."""
2

3
import asyncio
4
from abc import abstractmethod
5
from typing import Coroutine, List, Tuple
6

7
from llama_index.legacy.callbacks.schema import CBEventType, EventPayload
8
from llama_index.legacy.core.embeddings.base import (
9
    BaseEmbedding,
10
    Embedding,
11
)
12
from llama_index.legacy.schema import ImageType
13
from llama_index.legacy.utils import get_tqdm_iterable
14

15

16
class MultiModalEmbedding(BaseEmbedding):
    """Base class for Multi Modal embeddings.

    Extends ``BaseEmbedding`` with image-embedding entry points. Subclasses
    must implement ``_get_image_embedding`` and ``_aget_image_embedding``.
    The batch hooks ``_get_image_embeddings`` / ``_aget_image_embeddings``
    have default implementations that delegate to the single-image hooks and
    may be overridden when the backend supports batch queries.
    """

    @abstractmethod
    def _get_image_embedding(self, img_file_path: ImageType) -> Embedding:
        """
        Embed the input image synchronously.

        Subclasses should implement this method. Reference get_image_embedding's
        docstring for more information.
        """

    @abstractmethod
    async def _aget_image_embedding(self, img_file_path: ImageType) -> Embedding:
        """
        Embed the input image asynchronously.

        Subclasses should implement this method. Reference get_image_embedding's
        docstring for more information.
        """

    def get_image_embedding(self, img_file_path: ImageType) -> Embedding:
        """
        Embed the input image.

        Runs ``_get_image_embedding`` inside an ``EMBEDDING`` callback event.
        The event's start payload is ``self.to_dict()``; its end payload
        carries the input image and the resulting embedding.

        Args:
            img_file_path: The image to embed.

        Returns:
            The embedding returned by ``_get_image_embedding``.
        """
        with self.callback_manager.event(
            CBEventType.EMBEDDING, payload={EventPayload.SERIALIZED: self.to_dict()}
        ) as event:
            image_embedding = self._get_image_embedding(img_file_path)

            event.on_end(
                payload={
                    EventPayload.CHUNKS: [img_file_path],
                    EventPayload.EMBEDDINGS: [image_embedding],
                },
            )
        return image_embedding

    async def aget_image_embedding(self, img_file_path: ImageType) -> Embedding:
        """
        Asynchronously embed the input image.

        Awaits ``_aget_image_embedding`` inside an ``EMBEDDING`` callback
        event, reporting the same payloads as ``get_image_embedding``.

        Args:
            img_file_path: The image to embed.

        Returns:
            The embedding returned by ``_aget_image_embedding``.
        """
        with self.callback_manager.event(
            CBEventType.EMBEDDING, payload={EventPayload.SERIALIZED: self.to_dict()}
        ) as event:
            image_embedding = await self._aget_image_embedding(img_file_path)

            event.on_end(
                payload={
                    EventPayload.CHUNKS: [img_file_path],
                    EventPayload.EMBEDDINGS: [image_embedding],
                },
            )
        return image_embedding

    def _get_image_embeddings(self, img_file_paths: List[ImageType]) -> List[Embedding]:
        """
        Embed the input sequence of image synchronously.

        Subclasses can implement this method if batch queries are supported.

        Args:
            img_file_paths: The images to embed.

        Returns:
            One embedding per input image, in input order.
        """
        # Default implementation just loops over _get_image_embedding
        return [
            self._get_image_embedding(img_file_path) for img_file_path in img_file_paths
        ]

    async def _aget_image_embeddings(
        self, img_file_paths: List[ImageType]
    ) -> List[Embedding]:
        """
        Embed the input sequence of image asynchronously.

        Subclasses can implement this method if batch queries are supported.

        Args:
            img_file_paths: The images to embed.

        Returns:
            One embedding per input image, in input order.
        """
        # Default implementation runs _aget_image_embedding concurrently;
        # asyncio.gather returns results in the order the awaitables were given.
        return await asyncio.gather(
            *[
                self._aget_image_embedding(img_file_path)
                for img_file_path in img_file_paths
            ]
        )

    def get_image_embedding_batch(
        self, img_file_paths: List[ImageType], show_progress: bool = False
    ) -> List[Embedding]:
        """Get a list of image embeddings, with batching.

        Inputs are grouped into batches of at most ``self.embed_batch_size``
        items. Each batch is embedded with ``_get_image_embeddings`` inside
        its own ``EMBEDDING`` callback event.

        Args:
            img_file_paths: The images to embed.
            show_progress: Forwarded to ``get_tqdm_iterable`` together with
                the description "Generating image embeddings".

        Returns:
            The embeddings of all batches concatenated in input order; an
            empty list when ``img_file_paths`` is empty.
        """
        cur_batch: List[ImageType] = []
        result_embeddings: List[Embedding] = []

        queue_with_progress = enumerate(
            get_tqdm_iterable(
                img_file_paths, show_progress, "Generating image embeddings"
            )
        )

        for idx, img_file_path in queue_with_progress:
            cur_batch.append(img_file_path)
            # Flush when the batch is full or the last input has been queued.
            if (
                idx == len(img_file_paths) - 1
                or len(cur_batch) == self.embed_batch_size
            ):
                with self.callback_manager.event(
                    CBEventType.EMBEDDING,
                    payload={EventPayload.SERIALIZED: self.to_dict()},
                ) as event:
                    embeddings = self._get_image_embeddings(cur_batch)
                    result_embeddings.extend(embeddings)
                    event.on_end(
                        payload={
                            EventPayload.CHUNKS: cur_batch,
                            EventPayload.EMBEDDINGS: embeddings,
                        },
                    )
                # Rebind (rather than clear) so the list handed to the event
                # payload above is not mutated afterwards.
                cur_batch = []

        return result_embeddings

    async def aget_image_embedding_batch(
        self, img_file_paths: List[ImageType], show_progress: bool = False
    ) -> List[Embedding]:
        """Asynchronously get a list of image embeddings, with batching.

        Inputs are grouped into batches of at most ``self.embed_batch_size``
        items. An ``EMBEDDING`` callback event is started per batch, all
        batches are awaited concurrently via ``_aget_image_embeddings``, and
        each event is then ended with its own batch and embeddings.

        Args:
            img_file_paths: The images to embed.
            show_progress: When true and ``tqdm`` is importable, display a
                progress bar that advances as batches complete. If ``tqdm``
                is not installed the batches run without a progress bar.

        Returns:
            The embeddings of all batches concatenated in input order,
            regardless of the order in which batches finish; an empty list
            when ``img_file_paths`` is empty.
        """
        cur_batch: List[ImageType] = []
        callback_payloads: List[Tuple[str, List[ImageType]]] = []
        embeddings_coroutines: List[Coroutine] = []
        for idx, img_file_path in enumerate(img_file_paths):
            cur_batch.append(img_file_path)
            # Flush when the batch is full or the last input has been queued.
            if (
                idx == len(img_file_paths) - 1
                or len(cur_batch) == self.embed_batch_size
            ):
                event_id = self.callback_manager.on_event_start(
                    CBEventType.EMBEDDING,
                    payload={EventPayload.SERIALIZED: self.to_dict()},
                )
                callback_payloads.append((event_id, cur_batch))
                embeddings_coroutines.append(self._aget_image_embeddings(cur_batch))
                cur_batch = []

        # Only the import is guarded: an ImportError raised while embedding
        # must propagate instead of triggering a second await of coroutines
        # that have already been scheduled.
        tqdm_cls = None
        if show_progress:
            try:
                from tqdm.auto import tqdm as tqdm_cls
            except ImportError:
                # tqdm is optional; run without a progress bar.
                tqdm_cls = None

        if tqdm_cls is None:
            # gather preserves the order of the awaitables it is given.
            nested_embeddings = await asyncio.gather(*embeddings_coroutines)
        else:

            async def _with_position(
                position: int, coroutine: Coroutine
            ) -> Tuple[int, List[Embedding]]:
                """Await ``coroutine`` and tag its result with its batch index."""
                return position, await coroutine

            # as_completed yields in completion order, so every result is
            # stored at its batch position to keep the embeddings aligned
            # with img_file_paths and with callback_payloads.
            nested_embeddings = [[] for _ in embeddings_coroutines]
            completion_queue = asyncio.as_completed(
                [
                    _with_position(position, coroutine)
                    for position, coroutine in enumerate(embeddings_coroutines)
                ]
            )
            for finished in tqdm_cls(
                completion_queue,
                total=len(embeddings_coroutines),
                desc="Generating image embeddings",
            ):
                position, batch_embeddings = await finished
                nested_embeddings[position] = batch_embeddings

        # flatten the per-batch results into a single list of embeddings
        result_embeddings = [
            embedding for embeddings in nested_embeddings for embedding in embeddings
        ]

        for (event_id, image_batch), embeddings in zip(
            callback_payloads, nested_embeddings
        ):
            self.callback_manager.on_event_end(
                CBEventType.EMBEDDING,
                payload={
                    EventPayload.CHUNKS: image_batch,
                    EventPayload.EMBEDDINGS: embeddings,
                },
                event_id=event_id,
            )

        return result_embeddings
190

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.