llama-index

Форк
0
146 строк · 4.7 Кб
1
import logging
2
from typing import Any, List
3

4
from llama_index.legacy.bridge.pydantic import Field, PrivateAttr
5
from llama_index.legacy.constants import DEFAULT_EMBED_BATCH_SIZE
6
from llama_index.legacy.core.embeddings.base import Embedding
7
from llama_index.legacy.embeddings.multi_modal_base import MultiModalEmbedding
8
from llama_index.legacy.schema import ImageType
9

10
logger = logging.getLogger(__name__)
11

12

13
AVAILABLE_CLIP_MODELS = (
14
    "RN50",
15
    "RN101",
16
    "RN50x4",
17
    "RN50x16",
18
    "RN50x64",
19
    "ViT-B/32",
20
    "ViT-B/16",
21
    "ViT-L/14",
22
    "ViT-L/14@336px",
23
)
24
DEFAULT_CLIP_MODEL = "ViT-B/32"
25

26

27
class ClipEmbedding(MultiModalEmbedding):
28
    """CLIP embedding models for encoding text and image for Multi-Modal purpose.
29

30
    This class provides an interface to generate embeddings using a model
31
    deployed in OpenAI CLIP. At the initialization it requires a model name
32
    of CLIP.
33

34
    Note:
35
        Requires `clip` package to be available in the PYTHONPATH. It can be installed with
36
        `pip install git+https://github.com/openai/CLIP.git`.
37
    """
38

39
    embed_batch_size: int = Field(default=DEFAULT_EMBED_BATCH_SIZE, gt=0)
40

41
    _clip: Any = PrivateAttr()
42
    _model: Any = PrivateAttr()
43
    _preprocess: Any = PrivateAttr()
44
    _device: Any = PrivateAttr()
45

46
    @classmethod
47
    def class_name(cls) -> str:
48
        return "ClipEmbedding"
49

50
    def __init__(
51
        self,
52
        *,
53
        embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
54
        model_name: str = DEFAULT_CLIP_MODEL,
55
        **kwargs: Any,
56
    ):
57
        """Initializes the ClipEmbedding class.
58

59
        During the initialization the `clip` package is imported.
60

61
        Args:
62
            embed_batch_size (int, optional): The batch size for embedding generation. Defaults to 10,
63
                must be > 0 and <= 100.
64
            model_name (str): The model name of Clip model.
65

66
        Raises:
67
            ImportError: If the `clip` package is not available in the PYTHONPATH.
68
            ValueError: If the model cannot be fetched from Open AI. or if the embed_batch_size
69
                is not in the range (0, 100].
70
        """
71
        if embed_batch_size <= 0:
72
            raise ValueError(f"Embed batch size {embed_batch_size}  must be > 0.")
73

74
        try:
75
            import clip
76
            import torch
77
        except ImportError:
78
            raise ImportError(
79
                "ClipEmbedding requires `pip install git+https://github.com/openai/CLIP.git` and torch."
80
            )
81

82
        super().__init__(
83
            embed_batch_size=embed_batch_size, model_name=model_name, **kwargs
84
        )
85

86
        try:
87
            self._device = "cuda" if torch.cuda.is_available() else "cpu"
88
            if self.model_name not in AVAILABLE_CLIP_MODELS:
89
                raise ValueError(
90
                    f"Model name {self.model_name} is not available in CLIP."
91
                )
92
            self._model, self._preprocess = clip.load(
93
                self.model_name, device=self._device
94
            )
95

96
        except Exception as e:
97
            logger.error(f"Error while loading clip model.")
98
            raise ValueError("Unable to fetch the requested embeddings model") from e
99

100
    # TEXT EMBEDDINGS
101

102
    async def _aget_query_embedding(self, query: str) -> Embedding:
103
        return self._get_query_embedding(query)
104

105
    def _get_text_embedding(self, text: str) -> Embedding:
106
        return self._get_text_embeddings([text])[0]
107

108
    def _get_text_embeddings(self, texts: List[str]) -> List[Embedding]:
109
        results = []
110
        for text in texts:
111
            try:
112
                import clip
113
            except ImportError:
114
                raise ImportError(
115
                    "ClipEmbedding requires `pip install git+https://github.com/openai/CLIP.git` and torch."
116
                )
117
            text_embedding = self._model.encode_text(
118
                clip.tokenize(text).to(self._device)
119
            )
120
            results.append(text_embedding.tolist()[0])
121

122
        return results
123

124
    def _get_query_embedding(self, query: str) -> Embedding:
125
        return self._get_text_embedding(query)
126

127
    # IMAGE EMBEDDINGS
128

129
    async def _aget_image_embedding(self, img_file_path: ImageType) -> Embedding:
130
        return self._get_image_embedding(img_file_path)
131

132
    def _get_image_embedding(self, img_file_path: ImageType) -> Embedding:
133
        try:
134
            import torch
135
            from PIL import Image
136
        except ImportError:
137
            raise ImportError(
138
                "ClipEmbedding requires `pip install torch` and `pip install pillow`."
139
            )
140
        with torch.no_grad():
141
            image = (
142
                self._preprocess(Image.open(img_file_path))
143
                .unsqueeze(0)
144
                .to(self._device)
145
            )
146
            return self._model.encode_image(image).tolist()[0]
147

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.