llama-index
87 строк · 3.1 Кб
1"""Langchain Embedding Wrapper Module."""
2
3from typing import TYPE_CHECKING, List, Optional4
5from llama_index.legacy.bridge.pydantic import PrivateAttr6from llama_index.legacy.callbacks import CallbackManager7from llama_index.legacy.core.embeddings.base import (8DEFAULT_EMBED_BATCH_SIZE,9BaseEmbedding,10)
11
12if TYPE_CHECKING:13from llama_index.legacy.bridge.langchain import Embeddings as LCEmbeddings14
15
class LangchainEmbedding(BaseEmbedding):
    """External embeddings (taken from Langchain).

    Thin adapter that forwards every embedding request to a wrapped
    Langchain embeddings object.

    Args:
        langchain_embeddings (langchain.embeddings.Embeddings): Langchain
            embeddings object that performs the actual embedding calls.
        model_name (Optional[str]): Explicit model name. When omitted, the
            name is taken from the wrapped object's ``model_name`` attribute,
            then from its ``model`` attribute, and finally from its class
            name. Attributes whose value is ``None`` are skipped.
        embed_batch_size (int): Passed through to the base class initializer.
        callback_manager (Optional[CallbackManager]): Passed through to the
            base class initializer.
    """

    _langchain_embedding: "LCEmbeddings" = PrivateAttr()
    _async_not_implemented_warned: bool = PrivateAttr(default=False)

    def __init__(
        self,
        langchain_embeddings: "LCEmbeddings",
        model_name: Optional[str] = None,
        embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
        callback_manager: Optional[CallbackManager] = None,
    ):
        # Attempt to get a useful model name. An attribute that exists but is
        # set to None is treated like a missing one, so the search continues
        # down to the class name instead of handing None to the base class.
        if model_name is None:
            model_name = getattr(langchain_embeddings, "model_name", None)
        if model_name is None:
            model_name = getattr(langchain_embeddings, "model", None)
        if model_name is None:
            model_name = type(langchain_embeddings).__name__

        self._langchain_embedding = langchain_embeddings
        super().__init__(
            embed_batch_size=embed_batch_size,
            callback_manager=callback_manager,
            model_name=model_name,
        )

    @classmethod
    def class_name(cls) -> str:
        """Return the name identifying this embedding class."""
        return "LangchainEmbedding"

    def _async_not_implemented_warn_once(self) -> None:
        """Print the sync-fallback notice the first time it is needed."""
        if not self._async_not_implemented_warned:
            print("Async embedding not available, falling back to sync method.")
            self._async_not_implemented_warned = True

    def _get_query_embedding(self, query: str) -> List[float]:
        """Get query embedding."""
        return self._langchain_embedding.embed_query(query)

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """Get query embedding asynchronously.

        Falls back to the synchronous path when the wrapped object raises
        ``NotImplementedError`` from ``aembed_query``.
        """
        try:
            return await self._langchain_embedding.aembed_query(query)
        except NotImplementedError:
            # Warn the user that sync is being used
            self._async_not_implemented_warn_once()
            return self._get_query_embedding(query)

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Get text embedding asynchronously.

        Falls back to the synchronous path when the wrapped object raises
        ``NotImplementedError`` from ``aembed_documents``.
        """
        try:
            embeds = await self._langchain_embedding.aembed_documents([text])
        except NotImplementedError:
            # Warn the user that sync is being used
            self._async_not_implemented_warn_once()
            return self._get_text_embedding(text)
        # The wrapped call was given a one-element batch; return its only item.
        return embeds[0]

    def _get_text_embedding(self, text: str) -> List[float]:
        """Get text embedding."""
        return self._langchain_embedding.embed_documents([text])[0]

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Get text embeddings."""
        return self._langchain_embedding.embed_documents(texts)