llama-index
163 lines · 6.1 KB
1from enum import Enum2from typing import Any, List, Optional3
4from llama_index.legacy.bridge.pydantic import Field5from llama_index.legacy.callbacks import CallbackManager6from llama_index.legacy.core.embeddings.base import (7DEFAULT_EMBED_BATCH_SIZE,8BaseEmbedding,9)
10
11
12# Enums for validation and type safety
class CohereAIModelName(str, Enum):
    """Identifiers of the Cohere embedding models known to this module.

    Because the enum mixes in ``str``, every member compares equal to its
    plain model-name string (e.g. ``"embed-english-v3.0"``).
    """

    # v3 family (used together with an input type in the validation table)
    ENGLISH_V3 = "embed-english-v3.0"
    ENGLISH_LIGHT_V3 = "embed-english-light-v3.0"
    MULTILINGUAL_V3 = "embed-multilingual-v3.0"
    MULTILINGUAL_LIGHT_V3 = "embed-multilingual-light-v3.0"

    # v2 family (paired with input_type=None in the validation table)
    ENGLISH_V2 = "embed-english-v2.0"
    ENGLISH_LIGHT_V2 = "embed-english-light-v2.0"
    MULTILINGUAL_V2 = "embed-multilingual-v2.0"
23
class CohereAIInputType(str, Enum):
    """Values accepted for the ``input_type`` setting of v3 Cohere models.

    Members are ``str`` subclasses, so they compare equal to the raw
    strings that are forwarded to the Cohere client.
    """

    SEARCH_QUERY = "search_query"
    SEARCH_DOCUMENT = "search_document"
    CLASSIFICATION = "classification"
    CLUSTERING = "clustering"
30
class CohereAITruncate(str, Enum):
    """Truncation strategies passed as the ``truncate`` argument of embed calls.

    The member values are upper-case strings; as a ``str`` mixin each member
    compares equal to its raw value.
    """

    START = "START"
    END = "END"
    NONE = "NONE"
36
# Short aliases that keep the validation tables below compact.
CAMN, CAIT, CAT = CohereAIModelName, CohereAIInputType, CohereAITruncate
# Allowed (model_name, input_type) pairs; CohereEmbedding.__init__ rejects any
# combination not listed here. Every v3 model is paired with each input type
# (grouped by input type, in enum order), then every v2 model with None.
VALID_MODEL_INPUT_TYPES = [
    (model, input_type)
    for input_type in CAIT
    for model in (
        CAMN.ENGLISH_V3,
        CAMN.ENGLISH_LIGHT_V3,
        CAMN.MULTILINGUAL_V3,
        CAMN.MULTILINGUAL_LIGHT_V3,
    )
] + [
    (model, None)
    for model in (CAMN.ENGLISH_V2, CAMN.ENGLISH_LIGHT_V2, CAMN.MULTILINGUAL_V2)
]
64
# Every truncation strategy is accepted: START, END, NONE (enum order).
VALID_TRUNCATE_OPTIONS = list(CAT)
67
# BaseEmbedding is treated as a Pydantic model: the fields declared below are
# populated by passing them as keyword arguments to ``super().__init__``.
class CohereEmbedding(BaseEmbedding):
    """Generate text embeddings through the Cohere ``embed`` API.

    The configuration is validated at construction time against the
    module-level ``VALID_MODEL_INPUT_TYPES`` (v3 models require an
    ``input_type``; v2 models require ``input_type=None``) and
    ``VALID_TRUNCATE_OPTIONS``.
    """

    # Instance variables initialized via Pydantic's mechanism.
    cohere_client: Any = Field(description="CohereAI client")
    truncate: str = Field(description="Truncation type - START/ END/ NONE")
    input_type: Optional[str] = Field(description="Model Input type")

    def __init__(
        self,
        cohere_api_key: Optional[str] = None,
        model_name: str = "embed-english-v2.0",
        truncate: str = "END",
        input_type: Optional[str] = None,
        embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
        callback_manager: Optional[CallbackManager] = None,
    ):
        """Validate the configuration and create the underlying Cohere client.

        Args:
            cohere_api_key: API key handed to ``cohere.Client``. When ``None``,
                key resolution is left to the cohere library (presumably an
                environment variable -- confirm against the cohere docs).
            model_name: Cohere embedding model, e.g. ``"embed-english-v3.0"``.
                Defaults to ``"embed-english-v2.0"``.
            truncate: Truncation strategy sent with every request; one of
                ``"START"``, ``"END"`` (default) or ``"NONE"``.
            input_type: ``"search_query"``, ``"search_document"``,
                ``"classification"`` or ``"clustering"`` for v3 models; must be
                ``None`` for v2 models.
            embed_batch_size: Batch size forwarded to ``BaseEmbedding``.
            callback_manager: Callback manager forwarded to ``BaseEmbedding``.

        Raises:
            ValueError: If ``(model_name, input_type)`` is not an allowed
                combination, or ``truncate`` is not a supported option.
            ImportError: If the ``cohere`` package is not installed.
        """
        # Validate first so configuration errors are reported the same way
        # whether or not the optional `cohere` dependency is installed.
        if (model_name, input_type) not in VALID_MODEL_INPUT_TYPES:
            allowed_input_types = [
                allowed if allowed is None else allowed.value
                for model, allowed in VALID_MODEL_INPUT_TYPES
                if model == model_name
            ]
            if not allowed_input_types:
                supported_models = sorted(
                    {model.value for model, _ in VALID_MODEL_INPUT_TYPES}
                )
                raise ValueError(
                    f"Unsupported Cohere model '{model_name}'; "
                    f"expected one of {supported_models}"
                )
            raise ValueError(
                f"input_type {input_type!r} is not valid for model "
                f"'{model_name}'; expected one of {allowed_input_types}"
            )

        if truncate not in VALID_TRUNCATE_OPTIONS:
            # Show the plain string values rather than enum reprs.
            raise ValueError(
                "truncate must be one of "
                f"{[option.value for option in VALID_TRUNCATE_OPTIONS]}"
            )

        # Import lazily: `cohere` is an optional dependency of the package.
        try:
            import cohere
        except ImportError as err:
            raise ImportError(
                "CohereEmbedding requires the 'cohere' package to be installed.\n"
                "Please install it with `pip install cohere`."
            ) from err

        super().__init__(
            cohere_client=cohere.Client(cohere_api_key, client_name="llama_index"),
            # NOTE(review): `cohere_api_key` is not a declared field of this
            # class; whether Pydantic ignores, stores or rejects it depends on
            # BaseEmbedding's model config (not visible here) -- confirm.
            cohere_api_key=cohere_api_key,
            model_name=model_name,
            truncate=truncate,
            input_type=input_type,
            embed_batch_size=embed_batch_size,
            callback_manager=callback_manager,
        )

    @classmethod
    def class_name(cls) -> str:
        """Return the registry name of this embedding class."""
        return "CohereEmbedding"

    def _embed(self, texts: List[str]) -> List[List[float]]:
        """Embed ``texts`` with a single Cohere ``embed`` request.

        ``input_type`` is only included when it is set (truthy). Per
        ``VALID_MODEL_INPUT_TYPES``, v2 models are configured without one.
        An empty ``texts`` list returns ``[]`` without contacting the API.

        Returns:
            One list of floats per embedding returned by the client.
        """
        if not texts:
            return []

        request = {
            "texts": texts,
            "model": self.model_name,
            "truncate": self.truncate,
        }
        if self.input_type:
            request["input_type"] = self.input_type

        embeddings = self.cohere_client.embed(**request).embeddings
        # Coerce each component to a built-in float; the client's element type
        # is not visible here.
        return [[float(value) for value in vector] for vector in embeddings]

    def _get_query_embedding(self, query: str) -> List[float]:
        """Get query embedding."""
        return self._embed([query])[0]

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """Get query embedding async.

        Delegates to the synchronous implementation (nothing is awaited), so
        the request runs on the calling thread.
        """
        return self._get_query_embedding(query)

    def _get_text_embedding(self, text: str) -> List[float]:
        """Get text embedding."""
        return self._embed([text])[0]

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Get text embedding async.

        Delegates to the synchronous implementation (nothing is awaited), so
        the request runs on the calling thread.
        """
        return self._get_text_embedding(text)

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Get text embeddings for a batch of texts in one request."""
        return self._embed(texts)