# llama-index (legacy): embedding-model loader registry
1from typing import Dict, Type2
3from llama_index.legacy.embeddings.base import BaseEmbedding4from llama_index.legacy.embeddings.google import GoogleUnivSentEncoderEmbedding5from llama_index.legacy.embeddings.huggingface import HuggingFaceEmbedding6from llama_index.legacy.embeddings.langchain import LangchainEmbedding7from llama_index.legacy.embeddings.openai import OpenAIEmbedding8from llama_index.legacy.embeddings.text_embeddings_inference import (9TextEmbeddingsInference,10)
11from llama_index.legacy.embeddings.utils import resolve_embed_model12from llama_index.legacy.token_counter.mock_embed_model import MockEmbedding13
# Registry mapping a serialized ``class_name`` string to the concrete
# embedding class that ``load_embed_model`` reconstructs via ``from_dict``.
# NOTE: the original listed ``OpenAIEmbedding`` twice; in a dict display the
# later duplicate silently overwrites the identical earlier entry, so the
# redundant entry is removed with no behavior change.
RECOGNIZED_EMBEDDINGS: Dict[str, Type[BaseEmbedding]] = {
    GoogleUnivSentEncoderEmbedding.class_name(): GoogleUnivSentEncoderEmbedding,
    OpenAIEmbedding.class_name(): OpenAIEmbedding,
    LangchainEmbedding.class_name(): LangchainEmbedding,
    MockEmbedding.class_name(): MockEmbedding,
    HuggingFaceEmbedding.class_name(): HuggingFaceEmbedding,
    TextEmbeddingsInference.class_name(): TextEmbeddingsInference,
}
23
24
def load_embed_model(data: dict) -> BaseEmbedding:
    """Reconstruct an embedding model from its serialized dict form.

    Accepts either an already-instantiated ``BaseEmbedding`` (returned
    unchanged) or a dict whose ``class_name`` key selects the concrete
    class from ``RECOGNIZED_EMBEDDINGS``.

    Raises:
        ValueError: if ``class_name`` is missing or unrecognized, or if a
            ``LangchainEmbedding`` payload lacks ``model_name``.
    """
    # Already a live model instance — nothing to deserialize.
    if isinstance(data, BaseEmbedding):
        return data

    name = data.get("class_name")
    if name is None:
        raise ValueError("Embedding loading requires a class_name")
    if name not in RECOGNIZED_EMBEDDINGS:
        raise ValueError(f"Invalid Embedding name: {name}")

    # LangchainEmbedding can wrap any local model, so it is resolved
    # through the "local:" shortcut instead of from_dict.
    if name == LangchainEmbedding.class_name():
        model_name = data.get("model_name")
        if model_name is None:
            raise ValueError("LangchainEmbedding requires a model_name")
        return resolve_embed_model("local:" + model_name)

    return RECOGNIZED_EMBEDDINGS[name].from_dict(data)