llama-index
142 строки · 4.6 Кб
1"""Response schema."""
2
3from dataclasses import dataclass, field
4from typing import Any, Dict, List, Optional, Union
5
6from llama_index.legacy.bridge.pydantic import BaseModel
7from llama_index.legacy.schema import NodeWithScore
8from llama_index.legacy.types import TokenGen
9from llama_index.legacy.utils import truncate_text
10
11
@dataclass
class Response:
    """Response object.

    Returned if streaming=False.

    Attributes:
        response: The response text, or None if no answer was produced.
        source_nodes: Scored source nodes the response was synthesized from.
        metadata: Optional extra information about the response.

    """

    response: Optional[str]
    source_nodes: List[NodeWithScore] = field(default_factory=list)
    metadata: Optional[Dict[str, Any]] = None

    def __str__(self) -> str:
        """Convert to string representation."""
        return self.response or "None"

    def get_formatted_sources(self, length: int = 100) -> str:
        """Get formatted sources text.

        Each source node becomes one "> Source (Doc id: ...)" entry with its
        content truncated to at most `length` characters; entries are joined
        by blank lines.
        """
        return "\n\n".join(
            f"> Source (Doc id: {source_node.node.node_id or 'None'}): "
            f"{truncate_text(source_node.node.get_content(), length)}"
            for source_node in self.source_nodes
        )
40
41
@dataclass
class PydanticResponse:
    """PydanticResponse object.

    Returned if streaming=False.

    Attributes:
        response: The structured (pydantic) response object, or None.
        source_nodes: Scored source nodes the response was synthesized from.
        metadata: Optional extra information about the response.

    """

    response: Optional[BaseModel]
    source_nodes: List[NodeWithScore] = field(default_factory=list)
    metadata: Optional[Dict[str, Any]] = None

    def __str__(self) -> str:
        """Convert to string representation."""
        return self.response.json() if self.response else "None"

    def __getattr__(self, name: str) -> Any:
        """Get attribute, but prioritize the pydantic response object.

        Returns the field's value from the wrapped pydantic model, or None
        when the model is absent or has no such field.
        """
        # Look up `response` via __dict__ rather than attribute access to
        # avoid infinite recursion when `response` itself is not yet set
        # (e.g. an instance created via __new__, copy, or unpickling).
        response = self.__dict__.get("response")
        if response is not None and name in response.dict():
            return getattr(response, name)
        return None

    def get_formatted_sources(self, length: int = 100) -> str:
        """Get formatted sources text.

        Each source node becomes one "> Source (Doc id: ...)" entry with its
        content truncated to at most `length` characters.
        """
        texts = []
        for source_node in self.source_nodes:
            fmt_text_chunk = truncate_text(source_node.node.get_content(), length)
            doc_id = source_node.node.node_id or "None"
            source_text = f"> Source (Doc id: {doc_id}): {fmt_text_chunk}"
            texts.append(source_text)
        return "\n\n".join(texts)

    def get_response(self) -> Response:
        """Get a standard Response object carrying the serialized JSON text."""
        response_txt = self.response.json() if self.response else "None"
        return Response(response_txt, self.source_nodes, self.metadata)
82
83
@dataclass
class StreamingResponse:
    """StreamingResponse object.

    Returned if streaming=True.

    Attributes:
        response_gen: The response generator (yields text tokens).
        source_nodes: Scored source nodes the response was synthesized from.
        metadata: Optional extra information about the response.
        response_txt: Full response text, cached once the generator is drained.

    """

    response_gen: TokenGen
    source_nodes: List[NodeWithScore] = field(default_factory=list)
    metadata: Optional[Dict[str, Any]] = None
    response_txt: Optional[str] = None

    def _drain_gen(self) -> None:
        """Exhaust the token generator once and cache the joined text."""
        if self.response_txt is None and self.response_gen is not None:
            self.response_txt = "".join(self.response_gen)

    def __str__(self) -> str:
        """Convert to string representation (drains the generator if needed)."""
        self._drain_gen()
        return self.response_txt or "None"

    def get_response(self) -> Response:
        """Get a standard response object (drains the generator if needed)."""
        self._drain_gen()
        return Response(self.response_txt, self.source_nodes, self.metadata)

    def print_response_stream(self) -> None:
        """Print the response stream.

        On the first call, tokens are echoed as they arrive and cached;
        subsequent calls print the cached full text.
        """
        if self.response_txt is None and self.response_gen is not None:
            response_txt = ""
            for text in self.response_gen:
                print(text, end="", flush=True)
                response_txt += text
            self.response_txt = response_txt
        else:
            print(self.response_txt)

    def get_formatted_sources(self, length: int = 100, trim_text: bool = True) -> str:
        """Get formatted sources text.

        Args:
            length: Max characters of each node's content to include.
            trim_text: Whether to truncate node content to `length`.
        """
        texts = []
        for source_node in self.source_nodes:
            fmt_text_chunk = source_node.node.get_content()
            if trim_text:
                fmt_text_chunk = truncate_text(fmt_text_chunk, length)
            node_id = source_node.node.node_id or "None"
            source_text = f"> Source (Node id: {node_id}): {fmt_text_chunk}"
            texts.append(source_text)
        return "\n\n".join(texts)
140
141
# Type alias covering every response flavor a query engine may return.
RESPONSE_TYPE = Union[Response, StreamingResponse, PydanticResponse]
143