llama-index

Форк
0
85 строк · 3.2 Кб
1
from typing import Optional, Dict
2
from llama_index.core.base.base_query_engine import BaseQueryEngine
3
from llama_index.core.callbacks.base import CallbackManager
4
from llama_index.core.schema import QueryBundle
5
from llama_index.core.base.response.schema import RESPONSE_TYPE
6
from llama_index.core.prompts.mixin import PromptMixinType
7
from llama_index.networks.schema.contributor import ContributorQueryResponse
8
from pydantic.v1 import BaseSettings, Field
9
import requests
10
import aiohttp
11

12

13
class ContributorClientSettings(BaseSettings):
14
    """Settings for contributor."""
15

16
    api_key: Optional[str] = Field(default=None, env="API_KEY")
17
    api_url: str = Field(..., env="API_URL")
18

19
    class Config:
20
        env_file = ".env", ".env.contributor.client"
21

22

23
class ContributorClient(BaseQueryEngine):
24
    """A remote QueryEngine exposed through a REST API."""
25

26
    def __init__(
27
        self,
28
        callback_manager: Optional[CallbackManager],
29
        config: ContributorClientSettings,
30
    ) -> None:
31
        self.config = config
32
        super().__init__(callback_manager)
33

34
    @classmethod
35
    def from_config_file(
36
        cls, env_file: str, callback_manager: Optional[CallbackManager] = None
37
    ) -> "ContributorClient":
38
        """Convenience constructor from a custom env file."""
39
        config = ContributorClientSettings(_env_file=env_file)
40
        return cls(callback_manager=callback_manager, config=config)
41

42
    def _query(
43
        self,
44
        query_bundle: QueryBundle,
45
        additional_data: Dict[str, str] = {},
46
        headers: Dict[str, str] = {},
47
    ) -> RESPONSE_TYPE:
48
        """Make a post request to submit a query to QueryEngine."""
49
        # headers = {"Authorization": f"Bearer {self.config.api_key}"}
50
        data = {"query": query_bundle.query_str, "api_key": self.config.api_key}
51
        data.update(additional_data)
52
        result = requests.post(
53
            self.config.api_url + "/api/query", json=data, headers=headers
54
        )
55
        try:
56
            contributor_response = ContributorQueryResponse.parse_obj(result.json())
57
        except Exception as e:
58
            raise ValueError("Failed to parse response") from e
59
        return contributor_response.to_response()
60

61
    async def _aquery(
62
        self,
63
        query_bundle: QueryBundle,
64
        api_token: Optional[str] = None,
65
        additional_data: Dict[str, str] = {},
66
        headers: Dict[str, str] = {},
67
    ) -> RESPONSE_TYPE:
68
        """Make a post request to submit a query to QueryEngine."""
69
        # headers = {"Authorization": f"Bearer {self.config.api_key}"}
70
        data = {"query": query_bundle.query_str, "api_token": api_token}
71
        data.update(additional_data)
72
        async with aiohttp.ClientSession() as session:
73
            async with session.post(
74
                self.config.api_url + "/api/query", json=data, headers=headers
75
            ) as resp:
76
                json_result = await resp.json()
77
            try:
78
                contributor_response = ContributorQueryResponse.parse_obj(json_result)
79
            except Exception as e:
80
                raise ValueError("Failed to parse response") from e
81
        return contributor_response.to_response()
82

83
    def _get_prompt_modules(self) -> PromptMixinType:
84
        """Get prompt sub-modules."""
85
        return {}
86

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.