llama-index
345 строк · 12.6 Кб
1import asyncio2import json3import logging4from typing import Any, Callable, Dict, List, Optional, Tuple5
6from llama_index.legacy.core.base_query_engine import BaseQueryEngine7from llama_index.legacy.core.response.schema import Response8from llama_index.legacy.indices.struct_store.sql_retriever import (9BaseSQLParser,10DefaultSQLParser,11)
12from llama_index.legacy.prompts import BasePromptTemplate, PromptTemplate13from llama_index.legacy.prompts.default_prompts import DEFAULT_JSONALYZE_PROMPT14from llama_index.legacy.prompts.mixin import PromptDictType, PromptMixinType15from llama_index.legacy.prompts.prompt_type import PromptType16from llama_index.legacy.schema import QueryBundle17from llama_index.legacy.service_context import ServiceContext18from llama_index.legacy.utils import print_text19
logger = logging.getLogger(__name__)

# Default prompt used to turn the executed SQL query and its result rows into a
# natural-language answer; the query engine fills in the four placeholders.
DEFAULT_RESPONSE_SYNTHESIS_PROMPT_TMPL = "\n".join(
    [
        "Given a query, synthesize a response based on SQL query results to"
        " satisfy the query. Only include details that are relevant to the"
        " query. If you don't know the answer, then say that.",
        "SQL Query: {sql_query}",
        "Table Schema: {table_schema}",
        "SQL Response: {sql_response}",
        "Query: {query_str}",
        "Response: ",
    ]
)

DEFAULT_RESPONSE_SYNTHESIS_PROMPT = PromptTemplate(
    DEFAULT_RESPONSE_SYNTHESIS_PROMPT_TMPL,
    prompt_type=PromptType.SQL_RESPONSE_SYNTHESIS,
)

# Name of the in-memory SQLite table that the list of dictionaries is loaded into.
DEFAULT_TABLE_NAME = "items"
40
def default_jsonalyzer(
    list_of_dict: List[Dict[str, Any]],
    query_bundle: QueryBundle,
    service_context: ServiceContext,
    table_name: str = DEFAULT_TABLE_NAME,
    prompt: BasePromptTemplate = DEFAULT_JSONALYZE_PROMPT,
    sql_parser: Optional[BaseSQLParser] = None,
) -> Tuple[str, Dict[str, Any], List[Dict[str, Any]]]:
    """Default JSONalyzer that executes a query on a list of dictionaries.

    The dictionaries are loaded into a temporary in-memory SQLite table, the
    LLM from ``service_context`` is asked to write SQL for the question, and
    the parsed SQL is executed against that table.

    Args:
        list_of_dict (List[Dict[str, Any]]): List of dictionaries to query.
        query_bundle (QueryBundle): The query bundle.
        service_context (ServiceContext): The service context whose ``llm``
            generates the SQL query.
        table_name (str): The table name to use, defaults to DEFAULT_TABLE_NAME.
        prompt (BasePromptTemplate): The text-to-SQL prompt to use; a falsy
            value falls back to DEFAULT_JSONALYZE_PROMPT.
        sql_parser (Optional[BaseSQLParser]): The SQL parser to use; a new
            DefaultSQLParser is created when omitted.

    Returns:
        Tuple[str, Dict[str, Any], List[Dict[str, Any]]]: The SQL Query,
        the Schema, and the Result.

    Raises:
        ImportError: If ``sqlite-utils`` is not installed.
        ValueError: If the rows cannot be inserted or the generated SQL
            cannot be executed.
    """
    try:
        import sqlite_utils
    except ImportError as exc:
        import_error_msg = (
            "sqlite-utils is needed to use this Query Engine:\n"
            "pip install sqlite-utils"
        )
        raise ImportError(import_error_msg) from exc

    # In-memory SQLite database that only lives for the duration of this call.
    db = sqlite_utils.Database(memory=True)
    try:
        try:
            # Load list of dictionaries into SQLite database
            db[table_name].insert_all(list_of_dict)
        except sqlite_utils.db_exceptions.IntegrityError as exc:
            print_text(f"Error inserting into table {table_name}, expected format:")
            print_text("[{col1: val1, col2: val2, ...}, ...]")
            raise ValueError("Invalid list_of_dict") from exc

        # Get the table schema; it is both fed to the LLM and returned.
        table_schema = db[table_name].columns_dict

        # Get the SQL query with text-to-SQL prompt
        response_str = service_context.llm.predict(
            prompt=prompt or DEFAULT_JSONALYZE_PROMPT,
            table_name=table_name,
            table_schema=table_schema,
            question=query_bundle.query_str,
        )

        # Parse the SQL statement out of the raw LLM response.
        parser = sql_parser or DefaultSQLParser()
        sql_query = parser.parse_response_to_sql(response_str, query_bundle)

        try:
            # Execute the SQL query; rows are materialized before the db closes.
            results = list(db.query(sql_query))
        except sqlite_utils.db_exceptions.OperationalError as exc:
            print_text(f"Error executing query: {sql_query}")
            raise ValueError("Invalid query") from exc
    finally:
        # Release the SQLite connection deterministically, on success or error.
        db.conn.close()

    return sql_query, table_schema, results
108
async def async_default_jsonalyzer(
    list_of_dict: List[Dict[str, Any]],
    query_bundle: QueryBundle,
    service_context: ServiceContext,
    prompt: Optional[BasePromptTemplate] = None,
    sql_parser: Optional[BaseSQLParser] = None,
    table_name: str = DEFAULT_TABLE_NAME,
) -> Tuple[str, Dict[str, Any], List[Dict[str, Any]]]:
    """Async default JSONalyzer that executes a query on a list of dictionaries.

    Same as the synchronous default JSONalyzer, except the text-to-SQL LLM
    call uses ``apredict`` and is awaited. The SQLite work itself is
    performed synchronously.

    Args:
        list_of_dict (List[Dict[str, Any]]): List of dictionaries to query.
        query_bundle (QueryBundle): The query bundle.
        service_context (ServiceContext): The service context whose ``llm``
            generates the SQL query.
        prompt (BasePromptTemplate, optional): The text-to-SQL prompt to use;
            defaults to DEFAULT_JSONALYZE_PROMPT.
        sql_parser (BaseSQLParser, optional): The SQL parser to use; a new
            DefaultSQLParser is created when omitted.
        table_name (str, optional): The table name to use, defaults to DEFAULT_TABLE_NAME.

    Returns:
        Tuple[str, Dict[str, Any], List[Dict[str, Any]]]: The SQL Query,
        the Schema, and the Result.

    Raises:
        ImportError: If ``sqlite-utils`` is not installed.
        ValueError: If the rows cannot be inserted or the generated SQL
            cannot be executed.
    """
    try:
        import sqlite_utils
    except ImportError as exc:
        import_error_msg = (
            "sqlite-utils is needed to use this Query Engine:\n"
            "pip install sqlite-utils"
        )
        raise ImportError(import_error_msg) from exc

    # In-memory SQLite database that only lives for the duration of this call.
    db = sqlite_utils.Database(memory=True)
    try:
        try:
            # Load list of dictionaries into SQLite database
            db[table_name].insert_all(list_of_dict)
        except sqlite_utils.db_exceptions.IntegrityError as exc:
            print_text(f"Error inserting into table {table_name}, expected format:")
            print_text("[{col1: val1, col2: val2, ...}, ...]")
            raise ValueError("Invalid list_of_dict") from exc

        # Get the table schema; it is both fed to the LLM and returned.
        table_schema = db[table_name].columns_dict

        # Get the SQL query with text-to-SQL prompt
        response_str = await service_context.llm.apredict(
            prompt=prompt or DEFAULT_JSONALYZE_PROMPT,
            table_name=table_name,
            table_schema=table_schema,
            question=query_bundle.query_str,
        )

        # Parse the SQL statement out of the raw LLM response.
        parser = sql_parser or DefaultSQLParser()
        sql_query = parser.parse_response_to_sql(response_str, query_bundle)

        try:
            # Execute the SQL query; rows are materialized before the db closes.
            results = list(db.query(sql_query))
        except sqlite_utils.db_exceptions.OperationalError as exc:
            print_text(f"Error executing query: {sql_query}")
            raise ValueError("Invalid query") from exc
    finally:
        # Release the SQLite connection deterministically, on success or error.
        db.conn.close()

    return sql_query, table_schema, results
176
def load_jsonalyzer(
    use_async: bool = False,
    custom_jsonalyzer: Optional[Callable] = None,
) -> Callable:
    """Load the JSONalyzer.

    Args:
        use_async (bool): Whether an async (coroutine) JSONalyzer is required.
        custom_jsonalyzer (Optional[Callable]): A custom JSONalyzer to use
            instead of the built-in ones.

    Returns:
        Callable: ``custom_jsonalyzer`` when given; otherwise
        ``async_default_jsonalyzer`` if ``use_async`` is True, else
        ``default_jsonalyzer``.

    Raises:
        TypeError: If ``use_async`` is True and ``custom_jsonalyzer`` is not
            an async function.
    """
    if custom_jsonalyzer is not None:
        # Explicit check rather than ``assert`` so validation still happens
        # when Python runs with optimizations (-O), which strips asserts.
        if use_async and not asyncio.iscoroutinefunction(custom_jsonalyzer):
            raise TypeError(
                "custom_jsonalyzer function must be async when use_async is True"
            )
        return custom_jsonalyzer
    return async_default_jsonalyzer if use_async else default_jsonalyzer
202
class JSONalyzeQueryEngine(BaseQueryEngine):
    """JSON List Shape Data Analysis Query Engine.

    Converts natural language statistical queries to SQL, which is executed
    against an in-memory SQLite table built from the list of dictionaries.

    Args:
        list_of_dict (List[Dict[str, Any]]): List of dictionaries to query.
        service_context (ServiceContext): ServiceContext; a default one is
            created when a falsy value is passed.
        jsonalyze_prompt (BasePromptTemplate): The JSONalyze prompt to use.
        use_async (bool): Whether to use async.
        analyzer (Callable): The analyzer that executes the query.
        sql_parser (BaseSQLParser): The SQL parser that ensures valid SQL is
            parsed from the LLM output.
        synthesize_response (bool): Whether to synthesize a response.
        response_synthesis_prompt (BasePromptTemplate): The response synthesis
            prompt to use.
        table_name (str): The table name to use.
        verbose (bool): Whether to print verbose output.
    """

    def __init__(
        self,
        list_of_dict: List[Dict[str, Any]],
        service_context: ServiceContext,
        jsonalyze_prompt: Optional[BasePromptTemplate] = None,
        use_async: bool = False,
        analyzer: Optional[Callable] = None,
        sql_parser: Optional[BaseSQLParser] = None,
        synthesize_response: bool = True,
        response_synthesis_prompt: Optional[BasePromptTemplate] = None,
        table_name: str = DEFAULT_TABLE_NAME,
        verbose: bool = False,
        **kwargs: Any,  # accepted but currently unused
    ) -> None:
        """Initialize params."""
        self._list_of_dict = list_of_dict
        self._service_context = service_context or ServiceContext.from_defaults()
        self._jsonalyze_prompt = jsonalyze_prompt or DEFAULT_JSONALYZE_PROMPT
        self._use_async = use_async
        self._analyzer = load_jsonalyzer(use_async, analyzer)
        self._sql_parser = sql_parser or DefaultSQLParser()
        self._synthesize_response = synthesize_response
        self._response_synthesis_prompt = (
            response_synthesis_prompt or DEFAULT_RESPONSE_SYNTHESIS_PROMPT
        )
        self._table_name = table_name
        self._verbose = verbose

        super().__init__(self._service_context.callback_manager)

    def _get_prompts(self) -> Dict[str, Any]:
        """Get prompts."""
        return {
            "jsonalyze_prompt": self._jsonalyze_prompt,
            "response_synthesis_prompt": self._response_synthesis_prompt,
        }

    def _update_prompts(self, prompts: PromptDictType) -> None:
        """Update prompts."""
        if "jsonalyze_prompt" in prompts:
            self._jsonalyze_prompt = prompts["jsonalyze_prompt"]
        if "response_synthesis_prompt" in prompts:
            self._response_synthesis_prompt = prompts["response_synthesis_prompt"]

    def _get_prompt_modules(self) -> PromptMixinType:
        """Get prompt sub-modules."""
        return {}

    def _call_analyzer(self, query_bundle: QueryBundle) -> Any:
        """Invoke the configured analyzer with this engine's settings.

        Returns whatever the analyzer returns: the ``(sql_query, table_schema,
        results)`` tuple for a synchronous analyzer, or an awaitable for an
        async one.
        """
        return self._analyzer(
            self._list_of_dict,
            query_bundle,
            self._service_context,
            table_name=self._table_name,
            prompt=self._jsonalyze_prompt,
            sql_parser=self._sql_parser,
        )

    def _print_analysis(
        self, sql_query: str, table_schema: Dict[str, Any], results: Any
    ) -> None:
        """Print the intermediate SQL artefacts when verbose output is enabled."""
        if not self._verbose:
            return
        print_text(f"SQL Query: {sql_query}\n", color="blue")
        print_text(f"Table Schema: {table_schema}\n", color="cyan")
        print_text(f"SQL Response: {results}\n", color="yellow")

    def _query(self, query_bundle: QueryBundle) -> Response:
        """Answer an analytical query on the JSON List."""
        query = query_bundle.query_str
        if self._verbose:
            print_text(f"Query: {query}\n", color="green")

        # Perform the analysis
        sql_query, table_schema, results = self._call_analyzer(query_bundle)
        self._print_analysis(sql_query, table_schema, results)

        if self._synthesize_response:
            response_str = self._service_context.llm.predict(
                self._response_synthesis_prompt,
                sql_query=sql_query,
                table_schema=table_schema,
                sql_response=results,
                query_str=query,
            )
            if self._verbose:
                print_text(f"Response: {response_str}", color="magenta")
        else:
            response_str = str(results)
        response_metadata = {"sql_query": sql_query, "table_schema": str(table_schema)}

        return Response(response=response_str, metadata=response_metadata)

    async def _aquery(self, query_bundle: QueryBundle) -> Response:
        """Answer an analytical query on the JSON List asynchronously."""
        import inspect

        query = query_bundle.query_str
        if self._verbose:
            print_text(f"Query: {query}", color="green")

        # The analyzer takes the full QueryBundle (it reads ``.query_str``).
        # With use_async=True it is a coroutine function, so its result must be
        # awaited before unpacking; a synchronous analyzer is also accepted.
        analysis = self._call_analyzer(query_bundle)
        if inspect.isawaitable(analysis):
            analysis = await analysis
        sql_query, table_schema, results = analysis
        self._print_analysis(sql_query, table_schema, results)

        if self._synthesize_response:
            response_str = await self._service_context.llm.apredict(
                self._response_synthesis_prompt,
                sql_query=sql_query,
                table_schema=table_schema,
                sql_response=results,
                query_str=query,
            )
            if self._verbose:
                print_text(f"Response: {response_str}", color="magenta")
        else:
            # Mirror the synchronous path; json.dumps can raise on schema or
            # row values that are not JSON serializable.
            response_str = str(results)
        response_metadata = {"sql_query": sql_query, "table_schema": str(table_schema)}

        return Response(response=response_str, metadata=response_metadata)