llama-index

"""Mock LLM Predictor."""

from typing import Any, Dict

from deprecated import deprecated

from llama_index.legacy.bridge.pydantic import Field, PrivateAttr
from llama_index.legacy.callbacks.base import CallbackManager
from llama_index.legacy.constants import DEFAULT_NUM_OUTPUTS
from llama_index.legacy.core.llms.types import LLMMetadata
from llama_index.legacy.llm_predictor.base import BaseLLMPredictor
from llama_index.legacy.llms.llm import LLM
from llama_index.legacy.prompts.base import BasePromptTemplate
from llama_index.legacy.prompts.prompt_type import PromptType
from llama_index.legacy.token_counter.utils import (
    mock_extract_keywords_response,
    mock_extract_kg_triplets_response,
)
from llama_index.legacy.types import TokenAsyncGen, TokenGen
from llama_index.legacy.utils import get_tokenizer

# TODO: consolidate with unit tests in tests/mock_utils/mock_predict.py


def _mock_summary_predict(max_tokens: int, prompt_args: Dict) -> str:
    """Mock summary predict."""
    # tokens in response shouldn't be larger than tokens in `context_str`
    num_text_tokens = len(get_tokenizer()(prompt_args["context_str"]))
    token_limit = min(num_text_tokens, max_tokens)
    return " ".join(["summary"] * token_limit)


def _mock_insert_predict() -> str:
    """Mock insert predict."""
    return "ANSWER: 1"


def _mock_query_select() -> str:
    """Mock query select."""
    return "ANSWER: 1"


def _mock_query_select_multiple(num_chunks: int) -> str:
    """Mock query select."""
    nums_str = ", ".join([str(i) for i in range(num_chunks)])
    return f"ANSWER: {nums_str}"


def _mock_answer(max_tokens: int, prompt_args: Dict) -> str:
    """Mock answer."""
    # tokens in response shouldn't be larger than tokens in `context_str`
    num_ctx_tokens = len(get_tokenizer()(prompt_args["context_str"]))
    token_limit = min(num_ctx_tokens, max_tokens)
    return " ".join(["answer"] * token_limit)


def _mock_refine(max_tokens: int, prompt: BasePromptTemplate, prompt_args: Dict) -> str:
    """Mock refine."""
    # tokens in response shouldn't be larger than tokens in
    # `existing_answer` + `context_msg`
    # NOTE: if existing_answer is not in prompt_args, we need to get it from the prompt
    if "existing_answer" not in prompt_args:
        existing_answer = prompt.kwargs["existing_answer"]
    else:
        existing_answer = prompt_args["existing_answer"]
    num_ctx_tokens = len(get_tokenizer()(prompt_args["context_msg"]))
    num_exist_tokens = len(get_tokenizer()(existing_answer))
    token_limit = min(num_ctx_tokens + num_exist_tokens, max_tokens)
    return " ".join(["answer"] * token_limit)


def _mock_keyword_extract(prompt_args: Dict) -> str:
    """Mock keyword extract."""
    return mock_extract_keywords_response(prompt_args["text"])


def _mock_query_keyword_extract(prompt_args: Dict) -> str:
    """Mock query keyword extract."""
    return mock_extract_keywords_response(prompt_args["question"])


def _mock_knowledge_graph_triplet_extract(prompt_args: Dict, max_triplets: int) -> str:
    """Mock knowledge graph triplet extract."""
    return mock_extract_kg_triplets_response(
        prompt_args["text"], max_triplets=max_triplets
    )


@deprecated("MockLLMPredictor is deprecated. Use MockLLM instead.")
class MockLLMPredictor(BaseLLMPredictor):
    """Mock LLM Predictor."""

    max_tokens: int = Field(
        default=DEFAULT_NUM_OUTPUTS, description="Number of tokens to mock generate."
    )

    _callback_manager: CallbackManager = PrivateAttr(default_factory=CallbackManager)

    @classmethod
    def class_name(cls) -> str:
        return "MockLLMPredictor"

    @property
    def metadata(self) -> LLMMetadata:
        return LLMMetadata()

    @property
    def callback_manager(self) -> CallbackManager:
        # return the private attribute; returning `self.callback_manager`
        # here would recurse into this property indefinitely
        return self._callback_manager

    @property
    def llm(self) -> LLM:
        raise NotImplementedError("MockLLMPredictor does not have an LLM model.")

    def predict(self, prompt: BasePromptTemplate, **prompt_args: Any) -> str:
        """Mock predict."""
        prompt_str = prompt.metadata["prompt_type"]
        if prompt_str == PromptType.SUMMARY:
            output = _mock_summary_predict(self.max_tokens, prompt_args)
        elif prompt_str == PromptType.TREE_INSERT:
            output = _mock_insert_predict()
        elif prompt_str == PromptType.TREE_SELECT:
            output = _mock_query_select()
        elif prompt_str == PromptType.TREE_SELECT_MULTIPLE:
            output = _mock_query_select_multiple(prompt_args["num_chunks"])
        elif prompt_str == PromptType.REFINE:
            output = _mock_refine(self.max_tokens, prompt, prompt_args)
        elif prompt_str == PromptType.QUESTION_ANSWER:
            output = _mock_answer(self.max_tokens, prompt_args)
        elif prompt_str == PromptType.KEYWORD_EXTRACT:
            output = _mock_keyword_extract(prompt_args)
        elif prompt_str == PromptType.QUERY_KEYWORD_EXTRACT:
            output = _mock_query_keyword_extract(prompt_args)
        elif prompt_str == PromptType.KNOWLEDGE_TRIPLET_EXTRACT:
            output = _mock_knowledge_graph_triplet_extract(
                prompt_args,
                int(prompt.kwargs.get("max_knowledge_triplets", 2)),
            )
        elif prompt_str == PromptType.CUSTOM:
            # we don't know specific prompt type, return generic response
            output = ""
        else:
            raise ValueError("Invalid prompt type.")

        return output

    def stream(self, prompt: BasePromptTemplate, **prompt_args: Any) -> TokenGen:
        raise NotImplementedError

    async def apredict(self, prompt: BasePromptTemplate, **prompt_args: Any) -> str:
        return self.predict(prompt, **prompt_args)

    async def astream(
        self, prompt: BasePromptTemplate, **prompt_args: Any
    ) -> TokenAsyncGen:
        raise NotImplementedError
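

# A minimal usage sketch of the mock predictor above. This is an illustrative
# assumption, not part of the library: it assumes `PromptTemplate` is
# importable from `llama_index.legacy.prompts`, consistent with the legacy
# package layout used in the imports at the top of this file.
if __name__ == "__main__":
    from llama_index.legacy.prompts import PromptTemplate

    predictor = MockLLMPredictor(max_tokens=8)
    qa_prompt = PromptTemplate(
        "Context: {context_str}\nQuestion: {query_str}\nAnswer:",
        prompt_type=PromptType.QUESTION_ANSWER,
    )
    # For a QUESTION_ANSWER prompt the mock returns the word "answer" repeated
    # min(number of tokens in `context_str`, max_tokens) times.
    print(predictor.predict(qa_prompt, context_str="alpha beta gamma", query_str="?"))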