llama-index

Форк
0
107 строк · 3.8 Кб
1
from typing import Any, Dict, List, Optional
2

3
from llama_index.legacy.bridge.pydantic import Field
4
from llama_index.legacy.callbacks.base import CallbackManager
5
from llama_index.legacy.constants import DEFAULT_EMBED_BATCH_SIZE
6
from llama_index.legacy.embeddings.base import BaseEmbedding
7

8

9
class OllamaEmbedding(BaseEmbedding):
    """Embedding model backed by an Ollama server's ``/api/embeddings`` endpoint.

    Each text is embedded with its own HTTP POST request; the contents of
    ``ollama_additional_kwargs`` are forwarded as the request's ``options``
    object. The async variants run the blocking HTTP call in the event loop's
    default executor so they do not stall the loop.
    """

    base_url: str = Field(description="Base url the model is hosted by Ollama")
    model_name: str = Field(description="The Ollama model to use.")
    embed_batch_size: int = Field(
        default=DEFAULT_EMBED_BATCH_SIZE,
        description="The batch size for embedding calls.",
        gt=0,
        # ``le`` is pydantic's "less than or equal" constraint; the previous
        # ``lte`` keyword is not recognised and was silently ignored.
        le=2048,
    )
    ollama_additional_kwargs: Dict[str, Any] = Field(
        default_factory=dict, description="Additional kwargs for the Ollama API."
    )

    def __init__(
        self,
        model_name: str,
        base_url: str = "http://localhost:11434",
        embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
        ollama_additional_kwargs: Optional[Dict[str, Any]] = None,
        callback_manager: Optional[CallbackManager] = None,
    ) -> None:
        """Create an Ollama embedding client.

        Args:
            model_name: Name of the Ollama model to request embeddings from.
            base_url: Root URL of the Ollama server.
            embed_batch_size: Batch size for embedding calls; must be in (0, 2048].
            ollama_additional_kwargs: Extra model options sent as ``options``
                in every request; ``None`` means no extra options.
            callback_manager: Optional callback manager passed to the base class.
        """
        super().__init__(
            model_name=model_name,
            base_url=base_url,
            embed_batch_size=embed_batch_size,
            ollama_additional_kwargs=ollama_additional_kwargs or {},
            callback_manager=callback_manager,
        )

    @classmethod
    def class_name(cls) -> str:
        """Return the serialisation name of this class."""
        return "OllamaEmbedding"

    def _get_query_embedding(self, query: str) -> List[float]:
        """Get query embedding."""
        return self.get_general_text_embedding(query)

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """Asynchronously get query embedding without blocking the event loop."""
        return await self._run_blocking(self.get_general_text_embedding, query)

    def _get_text_embedding(self, text: str) -> List[float]:
        """Get text embedding."""
        return self.get_general_text_embedding(text)

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Asynchronously get text embedding without blocking the event loop."""
        return await self._run_blocking(self.get_general_text_embedding, text)

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Get embeddings for ``texts``, one request per text, in input order."""
        return [self.get_general_text_embedding(text) for text in texts]

    async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Asynchronously get text embeddings without blocking the event loop.

        The texts are still embedded sequentially (one request at a time) inside
        a single executor job, matching the synchronous request pattern.
        """
        return await self._run_blocking(self._get_text_embeddings, texts)

    async def _run_blocking(self, func: Any, *args: Any) -> Any:
        """Run the blocking callable ``func(*args)`` in the default executor."""
        import asyncio

        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, func, *args)

    @staticmethod
    def _error_detail(response: Any) -> Any:
        """Best-effort extraction of the error message from a failed response.

        Returns the ``"error"`` field when the body is a JSON object, otherwise
        the raw response text, so a non-JSON error page (e.g. from a proxy)
        does not mask the HTTP status with a decode error.
        """
        try:
            payload = response.json()
        except ValueError:
            return response.text
        if isinstance(payload, dict):
            return payload.get("error")
        return response.text

    def get_general_text_embedding(self, prompt: str) -> List[float]:
        """Request the embedding of ``prompt`` from the Ollama server.

        Args:
            prompt: Text to embed.

        Returns:
            The value of the ``"embedding"`` key in the server's JSON response.

        Raises:
            ImportError: If the ``requests`` package is not installed.
            ValueError: If the server answers with a non-200 status, or with a
                body that is not JSON or has no ``"embedding"`` key.
        """
        try:
            import requests
        except ImportError as e:
            raise ImportError(
                "Could not import requests library. "
                "Please install requests with `pip install requests`"
            ) from e

        ollama_request_body = {
            "prompt": prompt,
            "model": self.model_name,
            "options": self.ollama_additional_kwargs,
        }

        # Strip a trailing slash so "http://host:11434/" does not yield "//api/...".
        response = requests.post(
            url=f"{self.base_url.rstrip('/')}/api/embeddings",
            headers={"Content-Type": "application/json"},
            json=ollama_request_body,
        )
        response.encoding = "utf-8"
        if response.status_code != 200:
            raise ValueError(
                f"Ollama call failed with status code {response.status_code}."
                f" Details: {self._error_detail(response)}"
            )

        try:
            return response.json()["embedding"]
        # JSON decode errors from ``requests`` subclass ValueError in every
        # version (``requests.exceptions.JSONDecodeError`` only exists >= 2.27);
        # KeyError/TypeError cover a JSON body without an "embedding" key.
        except (ValueError, KeyError, TypeError) as e:
            raise ValueError(
                f"Error raised for Ollama Call: {e}.\nResponse: {response.text}"
            ) from e
108

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.