llama-index

Форк
0
175 строк · 5.3 Кб
1
"""MyScale reader."""
2

3
import logging
4
from typing import Any, List, Optional
5

6
from llama_index.legacy.readers.base import BaseReader
7
from llama_index.legacy.schema import Document
8

9
logger = logging.getLogger(__name__)
10

11

12
def escape_str(value: str) -> str:
    """Backslash-escape every backslash and single quote in ``value``.

    Args:
        value: Text to escape. Any falsy value (empty string, ``None``)
            yields an empty string.

    Returns:
        A copy of ``value`` in which each backslash and each single quote
        is preceded by one backslash; all other characters are unchanged.
    """
    # Guard clause: nothing to escape for empty / missing input.
    if not value:
        return ""

    backslash = "\\"
    needs_escape = (backslash, "'")

    pieces = []
    for ch in value:
        if ch in needs_escape:
            pieces.append(backslash + ch)
        else:
            pieces.append(ch)
    return "".join(pieces)
18

19

20
def format_list_to_string(lst: List) -> str:
    """Render ``lst`` as a bracketed, comma-separated string without spaces.

    Each item is converted with ``str``; an empty list yields ``"[]"``.
    """
    rendered_items = map(str, lst)
    joined = ",".join(rendered_items)
    return "[" + joined + "]"
22

23

24
class MyScaleSettings:
    """MyScale client configuration.

    Stores the table/index settings and renders the vector-search SQL
    statement from them.

    Attributes:
        table (str): Table name to operate on.
        database (str): Database name to find the table.
        index_type (str): Index type string.
        metric (str): Metric type to compute distance.
        batch_size (int): The size of documents to insert.
        index_params (dict, optional): Index build parameters.
        search_params (dict, optional): Index search parameters for a
            MyScale query.
    """

    def __init__(
        self,
        table: str,
        database: str,
        index_type: str,
        metric: str,
        batch_size: int,
        index_params: Optional[dict] = None,
        search_params: Optional[dict] = None,
        **kwargs: Any,
    ) -> None:
        """Store the settings as attributes.

        Extra keyword arguments are accepted and ignored.
        """
        self.table = table
        self.database = database
        self.index_type = index_type
        self.metric = metric
        self.batch_size = batch_size
        self.index_params = index_params
        self.search_params = search_params

    def build_query_statement(
        self,
        query_embed: List[float],
        where_str: Optional[str] = None,
        limit: Optional[int] = None,
    ) -> str:
        """Build the SQL statement for a vector similarity search.

        Args:
            query_embed: Query embedding, rendered as a bracketed list literal.
            where_str: Optional filter condition. It is inserted verbatim
                after ``PREWHERE`` without any escaping, so the caller is
                responsible for passing a safe expression.
            limit: Maximum number of rows. When ``None`` the ``LIMIT`` clause
                is omitted entirely (previously this rendered the invalid
                ``LIMIT None``).

        Returns:
            The SQL text selecting ``id, doc_id, text, node_info, metadata``
            and the distance as ``dist``, ordered by ``dist``.
        """
        query_embed_str = format_list_to_string(query_embed)
        where_clause = f"PREWHERE {where_str}" if where_str else ""
        # Inner product is the only metric sorted descending here.
        order = "DESC" if self.metric.lower() == "ip" else "ASC"

        # Search parameters become a quoted 'key=value' tuple attached
        # directly to the distance function name.
        search_params_str = ""
        if self.search_params:
            rendered = ",".join(f"'{k}={v}'" for k, v in self.search_params.items())
            search_params_str = f"({rendered})"

        # Keep the exact layout (leading newline, 12-space indent, trailing
        # indent) of the statement so output with a limit stays byte-identical.
        indent = " " * 12
        lines = [
            "",
            f"{indent}SELECT id, doc_id, text, node_info, metadata,",
            f"{indent}distance{search_params_str}(vector, {query_embed_str}) AS dist",
            f"{indent}FROM {self.database}.{self.table} {where_clause}",
            f"{indent}ORDER BY dist {order}",
        ]
        if limit is not None:
            lines.append(f"{indent}LIMIT {limit}")
        lines.append(indent)
        return "\n".join(lines)
83

84

85
class MyScaleReader(BaseReader):
    """Reader that fetches documents from MyScale by vector similarity.

    Args:
        myscale_host (str): Host passed to the ClickHouse client.
        username (str): Username to log in.
        password (str): Password to log in.
        myscale_port (int): Port passed to the ClickHouse client.
            Defaults to 8443.
        database (str): Database name to find the table.
            Defaults to 'default'.
        table (str): Table name to operate on. Defaults to 'llama_index'.
        index_type (str): Index type string. Defaults to "IVFLAT".
        metric (str): Metric to compute distance. Defaults to 'cosine'.
        batch_size (int, optional): The size of documents to insert.
            Defaults to 32.
        index_params (dict, optional): The index parameters for MyScale.
            Defaults to None.
        search_params (dict, optional): The search parameters for a MyScale
            query. Defaults to None.

    """

    def __init__(
        self,
        myscale_host: str,
        username: str,
        password: str,
        myscale_port: Optional[int] = 8443,
        database: str = "default",
        table: str = "llama_index",
        index_type: str = "IVFLAT",
        metric: str = "cosine",
        batch_size: int = 32,
        index_params: Optional[dict] = None,
        search_params: Optional[dict] = None,
        **kwargs: Any,
    ) -> None:
        """Create the ClickHouse client and the settings object.

        Raises:
            ImportError: If the ``clickhouse_connect`` package is missing.
        """
        missing_dependency_msg = """
            `clickhouse_connect` package not found,
            please run `pip install clickhouse-connect`
        """
        # Imported lazily so the module loads without the optional package.
        try:
            import clickhouse_connect
        except ImportError:
            raise ImportError(missing_dependency_msg)

        connection_options = {
            "host": myscale_host,
            "port": myscale_port,
            "username": username,
            "password": password,
        }
        self.client = clickhouse_connect.get_client(**connection_options)

        # Extra keyword arguments are forwarded to the settings object.
        self.config = MyScaleSettings(
            table=table,
            database=database,
            index_type=index_type,
            metric=metric,
            batch_size=batch_size,
            index_params=index_params,
            search_params=search_params,
            **kwargs,
        )

    def load_data(
        self,
        query_vector: List[float],
        where_str: Optional[str] = None,
        limit: int = 10,
    ) -> List[Document]:
        """Query MyScale and wrap each result row in a ``Document``.

        Args:
            query_vector (List[float]): Query vector.
            where_str (Optional[str], optional): Where condition string.
                Defaults to None.
            limit (int): Number of results to return. Defaults to 10.

        Returns:
            List[Document]: One document per returned row, built from the
            row's ``doc_id``, ``text`` and ``metadata`` columns.
        """
        statement = self.config.build_query_statement(
            query_embed=query_vector,
            where_str=where_str,
            limit=limit,
        )

        documents: List[Document] = []
        for row in self.client.query(statement).named_results():
            documents.append(
                Document(
                    id_=row["doc_id"],
                    text=row["text"],
                    metadata=row["metadata"],
                )
            )
        return documents
176

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.