"""Response schema."""

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union

from llama_index.legacy.bridge.pydantic import BaseModel
from llama_index.legacy.schema import NodeWithScore
from llama_index.legacy.types import TokenGen
from llama_index.legacy.utils import truncate_text


@dataclass
class Response:
    """Response object.

    Returned if streaming=False.

    Attributes:
        response: The response text.

    """

    response: Optional[str]
    source_nodes: List[NodeWithScore] = field(default_factory=list)
    metadata: Optional[Dict[str, Any]] = None

    def __str__(self) -> str:
        """Convert to string representation."""
        return self.response or "None"

    def get_formatted_sources(self, length: int = 100) -> str:
        """Get formatted sources text."""
        texts = []
        for source_node in self.source_nodes:
            fmt_text_chunk = truncate_text(source_node.node.get_content(), length)
            doc_id = source_node.node.node_id or "None"
            source_text = f"> Source (Doc id: {doc_id}): {fmt_text_chunk}"
            texts.append(source_text)
        return "\n\n".join(texts)
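

# Illustrative usage sketch (not part of the original module): building a
# Response by hand and rendering its sources. Assumes the TextNode class from
# llama_index.legacy.schema; the text, score, and metadata below are invented
# for the demonstration.
def _example_response_usage() -> None:
    from llama_index.legacy.schema import TextNode

    node = NodeWithScore(
        node=TextNode(text="Paris has been the capital of France since 987."),
        score=0.87,
    )
    resp = Response("Paris.", source_nodes=[node], metadata={"query": "capital of France?"})
    print(str(resp))  # -> "Paris."
    # Each source renders as "> Source (Doc id: <node_id>): <truncated text>".
    print(resp.get_formatted_sources(length=50))

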
@dataclass
class PydanticResponse:
    """PydanticResponse object.

    Returned if streaming=False.

    Attributes:
        response: The response, as a pydantic object.

    """

    response: Optional[BaseModel]
    source_nodes: List[NodeWithScore] = field(default_factory=list)
    metadata: Optional[Dict[str, Any]] = None

    def __str__(self) -> str:
        """Convert to string representation."""
        return self.response.json() if self.response else "None"

    def __getattr__(self, name: str) -> Any:
        """Get attribute, but prioritize the pydantic response object."""
        if self.response is not None and name in self.response.dict():
            return getattr(self.response, name)
        else:
            return None

    def get_formatted_sources(self, length: int = 100) -> str:
        """Get formatted sources text."""
        texts = []
        for source_node in self.source_nodes:
            fmt_text_chunk = truncate_text(source_node.node.get_content(), length)
            doc_id = source_node.node.node_id or "None"
            source_text = f"> Source (Doc id: {doc_id}): {fmt_text_chunk}"
            texts.append(source_text)
        return "\n\n".join(texts)

    def get_response(self) -> Response:
        """Get a standard response object."""
        response_txt = self.response.json() if self.response else "None"
        return Response(response_txt, self.source_nodes, self.metadata)
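

# Illustrative usage sketch (not part of the original module). The Album model
# below is a made-up pydantic schema standing in for whatever structured output
# a program produces.
def _example_pydantic_response_usage() -> None:
    class Album(BaseModel):
        title: str
        year: int

    resp = PydanticResponse(response=Album(title="Kind of Blue", year=1959))
    print(str(resp))   # JSON dump of the Album instance
    print(resp.title)  # __getattr__ delegates to the pydantic object -> "Kind of Blue"
    plain = resp.get_response()  # convert to a plain Response holding the JSON string
    print(plain.response)

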
@dataclass
class StreamingResponse:
    """StreamingResponse object.

    Returned if streaming=True.

    Attributes:
        response_gen: The response generator.

    """

    response_gen: TokenGen
    source_nodes: List[NodeWithScore] = field(default_factory=list)
    metadata: Optional[Dict[str, Any]] = None
    response_txt: Optional[str] = None

    def __str__(self) -> str:
        """Convert to string representation."""
        if self.response_txt is None and self.response_gen is not None:
            response_txt = ""
            for text in self.response_gen:
                response_txt += text
            self.response_txt = response_txt
        return self.response_txt or "None"

    def get_response(self) -> Response:
        """Get a standard response object."""
        if self.response_txt is None and self.response_gen is not None:
            response_txt = ""
            for text in self.response_gen:
                response_txt += text
            self.response_txt = response_txt
        return Response(self.response_txt, self.source_nodes, self.metadata)

    def print_response_stream(self) -> None:
        """Print the response stream."""
        if self.response_txt is None and self.response_gen is not None:
            response_txt = ""
            for text in self.response_gen:
                print(text, end="", flush=True)
                response_txt += text
            self.response_txt = response_txt
        else:
            print(self.response_txt)

    def get_formatted_sources(self, length: int = 100, trim_text: bool = True) -> str:
        """Get formatted sources text."""
        texts = []
        for source_node in self.source_nodes:
            fmt_text_chunk = source_node.node.get_content()
            if trim_text:
                fmt_text_chunk = truncate_text(fmt_text_chunk, length)
            node_id = source_node.node.node_id or "None"
            source_text = f"> Source (Node id: {node_id}): {fmt_text_chunk}"
            texts.append(source_text)
        return "\n\n".join(texts)
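

# Illustrative usage sketch (not part of the original module). The generator
# below stands in for the token stream a streaming LLM call would yield.
def _example_streaming_response_usage() -> None:
    def fake_token_gen() -> TokenGen:
        yield from ["The ", "answer ", "is ", "42."]

    streaming = StreamingResponse(response_gen=fake_token_gen())
    streaming.print_response_stream()  # prints tokens as they arrive and caches the text
    final = streaming.get_response()   # reuses the cached text; the generator is spent
    print(final.response)              # -> "The answer is 42."

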
RESPONSE_TYPE = Union[Response, StreamingResponse, PydanticResponse]
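

# Illustrative sketch (not part of the original module): normalizing any
# RESPONSE_TYPE value to a plain Response before further processing. Both
# StreamingResponse and PydanticResponse expose get_response() for this.
def _example_as_plain_response(result: RESPONSE_TYPE) -> Response:
    if isinstance(result, Response):
        return result
    return result.get_response()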