llama-index

Форк
0
96 строк · 3.2 Кб
1
"""Prompt Mixin."""
2

3
from abc import ABC, abstractmethod
4
from collections import defaultdict
5
from copy import deepcopy
6
from typing import Dict, Union
7

8
from llama_index.legacy.prompts.base import BasePromptTemplate
9

10
HasPromptType = Union["PromptMixin", BasePromptTemplate]
11
PromptDictType = Dict[str, BasePromptTemplate]
12
PromptMixinType = Dict[str, "PromptMixin"]
13

14

15
class PromptMixin(ABC):
16
    """Prompt mixin.
17

18
    This mixin is used in other modules, like query engines, response synthesizers.
19
    This shows that the module supports getting, setting prompts,
20
    both within the immediate module as well as child modules.
21

22
    """
23

24
    def _validate_prompts(
25
        self,
26
        prompts_dict: PromptDictType,
27
        module_dict: PromptMixinType,
28
    ) -> None:
29
        """Validate prompts."""
30
        # check if prompts_dict, module_dict has restricted ":" token
31
        for key in prompts_dict:
32
            if ":" in key:
33
                raise ValueError(f"Prompt key {key} cannot contain ':'.")
34

35
        for key in module_dict:
36
            if ":" in key:
37
                raise ValueError(f"Prompt key {key} cannot contain ':'.")
38

39
    def get_prompts(self) -> Dict[str, BasePromptTemplate]:
40
        """Get a prompt."""
41
        prompts_dict = self._get_prompts()
42
        module_dict = self._get_prompt_modules()
43
        self._validate_prompts(prompts_dict, module_dict)
44

45
        # avoid modifying the original dict
46
        all_prompts = deepcopy(prompts_dict)
47
        for module_name, prompt_module in module_dict.items():
48
            # append module name to each key in sub-modules by ":"
49
            for key, prompt in prompt_module.get_prompts().items():
50
                all_prompts[f"{module_name}:{key}"] = prompt
51
        return all_prompts
52

53
    def update_prompts(self, prompts_dict: Dict[str, BasePromptTemplate]) -> None:
54
        """Update prompts.
55

56
        Other prompts will remain in place.
57

58
        """
59
        prompt_modules = self._get_prompt_modules()
60

61
        # update prompts for current module
62
        self._update_prompts(prompts_dict)
63

64
        # get sub-module keys
65
        # mapping from module name to sub-module prompt keys
66
        sub_prompt_dicts: Dict[str, PromptDictType] = defaultdict(dict)
67
        for key in prompts_dict:
68
            if ":" in key:
69
                module_name, sub_key = key.split(":")
70
                sub_prompt_dicts[module_name][sub_key] = prompts_dict[key]
71

72
        # now update prompts for submodules
73
        for module_name, sub_prompt_dict in sub_prompt_dicts.items():
74
            if module_name not in prompt_modules:
75
                raise ValueError(f"Module {module_name} not found.")
76
            module = prompt_modules[module_name]
77
            module.update_prompts(sub_prompt_dict)
78

79
    @abstractmethod
80
    def _get_prompts(self) -> PromptDictType:
81
        """Get prompts."""
82

83
    @abstractmethod
84
    def _get_prompt_modules(self) -> PromptMixinType:
85
        """Get prompt sub-modules.
86

87
        Return a dictionary of sub-modules within the current module
88
        that also implement PromptMixin (so that their prompts can also be get/set).
89

90
        Can be blank if no sub-modules.
91

92
        """
93

94
    @abstractmethod
95
    def _update_prompts(self, prompts_dict: PromptDictType) -> None:
96
        """Update prompts."""
97

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.