llama-index

Форк
0
83 строки · 2.6 Кб
1
from typing import Any, Sequence
2

3
from llama_index.legacy.core.llms.types import (
4
    ChatMessage,
5
    ChatResponse,
6
    ChatResponseAsyncGen,
7
    ChatResponseGen,
8
    CompletionResponse,
9
    CompletionResponseAsyncGen,
10
)
11
from llama_index.legacy.llms.base import (
12
    llm_chat_callback,
13
    llm_completion_callback,
14
)
15
from llama_index.legacy.llms.generic_utils import (
16
    completion_response_to_chat_response,
17
    stream_completion_response_to_chat_response,
18
)
19
from llama_index.legacy.llms.llm import LLM
20

21

22
class CustomLLM(LLM):
    """Simple abstract base class for custom LLMs.

    Subclasses must implement `__init__`, `complete`, `stream_complete`,
    and `metadata`; every chat and async entry point below is derived
    from those by delegation.
    """

    @llm_chat_callback()
    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
        # Chat is expressed in terms of the completion endpoint: render the
        # messages into a single prompt, complete it, then re-wrap the result.
        rendered = self.messages_to_prompt(messages)
        response = self.complete(rendered, formatted=True, **kwargs)
        return completion_response_to_chat_response(response)

    @llm_chat_callback()
    def stream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseGen:
        # Streaming variant of `chat`: same prompt rendering, but the
        # completion stream is converted chunk-by-chunk into chat responses.
        rendered = self.messages_to_prompt(messages)
        stream = self.stream_complete(rendered, formatted=True, **kwargs)
        return stream_completion_response_to_chat_response(stream)

    @llm_chat_callback()
    async def achat(
        self,
        messages: Sequence[ChatMessage],
        **kwargs: Any,
    ) -> ChatResponse:
        # No native async support: run the synchronous implementation inline.
        result = self.chat(messages, **kwargs)
        return result

    @llm_chat_callback()
    async def astream_chat(
        self,
        messages: Sequence[ChatMessage],
        **kwargs: Any,
    ) -> ChatResponseAsyncGen:
        # Wrap the synchronous chat stream in an async generator so callers
        # can `async for` over it. The sync generator is created lazily, on
        # first iteration of the returned async generator.
        async def _to_async() -> ChatResponseAsyncGen:
            for chunk in self.stream_chat(messages, **kwargs):
                yield chunk

        return _to_async()

    @llm_completion_callback()
    async def acomplete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        # No native async support: run the synchronous implementation inline.
        result = self.complete(prompt, formatted=formatted, **kwargs)
        return result

    @llm_completion_callback()
    async def astream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseAsyncGen:
        # Same sync-to-async bridging as `astream_chat`, for completions.
        async def _to_async() -> CompletionResponseAsyncGen:
            for chunk in self.stream_complete(prompt, formatted=formatted, **kwargs):
                yield chunk

        return _to_async()

    @classmethod
    def class_name(cls) -> str:
        return "custom_llm"
84

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.