# llama-index (115 lines · 3.8 KB)
1"""MistralAI embeddings file."""
2
3from typing import Any, List, Optional4
5from llama_index.legacy.bridge.pydantic import PrivateAttr6from llama_index.legacy.callbacks.base import CallbackManager7from llama_index.legacy.core.embeddings.base import (8DEFAULT_EMBED_BATCH_SIZE,9BaseEmbedding,10)
11from llama_index.legacy.llms.generic_utils import get_from_param_or_env12
13
class MistralAIEmbedding(BaseEmbedding):
    """MistralAI embedding model wrapper.

    Wraps the ``mistralai`` sync and async clients and exposes them through
    the ``BaseEmbedding`` interface.

    Args:
        model_name (str): Model for embedding.
            Defaults to "mistral-embed".

        api_key (Optional[str]): API key to access the model. Defaults to None;
            falls back to the ``MISTRAL_API_KEY`` environment variable.

    Raises:
        ImportError: If the ``mistralai`` package is not installed.
        ValueError: If no API key is provided via argument or environment.
    """

    # Clients are runtime objects, not pydantic fields — keep them out of
    # the model schema via PrivateAttr. Set in __init__.
    _mistralai_client: Any = PrivateAttr()
    _mistralai_async_client: Any = PrivateAttr()

    def __init__(
        self,
        model_name: str = "mistral-embed",
        api_key: Optional[str] = None,
        embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
        callback_manager: Optional[CallbackManager] = None,
        **kwargs: Any,
    ):
        # Import lazily so the package is only required when this
        # embedding class is actually instantiated.
        try:
            from mistralai.async_client import MistralAsyncClient
            from mistralai.client import MistralClient
        except ImportError:
            # Fixed: the two adjacent literals previously concatenated
            # without a space ("...install with'pip install mistralai'").
            raise ImportError(
                "mistralai package not found, install with "
                "'pip install mistralai'"
            )
        # Resolve the key: explicit argument wins, then MISTRAL_API_KEY env var.
        api_key = get_from_param_or_env("api_key", api_key, "MISTRAL_API_KEY", "")

        if not api_key:
            raise ValueError(
                "You must provide an API key to use mistralai. "
                "You can either pass it in as an argument or set it `MISTRAL_API_KEY`."
            )
        # NOTE(review): private attrs are assigned before super().__init__,
        # matching the original ordering — keep as-is for pydantic PrivateAttr.
        self._mistralai_client = MistralClient(api_key=api_key)
        self._mistralai_async_client = MistralAsyncClient(api_key=api_key)
        super().__init__(
            model_name=model_name,
            embed_batch_size=embed_batch_size,
            callback_manager=callback_manager,
            **kwargs,
        )

    @classmethod
    def class_name(cls) -> str:
        """Return the serialized class name."""
        return "MistralAIEmbedding"

    def _get_query_embedding(self, query: str) -> List[float]:
        """Get query embedding."""
        return (
            self._mistralai_client.embeddings(model=self.model_name, input=[query])
            .data[0]
            .embedding
        )

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """The asynchronous version of _get_query_embedding."""
        return (
            (
                await self._mistralai_async_client.embeddings(
                    model=self.model_name, input=[query]
                )
            )
            .data[0]
            .embedding
        )

    def _get_text_embedding(self, text: str) -> List[float]:
        """Get text embedding."""
        return (
            self._mistralai_client.embeddings(model=self.model_name, input=[text])
            .data[0]
            .embedding
        )

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Asynchronously get text embedding."""
        return (
            (
                await self._mistralai_async_client.embeddings(
                    model=self.model_name, input=[text]
                )
            )
            .data[0]
            .embedding
        )

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Get text embeddings for a batch of texts in one API call."""
        embedding_response = self._mistralai_client.embeddings(
            model=self.model_name, input=texts
        ).data
        return [embed.embedding for embed in embedding_response]

    async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Asynchronously get text embeddings for a batch of texts."""
        embedding_response = await self._mistralai_async_client.embeddings(
            model=self.model_name, input=texts
        )
        return [embed.embedding for embed in embedding_response.data]