llama-index

from threading import Thread
from typing import TYPE_CHECKING, Any, Callable, Generator, Optional, Sequence

if TYPE_CHECKING:
    from langchain.base_language import BaseLanguageModel

from llama_index.legacy.bridge.pydantic import PrivateAttr
from llama_index.legacy.callbacks import CallbackManager
from llama_index.legacy.core.llms.types import (
    ChatMessage,
    ChatResponse,
    ChatResponseAsyncGen,
    ChatResponseGen,
    CompletionResponse,
    CompletionResponseAsyncGen,
    CompletionResponseGen,
    LLMMetadata,
    MessageRole,
)
from llama_index.legacy.llms.base import llm_chat_callback, llm_completion_callback
from llama_index.legacy.llms.generic_utils import (
    completion_response_to_chat_response,
    stream_completion_response_to_chat_response,
)
from llama_index.legacy.llms.llm import LLM
from llama_index.legacy.types import BaseOutputParser, PydanticProgramMode


class LangChainLLM(LLM):
    """Adapter for a LangChain LLM."""

    _llm: Any = PrivateAttr()

    def __init__(
        self,
        llm: "BaseLanguageModel",
        callback_manager: Optional[CallbackManager] = None,
        system_prompt: Optional[str] = None,
        messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None,
        completion_to_prompt: Optional[Callable[[str], str]] = None,
        pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT,
        output_parser: Optional[BaseOutputParser] = None,
    ) -> None:
        self._llm = llm
        super().__init__(
            callback_manager=callback_manager,
            system_prompt=system_prompt,
            messages_to_prompt=messages_to_prompt,
            completion_to_prompt=completion_to_prompt,
            pydantic_program_mode=pydantic_program_mode,
            output_parser=output_parser,
        )

    @classmethod
    def class_name(cls) -> str:
        return "LangChainLLM"

    @property
    def llm(self) -> "BaseLanguageModel":
        return self._llm

    @property
    def metadata(self) -> LLMMetadata:
        from llama_index.legacy.llms.langchain_utils import get_llm_metadata

        return get_llm_metadata(self._llm)

    @llm_chat_callback()
    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
        from llama_index.legacy.llms.langchain_utils import (
            from_lc_messages,
            to_lc_messages,
        )

        if not self.metadata.is_chat_model:
            prompt = self.messages_to_prompt(messages)
            completion_response = self.complete(prompt, formatted=True, **kwargs)
            return completion_response_to_chat_response(completion_response)

        lc_messages = to_lc_messages(messages)
        lc_message = self._llm.predict_messages(messages=lc_messages, **kwargs)
        message = from_lc_messages([lc_message])[0]
        return ChatResponse(message=message)

    @llm_completion_callback()
    def complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        if not formatted:
            prompt = self.completion_to_prompt(prompt)

        output_str = self._llm.predict(prompt, **kwargs)
        return CompletionResponse(text=output_str)

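    # stream_chat below takes one of two streaming paths. If the wrapped
    # LangChain object exposes a native `stream` method, deltas are consumed
    # from it directly. Otherwise the adapter falls back to running a blocking
    # chat call in a background thread and draining tokens through
    # StreamingGeneratorCallbackHandler, which requires the LangChain LLM to
    # expose `streaming` and `callbacks` attributes.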
    @llm_chat_callback()
    def stream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseGen:
        if not self.metadata.is_chat_model:
            prompt = self.messages_to_prompt(messages)
            stream_completion = self.stream_complete(prompt, formatted=True, **kwargs)
            return stream_completion_response_to_chat_response(stream_completion)

        if hasattr(self._llm, "stream"):

            def gen() -> Generator[ChatResponse, None, None]:
                from llama_index.legacy.llms.langchain_utils import (
                    from_lc_messages,
                    to_lc_messages,
                )

                lc_messages = to_lc_messages(messages)
                response_str = ""
                for message in self._llm.stream(lc_messages, **kwargs):
                    message = from_lc_messages([message])[0]
                    delta = message.content
                    response_str += delta
                    yield ChatResponse(
                        message=ChatMessage(role=message.role, content=response_str),
                        delta=delta,
                    )

            return gen()

        else:
            from llama_index.legacy.langchain_helpers.streaming import (
                StreamingGeneratorCallbackHandler,
            )

            handler = StreamingGeneratorCallbackHandler()

            if not hasattr(self._llm, "streaming"):
                raise ValueError("LLM must support streaming.")
            if not hasattr(self._llm, "callbacks"):
                raise ValueError("LLM must support callbacks to use streaming.")

            self._llm.callbacks = [handler]  # type: ignore
            self._llm.streaming = True  # type: ignore

            thread = Thread(target=self.chat, args=[messages], kwargs=kwargs)
            thread.start()

            response_gen = handler.get_response_gen()

            def gen() -> Generator[ChatResponse, None, None]:
                text = ""
                for delta in response_gen:
                    text += delta
                    # ChatMessage stores its payload in `content`; the streamed
                    # output is surfaced as an assistant message.
                    yield ChatResponse(
                        message=ChatMessage(role=MessageRole.ASSISTANT, content=text),
                        delta=delta,
                    )

            return gen()

    @llm_completion_callback()
    def stream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseGen:
        if not formatted:
            prompt = self.completion_to_prompt(prompt)

        from llama_index.legacy.langchain_helpers.streaming import (
            StreamingGeneratorCallbackHandler,
        )

        handler = StreamingGeneratorCallbackHandler()

        if not hasattr(self._llm, "streaming"):
            raise ValueError("LLM must support streaming.")
        if not hasattr(self._llm, "callbacks"):
            raise ValueError("LLM must support callbacks to use streaming.")

        self._llm.callbacks = [handler]  # type: ignore
        self._llm.streaming = True  # type: ignore

        thread = Thread(target=self.complete, args=[prompt], kwargs=kwargs)
        thread.start()

        response_gen = handler.get_response_gen()

        def gen() -> Generator[CompletionResponse, None, None]:
            text = ""
            for delta in response_gen:
                text += delta
                yield CompletionResponse(delta=delta, text=text)

        return gen()

    @llm_chat_callback()
    async def achat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponse:
        # TODO: Implement async chat
        return self.chat(messages, **kwargs)

    @llm_completion_callback()
    async def acomplete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        # TODO: Implement async complete
        return self.complete(prompt, formatted=formatted, **kwargs)

    @llm_chat_callback()
    async def astream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseAsyncGen:
        # TODO: Implement async stream_chat

        async def gen() -> ChatResponseAsyncGen:
            for message in self.stream_chat(messages, **kwargs):
                yield message

        return gen()

    @llm_completion_callback()
    async def astream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseAsyncGen:
        # TODO: Implement async stream_complete

        async def gen() -> CompletionResponseAsyncGen:
            for response in self.stream_complete(prompt, formatted=formatted, **kwargs):
                yield response

        return gen()
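

# Usage sketch (not part of the original module): wraps a LangChain model in
# LangChainLLM and exercises the completion and streaming-chat paths. It
# assumes the `langchain-openai` package is installed and OPENAI_API_KEY is
# set; any other LangChain `BaseLanguageModel` could be substituted.
if __name__ == "__main__":
    from langchain_openai import ChatOpenAI

    lc_llm = ChatOpenAI(model="gpt-3.5-turbo")
    llm = LangChainLLM(llm=lc_llm)

    # Non-streaming completion: the prompt is formatted via completion_to_prompt
    # and delegated to the wrapped LangChain model.
    print(llm.complete("Paris is the capital of").text)

    # Streaming chat: deltas arrive incrementally from the wrapped model.
    messages = [ChatMessage(role=MessageRole.USER, content="Tell me a joke.")]
    for chunk in llm.stream_chat(messages):
        print(chunk.delta, end="", flush=True)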
