llama-index

Форк
0
102 строки · 3.6 Кб
1
from enum import Enum
2
from typing import Any, List, Optional
3

4
from llama_index.legacy.bridge.pydantic import Field, PrivateAttr
5
from llama_index.legacy.callbacks import CallbackManager
6
from llama_index.legacy.core.embeddings.base import BaseEmbedding
7

8

9
class NomicAITaskType(str, Enum):
    """Task types that ``NomicEmbedding`` forwards as ``task_type``.

    Members subclass ``str``, so each compares equal to its plain string
    value (e.g. ``NomicAITaskType.SEARCH_QUERY == "search_query"``), which
    lets callers pass either the enum member or the raw string.
    """

    SEARCH_QUERY = "search_query"
    SEARCH_DOCUMENT = "search_document"
    CLUSTERING = "clustering"
    CLASSIFICATION = "classification"


# Every valid task type, in declaration order. Derived from the enum so the
# list cannot drift out of sync when a member is added or removed.
TASK_TYPES = list(NomicAITaskType)
22

23

24
class NomicEmbedding(BaseEmbedding):
    """NomicEmbedding uses the Nomic API to generate embeddings.

    Query strings are embedded with ``query_task_type`` and document texts
    with ``document_task_type``; both must be members of ``TASK_TYPES``.
    The actual embedding call is delegated to the ``nomic.embed`` module.
    """

    # Instance variables initialized via Pydantic's mechanism.
    # NOTE: annotated Optional for backward compatibility, but ``__init__``
    # rejects ``None`` because it is not in ``TASK_TYPES``.
    query_task_type: Optional[str] = Field(description="Query Embedding prefix")
    document_task_type: Optional[str] = Field(
        description="Document Embedding prefix"
    )
    model_name: str = Field(description="Embedding model name")
    # Holds the ``nomic.embed`` module; kept as a private attribute so it is
    # not treated as a pydantic field.
    _model: Any = PrivateAttr()

    def __init__(
        self,
        model_name: str = "nomic-embed-text-v1",
        embed_batch_size: int = 32,
        api_key: Optional[str] = None,
        callback_manager: Optional[CallbackManager] = None,
        query_task_type: Optional[str] = "search_query",
        document_task_type: Optional[str] = "search_document",
        **kwargs: Any,
    ) -> None:
        """Validate the task types, import ``nomic`` and initialise fields.

        Args:
            model_name: Passed as ``model`` on every embedding call.
            embed_batch_size: Forwarded to ``BaseEmbedding``.
            api_key: If not ``None``, passed to ``nomic.cli.login`` first.
            callback_manager: Forwarded to ``BaseEmbedding``.
            query_task_type: Task type used when embedding queries.
            document_task_type: Task type used when embedding documents.
            **kwargs: Extra arguments forwarded to ``BaseEmbedding``.

        Raises:
            ValueError: If either task type is not in ``TASK_TYPES``.
            ImportError: If the ``nomic`` package cannot be imported.
        """
        if query_task_type not in TASK_TYPES or document_task_type not in TASK_TYPES:
            # Show plain string values rather than enum reprs in the message.
            valid = [task_type.value for task_type in TASK_TYPES]
            raise ValueError(
                f"Invalid task type {query_task_type}, {document_task_type}. "
                f"Must be one of {valid}"
            )

        try:
            import nomic
            from nomic import embed
        except ImportError as err:
            raise ImportError(
                "NomicEmbedding requires the 'nomic' package to be installed.\n"
                "Please install it with `pip install nomic`."
            ) from err

        if api_key is not None:
            nomic.cli.login(api_key)

        super().__init__(
            model_name=model_name,
            embed_batch_size=embed_batch_size,
            callback_manager=callback_manager,
            query_task_type=query_task_type,
            document_task_type=document_task_type,
            **kwargs,
        )
        # Private attributes are not populated by the pydantic constructor,
        # so assign the client module explicitly once the model exists.
        self._model = embed

    @classmethod
    def class_name(cls) -> str:
        """Return the registered name of this embedding class."""
        return "NomicEmbedding"

    def _embed(
        self, texts: List[str], task_type: Optional[str] = None
    ) -> List[List[float]]:
        """Embed ``texts`` with ``self._model.text`` using ``task_type``.

        Returns the ``"embeddings"`` entry of the client's response
        (presumably one vector per input text, in input order).
        """
        result = self._model.text(texts, model=self.model_name, task_type=task_type)
        return result["embeddings"]

    def _get_query_embedding(self, query: str) -> List[float]:
        """Embed a single query string with ``query_task_type``."""
        return self._embed([query], task_type=self.query_task_type)[0]

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """Async wrapper around ``_get_query_embedding``."""
        # NOTE: delegates to the synchronous call, so it blocks the event
        # loop for the duration of the request.
        return self._get_query_embedding(query)

    def _get_text_embedding(self, text: str) -> List[float]:
        """Embed a single document text with ``document_task_type``."""
        return self._embed([text], task_type=self.document_task_type)[0]

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Async wrapper around ``_get_text_embedding``."""
        # NOTE: delegates to the synchronous call, so it blocks the event
        # loop for the duration of the request.
        return self._get_text_embedding(text)

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Embed a batch of document texts with ``document_task_type``."""
        return self._embed(texts, task_type=self.document_task_type)
103

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.