llama-index
102 строки · 3.6 Кб
1from enum import Enum
2from typing import Any, List, Optional
3
4from llama_index.legacy.bridge.pydantic import Field, PrivateAttr
5from llama_index.legacy.callbacks import CallbackManager
6from llama_index.legacy.core.embeddings.base import BaseEmbedding
7
8
class NomicAITaskType(str, Enum):
    """Task types that NomicEmbedding forwards to Nomic's ``text`` embedding call.

    The ``str`` mixin makes every member compare equal to its plain-string
    value, so ``"search_query" == NomicAITaskType.SEARCH_QUERY`` holds.
    """

    # Members are declared in this order; iterating the enum yields them likewise.
    SEARCH_QUERY = "search_query"
    SEARCH_DOCUMENT = "search_document"
    CLUSTERING = "clustering"
    CLASSIFICATION = "classification"
14
15
# Every valid Nomic task type, derived directly from the enum (in member
# declaration order) so this list can never drift out of sync with it.
# Members are ``str`` subclasses, so plain strings such as "search_query"
# match them in ``in`` membership tests.
TASK_TYPES = list(NomicAITaskType)
22
23
class NomicEmbedding(BaseEmbedding):
    """NomicEmbedding uses the Nomic API to generate embeddings.

    Queries are embedded with ``query_task_type`` and documents with
    ``document_task_type``; both must be one of ``TASK_TYPES``.
    """

    # Instance variables initialized via Pydantic's mechanism
    query_task_type: Optional[str] = Field(description="Query Embedding prefix")
    document_task_type: Optional[str] = Field(description="Document Embedding prefix")
    model_name: str = Field(description="Embedding model name")
    # Handle to ``nomic.embed``; assigned in ``__init__`` after pydantic init.
    _model: Any = PrivateAttr()

    def __init__(
        self,
        model_name: str = "nomic-embed-text-v1",
        embed_batch_size: int = 32,
        api_key: Optional[str] = None,
        callback_manager: Optional[CallbackManager] = None,
        query_task_type: Optional[str] = "search_query",
        document_task_type: Optional[str] = "search_document",
        **kwargs: Any,
    ) -> None:
        """Validate task types, import ``nomic``, optionally log in, and init.

        Args:
            model_name: Nomic embedding model passed to every ``text`` call.
            embed_batch_size: Batch size forwarded to ``BaseEmbedding``.
            api_key: If given, passed to ``nomic.cli.login`` before use.
            callback_manager: Forwarded to ``BaseEmbedding``.
            query_task_type: Task type used when embedding queries.
            document_task_type: Task type used when embedding documents.
            **kwargs: Forwarded unchanged to ``BaseEmbedding``.

        Raises:
            ValueError: If either task type is not in ``TASK_TYPES``
                (``None`` is rejected too, despite the ``Optional`` hint).
            ImportError: If the ``nomic`` package is not installed.
        """
        if query_task_type not in TASK_TYPES or document_task_type not in TASK_TYPES:
            raise ValueError(
                f"Invalid task type {query_task_type}, {document_task_type}. Must be one of {TASK_TYPES}"
            )

        try:
            import nomic
            from nomic import embed
        except ImportError as exc:
            # Chain the original error so the real import failure stays visible.
            raise ImportError(
                "NomicEmbedding requires the 'nomic' package to be installed.\n"
                "Please install it with `pip install nomic`."
            ) from exc

        if api_key is not None:
            # Log in with the supplied key before any embedding request is made.
            nomic.cli.login(api_key)

        # Pydantic populates the public fields from these kwargs, so they need
        # no reassignment afterwards.
        super().__init__(
            model_name=model_name,
            embed_batch_size=embed_batch_size,
            callback_manager=callback_manager,
            query_task_type=query_task_type,
            document_task_type=document_task_type,
            **kwargs,
        )
        # Private attributes cannot be set through the pydantic constructor,
        # so the module handle is attached after initialization.
        self._model = embed

    @classmethod
    def class_name(cls) -> str:
        """Return the serialization name of this embedding class."""
        return "NomicEmbedding"

    def _embed(
        self, texts: List[str], task_type: Optional[str] = None
    ) -> List[List[float]]:
        """Embed sentences using NomicAI.

        Args:
            texts: Strings to embed; an empty list returns ``[]`` without
                contacting the API.
            task_type: Nomic task type forwarded to the ``text`` call.

        Returns:
            The ``"embeddings"`` entry of the Nomic response. Presumably this
            is one vector per input text, in input order; callers index it
            with ``[0]`` for single inputs.
        """
        if not texts:
            # Nothing to embed: avoid a pointless (and possibly rejected) remote call.
            return []
        result = self._model.text(texts, model=self.model_name, task_type=task_type)
        return result["embeddings"]

    def _get_query_embedding(self, query: str) -> List[float]:
        """Get query embedding."""
        return self._embed([query], task_type=self.query_task_type)[0]

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """Get query embedding async (delegates to the blocking sync call)."""
        return self._get_query_embedding(query)

    def _get_text_embedding(self, text: str) -> List[float]:
        """Get text embedding."""
        return self._embed([text], task_type=self.document_task_type)[0]

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Get text embedding async (delegates to the blocking sync call)."""
        return self._get_text_embedding(text)

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Get text embeddings for a batch using the document task type."""
        return self._embed(texts, task_type=self.document_task_type)
103