llama-index

Форк
0
98 строк · 3.3 Кб
1
from typing import Any, Generator, Optional, Sequence, cast
2

3
from llama_index.legacy.prompts import BasePromptTemplate
4
from llama_index.legacy.prompts.default_prompt_selectors import (
5
    DEFAULT_TEXT_QA_PROMPT_SEL,
6
)
7
from llama_index.legacy.prompts.mixin import PromptDictType
8
from llama_index.legacy.response_synthesizers.base import BaseSynthesizer
9
from llama_index.legacy.service_context import ServiceContext
10
from llama_index.legacy.types import RESPONSE_TEXT_TYPE
11

12

13
class SimpleSummarize(BaseSynthesizer):
14
    def __init__(
15
        self,
16
        text_qa_template: Optional[BasePromptTemplate] = None,
17
        service_context: Optional[ServiceContext] = None,
18
        streaming: bool = False,
19
    ) -> None:
20
        super().__init__(service_context=service_context, streaming=streaming)
21
        self._text_qa_template = text_qa_template or DEFAULT_TEXT_QA_PROMPT_SEL
22

23
    def _get_prompts(self) -> PromptDictType:
24
        """Get prompts."""
25
        return {"text_qa_template": self._text_qa_template}
26

27
    def _update_prompts(self, prompts: PromptDictType) -> None:
28
        """Update prompts."""
29
        if "text_qa_template" in prompts:
30
            self._text_qa_template = prompts["text_qa_template"]
31

32
    async def aget_response(
33
        self,
34
        query_str: str,
35
        text_chunks: Sequence[str],
36
        **response_kwargs: Any,
37
    ) -> RESPONSE_TEXT_TYPE:
38
        text_qa_template = self._text_qa_template.partial_format(query_str=query_str)
39
        truncated_chunks = self._service_context.prompt_helper.truncate(
40
            prompt=text_qa_template,
41
            text_chunks=text_chunks,
42
        )
43
        node_text = "\n".join(truncated_chunks)
44

45
        response: RESPONSE_TEXT_TYPE
46
        if not self._streaming:
47
            response = await self._service_context.llm.apredict(
48
                text_qa_template,
49
                context_str=node_text,
50
                **response_kwargs,
51
            )
52
        else:
53
            response = self._service_context.llm.stream(
54
                text_qa_template,
55
                context_str=node_text,
56
                **response_kwargs,
57
            )
58

59
        if isinstance(response, str):
60
            response = response or "Empty Response"
61
        else:
62
            response = cast(Generator, response)
63

64
        return response
65

66
    def get_response(
67
        self,
68
        query_str: str,
69
        text_chunks: Sequence[str],
70
        **kwargs: Any,
71
    ) -> RESPONSE_TEXT_TYPE:
72
        text_qa_template = self._text_qa_template.partial_format(query_str=query_str)
73
        truncated_chunks = self._service_context.prompt_helper.truncate(
74
            prompt=text_qa_template,
75
            text_chunks=text_chunks,
76
        )
77
        node_text = "\n".join(truncated_chunks)
78

79
        response: RESPONSE_TEXT_TYPE
80
        if not self._streaming:
81
            response = self._service_context.llm.predict(
82
                text_qa_template,
83
                context_str=node_text,
84
                **kwargs,
85
            )
86
        else:
87
            response = self._service_context.llm.stream(
88
                text_qa_template,
89
                context_str=node_text,
90
                **kwargs,
91
            )
92

93
        if isinstance(response, str):
94
            response = response or "Empty Response"
95
        else:
96
            response = cast(Generator, response)
97

98
        return response
99

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.