llama-index

Форк
0
163 строки · 6.1 Кб
1
from enum import Enum
2
from typing import Any, List, Optional
3

4
from llama_index.legacy.bridge.pydantic import Field
5
from llama_index.legacy.callbacks import CallbackManager
6
from llama_index.legacy.core.embeddings.base import (
7
    DEFAULT_EMBED_BATCH_SIZE,
8
    BaseEmbedding,
9
)
10

11

12
# Enums for validation and type safety
13
class CohereAIModelName(str, Enum):
    """Cohere embedding model identifiers known to this module.

    Because the enum mixes in ``str``, each member compares equal to its raw
    model-name string, e.g. ``CohereAIModelName.ENGLISH_V3 == "embed-english-v3.0"``.
    """

    # --- v3.0 models ---
    ENGLISH_V3 = "embed-english-v3.0"
    ENGLISH_LIGHT_V3 = "embed-english-light-v3.0"
    MULTILINGUAL_V3 = "embed-multilingual-v3.0"
    MULTILINGUAL_LIGHT_V3 = "embed-multilingual-light-v3.0"

    # --- v2.0 models ---
    ENGLISH_V2 = "embed-english-v2.0"
    ENGLISH_LIGHT_V2 = "embed-english-light-v2.0"
    MULTILINGUAL_V2 = "embed-multilingual-v2.0"
22

23

24
class CohereAIInputType(str, Enum):
    """Allowed values for the ``input_type`` option of the Cohere embed call.

    Members are ``str`` subclasses, so they compare equal to plain strings
    such as ``"search_query"``.
    """

    SEARCH_QUERY = "search_query"
    SEARCH_DOCUMENT = "search_document"
    CLASSIFICATION = "classification"
    CLUSTERING = "clustering"
29

30

31
class CohereAITruncate(str, Enum):
    """Allowed values for the ``truncate`` option of the Cohere embed call.

    The values are upper-case strings; members compare equal to them.
    """

    START = "START"
    END = "END"
    NONE = "NONE"
35

36

37
# Short aliases for the three enums, used to keep the validation tables compact.
CAMN, CAIT, CAT = CohereAIModelName, CohereAIInputType, CohereAITruncate
41

42
# Allowed (model_name, input_type) pairs, checked by CohereEmbedding.__init__.
# Every v3.0 model accepts each of the four input types (grouped by input
# type, in enum order); the v2.0 models are only valid with input_type None.
VALID_MODEL_INPUT_TYPES = [
    (model, input_type)
    for input_type in (
        CAIT.SEARCH_QUERY,
        CAIT.SEARCH_DOCUMENT,
        CAIT.CLASSIFICATION,
        CAIT.CLUSTERING,
    )
    for model in (
        CAMN.ENGLISH_V3,
        CAMN.ENGLISH_LIGHT_V3,
        CAMN.MULTILINGUAL_V3,
        CAMN.MULTILINGUAL_LIGHT_V3,
    )
] + [
    (legacy_model, None)
    for legacy_model in (
        CAMN.ENGLISH_V2,
        CAMN.ENGLISH_LIGHT_V2,
        CAMN.MULTILINGUAL_V2,
    )
]
64

65
# Every truncation strategy, in declaration order: START, END, NONE.
VALID_TRUNCATE_OPTIONS = list(CAT)
66

67

68
# NOTE(review): assumes BaseEmbedding is a Pydantic model whose __init__
# accepts the declared fields below (plus model_name, embed_batch_size and
# callback_manager) as keyword arguments — confirm against
# llama_index.legacy.core.embeddings.base.
class CohereEmbedding(BaseEmbedding):
    """Embedding model that calls the Cohere ``embed`` API.

    Query and document embeddings share the same ``model_name``,
    ``input_type`` and ``truncate`` settings chosen at construction time.
    """

    # Pydantic-declared fields; populated from the keyword arguments that
    # __init__ forwards to BaseEmbedding.__init__.
    cohere_client: Any = Field(description="CohereAI client")
    truncate: str = Field(description="Truncation type - START/ END/ NONE")
    input_type: Optional[str] = Field(description="Model Input type")

    def __init__(
        self,
        cohere_api_key: Optional[str] = None,
        model_name: str = "embed-english-v2.0",
        truncate: str = "END",
        input_type: Optional[str] = None,
        embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
        callback_manager: Optional[CallbackManager] = None,
    ):
        """
        Validate the configuration and create the underlying Cohere client.

        Args:
            cohere_api_key (Optional[str]): API key passed to ``cohere.Client``.
                NOTE(review): what the SDK does when this is None (e.g. an
                environment-variable fallback) is defined by the cohere
                package — confirm.
            model_name (str): Cohere model identifier. Together with
                ``input_type`` it must form a pair listed in
                VALID_MODEL_INPUT_TYPES: the v3.0 models require one of
                'search_query', 'search_document', 'classification' or
                'clustering'; the v2.0 models require ``input_type=None``.
            truncate (str): Truncation strategy sent with every request; one of
                'START', 'END' or 'NONE'.
            input_type (Optional[str]): Input type sent with every request, or
                None to omit it (see ``model_name`` for the valid pairs).
            embed_batch_size (int): Forwarded unchanged to BaseEmbedding.
            callback_manager (Optional[CallbackManager]): Forwarded unchanged
                to BaseEmbedding.

        Raises:
            ImportError: If the ``cohere`` package is not installed.
            ValueError: If ``(model_name, input_type)`` is not an allowed pair,
                or ``truncate`` is not an allowed option.
        """
        # Import lazily so the module can be imported without cohere installed.
        try:
            import cohere
        except ImportError as err:
            raise ImportError(
                "CohereEmbedding requires the 'cohere' package to be installed.\n"
                "Please install it with `pip install cohere`."
            ) from err

        # The enums mix in str, so plain-string arguments match enum members.
        if (model_name, input_type) not in VALID_MODEL_INPUT_TYPES:
            raise ValueError(
                f"{(model_name, input_type)} is not valid for model '{model_name}'"
            )

        if truncate not in VALID_TRUNCATE_OPTIONS:
            # Show the plain values, not enum reprs such as
            # "<CohereAITruncate.START: 'START'>".
            valid_options = [option.value for option in VALID_TRUNCATE_OPTIONS]
            raise ValueError(f"truncate must be one of {valid_options}")

        super().__init__(
            cohere_client=cohere.Client(cohere_api_key, client_name="llama_index"),
            # NOTE(review): no `cohere_api_key` field is declared on this class;
            # whether BaseEmbedding stores or ignores it depends on its pydantic
            # config — confirm.
            cohere_api_key=cohere_api_key,
            model_name=model_name,
            truncate=truncate,
            input_type=input_type,
            embed_batch_size=embed_batch_size,
            callback_manager=callback_manager,
        )

    @classmethod
    def class_name(cls) -> str:
        """Return the class identifier string."""
        return "CohereEmbedding"

    def _embed(self, texts: List[str]) -> List[List[float]]:
        """Embed ``texts`` with a single Cohere ``embed`` call.

        Returns one list of floats per input text, in the order returned by
        the client's ``embeddings`` attribute.
        """
        embed_kwargs = {
            "texts": texts,
            "model": self.model_name,
            "truncate": self.truncate,
        }
        # input_type is sent only when set: the v2.0 models are validated
        # with input_type None.
        if self.input_type:
            embed_kwargs["input_type"] = self.input_type
        result = self.cohere_client.embed(**embed_kwargs).embeddings
        return [[float(value) for value in embedding] for embedding in result]

    def _get_query_embedding(self, query: str) -> List[float]:
        """Get query embedding."""
        return self._embed([query])[0]

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """Get query embedding async.

        Delegates to the synchronous implementation (no await), so the event
        loop is not released while the request is in flight.
        """
        return self._get_query_embedding(query)

    def _get_text_embedding(self, text: str) -> List[float]:
        """Get text embedding."""
        return self._embed([text])[0]

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Get text embedding async.

        Delegates to the synchronous implementation (no await).
        """
        return self._get_text_embedding(text)

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Get text embeddings for all ``texts`` in one ``embed`` call."""
        return self._embed(texts)
164

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.