llama-index

Форк
0
87 строк · 3.1 Кб
1
"""Langchain Embedding Wrapper Module."""
2

3
from typing import TYPE_CHECKING, List, Optional
4

5
from llama_index.legacy.bridge.pydantic import PrivateAttr
6
from llama_index.legacy.callbacks import CallbackManager
7
from llama_index.legacy.core.embeddings.base import (
8
    DEFAULT_EMBED_BATCH_SIZE,
9
    BaseEmbedding,
10
)
11

12
if TYPE_CHECKING:
13
    from llama_index.legacy.bridge.langchain import Embeddings as LCEmbeddings
14

15

16
class LangchainEmbedding(BaseEmbedding):
    """Embedding model that delegates to a Langchain embeddings object.

    Args:
        langchain_embeddings (langchain.embeddings.Embeddings): Langchain
            embeddings instance whose ``embed_query`` / ``embed_documents``
            (and their ``aembed_*`` async counterparts) are called.
        model_name (Optional[str]): Name reported for the model. When
            ``None``, the name is taken from ``langchain_embeddings.model_name``,
            then ``langchain_embeddings.model``, and finally the class name of
            ``langchain_embeddings``.
        embed_batch_size (int): Forwarded to the base class initializer.
        callback_manager (Optional[CallbackManager]): Forwarded to the base
            class initializer.
    """

    # Wrapped Langchain embeddings object.
    _langchain_embedding: "LCEmbeddings" = PrivateAttr()
    # Set once the "async not available" notice has been printed.
    _async_not_implemented_warned: bool = PrivateAttr(default=False)

    def __init__(
        self,
        langchain_embeddings: "LCEmbeddings",
        model_name: Optional[str] = None,
        embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
        callback_manager: Optional[CallbackManager] = None,
    ):
        """Wrap ``langchain_embeddings`` and resolve a model name for it."""
        # Attempt to get a useful model name; an explicit argument always wins.
        if model_name is None:
            if hasattr(langchain_embeddings, "model_name"):
                model_name = langchain_embeddings.model_name
            elif hasattr(langchain_embeddings, "model"):
                model_name = langchain_embeddings.model
            else:
                model_name = type(langchain_embeddings).__name__

        # NOTE(review): the private attribute is assigned before the base
        # initializer runs, as in the original code; confirm before reordering.
        self._langchain_embedding = langchain_embeddings
        super().__init__(
            embed_batch_size=embed_batch_size,
            callback_manager=callback_manager,
            model_name=model_name,
        )

    @classmethod
    def class_name(cls) -> str:
        """Return the identifying name of this class."""
        return "LangchainEmbedding"

    def _async_not_implemented_warn_once(self) -> None:
        """Print the sync-fallback notice the first time it is needed."""
        if not self._async_not_implemented_warned:
            print("Async embedding not available, falling back to sync method.")
            self._async_not_implemented_warned = True

    def _get_query_embedding(self, query: str) -> List[float]:
        """Get query embedding."""
        return self._langchain_embedding.embed_query(query)

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """Get query embedding asynchronously.

        Falls back to the synchronous path when the wrapped object raises
        ``NotImplementedError`` from ``aembed_query``.
        """
        try:
            return await self._langchain_embedding.aembed_query(query)
        except NotImplementedError:
            # Warn the user that sync is being used
            self._async_not_implemented_warn_once()
            return self._get_query_embedding(query)

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Get text embedding asynchronously.

        Falls back to the synchronous path when the wrapped object raises
        ``NotImplementedError`` from ``aembed_documents``.
        """
        try:
            embeds = await self._langchain_embedding.aembed_documents([text])
        except NotImplementedError:
            # Warn the user that sync is being used
            self._async_not_implemented_warn_once()
            return self._get_text_embedding(text)
        else:
            # A single text was submitted, so take the single result.
            return embeds[0]

    def _get_text_embedding(self, text: str) -> List[float]:
        """Get text embedding."""
        return self._langchain_embedding.embed_documents([text])[0]

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Get text embeddings."""
        return self._langchain_embedding.embed_documents(texts)
88

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.