llama-index

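"""Ollama LLM integration for LlamaIndex (legacy).

Wraps a locally hosted Ollama server over HTTP: ``chat``/``stream_chat``
call ``/api/chat`` and ``complete``/``stream_complete`` call
``/api/generate``, using httpx for both blocking and streaming
(newline-delimited JSON) requests.
"""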
import json
from typing import Any, Dict, Sequence, Tuple

import httpx
from httpx import Timeout

from llama_index.legacy.bridge.pydantic import Field
from llama_index.legacy.constants import DEFAULT_CONTEXT_WINDOW, DEFAULT_NUM_OUTPUTS
from llama_index.legacy.core.llms.types import (
    ChatMessage,
    ChatResponse,
    ChatResponseGen,
    CompletionResponse,
    CompletionResponseGen,
    LLMMetadata,
    MessageRole,
)
from llama_index.legacy.llms.base import llm_chat_callback, llm_completion_callback
from llama_index.legacy.llms.custom import CustomLLM

DEFAULT_REQUEST_TIMEOUT = 30.0


def get_additional_kwargs(
    response: Dict[str, Any], exclude: Tuple[str, ...]
) -> Dict[str, Any]:
    """Return every field of ``response`` except the keys in ``exclude``."""
    return {k: v for k, v in response.items() if k not in exclude}


class Ollama(CustomLLM):
    base_url: str = Field(
        default="http://localhost:11434",
        description="Base URL the model is hosted under.",
    )
    model: str = Field(description="The Ollama model to use.")
    temperature: float = Field(
        default=0.75,
        description="The temperature to use for sampling.",
        ge=0.0,
        le=1.0,
    )
    context_window: int = Field(
        default=DEFAULT_CONTEXT_WINDOW,
        description="The maximum number of context tokens for the model.",
        gt=0,
    )
    request_timeout: float = Field(
        default=DEFAULT_REQUEST_TIMEOUT,
        description="The timeout for making an HTTP request to the Ollama API server.",
    )
    prompt_key: str = Field(
        default="prompt", description="The key to use for the prompt in API calls."
    )
    additional_kwargs: Dict[str, Any] = Field(
        default_factory=dict,
        description="Additional model parameters for the Ollama API.",
    )

    @classmethod
    def class_name(cls) -> str:
        return "Ollama_llm"

    @property
    def metadata(self) -> LLMMetadata:
        """LLM metadata."""
        return LLMMetadata(
            context_window=self.context_window,
            num_output=DEFAULT_NUM_OUTPUTS,
            model_name=self.model,
            is_chat_model=True,  # Ollama supports chat API for all models
        )

    @property
    def _model_kwargs(self) -> Dict[str, Any]:
        # Per-request options forwarded to Ollama; user-supplied
        # additional_kwargs override the base sampling settings.
        base_kwargs = {
            "temperature": self.temperature,
            "num_ctx": self.context_window,
        }
        return {
            **base_kwargs,
            **self.additional_kwargs,
        }

    @llm_chat_callback()
    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
        payload = {
            "model": self.model,
            "messages": [
                {
                    "role": message.role.value,
                    "content": message.content,
                    **message.additional_kwargs,
                }
                for message in messages
            ],
            "options": self._model_kwargs,
            "stream": False,
            **kwargs,
        }

        with httpx.Client(timeout=Timeout(self.request_timeout)) as client:
            response = client.post(
                url=f"{self.base_url}/api/chat",
                json=payload,
            )
            response.raise_for_status()
            raw = response.json()
            message = raw["message"]
            return ChatResponse(
                message=ChatMessage(
                    content=message.get("content"),
                    role=MessageRole(message.get("role")),
                    additional_kwargs=get_additional_kwargs(
                        message, ("content", "role")
                    ),
                ),
                raw=raw,
                additional_kwargs=get_additional_kwargs(raw, ("message",)),
            )

    @llm_chat_callback()
    def stream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseGen:
        payload = {
            "model": self.model,
            "messages": [
                {
                    "role": message.role.value,
                    "content": message.content,
                    **message.additional_kwargs,
                }
                for message in messages
            ],
            "options": self._model_kwargs,
            "stream": True,
            **kwargs,
        }

        with httpx.Client(timeout=Timeout(self.request_timeout)) as client:
            with client.stream(
                method="POST",
                url=f"{self.base_url}/api/chat",
                json=payload,
            ) as response:
                response.raise_for_status()
                text = ""
                # Ollama streams newline-delimited JSON chunks; accumulate the
                # deltas so each yielded response carries the full text so far.
                for line in response.iter_lines():
                    if line:
                        chunk = json.loads(line)
                        if "done" in chunk and chunk["done"]:
                            break
                        message = chunk["message"]
                        delta = message.get("content") or ""
                        text += delta
                        yield ChatResponse(
                            message=ChatMessage(
                                content=text,
                                role=MessageRole(message.get("role")),
                                additional_kwargs=get_additional_kwargs(
                                    message, ("content", "role")
                                ),
                            ),
                            delta=delta,
                            raw=chunk,
                            additional_kwargs=get_additional_kwargs(chunk, ("message",)),
                        )

    @llm_completion_callback()
    def complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        payload = {
            self.prompt_key: prompt,
            "model": self.model,
            "options": self._model_kwargs,
            "stream": False,
            **kwargs,
        }

        with httpx.Client(timeout=Timeout(self.request_timeout)) as client:
            response = client.post(
                url=f"{self.base_url}/api/generate",
                json=payload,
            )
            response.raise_for_status()
            raw = response.json()
            text = raw.get("response")
            return CompletionResponse(
                text=text,
                raw=raw,
                additional_kwargs=get_additional_kwargs(raw, ("response",)),
            )

    @llm_completion_callback()
    def stream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseGen:
        payload = {
            self.prompt_key: prompt,
            "model": self.model,
            "options": self._model_kwargs,
            "stream": True,
            **kwargs,
        }

        with httpx.Client(timeout=Timeout(self.request_timeout)) as client:
            with client.stream(
                method="POST",
                url=f"{self.base_url}/api/generate",
                json=payload,
            ) as response:
                response.raise_for_status()
                text = ""
                for line in response.iter_lines():
                    if line:
                        chunk = json.loads(line)
                        delta = chunk.get("response") or ""
                        text += delta
                        yield CompletionResponse(
                            delta=delta,
                            text=text,
                            raw=chunk,
                            additional_kwargs=get_additional_kwargs(
                                chunk, ("response",)
                            ),
                        )
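

# A minimal usage sketch, assuming an Ollama server is listening on the
# default base_url with a "llama2" model already pulled; the model name and
# prompts are illustrative only.
if __name__ == "__main__":
    llm = Ollama(model="llama2", request_timeout=60.0)

    # Blocking completion via /api/generate.
    print(llm.complete("Why is the sky blue?").text)

    # Blocking chat via /api/chat.
    reply = llm.chat([ChatMessage(role=MessageRole.USER, content="Hello!")])
    print(reply.message.content)

    # Streaming completion: each yielded chunk carries the new delta plus
    # the accumulated text so far.
    for chunk in llm.stream_complete("Count to three."):
        print(chunk.delta, end="", flush=True)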
