llama-index

Форк
0
67 строк · 2.1 Кб
1
"""Google Universal Sentence Encoder Embedding Wrapper Module."""
2

3
from typing import Any, List, Optional
4

5
from llama_index.legacy.bridge.pydantic import PrivateAttr
6
from llama_index.legacy.callbacks import CallbackManager
7
from llama_index.legacy.core.embeddings.base import (
8
    DEFAULT_EMBED_BATCH_SIZE,
9
    BaseEmbedding,
10
)
11

12
# Google Universal Sentence Encode v5
13
DEFAULT_HANDLE = "https://tfhub.dev/google/universal-sentence-encoder-large/5"
14

15

16
class GoogleUnivSentEncoderEmbedding(BaseEmbedding):
17
    _model: Any = PrivateAttr()
18

19
    def __init__(
20
        self,
21
        handle: Optional[str] = None,
22
        embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
23
        callback_manager: Optional[CallbackManager] = None,
24
    ):
25
        """Init params."""
26
        handle = handle or DEFAULT_HANDLE
27
        try:
28
            import tensorflow_hub as hub
29

30
            model = hub.load(handle)
31
        except ImportError:
32
            raise ImportError(
33
                "Please install tensorflow_hub: `pip install tensorflow_hub`"
34
            )
35

36
        self._model = model
37
        super().__init__(
38
            embed_batch_size=embed_batch_size,
39
            callback_manager=callback_manager,
40
            model_name=handle,
41
        )
42

43
    @classmethod
44
    def class_name(cls) -> str:
45
        return "GoogleUnivSentEncoderEmbedding"
46

47
    def _get_query_embedding(self, query: str) -> List[float]:
48
        """Get query embedding."""
49
        return self._get_embedding(query)
50

51
    # TODO: use proper async methods
52
    async def _aget_text_embedding(self, query: str) -> List[float]:
53
        """Get text embedding."""
54
        return self._get_embedding(query)
55

56
    # TODO: user proper async methods
57
    async def _aget_query_embedding(self, query: str) -> List[float]:
58
        """Get query embedding."""
59
        return self._get_embedding(query)
60

61
    def _get_text_embedding(self, text: str) -> List[float]:
62
        """Get text embedding."""
63
        return self._get_embedding(text)
64

65
    def _get_embedding(self, text: str) -> List[float]:
66
        vectors = self._model([text]).numpy().tolist()
67
        return vectors[0]
68

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.