# llama-index
# 136 lines · 5.3 KB
1import logging2from typing import Optional3
4from llama_index.legacy.callbacks.base import CallbackManager5from llama_index.legacy.core.base_query_engine import BaseQueryEngine6from llama_index.legacy.core.response.schema import RESPONSE_TYPE, Response7from llama_index.legacy.evaluation.base import BaseEvaluator8from llama_index.legacy.evaluation.guideline import GuidelineEvaluator9from llama_index.legacy.indices.query.query_transform.feedback_transform import (10FeedbackQueryTransformation,11)
12from llama_index.legacy.prompts.mixin import PromptMixinType13from llama_index.legacy.schema import QueryBundle14
# Module-level logger; the retry engines report evaluation outcomes through it
# at debug level.
logger = logging.getLogger(name=__name__)
17
class RetryQueryEngine(BaseQueryEngine):
    """Does retry on query engine if it fails evaluation.

    The wrapped engine answers the query and the evaluator scores the answer.
    If the evaluation does not pass and retries remain, the query is rewritten
    with ``FeedbackQueryTransformation`` (fed the evaluation result). It is then
    answered by a fresh ``RetryQueryEngine`` that has one fewer retry.

    Args:
        query_engine (BaseQueryEngine): A query engine object
        evaluator (BaseEvaluator): An evaluator object
        max_retries (int): Maximum number of retries
        callback_manager (Optional[CallbackManager]): A callback manager object
    """

    def __init__(
        self,
        query_engine: BaseQueryEngine,
        evaluator: BaseEvaluator,
        max_retries: int = 3,
        callback_manager: Optional[CallbackManager] = None,
    ) -> None:
        self._query_engine = query_engine
        self._evaluator = evaluator
        self.max_retries = max_retries
        super().__init__(callback_manager)

    def _get_prompt_modules(self) -> PromptMixinType:
        """Get prompt sub-modules."""
        return {"query_engine": self._query_engine, "evaluator": self._evaluator}

    def _query(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        """Answer a query, retrying with evaluator feedback on failure.

        Args:
            query_bundle (QueryBundle): The query to answer.

        Returns:
            RESPONSE_TYPE: The wrapped engine's response. This is the first one
            that passes evaluation, or the one produced once ``max_retries``
            has been used up.
        """
        response = self._query_engine._query(query_bundle)
        # No retries left: return the answer without evaluating it.
        if self.max_retries <= 0:
            return response
        # The evaluator is given a plain ``Response``; any other response type
        # is converted with ``get_response()`` first.
        typed_response = (
            response if isinstance(response, Response) else response.get_response()
        )
        evaluation = self._evaluator.evaluate_response(
            query_bundle.query_str, typed_response
        )
        if evaluation.passing:
            logger.debug("Evaluation returned True.")
            return response

        logger.debug("Evaluation returned False.")
        # Retry with one fewer attempt. Forward the callback manager so the
        # retried query uses the same callback manager as this engine.
        new_query_engine = RetryQueryEngine(
            self._query_engine,
            self._evaluator,
            self.max_retries - 1,
            self.callback_manager,
        )
        query_transformer = FeedbackQueryTransformation()
        new_query = query_transformer.run(query_bundle, {"evaluation": evaluation})
        return new_query_engine.query(new_query)

    async def _aquery(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        """Not supported natively; delegates to the synchronous ``_query``."""
        return self._query(query_bundle)
70
class RetryGuidelineQueryEngine(BaseQueryEngine):
    """Does retry with evaluator feedback
    if query engine fails evaluation.

    The wrapped engine answers the query and the guideline evaluator scores the
    answer. If the evaluation does not pass and retries remain, the query is
    rewritten by ``query_transformer`` (fed the evaluation result). It is then
    answered by a fresh ``RetryGuidelineQueryEngine`` with one fewer retry. That
    engine shares the same evaluator, transformer and callback manager.

    Args:
        query_engine (BaseQueryEngine): A query engine object
        guideline_evaluator (GuidelineEvaluator): A guideline evaluator object
        resynthesize_query (bool): Whether to resynthesize query
        max_retries (int): Maximum number of retries
        callback_manager (Optional[CallbackManager]): A callback manager object
        query_transformer (Optional[FeedbackQueryTransformation]): Transformation
            used to rewrite the query from evaluation feedback. Defaults to a
            ``FeedbackQueryTransformation`` built with ``resynthesize_query``.
    """

    def __init__(
        self,
        query_engine: BaseQueryEngine,
        guideline_evaluator: GuidelineEvaluator,
        resynthesize_query: bool = False,
        max_retries: int = 3,
        callback_manager: Optional[CallbackManager] = None,
        query_transformer: Optional[FeedbackQueryTransformation] = None,
    ) -> None:
        self._query_engine = query_engine
        self._guideline_evaluator = guideline_evaluator
        self.max_retries = max_retries
        self.resynthesize_query = resynthesize_query
        self.query_transformer = query_transformer or FeedbackQueryTransformation(
            resynthesize_query=self.resynthesize_query
        )
        super().__init__(callback_manager)

    def _get_prompt_modules(self) -> PromptMixinType:
        """Get prompt sub-modules."""
        return {
            "query_engine": self._query_engine,
            # NOTE(review): "evalator" is misspelled. It is kept as-is because
            # this key is presumably exposed to callers as a prompt-key prefix
            # via PromptMixin; confirm before renaming.
            "guideline_evalator": self._guideline_evaluator,
        }

    def _query(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        """Answer a query, retrying with guideline feedback on failure.

        Args:
            query_bundle (QueryBundle): The query to answer.

        Returns:
            RESPONSE_TYPE: The wrapped engine's response. This is the first one
            that passes evaluation, or the one produced once ``max_retries``
            has been used up.
        """
        response = self._query_engine._query(query_bundle)
        # No retries left: return the answer without evaluating it.
        if self.max_retries <= 0:
            return response
        # The evaluator is given a plain ``Response``; any other response type
        # is converted with ``get_response()`` first.
        typed_response = (
            response if isinstance(response, Response) else response.get_response()
        )
        evaluation = self._guideline_evaluator.evaluate_response(
            query_bundle.query_str, typed_response
        )
        if evaluation.passing:
            logger.debug("Evaluation returned True.")
            return response

        logger.debug("Evaluation returned False.")
        # Forward the transformer as well as the other settings. Otherwise a
        # caller-supplied transformer would only be applied on the first retry,
        # and deeper retries would fall back to a default one.
        new_query_engine = RetryGuidelineQueryEngine(
            self._query_engine,
            self._guideline_evaluator,
            self.resynthesize_query,
            self.max_retries - 1,
            self.callback_manager,
            query_transformer=self.query_transformer,
        )
        new_query = self.query_transformer.run(
            query_bundle, {"evaluation": evaluation}
        )
        logger.debug("New query: %s", new_query.query_str)
        return new_query_engine.query(new_query)

    async def _aquery(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        """Not supported natively; delegates to the synchronous ``_query``."""
        return self._query(query_bundle)