llama-index

Форк
0
97 строк · 2.8 Кб
1
"""Structured LLM Predictor."""
2

3
import logging
4
from typing import Any, Optional
5

6
from deprecated import deprecated
7

8
from llama_index.legacy.llm_predictor.base import LLMPredictor
9
from llama_index.legacy.prompts.base import BasePromptTemplate
10
from llama_index.legacy.types import TokenGen
11

12
logger = logging.getLogger(__name__)
13

14

15
@deprecated("StructuredLLMPredictor is deprecated. Use llm.structured_predict().")
16
class StructuredLLMPredictor(LLMPredictor):
17
    """Structured LLM predictor class.
18

19
    Args:
20
        llm_predictor (BaseLLMPredictor): LLM Predictor to use.
21

22
    """
23

24
    @classmethod
25
    def class_name(cls) -> str:
26
        return "StructuredLLMPredictor"
27

28
    def predict(
29
        self,
30
        prompt: BasePromptTemplate,
31
        output_cls: Optional[Any] = None,
32
        **prompt_args: Any
33
    ) -> str:
34
        """Predict the answer to a query.
35

36
        Args:
37
            prompt (BasePromptTemplate): BasePromptTemplate to use for prediction.
38

39
        Returns:
40
            Tuple[str, str]: Tuple of the predicted answer and the formatted prompt.
41

42
        """
43
        llm_prediction = super().predict(prompt, **prompt_args)
44
        # run output parser
45
        if prompt.output_parser is not None:
46
            # TODO: return other formats
47
            output_parser = prompt.output_parser
48
            parsed_llm_prediction = str(output_parser.parse(llm_prediction))
49
        else:
50
            parsed_llm_prediction = llm_prediction
51

52
        return parsed_llm_prediction
53

54
    def stream(
55
        self,
56
        prompt: BasePromptTemplate,
57
        output_cls: Optional[Any] = None,
58
        **prompt_args: Any
59
    ) -> TokenGen:
60
        """Stream the answer to a query.
61

62
        NOTE: this is a beta feature. Will try to build or use
63
        better abstractions about response handling.
64

65
        Args:
66
            prompt (BasePromptTemplate): BasePromptTemplate to use for prediction.
67

68
        Returns:
69
            str: The predicted answer.
70

71
        """
72
        raise NotImplementedError(
73
            "Streaming is not supported for structured LLM predictor."
74
        )
75

76
    async def apredict(
77
        self,
78
        prompt: BasePromptTemplate,
79
        output_cls: Optional[Any] = None,
80
        **prompt_args: Any
81
    ) -> str:
82
        """Async predict the answer to a query.
83

84
        Args:
85
            prompt (BasePromptTemplate): BasePromptTemplate to use for prediction.
86

87
        Returns:
88
            Tuple[str, str]: Tuple of the predicted answer and the formatted prompt.
89

90
        """
91
        llm_prediction = await super().apredict(prompt, **prompt_args)
92
        if prompt.output_parser is not None:
93
            output_parser = prompt.output_parser
94
            parsed_llm_prediction = str(output_parser.parse(llm_prediction))
95
        else:
96
            parsed_llm_prediction = llm_prediction
97
        return parsed_llm_prediction
98

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.