# llama-index
# 96 lines · 3.2 KB
"""Prompt Mixin.

Defines ``PromptMixin``, which lets a module expose and update its prompts and
those of its sub-modules, together with the type aliases it uses.
"""

from abc import ABC, abstractmethod
from collections import defaultdict
from copy import deepcopy
from typing import Dict, Union

from llama_index.legacy.prompts.base import BasePromptTemplate

# Either a module exposing prompts (PromptMixin) or a single prompt template.
HasPromptType = Union["PromptMixin", BasePromptTemplate]
# Mapping from prompt key to prompt template.
PromptDictType = Dict[str, BasePromptTemplate]
# Mapping from sub-module name to a child module that also implements PromptMixin.
PromptMixinType = Dict[str, "PromptMixin"]


class PromptMixin(ABC):
    """Prompt mixin.

    This mixin is used in other modules, like query engines and response
    synthesizers. It signals that the module supports getting and setting
    prompts, both within the immediate module and within its child modules.

    Prompts of child modules are namespaced with ``:``. A prompt ``key`` of
    sub-module ``name`` is exposed as ``"name:key"``, and deeper nesting
    yields keys such as ``"name:child:key"``. Because ``:`` is the separator,
    a module's own prompt keys and sub-module names must not contain it.

    """

    def _validate_prompts(
        self,
        prompts_dict: PromptDictType,
        module_dict: PromptMixinType,
    ) -> None:
        """Validate that no prompt key or sub-module name contains ``:``.

        Args:
            prompts_dict: This module's own prompts.
            module_dict: This module's sub-modules, keyed by name.

        Raises:
            ValueError: If any key of either dict contains ``:``, which is
                reserved as the namespace separator.

        """
        for key in prompts_dict:
            if ":" in key:
                raise ValueError(f"Prompt key {key} cannot contain ':'.")

        for key in module_dict:
            if ":" in key:
                raise ValueError(f"Prompt key {key} cannot contain ':'.")

    def get_prompts(self) -> PromptDictType:
        """Get all prompts of this module and, recursively, of its sub-modules.

        Returns:
            A new dict mapping prompt keys to templates. This module's own
            prompts keep their keys. Prompts of sub-module ``name`` are added
            under ``"name:<sub-key>"``, where ``<sub-key>`` may itself be
            namespaced for deeper sub-modules.

        Raises:
            ValueError: If a prompt key or sub-module name contains ``:``.

        """
        prompts_dict = self._get_prompts()
        module_dict = self._get_prompt_modules()
        self._validate_prompts(prompts_dict, module_dict)

        # Deep-copy so the result does not alias the dict (or templates)
        # returned by _get_prompts.
        all_prompts = deepcopy(prompts_dict)
        for module_name, prompt_module in module_dict.items():
            # Namespace each sub-module key with "<module_name>:".
            for key, prompt in prompt_module.get_prompts().items():
                all_prompts[f"{module_name}:{key}"] = prompt
        return all_prompts

    def update_prompts(self, prompts_dict: PromptDictType) -> None:
        """Update prompts.

        Other prompts will remain in place.

        The full ``prompts_dict`` is passed to this module's
        ``_update_prompts``. Namespaced keys ``"name:rest"`` are also routed
        to sub-module ``name`` as ``rest``. Only the first ``:`` is split on,
        so nested keys returned by ``get_prompts`` (e.g. ``"a:b:key"``) are
        forwarded recursively.

        Args:
            prompts_dict: Mapping from (possibly namespaced) prompt key to the
                new template.

        Raises:
            ValueError: If a namespaced key refers to an unknown sub-module.
                This is checked before any prompt of this module or its
                direct sub-modules is updated.

        """
        prompt_modules = self._get_prompt_modules()

        # Group namespaced keys by the sub-module they target.
        sub_prompt_dicts: Dict[str, PromptDictType] = defaultdict(dict)
        for key, prompt in prompts_dict.items():
            if ":" in key:
                # Split once: the remainder may itself be namespaced.
                module_name, sub_key = key.split(":", 1)
                sub_prompt_dicts[module_name][sub_key] = prompt

        # Fail before mutating anything if a target sub-module is unknown.
        for module_name in sub_prompt_dicts:
            if module_name not in prompt_modules:
                raise ValueError(f"Module {module_name} not found.")

        # Update prompts for the current module (receives the full dict).
        self._update_prompts(prompts_dict)

        # Now update prompts for sub-modules.
        for module_name, sub_prompt_dict in sub_prompt_dicts.items():
            prompt_modules[module_name].update_prompts(sub_prompt_dict)

    @abstractmethod
    def _get_prompts(self) -> PromptDictType:
        """Get prompts."""

    @abstractmethod
    def _get_prompt_modules(self) -> PromptMixinType:
        """Get prompt sub-modules.

        Return a dictionary of sub-modules within the current module
        that also implement PromptMixin (so that their prompts can also be get/set).

        Can be blank if no sub-modules.

        """

    @abstractmethod
    def _update_prompts(self, prompts_dict: PromptDictType) -> None:
        """Update prompts."""