llama-index

Форк
0
82 строки · 2.6 Кб
1
"""Google PaLM embeddings file."""
2

3
from typing import Any, List, Optional
4

5
from llama_index.legacy.bridge.pydantic import PrivateAttr
6
from llama_index.legacy.callbacks.base import CallbackManager
7
from llama_index.legacy.core.embeddings.base import (
8
    DEFAULT_EMBED_BATCH_SIZE,
9
    BaseEmbedding,
10
)
11

12

13
class GooglePaLMEmbedding(BaseEmbedding):
14
    """Class for Google PaLM embeddings.
15

16
    Args:
17
        model_name (str): Model for embedding.
18
            Defaults to "models/embedding-gecko-001".
19

20
        api_key (Optional[str]): API key to access the model. Defaults to None.
21
    """
22

23
    _model: Any = PrivateAttr()
24

25
    def __init__(
26
        self,
27
        model_name: str = "models/embedding-gecko-001",
28
        api_key: Optional[str] = None,
29
        embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
30
        callback_manager: Optional[CallbackManager] = None,
31
        **kwargs: Any,
32
    ):
33
        try:
34
            import google.generativeai as palm
35
        except ImportError:
36
            raise ImportError(
37
                "google-generativeai package not found, install with"
38
                "'pip install google-generativeai'"
39
            )
40
        palm.configure(api_key=api_key)
41
        self._model = palm
42

43
        super().__init__(
44
            model_name=model_name,
45
            embed_batch_size=embed_batch_size,
46
            callback_manager=callback_manager,
47
            **kwargs,
48
        )
49

50
    @classmethod
51
    def class_name(cls) -> str:
52
        return "PaLMEmbedding"
53

54
    def _get_query_embedding(self, query: str) -> List[float]:
55
        """Get query embedding."""
56
        return self._model.generate_embeddings(model=self.model_name, text=query)[
57
            "embedding"
58
        ]
59

60
    async def _aget_query_embedding(self, query: str) -> List[float]:
61
        """The asynchronous version of _get_query_embedding."""
62
        return await self._model.aget_embedding(query)
63

64
    def _get_text_embedding(self, text: str) -> List[float]:
65
        """Get text embedding."""
66
        return self._model.generate_embeddings(model=self.model_name, text=text)[
67
            "embedding"
68
        ]
69

70
    async def _aget_text_embedding(self, text: str) -> List[float]:
71
        """Asynchronously get text embedding."""
72
        return self._model._get_text_embedding(text)
73

74
    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
75
        """Get text embeddings."""
76
        return self._model.generate_embeddings(model=self.model_name, text=texts)[
77
            "embedding"
78
        ]
79

80
    async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
81
        """Asynchronously get text embeddings."""
82
        return await self._model._get_embeddings(texts)
83

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.