llama-index

Форк
0
332 строки · 13.6 Кб
1
"""SQL Join query engine."""
2

3
import logging
4
from typing import Callable, Dict, Optional, Union
5

6
from llama_index.legacy.callbacks.base import CallbackManager
7
from llama_index.legacy.core.base_query_engine import BaseQueryEngine
8
from llama_index.legacy.core.response.schema import RESPONSE_TYPE, Response
9
from llama_index.legacy.indices.query.query_transform.base import BaseQueryTransform
10
from llama_index.legacy.indices.struct_store.sql_query import (
11
    BaseSQLTableQueryEngine,
12
    NLSQLTableQueryEngine,
13
)
14
from llama_index.legacy.llm_predictor.base import LLMPredictorType
15
from llama_index.legacy.llms.utils import resolve_llm
16
from llama_index.legacy.prompts.base import BasePromptTemplate, PromptTemplate
17
from llama_index.legacy.prompts.mixin import PromptDictType, PromptMixinType
18
from llama_index.legacy.schema import QueryBundle
19
from llama_index.legacy.selectors.llm_selectors import LLMSingleSelector
20
from llama_index.legacy.selectors.pydantic_selectors import PydanticSingleSelector
21
from llama_index.legacy.selectors.utils import get_selector_from_context
22
from llama_index.legacy.service_context import ServiceContext
23
from llama_index.legacy.tools.query_engine import QueryEngineTool
24
from llama_index.legacy.utils import print_text
25

26
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


# Prompt used to merge the SQL answer with the answer of the second query
# engine into one final response. Placeholders: query_str, sql_query_str,
# sql_response_str, query_engine_query_str, query_engine_response_str.
DEFAULT_SQL_JOIN_SYNTHESIS_PROMPT_TMPL = """
The original question is given below.
This question has been translated into a SQL query. Both the SQL query and \
the response are given below.
Given the SQL response, the question has also been transformed into a more \
detailed query,
and executed against another query engine.
The transformed query and query engine response are also given below.
Given SQL query, SQL response, transformed query, and query engine response, \
please synthesize a response to the original question.

Original question: {query_str}
SQL query: {sql_query_str}
SQL response: {sql_response_str}
Transformed query: {query_engine_query_str}
Query engine response: {query_engine_response_str}
Response:
"""
DEFAULT_SQL_JOIN_SYNTHESIS_PROMPT = PromptTemplate(
    DEFAULT_SQL_JOIN_SYNTHESIS_PROMPT_TMPL
)
50

51

52
# Few-shot prompt that asks the LLM either for a more specific follow-up
# question or for the literal answer "None" when the SQL response already
# answers the original question. Placeholders: query_str, sql_query_str,
# sql_response_str.
# NOTE(review): the template text opens with a double quote and its last line
# ends with one, so a model may echo a quote in its reply -- confirm this is
# intended before changing it, since it is part of the runtime prompt.
DEFAULT_SQL_AUGMENT_TRANSFORM_PROMPT_TMPL = """
"The original question is given below.
This question has been translated into a SQL query. Both the SQL query and the \
response are given below.
The SQL response either answers the question, or should provide additional context \
that can be used to make the question more specific.
Your job is to come up with a more specific question that needs to be answered to \
fully answer the original question, or 'None' if the original question has already \
been fully answered from the SQL response. Do not create a new question that is \
irrelevant to the original question; in that case return None instead.

Examples:

Original question: Please give more details about the demographics of the city with \
the highest population.
SQL query: SELECT city, population FROM cities ORDER BY population DESC LIMIT 1
SQL response: The city with the highest population is New York City.
New question: Can you tell me more about the demographics of New York City?

Original question: Please compare the sports environment of cities in North America.
SQL query: SELECT city_name FROM cities WHERE continent = 'North America' LIMIT 3
SQL response: The cities in North America are New York, San Francisco, and Toronto.
New question: What sports are played in New York, San Francisco, and Toronto?

Original question: What is the city with the highest population?
SQL query: SELECT city, population FROM cities ORDER BY population DESC LIMIT 1
SQL response: The city with the highest population is New York City.
New question: None

Original question: What countries are the top 3 ATP players from?
SQL query: SELECT country FROM players WHERE rank <= 3
SQL response: The top 3 ATP players are from Serbia, Russia, and Spain.
New question: None

Original question: {query_str}
SQL query: {sql_query_str}
SQL response: {sql_response_str}
New question: "
"""
DEFAULT_SQL_AUGMENT_TRANSFORM_PROMPT = PromptTemplate(
    DEFAULT_SQL_AUGMENT_TRANSFORM_PROMPT_TMPL
)
94

95

96
def _default_check_stop(query_bundle: QueryBundle) -> bool:
    """Default check stop function.

    Reports whether the transformed query is the "None" stop marker. The
    comparison is case-insensitive and ignores surrounding whitespace and
    quote characters, because the default augment prompt in this module ends
    with an opening double quote and a model reply may therefore come back as
    ``None"`` or with a trailing newline.

    Args:
        query_bundle: Bundle whose ``query_str`` holds the transformed query.

    Returns:
        True if the normalized query string equals ``"none"``.
    """
    # Strip whitespace, then wrapping quotes, then any whitespace that was
    # sitting inside those quotes.
    answer = query_bundle.query_str.strip().strip("\"'").strip()
    return answer.lower() == "none"
99

100

101
def _format_sql_query(sql_query: str) -> str:
    """Format SQL query.

    Args:
        sql_query: SQL text, possibly containing newlines and tabs.

    Returns:
        The same text with every newline and every tab replaced by one space.
        No other character is changed and runs of spaces are not collapsed.
    """
    # One pass over the string: "\n" -> " " and "\t" -> " ".
    return sql_query.translate(str.maketrans("\n\t", "  "))
104

105

106
class SQLAugmentQueryTransform(BaseQueryTransform):
    """SQL Augment Query Transform.

    This query transform will transform the query into a more specific query
    after augmenting with SQL results.

    Args:
        llm (Optional[LLMPredictorType]): LLM to use for query transformation.
            When omitted (or falsy), ``resolve_llm("default")`` is used.
        sql_augment_transform_prompt (Optional[BasePromptTemplate]): PromptTemplate
            to use for query transformation. Defaults to
            ``DEFAULT_SQL_AUGMENT_TRANSFORM_PROMPT``.
        check_stop_parser (Optional[Callable[[QueryBundle], bool]]): Check stop
            function; receives the transformed query bundle. Defaults to
            ``_default_check_stop``.

    """

    def __init__(
        self,
        llm: Optional[LLMPredictorType] = None,
        sql_augment_transform_prompt: Optional[BasePromptTemplate] = None,
        check_stop_parser: Optional[Callable[[QueryBundle], bool]] = None,
    ) -> None:
        """Initialize params."""
        self._llm = llm or resolve_llm("default")

        self._sql_augment_transform_prompt = (
            sql_augment_transform_prompt or DEFAULT_SQL_AUGMENT_TRANSFORM_PROMPT
        )
        self._check_stop_parser = check_stop_parser or _default_check_stop

    def _get_prompts(self) -> PromptDictType:
        """Get prompts."""
        return {"sql_augment_transform_prompt": self._sql_augment_transform_prompt}

    def _update_prompts(self, prompts: PromptDictType) -> None:
        """Update prompts.

        Only the ``"sql_augment_transform_prompt"`` key is read; any other key
        in ``prompts`` is ignored.
        """
        if "sql_augment_transform_prompt" in prompts:
            self._sql_augment_transform_prompt = prompts["sql_augment_transform_prompt"]

    def _run(self, query_bundle: QueryBundle, metadata: Dict) -> QueryBundle:
        """Run query transform.

        Args:
            query_bundle: The original query.
            metadata: Must contain ``"sql_query"`` and ``"sql_query_response"``.

        Returns:
            A new QueryBundle whose query string is the LLM prediction and whose
            ``custom_embedding_strs`` are copied from ``query_bundle``.

        Raises:
            KeyError: If either required metadata key is missing.
        """
        query_str = query_bundle.query_str
        sql_query = metadata["sql_query"]
        sql_query_response = metadata["sql_query_response"]
        new_query_str = self._llm.predict(
            self._sql_augment_transform_prompt,
            query_str=query_str,
            sql_query_str=sql_query,
            sql_response_str=sql_query_response,
        )
        return QueryBundle(
            new_query_str, custom_embedding_strs=query_bundle.custom_embedding_strs
        )

    def check_stop(self, query_bundle: QueryBundle) -> bool:
        """Check if query indicates stop."""
        return self._check_stop_parser(query_bundle)
161

162

163
class SQLJoinQueryEngine(BaseQueryEngine):
    """SQL Join Query Engine.

    This query engine can "Join" a SQL database results
    with another query engine.
    It can decide it needs to query the SQL database or the other query engine.
    If it decides to query the SQL database, it will first query the SQL database,
    then decide whether to augment information with retrieved results from the
    other query engine.

    Args:
        sql_query_tool (QueryEngineTool): Query engine tool for SQL database.
        other_query_tool (QueryEngineTool): Other query engine tool.
        selector (Optional[Union[LLMSingleSelector, PydanticSingleSelector]]):
            Selector to use.
        service_context (Optional[ServiceContext]): Service context to use.
            Defaults to the service context of the SQL query engine.
        sql_join_synthesis_prompt (Optional[BasePromptTemplate]):
            PromptTemplate to use for SQL join synthesis.
        sql_augment_query_transform (Optional[SQLAugmentQueryTransform]): Query
            transform to use for SQL augmentation.
        use_sql_join_synthesis (bool): Whether to use SQL join synthesis. When
            False, the SQL response is returned without querying the other
            query engine.
        callback_manager (Optional[CallbackManager]): Callback manager to use.
        verbose (bool): Whether to print intermediate results.

    Raises:
        ValueError: If the SQL tool's query engine or the selector is not of a
            supported type.

    """

    def __init__(
        self,
        sql_query_tool: QueryEngineTool,
        other_query_tool: QueryEngineTool,
        selector: Optional[Union[LLMSingleSelector, PydanticSingleSelector]] = None,
        service_context: Optional[ServiceContext] = None,
        sql_join_synthesis_prompt: Optional[BasePromptTemplate] = None,
        sql_augment_query_transform: Optional[SQLAugmentQueryTransform] = None,
        use_sql_join_synthesis: bool = True,
        callback_manager: Optional[CallbackManager] = None,
        verbose: bool = True,
    ) -> None:
        """Initialize params."""
        super().__init__(callback_manager=callback_manager)
        # validate that the query engines are of the right type
        if not isinstance(
            sql_query_tool.query_engine,
            (BaseSQLTableQueryEngine, NLSQLTableQueryEngine),
        ):
            raise ValueError(
                "sql_query_tool.query_engine must be an instance of "
                "BaseSQLTableQueryEngine or NLSQLTableQueryEngine"
            )
        self._sql_query_tool = sql_query_tool
        self._other_query_tool = other_query_tool

        sql_query_engine = sql_query_tool.query_engine
        self._service_context = service_context or sql_query_engine.service_context

        self._selector = selector or get_selector_from_context(
            self._service_context, is_multi=False
        )
        # Raise instead of assert so the check survives `python -O`.
        if not isinstance(
            self._selector, (LLMSingleSelector, PydanticSingleSelector)
        ):
            raise ValueError(
                "selector must be an instance of LLMSingleSelector or "
                "PydanticSingleSelector"
            )

        self._sql_join_synthesis_prompt = (
            sql_join_synthesis_prompt or DEFAULT_SQL_JOIN_SYNTHESIS_PROMPT
        )
        self._sql_augment_query_transform = (
            sql_augment_query_transform
            or SQLAugmentQueryTransform(llm=self._service_context.llm)
        )
        self._use_sql_join_synthesis = use_sql_join_synthesis
        self._verbose = verbose

    def _get_prompt_modules(self) -> PromptMixinType:
        """Get prompt sub-modules."""
        return {
            "selector": self._selector,
            "sql_augment_query_transform": self._sql_augment_query_transform,
        }

    def _get_prompts(self) -> PromptDictType:
        """Get prompts."""
        return {"sql_join_synthesis_prompt": self._sql_join_synthesis_prompt}

    def _update_prompts(self, prompts: PromptDictType) -> None:
        """Update prompts."""
        if "sql_join_synthesis_prompt" in prompts:
            self._sql_join_synthesis_prompt = prompts["sql_join_synthesis_prompt"]

    def _query_sql_other(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        """Query SQL database + other query engine in sequence.

        Steps: run the SQL engine; ask the augment transform for a follow-up
        query; if the transform signals stop, return the SQL response;
        otherwise run the follow-up on the other engine and synthesize a final
        answer with the LLM from both results.
        """
        # first query SQL database
        sql_response = self._sql_query_tool.query_engine.query(query_bundle)
        if not self._use_sql_join_synthesis:
            return sql_response

        # The SQL text is optional: metadata may be empty or lack the key.
        sql_query = (sql_response.metadata or {}).get("sql_query")
        if self._verbose:
            print_text(f"SQL query: {sql_query}\n", color="yellow")
            print_text(f"SQL response: {sql_response}\n", color="yellow")

        # given SQL db, transform query into new query
        new_query = self._sql_augment_query_transform(
            query_bundle.query_str,
            metadata={
                # Without SQL text there is nothing to format; pass an empty
                # string rather than calling str methods on None.
                "sql_query": _format_sql_query(sql_query) if sql_query else "",
                "sql_query_response": str(sql_response),
            },
        )

        if self._verbose:
            print_text(
                f"Transformed query given SQL response: {new_query.query_str}\n",
                color="blue",
            )
        logger.info(
            "> Transformed query given SQL response: %s", new_query.query_str
        )
        if self._sql_augment_query_transform.check_stop(new_query):
            return sql_response

        other_response = self._other_query_tool.query_engine.query(new_query)
        if self._verbose:
            print_text(f"query engine response: {other_response}\n", color="pink")
        logger.info("> query engine response: %s", other_response)

        response_str = self._service_context.llm.predict(
            self._sql_join_synthesis_prompt,
            query_str=query_bundle.query_str,
            sql_query_str=sql_query,
            sql_response_str=str(sql_response),
            query_engine_query_str=new_query.query_str,
            query_engine_response_str=str(other_response),
        )
        if self._verbose:
            print_text(f"Final response: {response_str}\n", color="green")
        # Keys from the other engine's metadata win on collision.
        response_metadata = {
            **(sql_response.metadata or {}),
            **(other_response.metadata or {}),
        }
        source_nodes = other_response.source_nodes
        return Response(
            response_str,
            metadata=response_metadata,
            source_nodes=source_nodes,
        )

    def _query(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        """Query and get response.

        The selector chooses between the SQL tool (index 0) and the other tool
        (index 1).

        Raises:
            ValueError: If the selector returns any other index.
        """
        # TODO: see if this can be consolidated with logic in RouterQueryEngine
        metadatas = [self._sql_query_tool.metadata, self._other_query_tool.metadata]
        result = self._selector.select(metadatas, query_bundle)
        # pick sql query
        if result.ind == 0:
            if self._verbose:
                print_text(f"Querying SQL database: {result.reason}\n", color="blue")
            logger.info("> Querying SQL database: %s", result.reason)
            return self._query_sql_other(query_bundle)
        elif result.ind == 1:
            if self._verbose:
                print_text(
                    f"Querying other query engine: {result.reason}\n", color="blue"
                )
            logger.info("> Querying other query engine: %s", result.reason)
            response = self._other_query_tool.query_engine.query(query_bundle)
            if self._verbose:
                print_text(f"Query Engine response: {response}\n", color="pink")
            return response
        else:
            raise ValueError(f"Invalid result.ind: {result.ind}")

    async def _aquery(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        """Async query; currently runs the synchronous ``_query`` inline."""
        # TODO: make async
        return self._query(query_bundle)
333

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.