llama-index

"""Gemini embeddings file."""

import os
from typing import Any, Dict, List, Optional

from llama_index.legacy.bridge.pydantic import Field, PrivateAttr
from llama_index.legacy.callbacks.base import CallbackManager
from llama_index.legacy.core.embeddings.base import (
    DEFAULT_EMBED_BATCH_SIZE,
    BaseEmbedding,
)


class GeminiEmbedding(BaseEmbedding):
    """Google Gemini embeddings.

    Args:
        model_name (str): Model for embedding.
            Defaults to "models/embedding-001".

        api_key (Optional[str]): API key to access the model. Defaults to None.
        api_base (Optional[str]): API base to access the model. Defaults to the official base URL.
        transport (Optional[str]): Transport to access the model.
    """

    _model: Any = PrivateAttr()
    title: Optional[str] = Field(
        default="",
        description="Title is only applicable for retrieval_document tasks, and is used to represent a document title. For other tasks, title is invalid.",
    )
    task_type: Optional[str] = Field(
        default="retrieval_document",
        description="The task for embedding model.",
    )

    def __init__(
        self,
        model_name: str = "models/embedding-001",
        task_type: Optional[str] = "retrieval_document",
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        transport: Optional[str] = None,
        title: Optional[str] = None,
        embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
        callback_manager: Optional[CallbackManager] = None,
        **kwargs: Any,
    ):
        try:
            import google.generativeai as gemini
        except ImportError:
            raise ImportError(
                "google-generativeai package not found, install with "
                "'pip install google-generativeai'"
            )
        # API keys are optional. The API can be authorised via OAuth (detected
        # environmentally) or by the GOOGLE_API_KEY environment variable.
        config_params: Dict[str, Any] = {
            "api_key": api_key or os.getenv("GOOGLE_API_KEY"),
        }
        if api_base:
            config_params["client_options"] = {"api_endpoint": api_base}
        if transport:
            config_params["transport"] = transport
        # transport: A string, one of: [`rest`, `grpc`, `grpc_asyncio`].
        gemini.configure(**config_params)
        self._model = gemini

        super().__init__(
            model_name=model_name,
            embed_batch_size=embed_batch_size,
            callback_manager=callback_manager,
            **kwargs,
        )
        self.title = title
        self.task_type = task_type

    @classmethod
    def class_name(cls) -> str:
        return "GeminiEmbedding"

    def _get_query_embedding(self, query: str) -> List[float]:
        """Get query embedding."""
        return self._model.embed_content(
            model=self.model_name,
            content=query,
            title=self.title,
            task_type=self.task_type,
        )["embedding"]

    def _get_text_embedding(self, text: str) -> List[float]:
        """Get text embedding."""
        return self._model.embed_content(
            model=self.model_name,
            content=text,
            title=self.title,
            task_type=self.task_type,
        )["embedding"]

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Get text embeddings."""
        return [
            self._model.embed_content(
                model=self.model_name,
                content=text,
                title=self.title,
                task_type=self.task_type,
            )["embedding"]
            for text in texts
        ]

    ### Async methods ###
    # Need to wait for async calls to be implemented on the Gemini side.
    # Issue: https://github.com/google/generative-ai-python/issues/125
    async def _aget_query_embedding(self, query: str) -> List[float]:
        """The asynchronous version of _get_query_embedding."""
        return self._get_query_embedding(query)

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Asynchronously get text embedding."""
        return self._get_text_embedding(text)

    async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Asynchronously get text embeddings."""
        return self._get_text_embeddings(texts)
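A minimal usage sketch of the class above. The import path (llama_index.legacy.embeddings.gemini), the document title, and the sample strings are assumptions for illustration; get_text_embedding and get_query_embedding are the public BaseEmbedding entry points that route into the _get_* methods defined in this module.

# Usage sketch: embed a document and a query with GeminiEmbedding.
# Assumes GOOGLE_API_KEY is set in the environment (or pass api_key=...)
# and that this module is importable from the path below.
from llama_index.legacy.embeddings.gemini import GeminiEmbedding

embed_model = GeminiEmbedding(
    model_name="models/embedding-001",
    task_type="retrieval_document",
    title="Example document",  # hypothetical title; only used for retrieval_document tasks
)

doc_vector = embed_model.get_text_embedding("LlamaIndex connects LLMs to your data.")
query_vector = embed_model.get_query_embedding("What does LlamaIndex do?")
print(len(doc_vector), len(query_vector))  # embedding dimensionality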

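Until the upstream issue linked above is resolved, the _aget_* methods simply delegate to their synchronous counterparts, so awaiting them does not overlap network I/O; they exist so the model can still be called from async code paths. A small sketch, reusing the embed_model instance from the previous example and the aget_* wrappers provided by BaseEmbedding:

import asyncio

async def main() -> None:
    # These awaitables resolve to the same results as the sync calls,
    # since the _aget_* methods above call the synchronous implementations.
    query_vec = await embed_model.aget_query_embedding("What does LlamaIndex do?")
    text_vec = await embed_model.aget_text_embedding("LlamaIndex connects LLMs to your data.")
    print(len(query_vec), len(text_vec))

asyncio.run(main())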