llama-index

Форк
0
115 строк · 3.8 Кб
1
"""MistralAI embeddings file."""
2

3
from typing import Any, List, Optional
4

5
from llama_index.legacy.bridge.pydantic import PrivateAttr
6
from llama_index.legacy.callbacks.base import CallbackManager
7
from llama_index.legacy.core.embeddings.base import (
8
    DEFAULT_EMBED_BATCH_SIZE,
9
    BaseEmbedding,
10
)
11
from llama_index.legacy.llms.generic_utils import get_from_param_or_env
12

13

14
class MistralAIEmbedding(BaseEmbedding):
    """Class for MistralAI embeddings.

    Wraps a synchronous and an asynchronous ``mistralai`` client and exposes
    them through the ``BaseEmbedding`` hook methods defined below.

    Args:
        model_name (str): Model for embedding.
            Defaults to "mistral-embed".

        api_key (Optional[str]): API key to access the model. Defaults to None,
            in which case it is resolved through ``get_from_param_or_env``
            with the ``MISTRAL_API_KEY`` name.

        embed_batch_size (int): Forwarded unchanged to ``BaseEmbedding``.
            Defaults to ``DEFAULT_EMBED_BATCH_SIZE``.

        callback_manager (Optional[CallbackManager]): Forwarded unchanged to
            ``BaseEmbedding``. Defaults to None.

        **kwargs (Any): Extra keyword arguments forwarded to ``BaseEmbedding``.

    Raises:
        ImportError: If the ``mistralai`` package cannot be imported.
        ValueError: If no API key could be resolved.
    """

    # Instance variables initialized via Pydantic's mechanism
    _mistralai_client: Any = PrivateAttr()
    _mistralai_async_client: Any = PrivateAttr()

    def __init__(
        self,
        model_name: str = "mistral-embed",
        api_key: Optional[str] = None,
        embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
        callback_manager: Optional[CallbackManager] = None,
        **kwargs: Any,
    ):
        # Imported lazily so that the dependency is only needed when this
        # class is actually instantiated.
        try:
            from mistralai.async_client import MistralAsyncClient
            from mistralai.client import MistralClient
        except ImportError as err:
            # The two literals are concatenated; the trailing space keeps
            # "with" and the quoted pip command from running together.
            raise ImportError(
                "mistralai package not found, install with "
                "'pip install mistralai'"
            ) from err

        # presumably falls back to the MISTRAL_API_KEY environment variable
        # and then to "" -- verify against get_from_param_or_env.
        api_key = get_from_param_or_env("api_key", api_key, "MISTRAL_API_KEY", "")

        if not api_key:
            raise ValueError(
                "You must provide an API key to use mistralai. "
                "You can either pass it in as an argument or set it `MISTRAL_API_KEY`."
            )

        # NOTE(review): the clients are assigned before super().__init__(),
        # exactly as in the original; the ordering is kept on purpose.
        self._mistralai_client = MistralClient(api_key=api_key)
        self._mistralai_async_client = MistralAsyncClient(api_key=api_key)
        super().__init__(
            model_name=model_name,
            embed_batch_size=embed_batch_size,
            callback_manager=callback_manager,
            **kwargs,
        )

    @classmethod
    def class_name(cls) -> str:
        """Return the fixed identifier string for this class."""
        return "MistralAIEmbedding"

    def _get_query_embedding(self, query: str) -> List[float]:
        """Get query embedding.

        Sends ``query`` as a one-element input list and returns the
        ``embedding`` of the first item in the response data.
        """
        response = self._mistralai_client.embeddings(
            model=self.model_name, input=[query]
        )
        return response.data[0].embedding

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """The asynchronous version of _get_query_embedding."""
        response = await self._mistralai_async_client.embeddings(
            model=self.model_name, input=[query]
        )
        return response.data[0].embedding

    def _get_text_embedding(self, text: str) -> List[float]:
        """Get text embedding.

        Sends ``text`` as a one-element input list and returns the
        ``embedding`` of the first item in the response data.
        """
        response = self._mistralai_client.embeddings(
            model=self.model_name, input=[text]
        )
        return response.data[0].embedding

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Asynchronously get text embedding."""
        response = await self._mistralai_async_client.embeddings(
            model=self.model_name, input=[text]
        )
        return response.data[0].embedding

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Get text embeddings.

        Sends all ``texts`` in a single request and returns one embedding
        per item of the response data, in the order the response lists them.
        """
        response = self._mistralai_client.embeddings(
            model=self.model_name, input=texts
        )
        return [item.embedding for item in response.data]

    async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Asynchronously get text embeddings."""
        response = await self._mistralai_async_client.embeddings(
            model=self.model_name, input=texts
        )
        return [item.embedding for item in response.data]
116

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.