llama-index
107 lines · 3.5 KB
1from typing import Any, List, Literal, Optional
2
3import numpy as np
4
5from llama_index.legacy.bridge.pydantic import Field, PrivateAttr
6from llama_index.legacy.embeddings.base import BaseEmbedding
7
8
class FastEmbedEmbedding(BaseEmbedding):
    """
    Qdrant FastEmbedding models.

    FastEmbed is a lightweight, fast, Python library built for embedding generation.
    See more documentation at:
    * https://github.com/qdrant/fastembed/
    * https://qdrant.github.io/fastembed/.

    To use this class, you must install the `fastembed` Python package.

    `pip install fastembed`
    Example:
        from llama_index.legacy.embeddings import FastEmbedEmbedding
        fastembed = FastEmbedEmbedding()
    """

    model_name: str = Field(
        "BAAI/bge-small-en-v1.5",
        description="Name of the FastEmbedding model to use.\n"
        "Defaults to 'BAAI/bge-small-en-v1.5'.\n"
        "Find the list of supported models at "
        "https://qdrant.github.io/fastembed/examples/Supported_Models/",
    )

    max_length: int = Field(
        512,
        description="The maximum number of tokens. Defaults to 512.\n"
        "Unknown behavior for values > 512.",
    )

    cache_dir: Optional[str] = Field(
        None,
        description="The path to the cache directory.\n"
        "Defaults to `local_cache` in the parent directory",
    )

    threads: Optional[int] = Field(
        None,
        description="The number of threads single onnxruntime session can use.\n"
        "Defaults to None",
    )

    doc_embed_type: Literal["default", "passage"] = Field(
        "default",
        description="Type of embedding to use for documents.\n"
        "'default': Uses FastEmbed's default embedding method.\n"
        "'passage': Prefixes the text with 'passage' before embedding.\n"
        "Defaults to 'default'.",
    )

    # Underlying fastembed.FlagEmbedding instance; assigned in __init__.
    _model: Any = PrivateAttr()

    @classmethod
    def class_name(cls) -> str:  # was `self`; classmethods take `cls`
        """Return the class name used for (de)serialization."""
        return "FastEmbedEmbedding"

    def __init__(
        self,
        model_name: Optional[str] = "BAAI/bge-small-en-v1.5",
        max_length: Optional[int] = 512,
        cache_dir: Optional[str] = None,
        threads: Optional[int] = None,
        doc_embed_type: Literal["default", "passage"] = "default",
    ) -> None:
        """Initialize the pydantic fields and load the FlagEmbedding model.

        Args:
            model_name: FastEmbed model identifier.
            max_length: Maximum number of tokens per input.
            cache_dir: Directory for model downloads/cache.
            threads: onnxruntime session thread count.
            doc_embed_type: 'default' or 'passage' document-embedding mode.

        Raises:
            ImportError: If the `fastembed` package is not installed.
        """
        super().__init__(
            model_name=model_name,
            max_length=max_length,
            # Fix: previously omitted, so the declared `cache_dir` field was
            # never populated even when a caller supplied one.
            cache_dir=cache_dir,
            threads=threads,
            doc_embed_type=doc_embed_type,
        )
        try:
            from fastembed.embedding import FlagEmbedding

            self._model = FlagEmbedding(
                model_name=model_name,
                max_length=max_length,
                cache_dir=cache_dir,
                threads=threads,
            )
        except ImportError as ie:
            raise ImportError(
                "Could not import 'fastembed' Python package. "
                "Please install it with `pip install fastembed`."
            ) from ie

    def _get_text_embedding(self, text: str) -> List[float]:
        """Embed a single document string, honoring `doc_embed_type`."""
        embeddings: List[np.ndarray]
        if self.doc_embed_type == "passage":
            # fastembed prefixes the text with 'passage' internally.
            embeddings = list(self._model.passage_embed(text))
        else:
            embeddings = list(self._model.embed(text))
        # The generators yield one vector per input; we passed one string.
        return embeddings[0].tolist()

    def _get_query_embedding(self, query: str) -> List[float]:
        """Embed a single query string via fastembed's query path."""
        query_embeddings: np.ndarray = next(self._model.query_embed(query))
        return query_embeddings.tolist()

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """Async query embedding; delegates to the synchronous implementation."""
        return self._get_query_embedding(query)
108