llama-index

Форк
0
345 строк · 12.6 Кб
1
import asyncio
2
import json
3
import logging
4
from typing import Any, Callable, Dict, List, Optional, Tuple
5

6
from llama_index.legacy.core.base_query_engine import BaseQueryEngine
7
from llama_index.legacy.core.response.schema import Response
8
from llama_index.legacy.indices.struct_store.sql_retriever import (
9
    BaseSQLParser,
10
    DefaultSQLParser,
11
)
12
from llama_index.legacy.prompts import BasePromptTemplate, PromptTemplate
13
from llama_index.legacy.prompts.default_prompts import DEFAULT_JSONALYZE_PROMPT
14
from llama_index.legacy.prompts.mixin import PromptDictType, PromptMixinType
15
from llama_index.legacy.prompts.prompt_type import PromptType
16
from llama_index.legacy.schema import QueryBundle
17
from llama_index.legacy.service_context import ServiceContext
18
from llama_index.legacy.utils import print_text
19

20
logger = logging.getLogger(__name__)

# Template for turning the generated SQL, the table schema and the SQL result
# into a natural-language answer to the user's original query.
DEFAULT_RESPONSE_SYNTHESIS_PROMPT_TMPL = "\n".join(
    [
        "Given a query, synthesize a response based on SQL query results to "
        "satisfy the query. Only include details that are relevant to the "
        "query. If you don't know the answer, then say that.",
        "SQL Query: {sql_query}",
        "Table Schema: {table_schema}",
        "SQL Response: {sql_response}",
        "Query: {query_str}",
        "Response: ",
    ]
)

DEFAULT_RESPONSE_SYNTHESIS_PROMPT = PromptTemplate(
    DEFAULT_RESPONSE_SYNTHESIS_PROMPT_TMPL,
    prompt_type=PromptType.SQL_RESPONSE_SYNTHESIS,
)

# Name of the in-memory SQLite table the list of dictionaries is loaded into.
DEFAULT_TABLE_NAME = "items"
39

40

41
def default_jsonalyzer(
    list_of_dict: List[Dict[str, Any]],
    query_bundle: QueryBundle,
    service_context: ServiceContext,
    table_name: str = DEFAULT_TABLE_NAME,
    prompt: BasePromptTemplate = DEFAULT_JSONALYZE_PROMPT,
    sql_parser: Optional[BaseSQLParser] = None,
) -> Tuple[str, Dict[str, Any], List[Dict[str, Any]]]:
    """Default JSONalyzer that executes a query on a list of dictionaries.

    The dictionaries are inserted into an in-memory SQLite table, the LLM is
    asked to translate the query into SQL for that table, and the parsed SQL
    is executed against it.

    Args:
        list_of_dict (List[Dict[str, Any]]): List of dictionaries to query.
        query_bundle (QueryBundle): The query bundle.
        service_context (ServiceContext): The service context whose LLM
            generates the SQL.
        table_name (str): The table name to use, defaults to DEFAULT_TABLE_NAME.
        prompt (BasePromptTemplate): The text-to-SQL prompt to use; falls back
            to DEFAULT_JSONALYZE_PROMPT when falsy.
        sql_parser (Optional[BaseSQLParser]): The SQL parser to use; a fresh
            DefaultSQLParser is created when omitted.

    Returns:
        Tuple[str, Dict[str, Any], List[Dict[str, Any]]]: The SQL Query,
            the Schema, and the Result.

    Raises:
        ImportError: If sqlite-utils is not installed.
        ValueError: If the rows cannot be inserted or the generated SQL
            cannot be executed.
    """
    try:
        import sqlite_utils
    except ImportError as exc:
        IMPORT_ERROR_MSG = (
            "sqlite-utils is needed to use this Query Engine:\n"
            "pip install sqlite-utils"
        )

        raise ImportError(IMPORT_ERROR_MSG) from exc
    # Instantiate in-memory SQLite database
    db = sqlite_utils.Database(memory=True)
    try:
        # Load list of dictionaries into SQLite database
        db[table_name].insert_all(list_of_dict)
    except sqlite_utils.utils.sqlite3.IntegrityError as exc:
        # Errors come from the sqlite3 driver sqlite-utils loaded (stdlib
        # sqlite3 or a drop-in replacement), re-exported as
        # ``sqlite_utils.utils.sqlite3``.
        print_text(f"Error inserting into table {table_name}, expected format:")
        print_text("[{col1: val1, col2: val2, ...}, ...]")
        raise ValueError("Invalid list_of_dict") from exc

    # Get the table schema
    table_schema = db[table_name].columns_dict

    query = query_bundle.query_str
    prompt = prompt or DEFAULT_JSONALYZE_PROMPT
    # Get the SQL query with text-to-SQL prompt
    response_str = service_context.llm.predict(
        prompt=prompt,
        table_name=table_name,
        table_schema=table_schema,
        question=query,
    )

    # Create the default parser per call instead of sharing one instance
    # built at import time as a default argument.
    sql_parser = sql_parser or DefaultSQLParser()

    sql_query = sql_parser.parse_response_to_sql(response_str, query_bundle)

    try:
        # Execute the SQL query
        results = list(db.query(sql_query))
    except sqlite_utils.utils.sqlite3.OperationalError as exc:
        print_text(f"Error executing query: {sql_query}")
        raise ValueError("Invalid query") from exc

    return sql_query, table_schema, results
107

108

109
async def async_default_jsonalyzer(
    list_of_dict: List[Dict[str, Any]],
    query_bundle: QueryBundle,
    service_context: ServiceContext,
    prompt: Optional[BasePromptTemplate] = None,
    sql_parser: Optional[BaseSQLParser] = None,
    table_name: str = DEFAULT_TABLE_NAME,
) -> Tuple[str, Dict[str, Any], List[Dict[str, Any]]]:
    """Default async JSONalyzer.

    Same pipeline as the synchronous default: load the dictionaries into an
    in-memory SQLite table, ask the LLM (via ``apredict``) for SQL, and run it.

    Args:
        list_of_dict (List[Dict[str, Any]]): List of dictionaries to query.
        query_bundle (QueryBundle): The query bundle.
        service_context (ServiceContext): The service context whose LLM
            generates the SQL.
        prompt (BasePromptTemplate, optional): The text-to-SQL prompt to use;
            defaults to DEFAULT_JSONALYZE_PROMPT.
        sql_parser (BaseSQLParser, optional): The SQL parser to use; defaults
            to a new DefaultSQLParser.
        table_name (str, optional): The table name to use, defaults to DEFAULT_TABLE_NAME.

    Returns:
        Tuple[str, Dict[str, Any], List[Dict[str, Any]]]: The SQL Query,
            the Schema, and the Result.

    Raises:
        ImportError: If sqlite-utils is not installed.
        ValueError: If the rows cannot be inserted or the generated SQL
            cannot be executed.
    """
    try:
        import sqlite_utils
    except ImportError as exc:
        IMPORT_ERROR_MSG = (
            "sqlite-utils is needed to use this Query Engine:\n"
            "pip install sqlite-utils"
        )

        raise ImportError(IMPORT_ERROR_MSG) from exc
    # Instantiate in-memory SQLite database
    db = sqlite_utils.Database(memory=True)
    try:
        # Load list of dictionaries into SQLite database
        db[table_name].insert_all(list_of_dict)
    except sqlite_utils.utils.sqlite3.IntegrityError as exc:
        # Errors come from the sqlite3 driver sqlite-utils loaded (stdlib
        # sqlite3 or a drop-in replacement), re-exported as
        # ``sqlite_utils.utils.sqlite3``.
        print_text(f"Error inserting into table {table_name}, expected format:")
        print_text("[{col1: val1, col2: val2, ...}, ...]")
        raise ValueError("Invalid list_of_dict") from exc

    # Get the table schema
    table_schema = db[table_name].columns_dict

    query = query_bundle.query_str
    prompt = prompt or DEFAULT_JSONALYZE_PROMPT
    # Get the SQL query with text-to-SQL prompt
    response_str = await service_context.llm.apredict(
        prompt=prompt,
        table_name=table_name,
        table_schema=table_schema,
        question=query,
    )

    sql_parser = sql_parser or DefaultSQLParser()

    sql_query = sql_parser.parse_response_to_sql(response_str, query_bundle)

    try:
        # Execute the SQL query
        results = list(db.query(sql_query))
    except sqlite_utils.utils.sqlite3.OperationalError as exc:
        print_text(f"Error executing query: {sql_query}")
        raise ValueError("Invalid query") from exc

    return sql_query, table_schema, results
175

176

177
def load_jsonalyzer(
    use_async: bool = False,
    custom_jsonalyzer: Optional[Callable] = None,
) -> Callable:
    """Load the JSONalyzer.

    Args:
        use_async (bool): Whether to use async. When True, a custom analyzer
            must be a coroutine function and the default is the async one.
        custom_jsonalyzer (Optional[Callable]): A custom JSONalyzer to use
            instead of the default one.

    Returns:
        Callable: The JSONalyzer.

    Raises:
        ValueError: If ``use_async`` is True and ``custom_jsonalyzer`` is not
            a coroutine function.
    """
    if custom_jsonalyzer is not None:
        # Explicit check instead of ``assert`` so the validation still runs
        # when Python is started with -O.
        if use_async and not asyncio.iscoroutinefunction(custom_jsonalyzer):
            raise ValueError(
                "custom_jsonalyzer function must be async when use_async is True"
            )
        return custom_jsonalyzer
    # Separate returns (rather than a conditional expression) keep mypy happy
    # with the differing sync/async function types.
    if use_async:
        return async_default_jsonalyzer
    return default_jsonalyzer
201

202

203
class JSONalyzeQueryEngine(BaseQueryEngine):
    """JSON List Shape Data Analysis Query Engine.

    Converts natural language statistical queries into SQL that is executed
    against an in-memory SQLite table built from a list of dictionaries.

    Args:
        list_of_dict (List[Dict[str, Any]]): List of dictionaries to query.
        service_context (ServiceContext): ServiceContext; a default one from
            ``ServiceContext.from_defaults()`` is used when a falsy value is
            passed.
        jsonalyze_prompt (BasePromptTemplate): The JSONalyze (text-to-SQL)
            prompt to use.
        use_async (bool): Whether to use async.
        analyzer (Callable): The analyzer that executes the query.
        sql_parser (BaseSQLParser): The SQL parser that ensures valid SQL is
            parsed from the LLM output.
        synthesize_response (bool): Whether to synthesize a response.
        response_synthesis_prompt (BasePromptTemplate): The response synthesis
            prompt to use.
        table_name (str): The table name to use.
        verbose (bool): Whether to print verbose output.
    """

    def __init__(
        self,
        list_of_dict: List[Dict[str, Any]],
        service_context: ServiceContext,
        jsonalyze_prompt: Optional[BasePromptTemplate] = None,
        use_async: bool = False,
        analyzer: Optional[Callable] = None,
        sql_parser: Optional[BaseSQLParser] = None,
        synthesize_response: bool = True,
        response_synthesis_prompt: Optional[BasePromptTemplate] = None,
        table_name: str = DEFAULT_TABLE_NAME,
        verbose: bool = False,
        **kwargs: Any,
    ) -> None:
        """Initialize params."""
        self._list_of_dict = list_of_dict
        self._service_context = service_context or ServiceContext.from_defaults()
        self._jsonalyze_prompt = jsonalyze_prompt or DEFAULT_JSONALYZE_PROMPT
        self._use_async = use_async
        # load_jsonalyzer picks the default analyzer for the requested mode and
        # validates that a custom analyzer is async when use_async is True.
        self._analyzer = load_jsonalyzer(use_async, analyzer)
        self._sql_parser = sql_parser or DefaultSQLParser()
        self._synthesize_response = synthesize_response
        self._response_synthesis_prompt = (
            response_synthesis_prompt or DEFAULT_RESPONSE_SYNTHESIS_PROMPT
        )
        self._table_name = table_name
        self._verbose = verbose

        super().__init__(self._service_context.callback_manager)

    def _get_prompts(self) -> PromptDictType:
        """Get prompts."""
        return {
            "jsonalyze_prompt": self._jsonalyze_prompt,
            "response_synthesis_prompt": self._response_synthesis_prompt,
        }

    def _update_prompts(self, prompts: PromptDictType) -> None:
        """Update prompts."""
        if "jsonalyze_prompt" in prompts:
            self._jsonalyze_prompt = prompts["jsonalyze_prompt"]
        if "response_synthesis_prompt" in prompts:
            self._response_synthesis_prompt = prompts["response_synthesis_prompt"]

    def _get_prompt_modules(self) -> PromptMixinType:
        """Get prompt sub-modules."""
        return {}

    def _call_analyzer(self, query_bundle: QueryBundle) -> Any:
        """Invoke the analyzer with the same arguments on both sync and async paths.

        Returns the analyzer's raw result, which is a coroutine when the
        analyzer is an ``async def`` function.
        """
        return self._analyzer(
            self._list_of_dict,
            query_bundle,
            self._service_context,
            table_name=self._table_name,
            prompt=self._jsonalyze_prompt,
            sql_parser=self._sql_parser,
        )

    @staticmethod
    def _run_to_completion(outcome: Any) -> Any:
        """Return ``outcome``, first running it to completion if it is a coroutine.

        Used by the synchronous ``_query`` so an async analyzer still works.
        A coroutine can only be driven when no event loop is running in this
        thread; otherwise it is closed and a RuntimeError is raised.
        """
        if not asyncio.iscoroutine(outcome):
            return outcome
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No running loop in this thread, so starting a fresh one is safe.
            return asyncio.run(outcome)
        # Close it so Python does not warn about a never-awaited coroutine.
        outcome.close()
        raise RuntimeError(
            "The configured analyzer is async and an event loop is already "
            "running; call aquery() instead of query()."
        )

    def _print_analysis(
        self, sql_query: str, table_schema: Dict[str, Any], results: Any
    ) -> None:
        """Print the SQL query, schema and raw SQL results when verbose."""
        if self._verbose:
            print_text(f"SQL Query: {sql_query}\n", color="blue")
            print_text(f"Table Schema: {table_schema}\n", color="cyan")
            print_text(f"SQL Response: {results}\n", color="yellow")

    @staticmethod
    def _synthesis_kwargs(
        query_bundle: QueryBundle,
        sql_query: str,
        table_schema: Dict[str, Any],
        results: Any,
    ) -> Dict[str, Any]:
        """Template variables for the response synthesis prompt."""
        return {
            "sql_query": sql_query,
            "table_schema": table_schema,
            "sql_response": results,
            "query_str": query_bundle.query_str,
        }

    @staticmethod
    def _build_response(
        response_str: str, sql_query: str, table_schema: Dict[str, Any]
    ) -> Response:
        """Wrap the answer together with the SQL query and schema metadata."""
        response_metadata = {"sql_query": sql_query, "table_schema": str(table_schema)}
        return Response(response=response_str, metadata=response_metadata)

    def _query(self, query_bundle: QueryBundle) -> Response:
        """Answer an analytical query on the JSON List."""
        if self._verbose:
            print_text(f"Query: {query_bundle.query_str}\n", color="green")

        # Perform the analysis (an async analyzer is run to completion here).
        sql_query, table_schema, results = self._run_to_completion(
            self._call_analyzer(query_bundle)
        )
        self._print_analysis(sql_query, table_schema, results)

        if self._synthesize_response:
            response_str = self._service_context.llm.predict(
                self._response_synthesis_prompt,
                **self._synthesis_kwargs(
                    query_bundle, sql_query, table_schema, results
                ),
            )
            if self._verbose:
                print_text(f"Response: {response_str}", color="magenta")
        else:
            response_str = str(results)

        return self._build_response(response_str, sql_query, table_schema)

    async def _aquery(self, query_bundle: QueryBundle) -> Response:
        """Answer an analytical query on the JSON List."""
        if self._verbose:
            print_text(f"Query: {query_bundle.query_str}\n", color="green")

        # Perform the analysis; await it when the analyzer is async.
        analysis = self._call_analyzer(query_bundle)
        if asyncio.iscoroutine(analysis):
            analysis = await analysis
        sql_query, table_schema, results = analysis
        self._print_analysis(sql_query, table_schema, results)

        if self._synthesize_response:
            response_str = await self._service_context.llm.apredict(
                self._response_synthesis_prompt,
                **self._synthesis_kwargs(
                    query_bundle, sql_query, table_schema, results
                ),
            )
            if self._verbose:
                print_text(f"Response: {response_str}", color="magenta")
        else:
            # Same plain-results format as the sync path; the schema values
            # can be Python type objects, which json.dumps cannot serialize.
            response_str = str(results)

        return self._build_response(response_str, sql_query, table_schema)
346

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.