"""Prompts."""

from abc import ABC, abstractmethod
from copy import deepcopy
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
    List,
    Optional,
    Sequence,
    Tuple,
    Union,
)

from llama_index.legacy.bridge.pydantic import BaseModel, Field

if TYPE_CHECKING:
    from llama_index.legacy.bridge.langchain import (
        BasePromptTemplate as LangchainTemplate,
    )
    from llama_index.legacy.bridge.langchain import (
        ConditionalPromptSelector as LangchainSelector,
    )
from llama_index.legacy.core.llms.types import ChatMessage
from llama_index.legacy.core.query_pipeline.query_component import (
    ChainableMixin,
    InputKeys,
    OutputKeys,
    QueryComponent,
    validate_and_convert_stringable,
)
from llama_index.legacy.llms.base import BaseLLM
from llama_index.legacy.llms.generic_utils import (
    messages_to_prompt as default_messages_to_prompt,
    prompt_to_messages,
)
from llama_index.legacy.prompts.prompt_type import PromptType
from llama_index.legacy.prompts.utils import get_template_vars
from llama_index.legacy.types import BaseOutputParser


class BasePromptTemplate(ChainableMixin, BaseModel, ABC):
    metadata: Dict[str, Any]
    template_vars: List[str]
    kwargs: Dict[str, str]
    output_parser: Optional[BaseOutputParser]
    template_var_mappings: Optional[Dict[str, Any]] = Field(
        default_factory=dict, description="Template variable mappings (Optional)."
    )
    function_mappings: Optional[Dict[str, Callable]] = Field(
        default_factory=dict,
        description=(
            "Function mappings (Optional). This is a mapping from template "
            "variable names to functions that take in the current kwargs and "
            "return a string."
        ),
    )

    def _map_template_vars(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """For keys in template_var_mappings, swap in the right keys."""
        template_var_mappings = self.template_var_mappings or {}
        return {template_var_mappings.get(k, k): v for k, v in kwargs.items()}

    def _map_function_vars(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """For keys in function_mappings, compute values and combine w/ kwargs.

        Users can pass in functions instead of fixed values as format variables.
        For each function, we call the function with the current kwargs,
        get back the value, and then use that value in the template
        for the corresponding format variable.

        """
        function_mappings = self.function_mappings or {}
        # first, generate the values for the functions
        new_kwargs = {}
        for k, v in function_mappings.items():
            # TODO: figure out what variables to pass into each function
            # is it the kwargs specified during query time? just the fixed kwargs?
            # all kwargs?
            new_kwargs[k] = v(**kwargs)

        # then, add the fixed variables only if not in new_kwargs already
        # (implying that function mappings override fixed variables)
        for k, v in kwargs.items():
            if k not in new_kwargs:
                new_kwargs[k] = v

        return new_kwargs

    def _map_all_vars(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Map both template and function variables.

        We (1) first call function mappings to compute functions,
        and then (2) call the template_var_mappings.

        """
        # map function vars first
        new_kwargs = self._map_function_vars(kwargs)
        # then map template vars (to point to existing format vars in the
        # string template)
        return self._map_template_vars(new_kwargs)

    class Config:
        arbitrary_types_allowed = True

    @abstractmethod
    def partial_format(self, **kwargs: Any) -> "BasePromptTemplate":
        ...

    @abstractmethod
    def format(self, llm: Optional[BaseLLM] = None, **kwargs: Any) -> str:
        ...

    @abstractmethod
    def format_messages(
        self, llm: Optional[BaseLLM] = None, **kwargs: Any
    ) -> List[ChatMessage]:
        ...

    @abstractmethod
    def get_template(self, llm: Optional[BaseLLM] = None) -> str:
        ...

    def _as_query_component(
        self, llm: Optional[BaseLLM] = None, **kwargs: Any
    ) -> QueryComponent:
        """As query component."""
        return PromptComponent(prompt=self, format_messages=False, llm=llm)


class PromptTemplate(BasePromptTemplate):
    """A string prompt template using ``str.format``-style variables."""

    template: str

    def __init__(
        self,
        template: str,
        prompt_type: str = PromptType.CUSTOM,
        output_parser: Optional[BaseOutputParser] = None,
        metadata: Optional[Dict[str, Any]] = None,
        template_var_mappings: Optional[Dict[str, Any]] = None,
        function_mappings: Optional[Dict[str, Callable]] = None,
        **kwargs: Any,
    ) -> None:
        if metadata is None:
            metadata = {}
        metadata["prompt_type"] = prompt_type

        template_vars = get_template_vars(template)

        super().__init__(
            template=template,
            template_vars=template_vars,
            kwargs=kwargs,
            metadata=metadata,
            output_parser=output_parser,
            template_var_mappings=template_var_mappings,
            function_mappings=function_mappings,
        )

    def partial_format(self, **kwargs: Any) -> "PromptTemplate":
        """Partially format the prompt."""
        # NOTE: this is a hack to get around deepcopy failing on output parser
        output_parser = self.output_parser
        self.output_parser = None

        # get function and fixed kwargs, and add that to a copy
        # of the current prompt object
        prompt = deepcopy(self)
        prompt.kwargs.update(kwargs)

        # NOTE: put the output parser back
        prompt.output_parser = output_parser
        self.output_parser = output_parser
        return prompt

    def format(
        self,
        llm: Optional[BaseLLM] = None,
        completion_to_prompt: Optional[Callable[[str], str]] = None,
        **kwargs: Any,
    ) -> str:
        """Format the prompt into a string."""
        del llm  # unused
        all_kwargs = {
            **self.kwargs,
            **kwargs,
        }

        mapped_all_kwargs = self._map_all_vars(all_kwargs)
        prompt = self.template.format(**mapped_all_kwargs)

        if self.output_parser is not None:
            prompt = self.output_parser.format(prompt)

        if completion_to_prompt is not None:
            prompt = completion_to_prompt(prompt)

        return prompt

    def format_messages(
        self, llm: Optional[BaseLLM] = None, **kwargs: Any
    ) -> List[ChatMessage]:
        """Format the prompt into a list of chat messages."""
        del llm  # unused
        prompt = self.format(**kwargs)
        return prompt_to_messages(prompt)

    def get_template(self, llm: Optional[BaseLLM] = None) -> str:
        return self.template

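# Editor's note: the sketch below is illustrative and not part of the original
# module. It shows basic PromptTemplate usage, including partial formatting and
# ``function_mappings``; all names (``qa_template``, ``raw_query``) are made up.
def _example_prompt_template() -> None:
    qa_template = PromptTemplate(
        "Context:\n{context_str}\nQuestion: {query_str}\nAnswer:"
    )
    # bind some variables now, supply the rest at format time
    partial = qa_template.partial_format(context_str="The sky is blue.")
    print(partial.format(query_str="What color is the sky?"))

    # a function mapping computes a template variable from the kwargs passed
    # at format time (and overrides a fixed value of the same name)
    shouty = PromptTemplate(
        "Answer: {query_str}",
        function_mappings={"query_str": lambda **kw: kw["raw_query"].upper()},
    )
    print(shouty.format(raw_query="why?"))  # -> "Answer: WHY?"
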

class ChatPromptTemplate(BasePromptTemplate):
    """A chat prompt template built from a list of message templates."""

    message_templates: List[ChatMessage]

    def __init__(
        self,
        message_templates: List[ChatMessage],
        prompt_type: str = PromptType.CUSTOM,
        output_parser: Optional[BaseOutputParser] = None,
        metadata: Optional[Dict[str, Any]] = None,
        template_var_mappings: Optional[Dict[str, Any]] = None,
        function_mappings: Optional[Dict[str, Callable]] = None,
        **kwargs: Any,
    ):
        if metadata is None:
            metadata = {}
        metadata["prompt_type"] = prompt_type

        template_vars = []
        for message_template in message_templates:
            template_vars.extend(get_template_vars(message_template.content or ""))

        super().__init__(
            message_templates=message_templates,
            kwargs=kwargs,
            metadata=metadata,
            output_parser=output_parser,
            template_vars=template_vars,
            template_var_mappings=template_var_mappings,
            function_mappings=function_mappings,
        )

    def partial_format(self, **kwargs: Any) -> "ChatPromptTemplate":
        """Partially format the prompt."""
        prompt = deepcopy(self)
        prompt.kwargs.update(kwargs)
        return prompt

    def format(
        self,
        llm: Optional[BaseLLM] = None,
        messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None,
        **kwargs: Any,
    ) -> str:
        """Format the prompt into a single string."""
        del llm  # unused
        messages = self.format_messages(**kwargs)

        if messages_to_prompt is not None:
            return messages_to_prompt(messages)

        return default_messages_to_prompt(messages)

    def format_messages(
        self, llm: Optional[BaseLLM] = None, **kwargs: Any
    ) -> List[ChatMessage]:
        """Format the prompt into a list of chat messages."""
        del llm  # unused
        all_kwargs = {
            **self.kwargs,
            **kwargs,
        }
        mapped_all_kwargs = self._map_all_vars(all_kwargs)

        messages: List[ChatMessage] = []
        for message_template in self.message_templates:
            template_vars = get_template_vars(message_template.content or "")
            relevant_kwargs = {
                k: v for k, v in mapped_all_kwargs.items() if k in template_vars
            }
            content_template = message_template.content or ""

            # if there are mappings specified, make sure those are used
            content = content_template.format(**relevant_kwargs)

            message: ChatMessage = message_template.copy()
            message.content = content
            messages.append(message)

        if self.output_parser is not None:
            messages = self.output_parser.format_messages(messages)

        return messages

    def get_template(self, llm: Optional[BaseLLM] = None) -> str:
        return default_messages_to_prompt(self.message_templates)

    def _as_query_component(
        self, llm: Optional[BaseLLM] = None, **kwargs: Any
    ) -> QueryComponent:
        """As query component."""
        return PromptComponent(prompt=self, format_messages=True, llm=llm)

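# Editor's note: an illustrative sketch, not part of the original module. It
# shows how per-message template variables are filled in; ``MessageRole`` is
# assumed to live alongside ``ChatMessage`` in the legacy types module.
def _example_chat_prompt_template() -> None:
    from llama_index.legacy.core.llms.types import MessageRole

    chat_template = ChatPromptTemplate(
        message_templates=[
            ChatMessage(
                role=MessageRole.SYSTEM,
                content="You answer questions about {topic}.",
            ),
            ChatMessage(role=MessageRole.USER, content="{question}"),
        ]
    )
    # each message is formatted with only the kwargs its content references
    messages = chat_template.format_messages(topic="prompts", question="What is this?")
    print(messages)
    # ...or flattened into a single completion-style string
    print(chat_template.format(topic="prompts", question="What is this?"))
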

class SelectorPromptTemplate(BasePromptTemplate):
    """A template that selects among prompts via conditions on the LLM."""

    default_template: BasePromptTemplate
    conditionals: Optional[
        List[Tuple[Callable[[BaseLLM], bool], BasePromptTemplate]]
    ] = None

    def __init__(
        self,
        default_template: BasePromptTemplate,
        conditionals: Optional[
            List[Tuple[Callable[[BaseLLM], bool], BasePromptTemplate]]
        ] = None,
    ):
        metadata = default_template.metadata
        kwargs = default_template.kwargs
        template_vars = default_template.template_vars
        output_parser = default_template.output_parser
        super().__init__(
            default_template=default_template,
            conditionals=conditionals,
            metadata=metadata,
            kwargs=kwargs,
            template_vars=template_vars,
            output_parser=output_parser,
        )

    def select(self, llm: Optional[BaseLLM] = None) -> BasePromptTemplate:
        # ensure output parser is up to date
        self.default_template.output_parser = self.output_parser

        if llm is None:
            return self.default_template

        if self.conditionals is not None:
            for condition, prompt in self.conditionals:
                if condition(llm):
                    # ensure output parser is up to date
                    prompt.output_parser = self.output_parser
                    return prompt

        return self.default_template

    def partial_format(self, **kwargs: Any) -> "SelectorPromptTemplate":
        default_template = self.default_template.partial_format(**kwargs)
        if self.conditionals is None:
            conditionals = None
        else:
            conditionals = [
                (condition, prompt.partial_format(**kwargs))
                for condition, prompt in self.conditionals
            ]
        return SelectorPromptTemplate(
            default_template=default_template, conditionals=conditionals
        )

    def format(self, llm: Optional[BaseLLM] = None, **kwargs: Any) -> str:
        """Format the prompt into a string."""
        prompt = self.select(llm=llm)
        return prompt.format(**kwargs)

    def format_messages(
        self, llm: Optional[BaseLLM] = None, **kwargs: Any
    ) -> List[ChatMessage]:
        """Format the prompt into a list of chat messages."""
        prompt = self.select(llm=llm)
        return prompt.format_messages(**kwargs)

    def get_template(self, llm: Optional[BaseLLM] = None) -> str:
        prompt = self.select(llm=llm)
        return prompt.get_template(llm=llm)

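# Editor's note: an illustrative sketch, not part of the original module. The
# condition checks ``llm.metadata.is_chat_model``, which the legacy LLM
# metadata is assumed to expose; the template strings are made up.
def _example_selector_prompt_template() -> None:
    text_prompt = PromptTemplate("Answer concisely: {query_str}")
    chat_prompt = PromptTemplate("You are a helpful assistant. {query_str}")

    def use_chat_prompt(llm: BaseLLM) -> bool:
        return llm.metadata.is_chat_model

    selector_prompt = SelectorPromptTemplate(
        default_template=text_prompt,
        conditionals=[(use_chat_prompt, chat_prompt)],
    )
    # with no llm the default template is used; with one, the first matching
    # conditional wins
    print(selector_prompt.format(query_str="What is a prompt selector?"))
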

class LangchainPromptTemplate(BasePromptTemplate):
    """Adapter for a langchain prompt template (or prompt selector)."""

    selector: Any
    requires_langchain_llm: bool = False

    def __init__(
        self,
        template: Optional["LangchainTemplate"] = None,
        selector: Optional["LangchainSelector"] = None,
        output_parser: Optional[BaseOutputParser] = None,
        prompt_type: str = PromptType.CUSTOM,
        metadata: Optional[Dict[str, Any]] = None,
        template_var_mappings: Optional[Dict[str, Any]] = None,
        function_mappings: Optional[Dict[str, Callable]] = None,
        requires_langchain_llm: bool = False,
    ) -> None:
        try:
            from llama_index.legacy.bridge.langchain import (
                ConditionalPromptSelector as LangchainSelector,
            )
        except ImportError:
            raise ImportError(
                "Must install `llama_index[langchain]` to use LangchainPromptTemplate."
            )
        if selector is None:
            if template is None:
                raise ValueError("Must provide either template or selector.")
            selector = LangchainSelector(default_prompt=template)
        elif template is not None:
            raise ValueError("Must provide either template or selector, not both.")

        kwargs = selector.default_prompt.partial_variables
        template_vars = selector.default_prompt.input_variables

        if metadata is None:
            metadata = {}
        metadata["prompt_type"] = prompt_type

        super().__init__(
            selector=selector,
            metadata=metadata,
            kwargs=kwargs,
            template_vars=template_vars,
            output_parser=output_parser,
            template_var_mappings=template_var_mappings,
            function_mappings=function_mappings,
            requires_langchain_llm=requires_langchain_llm,
        )

    def _get_lc_template(self, llm: Optional[BaseLLM] = None) -> Any:
        """Select the underlying langchain template for the given LLM."""
        from llama_index.legacy.llms.langchain import LangChainLLM

        if llm is None:
            return self.selector.default_prompt
        if not isinstance(llm, LangChainLLM):
            # a llama-index LLM was provided: error if a langchain LLM is
            # required, otherwise fall back to the default prompt
            if self.requires_langchain_llm:
                raise ValueError("Must provide a LangChainLLM.")
            return self.selector.default_prompt
        return self.selector.get_prompt(llm=llm.llm)

    def partial_format(self, **kwargs: Any) -> "BasePromptTemplate":
        """Partially format the prompt."""
        from llama_index.legacy.bridge.langchain import (
            ConditionalPromptSelector as LangchainSelector,
        )

        mapped_kwargs = self._map_all_vars(kwargs)
        default_prompt = self.selector.default_prompt.partial(**mapped_kwargs)
        conditionals = [
            (condition, prompt.partial(**mapped_kwargs))
            for condition, prompt in self.selector.conditionals
        ]
        lc_selector = LangchainSelector(
            default_prompt=default_prompt, conditionals=conditionals
        )

        # copy full prompt object, replace selector
        lc_prompt = deepcopy(self)
        lc_prompt.selector = lc_selector
        return lc_prompt

    def format(self, llm: Optional[BaseLLM] = None, **kwargs: Any) -> str:
        """Format the prompt into a string."""
        lc_template = self._get_lc_template(llm=llm)

        # if there are mappings specified, make sure those are used
        mapped_kwargs = self._map_all_vars(kwargs)
        return lc_template.format(**mapped_kwargs)

    def format_messages(
        self, llm: Optional[BaseLLM] = None, **kwargs: Any
    ) -> List[ChatMessage]:
        """Format the prompt into a list of chat messages."""
        from llama_index.legacy.llms.langchain_utils import from_lc_messages

        lc_template = self._get_lc_template(llm=llm)

        # if there are mappings specified, make sure those are used
        mapped_kwargs = self._map_all_vars(kwargs)
        lc_prompt_value = lc_template.format_prompt(**mapped_kwargs)
        lc_messages = lc_prompt_value.to_messages()
        return from_lc_messages(lc_messages)

    def get_template(self, llm: Optional[BaseLLM] = None) -> str:
        lc_template = self._get_lc_template(llm=llm)

        try:
            return str(lc_template.template)  # type: ignore
        except AttributeError:
            return str(lc_template)

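# Editor's note: an illustrative sketch, not part of the original module; it
# requires the langchain extra. The bridge module is assumed to re-export
# langchain's ``PromptTemplate``; adjust the import if it does not.
def _example_langchain_prompt_template() -> None:
    from llama_index.legacy.bridge.langchain import PromptTemplate as LCTemplate

    lc_template = LCTemplate.from_template("Summarize the following: {text}")
    wrapped = LangchainPromptTemplate(template=lc_template)
    # with no llm (or a non-langchain llm while requires_langchain_llm=False),
    # the selector's default prompt is used
    print(wrapped.format(text="LlamaIndex bridges langchain prompt templates."))
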

# NOTE: only for backwards compatibility
Prompt = PromptTemplate


class PromptComponent(QueryComponent):
    """Prompt component."""

    prompt: BasePromptTemplate = Field(..., description="Prompt")
    llm: Optional[BaseLLM] = Field(
        default=None, description="LLM to use for formatting prompt."
    )
    format_messages: bool = Field(
        default=False,
        description="Whether to format the prompt into a list of chat messages.",
    )

    class Config:
        arbitrary_types_allowed = True

    def set_callback_manager(self, callback_manager: Any) -> None:
        """Set callback manager."""

    def _validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """Validate component inputs during run_component."""
        keys = list(input.keys())
        for k in keys:
            input[k] = validate_and_convert_stringable(input[k])
        return input

    def _run_component(self, **kwargs: Any) -> Any:
        """Run component."""
        if self.format_messages:
            output: Union[str, List[ChatMessage]] = self.prompt.format_messages(
                llm=self.llm, **kwargs
            )
        else:
            output = self.prompt.format(llm=self.llm, **kwargs)
        return {"prompt": output}

    async def _arun_component(self, **kwargs: Any) -> Any:
        """Run component (async)."""
        # NOTE: no native async for prompt formatting; reuse the sync path
        return self._run_component(**kwargs)

    @property
    def input_keys(self) -> InputKeys:
        """Input keys."""
        return InputKeys.from_keys(
            set(self.prompt.template_vars) - set(self.prompt.kwargs)
        )

    @property
    def output_keys(self) -> OutputKeys:
        """Output keys."""
        return OutputKeys.from_keys({"prompt"})
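# Editor's note: an illustrative sketch, not part of the original module,
# showing a prompt used as a query-pipeline component; the ``run_component``
# call follows the legacy QueryComponent interface as assumed here.
def _example_prompt_component() -> None:
    qa_template = PromptTemplate("Context: {context_str}\nQuestion: {query_str}")
    component = qa_template.as_query_component()
    # input keys are the template vars not already bound as fixed kwargs
    output = component.run_component(
        context_str="The sky is blue.", query_str="What color is the sky?"
    )
    # the formatted prompt comes back under the fixed "prompt" key
    print(output["prompt"])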
