llama-index
96 lines · 3.5 KB
1"""Embedding utils for LlamaIndex."""
2
3import os
4from typing import TYPE_CHECKING, List, Optional, Union
5
6if TYPE_CHECKING:
7from llama_index.legacy.bridge.langchain import Embeddings as LCEmbeddings
8from llama_index.legacy.embeddings.base import BaseEmbedding
9from llama_index.legacy.embeddings.clip import ClipEmbedding
10from llama_index.legacy.embeddings.huggingface import HuggingFaceEmbedding
11from llama_index.legacy.embeddings.huggingface_utils import (
12INSTRUCTOR_MODELS,
13)
14from llama_index.legacy.embeddings.instructor import InstructorEmbedding
15from llama_index.legacy.embeddings.langchain import LangchainEmbedding
16from llama_index.legacy.embeddings.openai import OpenAIEmbedding
17from llama_index.legacy.llms.openai_utils import validate_openai_api_key
18from llama_index.legacy.token_counter.mock_embed_model import MockEmbedding
19from llama_index.legacy.utils import get_cache_dir
20
# Anything resolve_embed_model() accepts: a llama-index embedding, a LangChain
# Embeddings object (bridged lazily, hence the string annotation), or a string
# spec such as "default", "clip", "local", or "local:<model_name>".
EmbedType = Union[BaseEmbedding, "LCEmbeddings", str]
22
23
def save_embedding(embedding: List[float], file_path: str) -> None:
    """Persist a single embedding vector to *file_path* as comma-separated floats."""
    serialized = ",".join(str(value) for value in embedding)
    with open(file_path, "w") as out_file:
        out_file.write(serialized)
28
29
def load_embedding(file_path: str) -> List[float]:
    """Load embedding from file. Will only return first embedding in file.

    Args:
        file_path: Path to a file written by ``save_embedding`` (one
            comma-separated embedding per line).

    Returns:
        The embedding parsed from the first line of the file.

    Raises:
        ValueError: If the file is empty (original code raised an opaque
            ``UnboundLocalError`` in this case).
    """
    with open(file_path) as f:
        for line in f:
            # Only the first line is parsed; any further lines are ignored.
            return [float(x) for x in line.strip().split(",")]
    raise ValueError(f"No embedding found in file: {file_path}")
37
38
def resolve_embed_model(embed_model: Optional[EmbedType] = None) -> BaseEmbedding:
    """Resolve an ``EmbedType`` spec into a concrete ``BaseEmbedding``.

    Accepted values:
        * ``"default"`` -> ``OpenAIEmbedding`` (API key validated eagerly).
        * ``"clip"`` -> ``ClipEmbedding`` for image embeddings.
        * ``"local"`` / ``"local:<model_name>"`` -> an Instructor or
          HuggingFace embedding cached under the local cache dir.
        * a LangChain ``Embeddings`` object -> wrapped in ``LangchainEmbedding``.
        * ``None`` -> ``MockEmbedding`` (embeddings explicitly disabled).
        * any ``BaseEmbedding`` instance -> returned unchanged.

    Raises:
        ValueError: If the OpenAI embedding model cannot be loaded, or a
            string spec does not start with ``"local"``.
    """
    # LangChain is an optional dependency: import lazily so resolution still
    # works (minus LangChain support) when it is not installed.
    try:
        from llama_index.legacy.bridge.langchain import Embeddings as LCEmbeddings
    except ImportError:
        LCEmbeddings = None  # type: ignore

    if embed_model == "default":
        try:
            embed_model = OpenAIEmbedding()
            # Fail fast with a helpful message now rather than on first query.
            validate_openai_api_key(embed_model.api_key)
        except ValueError as e:
            raise ValueError(
                "\n******\n"
                "Could not load OpenAI embedding model. "
                "If you intended to use OpenAI, please check your OPENAI_API_KEY.\n"
                "Original error:\n"
                f"{e!s}"
                "\nConsider using embed_model='local'.\n"
                "Visit our documentation for more embedding options: "
                "https://docs.llamaindex.ai/en/stable/module_guides/models/"
                "embeddings.html#modules"
                "\n******"
            )

    # for image embeddings
    if embed_model == "clip":
        embed_model = ClipEmbedding()

    if isinstance(embed_model, str):
        # String specs look like "local" or "local:<model_name>"; split on the
        # first ":" only so model names containing ":" survive intact.
        splits = embed_model.split(":", 1)
        is_local = splits[0]
        model_name = splits[1] if len(splits) > 1 else None
        if is_local != "local":
            raise ValueError(
                "embed_model must start with str 'local' or of type BaseEmbedding"
            )

        # Models are downloaded/loaded via a shared on-disk cache.
        cache_folder = os.path.join(get_cache_dir(), "models")
        os.makedirs(cache_folder, exist_ok=True)

        # Instructor models require their dedicated wrapper class.
        # NOTE(review): model_name=None presumably falls back to each wrapper's
        # default model — confirm against the embedding class defaults.
        if model_name in INSTRUCTOR_MODELS:
            embed_model = InstructorEmbedding(
                model_name=model_name, cache_folder=cache_folder
            )
        else:
            embed_model = HuggingFaceEmbedding(
                model_name=model_name, cache_folder=cache_folder
            )

    # Wrap raw LangChain embeddings so they satisfy the BaseEmbedding interface.
    if LCEmbeddings is not None and isinstance(embed_model, LCEmbeddings):
        embed_model = LangchainEmbedding(embed_model)

    # None means the caller explicitly disabled embeddings.
    if embed_model is None:
        print("Embeddings have been explicitly disabled. Using MockEmbedding.")
        embed_model = MockEmbedding(embed_dim=1)

    return embed_model
97