llama-index

Форк
0
293 строки · 9.9 Кб
1
import logging
2
from typing import Any, Dict, Generator, List, Optional, Tuple, Type, Union, cast
3

4
from llama_index.legacy.agent.openai.utils import resolve_tool_choice
5
from llama_index.legacy.llms.llm import LLM
6
from llama_index.legacy.llms.openai import OpenAI
7
from llama_index.legacy.llms.openai_utils import OpenAIToolCall, to_openai_tool
8
from llama_index.legacy.program.llm_prompt_program import BaseLLMFunctionProgram
9
from llama_index.legacy.program.utils import create_list_model
10
from llama_index.legacy.prompts.base import BasePromptTemplate, PromptTemplate
11
from llama_index.legacy.types import Model
12

13
_logger = logging.getLogger(__name__)
14

15

16
def _default_tool_choice(
    output_cls: Type[Model], allow_multiple: bool = False
) -> Union[str, Dict[str, Any]]:
    """Return the ``tool_choice`` to use when the caller did not give one.

    If multiple outputs are allowed, the model is left to decide how many
    tool calls to make (``"auto"``). Otherwise the single tool named after
    the pydantic schema title of ``output_cls`` is selected through
    ``resolve_tool_choice``.
    """
    if not allow_multiple:
        # The generated tool is named after the model's schema title.
        tool_name = output_cls.schema()["title"]
        return resolve_tool_choice(tool_name)
    return "auto"
25

26

27
def _get_json_str(raw_str: str, start_idx: int) -> Tuple[Optional[str], int]:
    """Extract the next complete JSON object from ``raw_str``.

    Scanning starts at ``start_idx``. The first ``{`` opens an object and the
    text up to its matching ``}`` is returned. Braces inside JSON string
    literals (including after escaped quotes) are ignored, so values such as
    ``"a}b"`` neither close nor open an object. Characters before the opening
    brace (e.g. whitespace after a ``,`` separator) are not part of the result.

    Args:
        raw_str: Possibly partial JSON text, e.g. accumulated streamed
            tool-call arguments.
        start_idx: Index in ``raw_str`` at which scanning begins.

    Returns:
        ``(json_str, next_idx)`` where ``next_idx`` is two characters past the
        closing brace (skipping one separator such as ``,``), or
        ``(None, start_idx)`` if no complete object is present yet.
    """
    segment = raw_str[start_idx:]
    depth = 0
    obj_start = 0
    in_string = False
    escaped = False
    for i, c in enumerate(segment):
        if in_string:
            # Inside a string literal only escapes and the closing quote matter.
            if escaped:
                escaped = False
            elif c == "\\":
                escaped = True
            elif c == '"':
                in_string = False
            continue

        if c == '"':
            in_string = True
        elif c == "{":
            if depth == 0:
                obj_start = i
            depth += 1
        elif c == "}" and depth > 0:
            # A "}" outside any object (e.g. the list model's closing brace) is ignored.
            depth -= 1
            if depth == 0:
                return segment[obj_start : i + 1], i + 2 + start_idx

    return None, start_idx
40

41

42
def _parse_tool_calls(
    tool_calls: List[OpenAIToolCall],
    output_cls: Type[Model],
    allow_multiple: bool = False,
    verbose: bool = False,
) -> Union[Model, List[Model]]:
    """Parse OpenAI tool calls into ``output_cls`` instances.

    Args:
        tool_calls: Tool calls from a chat response; each must have a
            ``function`` with a ``name`` and ``arguments``.
        output_cls: Pydantic model class used to parse each call's arguments.
        allow_multiple: If True, return all parsed objects as a list;
            otherwise return only the first one.
        verbose: If True, print each function call before parsing it.

    Returns:
        A list of ``output_cls`` instances if ``allow_multiple`` is True,
        otherwise the first parsed instance.

    Raises:
        ValueError: If ``allow_multiple`` is False and ``tool_calls`` is empty.
    """
    outputs = []
    for tool_call in tool_calls:
        function_call = tool_call.function
        # These asserts only narrow Optional types to get past mypy.
        assert function_call is not None
        assert function_call.name is not None
        assert function_call.arguments is not None
        if verbose:
            print(
                f"Function call: {function_call.name} with args: "
                f"{function_call.arguments}"
            )

        # Arguments may arrive already decoded (dict) or as a raw JSON string.
        if isinstance(function_call.arguments, dict):
            output = output_cls.parse_obj(function_call.arguments)
        else:
            output = output_cls.parse_raw(function_call.arguments)
        outputs.append(output)

    if allow_multiple:
        return outputs

    if not outputs:
        raise ValueError("Expected at least one tool call, but none found.")
    if len(outputs) > 1:
        _logger.warning(
            "Multiple outputs found, returning first one. "
            "If you want to return all outputs, set allow_multiple=True."
        )
    return outputs[0]
77

78

79
class OpenAIPydanticProgram(BaseLLMFunctionProgram[LLM]):
    """
    An OpenAI-based function that returns a pydantic model.

    The prompt is formatted into chat messages and sent to an OpenAI chat
    model together with a tool generated from ``output_cls``; the arguments
    of the tool call(s) in the reply are parsed into ``output_cls`` instances.

    Note: this interface is not yet stable.
    """

    def __init__(
        self,
        output_cls: Type[Model],
        llm: LLM,
        prompt: BasePromptTemplate,
        tool_choice: Union[str, Dict[str, Any]],
        allow_multiple: bool = False,
        verbose: bool = False,
    ) -> None:
        """Init params.

        Args:
            output_cls: Pydantic model class that tool-call arguments are
                parsed into.
            llm: Chat LLM asked to call the generated tool.
            prompt: Template formatted into chat messages on every call.
            tool_choice: Forwarded as ``tool_choice`` to ``llm.chat`` /
                ``llm.achat``.
            allow_multiple: Return every parsed object (as a list) instead of
                only the first one.
            verbose: Print tool calls and extracted objects.
        """
        self._output_cls = output_cls
        self._llm = llm
        self._prompt = prompt
        self._verbose = verbose
        self._allow_multiple = allow_multiple
        self._tool_choice = tool_choice

    @classmethod
    def from_defaults(
        cls,
        output_cls: Type[Model],
        prompt_template_str: Optional[str] = None,
        prompt: Optional[PromptTemplate] = None,
        llm: Optional[LLM] = None,
        verbose: bool = False,
        allow_multiple: bool = False,
        tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
        **kwargs: Any,
    ) -> "OpenAIPydanticProgram":
        """Build a program, filling in the LLM, prompt and tool choice.

        Args:
            output_cls: Pydantic model class to produce.
            prompt_template_str: Template string; mutually exclusive with
                ``prompt``.
            prompt: Prompt template; mutually exclusive with
                ``prompt_template_str``.
            llm: OpenAI LLM to use; defaults to
                ``OpenAI(model="gpt-3.5-turbo-0613")``.
            verbose: Print tool calls and extracted objects.
            allow_multiple: Return all parsed objects instead of the first.
            tool_choice: Explicit ``tool_choice``; if falsy it is derived
                from ``output_cls`` and ``allow_multiple``.
            **kwargs: Accepted for interface compatibility; unused here.

        Raises:
            ValueError: If ``llm`` is not an ``OpenAI`` instance, its model
                does not support function calling, or not exactly one of
                ``prompt`` / ``prompt_template_str`` is given.
        """
        llm = llm or OpenAI(model="gpt-3.5-turbo-0613")

        if not isinstance(llm, OpenAI):
            raise ValueError(
                "OpenAIPydanticProgram only supports OpenAI LLMs. " f"Got: {type(llm)}"
            )

        if not llm.metadata.is_function_calling_model:
            raise ValueError(
                f"Model name {llm.metadata.model_name} does not support "
                "function calling API. "
            )

        if prompt is None and prompt_template_str is None:
            raise ValueError("Must provide either prompt or prompt_template_str.")
        if prompt is not None and prompt_template_str is not None:
            raise ValueError(
                "Must provide either prompt or prompt_template_str, not both."
            )
        if prompt_template_str is not None:
            prompt = PromptTemplate(prompt_template_str)

        tool_choice = tool_choice or _default_tool_choice(output_cls, allow_multiple)

        return cls(
            output_cls=output_cls,
            llm=llm,
            prompt=cast(PromptTemplate, prompt),
            tool_choice=tool_choice,
            allow_multiple=allow_multiple,
            verbose=verbose,
        )

    @property
    def output_cls(self) -> Type[Model]:
        """Pydantic model class produced by this program."""
        return self._output_cls

    @property
    def prompt(self) -> BasePromptTemplate:
        """Prompt template formatted into chat messages on each call."""
        return self._prompt

    @prompt.setter
    def prompt(self, prompt: BasePromptTemplate) -> None:
        self._prompt = prompt

    def _prepare_request(self, prompt_kwargs: Dict[str, Any]) -> Tuple[Any, Any]:
        """Return ``(messages, tool_spec)`` for one non-streaming call.

        ``prompt_kwargs`` are the prompt variables; a ``description`` entry,
        if present, is also used as the tool description.
        """
        description = self._description_eval(**prompt_kwargs)
        openai_fn_spec = to_openai_tool(self._output_cls, description=description)
        messages = self._prompt.format_messages(llm=self._llm, **prompt_kwargs)
        return messages, openai_fn_spec

    def _parse_response_message(self, message: Any) -> Union[Model, List[Model]]:
        """Parse the tool calls of a chat response message into ``output_cls``.

        Raises:
            ValueError: If the message carries no ``tool_calls``.
        """
        if "tool_calls" not in message.additional_kwargs:
            raise ValueError(
                "Expected tool_calls in ai_message.additional_kwargs, "
                "but none found."
            )
        return _parse_tool_calls(
            message.additional_kwargs["tool_calls"],
            output_cls=self.output_cls,
            allow_multiple=self._allow_multiple,
            verbose=self._verbose,
        )

    def __call__(
        self,
        llm_kwargs: Optional[Dict[str, Any]] = None,
        *args: Any,
        **kwargs: Any,
    ) -> Union[Model, List[Model]]:
        """Run the program synchronously.

        Args:
            llm_kwargs: Extra keyword arguments forwarded to ``llm.chat``.
            **kwargs: Prompt variables (``description`` optionally sets the
                tool description).

        Returns:
            One ``output_cls`` instance, or a list if ``allow_multiple``.
        """
        llm_kwargs = llm_kwargs or {}
        messages, openai_fn_spec = self._prepare_request(kwargs)

        chat_response = self._llm.chat(
            messages=messages,
            tools=[openai_fn_spec],
            tool_choice=self._tool_choice,
            **llm_kwargs,
        )
        return self._parse_response_message(chat_response.message)

    async def acall(
        self,
        llm_kwargs: Optional[Dict[str, Any]] = None,
        *args: Any,
        **kwargs: Any,
    ) -> Union[Model, List[Model]]:
        """Async counterpart of ``__call__`` (uses ``llm.achat``)."""
        llm_kwargs = llm_kwargs or {}
        messages, openai_fn_spec = self._prepare_request(kwargs)

        chat_response = await self._llm.achat(
            messages=messages,
            tools=[openai_fn_spec],
            tool_choice=self._tool_choice,
            **llm_kwargs,
        )
        return self._parse_response_message(chat_response.message)

    def stream_list(
        self,
        llm_kwargs: Optional[Dict[str, Any]] = None,
        *args: Any,
        **kwargs: Any,
    ) -> Generator[Model, None, None]:
        """Stream a list of objects, yielding each one as soon as it is complete.

        ``output_cls`` is wrapped in a list model and that tool is forced.
        As the tool-call arguments stream in, every fully received object
        inside the JSON list is parsed and yielded.
        """
        llm_kwargs = llm_kwargs or {}
        messages = self._prompt.format_messages(llm=self._llm, **kwargs)

        description = self._description_eval(**kwargs)

        list_output_cls = create_list_model(self._output_cls)
        openai_fn_spec = to_openai_tool(list_output_cls, description=description)

        chat_response_gen = self._llm.stream_chat(
            messages=messages,
            tools=[openai_fn_spec],
            tool_choice=_default_tool_choice(list_output_cls),
            **llm_kwargs,
        )
        # Scan position for the next object in the accumulated arguments;
        # -1 until the opening "[" of the list has been received.
        obj_start_idx: int = -1
        for stream_resp in chat_response_gen:
            # Chunks received before the tool call starts may lack "tool_calls".
            tool_calls = stream_resp.message.additional_kwargs.get("tool_calls", [])
            if not tool_calls:
                continue

            # NOTE: right now assume only one tool call
            # TODO: handle parallel tool calls in streaming setting
            fn_args = tool_calls[0].function.arguments

            # this is inspired by `get_object` from `MultiTaskBase` in
            # the openai_function_call repo
            if obj_start_idx == -1:
                list_start = fn_args.find("[")
                if list_start == -1:
                    # keep going until we find the start position
                    continue
                obj_start_idx = list_start + 1

            # A single chunk can complete more than one object (notably the
            # last one), so drain every complete object before moving on.
            while True:
                obj_json_str, obj_start_idx = _get_json_str(fn_args, obj_start_idx)
                if obj_json_str is None:
                    break
                obj = self._output_cls.parse_raw(obj_json_str)
                if self._verbose:
                    print(f"Extracted object: {obj.json()}")
                yield obj

    def _description_eval(self, **kwargs: Any) -> Optional[str]:
        """Validate the tool description source and return ``description``.

        Exactly one of the ``output_cls`` docstring or a ``description``
        kwarg must be provided.

        Raises:
            ValueError: If neither or both are provided.
        """
        description = kwargs.get("description", None)

        # __doc__ is set when the pydantic model has a docstring.
        if not (self._output_cls.__doc__ or description):
            raise ValueError(
                "Must provide description for your Pydantic Model. Either provide a docstring or add `description=<your_description>` to the method. Required to convert Pydantic Model to OpenAI Function."
            )

        if self._output_cls.__doc__ and description:
            raise ValueError(
                "Must provide either a docstring or a description, not both."
            )

        return description
294

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.