llama-index

Форк
0
141 строка · 5.1 Кб
1
import logging
2
from typing import Any, List, Optional
3

4
from llama_index.legacy.bridge.pydantic import Field, PrivateAttr
5
from llama_index.legacy.callbacks import CallbackManager
6
from llama_index.legacy.constants import DEFAULT_EMBED_BATCH_SIZE
7
from llama_index.legacy.core.embeddings.base import BaseEmbedding
8

9
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)

# Sample model URL interpolated into the ``model_url`` field description.
# NOTE(review): this URL names a completion model (claude-v2), not an
# embedding model; consider replacing it with an embedding-model URL.
EXAMPLE_URL = (
    "https://clarifai.com/anthropic/completion"
    "/models/claude-v2"
)
12

13

14
class ClarifaiEmbedding(BaseEmbedding):
    """Clarifai embeddings class.

    Clarifai uses Personal Access Tokens(PAT) to validate requests.
    You can create and manage PATs under your Clarifai account security settings.
    Export your PAT as an environment variable by running `export CLARIFAI_PAT={PAT}`
    """

    # NOTE(review): the fields below are declared but ``__init__`` never passes
    # them to ``super().__init__``, so they are not populated from the
    # constructor arguments. ``pat`` in particular should stay unpopulated
    # unless serialisation of the secret is acceptable.
    model_url: Optional[str] = Field(
        description=f"Full URL of the model. e.g. `{EXAMPLE_URL}`"
    )
    model_id: Optional[str] = Field(description="Model ID.")
    model_version_id: Optional[str] = Field(description="Model Version ID.")
    app_id: Optional[str] = Field(description="Clarifai application ID of the model.")
    user_id: Optional[str] = Field(description="Clarifai user ID of the model.")
    pat: Optional[str] = Field(
        description="Personal Access Tokens(PAT) to validate requests."
    )

    # Underlying ``clarifai.client.model.Model`` client used for predictions.
    _model: Any = PrivateAttr()

    def __init__(
        self,
        model_name: Optional[str] = None,
        model_url: Optional[str] = None,
        model_version_id: Optional[str] = "",
        app_id: Optional[str] = None,
        user_id: Optional[str] = None,
        pat: Optional[str] = None,
        embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
        callback_manager: Optional[CallbackManager] = None,
    ):
        """Create a Clarifai embedding model.

        Exactly one of ``model_name`` or ``model_url`` must be given.

        Args:
            model_name: Clarifai model ID. Requires ``app_id`` and ``user_id``.
            model_url: Full Clarifai model URL. The model name is then taken
                from the resolved model's ``id``.
            model_version_id: Model version ID. Only used together with
                ``model_name``.
            app_id: Clarifai application ID owning the model.
            user_id: Clarifai user ID owning the model.
            pat: Personal Access Token. Falls back to the ``CLARIFAI_PAT``
                environment variable when not given (or empty).
            embed_batch_size: Inputs sent per predict request (capped at 128).
            callback_manager: Optional LlamaIndex callback manager.

        Raises:
            ImportError: If the ``clarifai`` package is not installed.
            ValueError: If no PAT is available, if both or neither of
                ``model_url`` / ``model_name`` are given, or if ``model_name``
                is given without both ``app_id`` and ``user_id``.
        """
        import os

        try:
            from clarifai.client.model import Model
        except ImportError:
            raise ImportError("ClarifaiEmbedding requires `pip install clarifai`.")

        # Hard cap on the per-request batch size; presumably a Clarifai
        # per-request input limit -- TODO confirm against the Clarifai docs.
        embed_batch_size = min(128, embed_batch_size)

        # An explicit non-empty ``pat`` wins. Otherwise use the environment.
        # An empty token from either source is rejected rather than sent.
        pat = pat or os.environ.get("CLARIFAI_PAT")
        if not pat:
            raise ValueError(
                "Set `CLARIFAI_PAT` as env variable or pass `pat` as constructor argument"
            )

        if model_url is not None and model_name is not None:
            raise ValueError("You can only specify one of model_url or model_name.")
        if model_url is None and model_name is None:
            raise ValueError("You must specify one of model_url or model_name.")

        if model_url is not None:
            model = Model(model_url, pat=pat)
            model_name = model.id
        else:
            if app_id is None or user_id is None:
                raise ValueError(
                    f"Missing one app ID or user ID of the model: {app_id=}, {user_id=}"
                )
            model = Model(
                user_id=user_id,
                app_id=app_id,
                model_id=model_name,
                model_version={"id": model_version_id},
                pat=pat,
            )

        super().__init__(
            embed_batch_size=embed_batch_size,
            callback_manager=callback_manager,
            model_name=model_name,
        )
        # Assign the private client only after the pydantic base has been
        # initialised. Setting private attributes before that is fragile.
        self._model = model

    @classmethod
    def class_name(cls) -> str:
        """Return the serialisation name of this class."""
        return "ClarifaiEmbedding"

    def _embed(self, sentences: List[str]) -> List[List[float]]:
        """Embed ``sentences`` with the Clarifai model, in batches.

        Sentences are sent in chunks of ``self.embed_batch_size``. For each
        prediction output, the first embedding's vector is collected.

        Args:
            sentences: Texts to embed. An empty list yields an empty result.

        Returns:
            The collected vectors, batch by batch. This assumes the client
            returns one output per input, in input order (TODO confirm
            against the Clarifai SDK).

        Raises:
            ImportError: If the ``clarifai`` package is not installed.
            Exception: Any error raised by the Clarifai client is logged and
                re-raised. Callers therefore never get a silently truncated
                list that is misaligned with ``sentences``.
        """
        try:
            from clarifai.client.input import Inputs
        except ImportError:
            raise ImportError("ClarifaiEmbedding requires `pip install clarifai`.")

        embeddings: List[List[float]] = []
        try:
            for start in range(0, len(sentences), self.embed_batch_size):
                batch = sentences[start : start + self.embed_batch_size]
                # Input IDs restart at "0" for every request.
                input_batch = [
                    Inputs.get_text_input(input_id=str(idx), raw_text=text)
                    for idx, text in enumerate(batch)
                ]
                predict_response = self._model.predict(input_batch)
                embeddings.extend(
                    list(output.data.embeddings[0].vector)
                    for output in predict_response.outputs
                )
        except Exception as e:
            # Do not swallow the failure. A partial list would be misaligned
            # with the inputs, and single-item callers would hit IndexError.
            logger.exception("Predict failed, exception: %s", e)
            raise

        return embeddings

    def _get_query_embedding(self, query: str) -> List[float]:
        """Get query embedding."""
        return self._embed([query])[0]

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """Get query embedding async (delegates to the synchronous call)."""
        return self._get_query_embedding(query)

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Get text embedding async (delegates to the synchronous call)."""
        return self._get_text_embedding(text)

    def _get_text_embedding(self, text: str) -> List[float]:
        """Get text embedding."""
        return self._embed([text])[0]

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Get text embeddings, one vector per input text."""
        return self._embed(texts)
142

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.