llama-index

Форк
0
78 строк · 2.9 Кб
1
"""Custom query engine."""
2

3
from abc import abstractmethod
4
from typing import Union
5

6
from llama_index.legacy.bridge.pydantic import BaseModel, Field
7
from llama_index.legacy.callbacks.base import CallbackManager
8
from llama_index.legacy.core.base_query_engine import BaseQueryEngine
9
from llama_index.legacy.core.response.schema import RESPONSE_TYPE, Response
10
from llama_index.legacy.prompts.mixin import PromptMixinType
11
from llama_index.legacy.schema import QueryBundle, QueryType
12

13
STR_OR_RESPONSE_TYPE = Union[RESPONSE_TYPE, str]
14

15

16
class CustomQueryEngine(BaseModel, BaseQueryEngine):
17
    """Custom query engine.
18

19
    Subclasses can define additional attributes as Pydantic fields.
20
    Subclasses must implement the `custom_query` method, which takes a query string
21
    and returns either a Response object or a string as output.
22

23
    They can optionally implement the `acustom_query` method for async support.
24

25
    """
26

27
    callback_manager: CallbackManager = Field(
28
        default_factory=lambda: CallbackManager([]), exclude=True
29
    )
30

31
    def _get_prompt_modules(self) -> PromptMixinType:
32
        """Get prompt sub-modules."""
33
        return {}
34

35
    class Config:
36
        arbitrary_types_allowed = True
37

38
    def query(self, str_or_query_bundle: QueryType) -> RESPONSE_TYPE:
39
        with self.callback_manager.as_trace("query"):
40
            # if query bundle, just run the query
41
            if isinstance(str_or_query_bundle, QueryBundle):
42
                query_str = str_or_query_bundle.query_str
43
            else:
44
                query_str = str_or_query_bundle
45
            raw_response = self.custom_query(query_str)
46
            return (
47
                Response(raw_response)
48
                if isinstance(raw_response, str)
49
                else raw_response
50
            )
51

52
    async def aquery(self, str_or_query_bundle: QueryType) -> RESPONSE_TYPE:
53
        with self.callback_manager.as_trace("query"):
54
            if isinstance(str_or_query_bundle, QueryBundle):
55
                query_str = str_or_query_bundle.query_str
56
            else:
57
                query_str = str_or_query_bundle
58
            raw_response = await self.acustom_query(query_str)
59
            return (
60
                Response(raw_response)
61
                if isinstance(raw_response, str)
62
                else raw_response
63
            )
64

65
    @abstractmethod
66
    def custom_query(self, query_str: str) -> STR_OR_RESPONSE_TYPE:
67
        """Run a custom query."""
68

69
    async def acustom_query(self, query_str: str) -> STR_OR_RESPONSE_TYPE:
70
        """Run a custom query asynchronously."""
71
        # by default, just run the synchronous version
72
        return self.custom_query(query_str)
73

74
    def _query(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
75
        raise NotImplementedError("This query engine does not support _query.")
76

77
    async def _aquery(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
78
        raise NotImplementedError("This query engine does not support _aquery.")
79

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.