import asyncio
import logging
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
    List,
    Literal,
    Optional,
    Sequence,
)

from llama_index.legacy.bridge.pydantic import Field, PrivateAttr
from llama_index.legacy.callbacks import CallbackManager
from llama_index.legacy.core.llms.types import (
    ChatMessage,
    ChatResponse,
    ChatResponseAsyncGen,
    ChatResponseGen,
    CompletionResponse,
    CompletionResponseAsyncGen,
    CompletionResponseGen,
    LLMMetadata,
)
from llama_index.legacy.llms.base import (
    llm_chat_callback,
    llm_completion_callback,
)
from llama_index.legacy.llms.generic_utils import (
    completion_response_to_chat_response,
)
from llama_index.legacy.llms.generic_utils import (
    messages_to_prompt as generic_messages_to_prompt,
)
from llama_index.legacy.llms.llm import LLM
from llama_index.legacy.types import PydanticProgramMode

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    from typing import TypeVar

    M = TypeVar("M")
    T = TypeVar("T")
    Metadata = Any


class OpenLLM(LLM):
    """OpenLLM LLM, running a model in-process via the OpenLLM engine."""

    model_id: str = Field(
        description="Model ID from the HuggingFace Hub. This can be either a pretrained ID or a local path. This is synonymous with the first argument of HuggingFace's '.from_pretrained'."
    )
    model_version: Optional[str] = Field(
        description="Optional model version to save the model as."
    )
    model_tag: Optional[str] = Field(
        description="Optional tag to save to the BentoML store."
    )
    prompt_template: Optional[str] = Field(
        description="Optional prompt template to use for this LLM."
    )
    backend: Optional[Literal["vllm", "pt"]] = Field(
        description="Optional backend to use for this LLM. By default, vLLM is used if it is available on the local system; otherwise it falls back to PyTorch."
    )
    quantize: Optional[Literal["awq", "gptq", "int8", "int4", "squeezellm"]] = Field(
        description="Optional quantization method to use with this LLM. See OpenLLM's `--quantize` option of `openllm start` for more information."
    )
    serialization: Literal["safetensors", "legacy"] = Field(
        description="Serialization format to save this LLM as. Defaults to 'safetensors', but falls back to the PyTorch pickle `.bin` format for some models."
    )
    trust_remote_code: bool = Field(
        description="Optional flag to trust remote code. This is synonymous with Transformers' `trust_remote_code`. Defaults to False."
    )
    if TYPE_CHECKING:
        from typing import Generic

        try:
            import openllm

            _llm: openllm.LLM[Any, Any]
        except ImportError:
            _llm: Any  # type: ignore[no-redef]
    else:
        _llm: Any = PrivateAttr()

    def __init__(
        self,
        model_id: str,
        model_version: Optional[str] = None,
        model_tag: Optional[str] = None,
        prompt_template: Optional[str] = None,
        backend: Optional[Literal["vllm", "pt"]] = None,
        *args: Any,
        quantize: Optional[Literal["awq", "gptq", "int8", "int4", "squeezellm"]] = None,
        serialization: Literal["safetensors", "legacy"] = "safetensors",
        trust_remote_code: bool = False,
        callback_manager: Optional[CallbackManager] = None,
        system_prompt: Optional[str] = None,
        messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None,
        completion_to_prompt: Optional[Callable[[str], str]] = None,
        pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT,
        **attrs: Any,
    ):
        try:
            import openllm
        except ImportError:
            raise ImportError(
                "OpenLLM is not installed. Please install OpenLLM via `pip install openllm`"
            )
        self._llm = openllm.LLM[Any, Any](
            model_id,
            model_version=model_version,
            model_tag=model_tag,
            prompt_template=prompt_template,
            system_message=system_prompt,
            backend=backend,
            quantize=quantize,
            serialisation=serialization,
            trust_remote_code=trust_remote_code,
            embedded=True,
            **attrs,
        )
        if messages_to_prompt is None:
            messages_to_prompt = self._tokenizer_messages_to_prompt

        # NOTE: We need to do this here to ensure model is saved and revision is set correctly.
        assert self._llm.bentomodel

        super().__init__(
            model_id=model_id,
            model_version=self._llm.revision,
            model_tag=str(self._llm.tag),
            prompt_template=prompt_template,
            backend=self._llm.__llm_backend__,
            quantize=self._llm.quantise,
            serialization=self._llm._serialisation,
            trust_remote_code=self._llm.trust_remote_code,
            callback_manager=callback_manager,
            system_prompt=system_prompt,
            messages_to_prompt=messages_to_prompt,
            completion_to_prompt=completion_to_prompt,
            pydantic_program_mode=pydantic_program_mode,
        )

    @classmethod
    def class_name(cls) -> str:
        return "OpenLLM"

    @property
    def metadata(self) -> LLMMetadata:
        """LLM metadata."""
        return LLMMetadata(
            num_output=self._llm.config["max_new_tokens"],
            model_name=self.model_id,
        )

    def _tokenizer_messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
        """Use the tokenizer to convert messages to a prompt, falling back to the generic converter."""
        if hasattr(self._llm.tokenizer, "apply_chat_template"):
            return self._llm.tokenizer.apply_chat_template(
                [message.dict() for message in messages],
                tokenize=False,
                add_generation_prompt=True,
            )
        return generic_messages_to_prompt(messages)

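    # The synchronous entry points below (``complete``/``chat`` and the
    # ``stream_*`` variants) run their async counterparts on an event loop.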
    @llm_completion_callback()
    def complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        return asyncio.run(self.acomplete(prompt, **kwargs))

    @llm_chat_callback()
    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
        return asyncio.run(self.achat(messages, **kwargs))

    @property
    def _loop(self) -> asyncio.AbstractEventLoop:
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            loop = asyncio.get_event_loop()
        return loop

    @llm_completion_callback()
    def stream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseGen:
        generator = self.astream_complete(prompt, **kwargs)
        # Drive the async generator and yield each chunk synchronously
        while True:
            try:
                yield self._loop.run_until_complete(generator.__anext__())
            except StopAsyncIteration:
                break

    @llm_chat_callback()
    def stream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseGen:
        generator = self.astream_chat(messages, **kwargs)
        # Drive the async generator and yield each chunk synchronously
        while True:
            try:
                yield self._loop.run_until_complete(generator.__anext__())
            except StopAsyncIteration:
                break

    @llm_chat_callback()
    async def achat(
        self,
        messages: Sequence[ChatMessage],
        **kwargs: Any,
    ) -> ChatResponse:
        response = await self.acomplete(self.messages_to_prompt(messages), **kwargs)
        return completion_response_to_chat_response(response)

    @llm_completion_callback()
    async def acomplete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        response = await self._llm.generate(prompt, **kwargs)
        return CompletionResponse(
            text=response.outputs[0].text,
            raw=response.model_dump(),
            additional_kwargs={
                "prompt_token_ids": response.prompt_token_ids,
                "prompt_logprobs": response.prompt_logprobs,
                "finished": response.finished,
                "outputs": {
                    "token_ids": response.outputs[0].token_ids,
                    "cumulative_logprob": response.outputs[0].cumulative_logprob,
                    "logprobs": response.outputs[0].logprobs,
                    "finish_reason": response.outputs[0].finish_reason,
                },
            },
        )

    @llm_chat_callback()
    async def astream_chat(
        self,
        messages: Sequence[ChatMessage],
        **kwargs: Any,
    ) -> ChatResponseAsyncGen:
        async for response_chunk in self.astream_complete(
            self.messages_to_prompt(messages), **kwargs
        ):
            yield completion_response_to_chat_response(response_chunk)

    @llm_completion_callback()
    async def astream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseAsyncGen:
        config = self._llm.config.model_construct_env(**kwargs)
        if config["n"] > 1:
            logger.warning("Currently only n=1 is supported.")

        # Accumulate generated text per candidate; use independent lists rather
        # than `[[]] * n`, which would alias the same list n times.
        texts: List[List[str]] = [[] for _ in range(config["n"])]

        async for response_chunk in self._llm.generate_iterator(prompt, **kwargs):
            for output in response_chunk.outputs:
                texts[output.index].append(output.text)
            yield CompletionResponse(
                text=response_chunk.outputs[0].text,
                delta=response_chunk.outputs[0].text,
                raw=response_chunk.model_dump(),
                additional_kwargs={
                    "prompt_token_ids": response_chunk.prompt_token_ids,
                    "prompt_logprobs": response_chunk.prompt_logprobs,
                    "finished": response_chunk.finished,
                    "outputs": {
                        "text": response_chunk.outputs[0].text,
                        "token_ids": response_chunk.outputs[0].token_ids,
                        "cumulative_logprob": response_chunk.outputs[
                            0
                        ].cumulative_logprob,
                        "logprobs": response_chunk.outputs[0].logprobs,
                        "finish_reason": response_chunk.outputs[0].finish_reason,
                    },
                },
            )


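# While ``OpenLLM`` runs a model in-process, ``OpenLLMAPI`` talks to an
# already-running OpenLLM server through the ``openllm-client`` HTTP clients.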
class OpenLLMAPI(LLM):
    """OpenLLM client interface for interacting with a remote OpenLLM server."""

    address: Optional[str] = Field(
        description="OpenLLM server address. This can be set here or via the OPENLLM_ENDPOINT environment variable."
    )
    timeout: int = Field(description="Timeout for sending requests.")
    max_retries: int = Field(description="Maximum number of retries.")
    api_version: Literal["v1"] = Field(description="OpenLLM Server API version.")

    if TYPE_CHECKING:
        try:
            from openllm_client import AsyncHTTPClient, HTTPClient

            _sync_client: HTTPClient
            _async_client: AsyncHTTPClient
        except ImportError:
            _sync_client: Any  # type: ignore[no-redef]
            _async_client: Any  # type: ignore[no-redef]
    else:
        _sync_client: Any = PrivateAttr()
        _async_client: Any = PrivateAttr()

    def __init__(
        self,
        address: Optional[str] = None,
        timeout: int = 30,
        max_retries: int = 2,
        api_version: Literal["v1"] = "v1",
        **kwargs: Any,
    ):
        try:
            from openllm_client import AsyncHTTPClient, HTTPClient
        except ImportError:
            raise ImportError(
                f'"{type(self).__name__}" requires "openllm-client". Make sure to install with `pip install openllm-client`'
            )
        super().__init__(
            address=address,
            timeout=timeout,
            max_retries=max_retries,
            api_version=api_version,
            **kwargs,
        )
        self._sync_client = HTTPClient(
            address=address,
            timeout=timeout,
            max_retries=max_retries,
            api_version=api_version,
        )
        self._async_client = AsyncHTTPClient(
            address=address,
            timeout=timeout,
            max_retries=max_retries,
            api_version=api_version,
        )

    @classmethod
    def class_name(cls) -> str:
        return "OpenLLM_Client"

    @property
    def _server_metadata(self) -> "Metadata":
        return self._sync_client._metadata

    @property
    def _server_config(self) -> Dict[str, Any]:
        return self._sync_client._config

    @property
    def metadata(self) -> LLMMetadata:
        return LLMMetadata(
            num_output=self._server_config["max_new_tokens"],
            model_name=self._server_metadata.model_id.replace("/", "--"),
        )

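    # Chat messages are converted to a prompt by the client's ``helpers.messages``
    # utility (with ``add_generation_prompt=True``) rather than locally in this class.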
    def _convert_messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
        return self._sync_client.helpers.messages(
            messages=[
                {"role": message.role, "content": message.content}
                for message in messages
            ],
            add_generation_prompt=True,
        )

    async def _async_messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
        return await self._async_client.helpers.messages(
            messages=[
                {"role": message.role, "content": message.content}
                for message in messages
            ],
            add_generation_prompt=True,
        )

    @llm_completion_callback()
    def complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        response = self._sync_client.generate(prompt, **kwargs)
        return CompletionResponse(
            text=response.outputs[0].text,
            raw=response.model_dump(),
            additional_kwargs={
                "prompt_token_ids": response.prompt_token_ids,
                "prompt_logprobs": response.prompt_logprobs,
                "finished": response.finished,
                "outputs": {
                    "token_ids": response.outputs[0].token_ids,
                    "cumulative_logprob": response.outputs[0].cumulative_logprob,
                    "logprobs": response.outputs[0].logprobs,
                    "finish_reason": response.outputs[0].finish_reason,
                },
            },
        )

    @llm_completion_callback()
    def stream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseGen:
        for response_chunk in self._sync_client.generate_stream(prompt, **kwargs):
            yield CompletionResponse(
                text=response_chunk.text,
                delta=response_chunk.text,
                raw=response_chunk.model_dump(),
                additional_kwargs={"token_ids": response_chunk.token_ids},
            )

    @llm_chat_callback()
    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
        return completion_response_to_chat_response(
            self.complete(self._convert_messages_to_prompt(messages), **kwargs)
        )

    @llm_chat_callback()
    def stream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseGen:
        for response_chunk in self.stream_complete(
            self._convert_messages_to_prompt(messages), **kwargs
        ):
            yield completion_response_to_chat_response(response_chunk)

    @llm_completion_callback()
    async def acomplete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        response = await self._async_client.generate(prompt, **kwargs)
        return CompletionResponse(
            text=response.outputs[0].text,
            raw=response.model_dump(),
            additional_kwargs={
                "prompt_token_ids": response.prompt_token_ids,
                "prompt_logprobs": response.prompt_logprobs,
                "finished": response.finished,
                "outputs": {
                    "token_ids": response.outputs[0].token_ids,
                    "cumulative_logprob": response.outputs[0].cumulative_logprob,
                    "logprobs": response.outputs[0].logprobs,
                    "finish_reason": response.outputs[0].finish_reason,
                },
            },
        )

    @llm_completion_callback()
    async def astream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseAsyncGen:
        async for response_chunk in self._async_client.generate_stream(
            prompt, **kwargs
        ):
            yield CompletionResponse(
                text=response_chunk.text,
                delta=response_chunk.text,
                raw=response_chunk.model_dump(),
                additional_kwargs={"token_ids": response_chunk.token_ids},
            )

    @llm_chat_callback()
    async def achat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponse:
        return completion_response_to_chat_response(
            await self.acomplete(
                await self._async_messages_to_prompt(messages), **kwargs
            )
        )

    @llm_chat_callback()
    async def astream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseAsyncGen:
        async for response_chunk in self.astream_complete(
            await self._async_messages_to_prompt(messages), **kwargs
        ):
            yield completion_response_to_chat_response(response_chunk)
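

# Usage sketch: a minimal example of both classes. It assumes that `openllm`
# (for OpenLLM) and `openllm-client` (for OpenLLMAPI) are installed, that the
# example model ID below can be downloaded or found locally, and that an
# OpenLLM server is already listening at the placeholder address.
if __name__ == "__main__":
    # In-process model: loads the weights and generates locally (``embedded=True``).
    local_llm = OpenLLM("HuggingFaceH4/zephyr-7b-alpha")
    print(local_llm.complete("The capital of France is").text)

    # Remote server: only `openllm-client` is required on this machine.
    remote_llm = OpenLLMAPI(address="http://localhost:3000")
    chat_response = remote_llm.chat(
        [ChatMessage(role="user", content="Tell me a short joke.")]
    )
    print(chat_response.message.content)

    # Streaming completion from the remote server.
    for chunk in remote_llm.stream_complete("Once upon a time"):
        print(chunk.delta, end="", flush=True)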
