llama-index

Форк
0
232 строки · 9.4 Кб
1
"""SQL wrapper around SQLDatabase in langchain."""
2
from typing import Any, Dict, Iterable, List, Optional, Tuple
3

4
from sqlalchemy import MetaData, create_engine, insert, inspect, text
5
from sqlalchemy.engine import Engine
6
from sqlalchemy.exc import OperationalError, ProgrammingError
7

8

9
class SQLDatabase:
    """SQL Database.

    This class provides a wrapper around the SQLAlchemy engine to interact with a SQL
    database.
    It provides methods to execute SQL commands, insert data into tables, and retrieve
    information about the database schema.
    It also supports optional features such as including or excluding specific tables,
    sampling rows for table info,
    including indexes in table info, and supporting views.

    Based on langchain SQLDatabase.
    https://github.com/langchain-ai/langchain/blob/e355606b1100097665207ca259de6dc548d44c78/libs/langchain/langchain/utilities/sql_database.py#L39

    Args:
        engine (Engine): The SQLAlchemy engine instance to use for database operations.
        schema (Optional[str]): The name of the schema to use, if any.
        metadata (Optional[MetaData]): The metadata instance to use, if any.
        ignore_tables (Optional[List[str]]): List of table names to ignore. If set,
            include_tables must be None.
        include_tables (Optional[List[str]]): List of table names to include. If set,
            ignore_tables must be None.
        sample_rows_in_table_info (int): The number of sample rows to include in table
            info.
        indexes_in_table_info (bool): Whether to include indexes in table info.
        custom_table_info (Optional[dict]): Custom table info to use.
        view_support (bool): Whether to support views.
        max_string_length (int): The maximum string length to use.

    """

    def __init__(
        self,
        engine: Engine,
        schema: Optional[str] = None,
        metadata: Optional[MetaData] = None,
        ignore_tables: Optional[List[str]] = None,
        include_tables: Optional[List[str]] = None,
        sample_rows_in_table_info: int = 3,
        indexes_in_table_info: bool = False,
        custom_table_info: Optional[dict] = None,
        view_support: bool = False,
        max_string_length: int = 300,
    ):
        """Inspect the database behind ``engine`` and reflect its usable tables.

        Raises:
            ValueError: If both ``include_tables`` and ``ignore_tables`` are given,
                or if either of them names a table that is not in the database.
            TypeError: If ``sample_rows_in_table_info`` is not an ``int``, or if
                ``custom_table_info`` is truthy but not a ``dict``.
        """
        self._engine = engine
        self._schema = schema
        if include_tables and ignore_tables:
            raise ValueError("Cannot specify both include_tables and ignore_tables")

        self._inspector = inspect(self._engine)

        # Views are added to the table list only when view_support is enabled.
        self._all_tables = set(
            self._inspector.get_table_names(schema=schema)
            + (self._inspector.get_view_names(schema=schema) if view_support else [])
        )

        self._include_tables = set(include_tables) if include_tables else set()
        self._check_tables_exist(self._include_tables, "include_tables")
        self._ignore_tables = set(ignore_tables) if ignore_tables else set()
        self._check_tables_exist(self._ignore_tables, "ignore_tables")

        usable_tables = self.get_usable_table_names()
        # An empty usable list falls back to every known table for reflection.
        self._usable_tables = set(usable_tables) if usable_tables else self._all_tables

        if not isinstance(sample_rows_in_table_info, int):
            raise TypeError("sample_rows_in_table_info must be an integer")

        self._sample_rows_in_table_info = sample_rows_in_table_info
        self._indexes_in_table_info = indexes_in_table_info

        self._custom_table_info = custom_table_info
        if self._custom_table_info:
            if not isinstance(self._custom_table_info, dict):
                raise TypeError(
                    "table_info must be a dictionary with table names as keys and the "
                    "desired table info as values"
                )
            # Only keep entries for tables that are also present in the database.
            self._custom_table_info = {
                table: info
                for table, info in self._custom_table_info.items()
                if table in self._all_tables
            }

        self._max_string_length = max_string_length

        # Explicit None check so a caller-supplied MetaData is always honoured.
        self._metadata = metadata if metadata is not None else MetaData()
        # Views are reflected as well when view_support is enabled.
        self._metadata.reflect(
            views=view_support,
            bind=self._engine,
            only=list(self._usable_tables),
            schema=self._schema,
        )

    def _check_tables_exist(self, tables: set, param_name: str) -> None:
        """Raise ``ValueError`` if any name in ``tables`` is not a known table."""
        missing_tables = tables - self._all_tables
        if missing_tables:
            raise ValueError(f"{param_name} {missing_tables} not found in database")

    @property
    def engine(self) -> Engine:
        """Return SQL Alchemy engine."""
        return self._engine

    @property
    def metadata_obj(self) -> MetaData:
        """Return SQL Alchemy metadata."""
        return self._metadata

    @classmethod
    def from_uri(
        cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any
    ) -> "SQLDatabase":
        """Build an ``SQLDatabase`` from a database URI.

        Args:
            database_uri: URI passed to ``sqlalchemy.create_engine``.
            engine_args: Extra keyword arguments for ``create_engine``.
            **kwargs: Forwarded to the ``SQLDatabase`` constructor.
        """
        _engine_args = engine_args or {}
        return cls(create_engine(database_uri, **_engine_args), **kwargs)

    @property
    def dialect(self) -> str:
        """Return the name of the engine's SQL dialect."""
        return self._engine.dialect.name

    def get_usable_table_names(self) -> Iterable[str]:
        """Return the sorted names of the tables this wrapper exposes.

        With ``include_tables`` set, those names are returned. Otherwise all known
        tables minus ``ignore_tables`` are returned.
        """
        if self._include_tables:
            return sorted(self._include_tables)
        return sorted(self._all_tables - self._ignore_tables)

    def get_table_columns(self, table_name: str) -> List[Any]:
        """Return the inspector's column descriptions for ``table_name``.

        The configured schema is used, matching ``get_single_table_info``.
        """
        return self._inspector.get_columns(table_name, schema=self._schema)

    def get_single_table_info(self, table_name: str) -> str:
        """Describe one table's columns (with comments) and foreign keys as text."""
        template = (
            "Table '{table_name}' has columns: {columns}, "
            "and foreign keys: {foreign_keys}."
        )
        columns = []
        for column in self._inspector.get_columns(table_name, schema=self._schema):
            if column.get("comment"):
                columns.append(
                    f"{column['name']} ({column['type']!s}): "
                    f"'{column.get('comment')}'"
                )
            else:
                columns.append(f"{column['name']} ({column['type']!s})")

        column_str = ", ".join(columns)
        foreign_keys = [
            f"{foreign_key['constrained_columns']} -> "
            f"{foreign_key['referred_table']}.{foreign_key['referred_columns']}"
            for foreign_key in self._inspector.get_foreign_keys(
                table_name, schema=self._schema
            )
        ]
        foreign_key_str = ", ".join(foreign_keys)
        return template.format(
            table_name=table_name, columns=column_str, foreign_keys=foreign_key_str
        )

    def insert_into_table(self, table_name: str, data: dict) -> None:
        """Insert a single row into ``table_name`` within its own transaction.

        Args:
            table_name: Bare table name, or one already qualified as
                ``"<schema>.<table>"``.
            data: Mapping of column name to value for the new row.

        Raises:
            KeyError: If the table is not present in the reflected metadata.
        """
        tables = self._metadata.tables
        key = table_name
        # MetaData keys schema-reflected tables as "<schema>.<table>", so a bare
        # name is not found when this database was created with a schema.
        if key not in tables and self._schema:
            key = f"{self._schema}.{table_name}"
        table = tables[key]
        stmt = insert(table).values(**data)
        with self._engine.begin() as connection:
            connection.execute(stmt)

    def truncate_word(self, content: Any, *, length: int, suffix: str = "...") -> Any:
        """Shorten a string to at most ``length`` characters, cutting at a space.

        Non-string values, and any value when ``length <= 0``, are returned
        unchanged. Strings that already fit are also returned unchanged. Otherwise
        the text is cut so that it plus ``suffix`` fits in ``length``. The cut
        falls back to the last space within that prefix, if there is one.
        """
        if not isinstance(content, str) or length <= 0:
            return content

        if len(content) <= length:
            return content

        if length < len(suffix):
            # The suffix alone does not fit. A negative slice bound below would
            # keep almost the whole string, so hard-cut instead.
            return content[:length]

        return content[: length - len(suffix)].rsplit(" ", 1)[0] + suffix

    def run_sql(self, command: str) -> Tuple[str, Dict]:
        """Execute a SQL statement and return its results.

        If a schema is configured, each ``"FROM "`` in ``command`` is rewritten to
        ``"FROM <schema>."`` before execution.

        Returns:
            A ``(text, metadata)`` tuple. When the statement returns rows, ``text``
            is ``str()`` of the row tuples, with string values truncated to
            ``max_string_length``. In that case ``metadata`` holds ``"result"``
            (the same tuples) and ``"col_keys"`` (the column names). Otherwise the
            result is ``("", {})``.

        Raises:
            NotImplementedError: If executing the statement raises SQLAlchemy's
                ``ProgrammingError`` or ``OperationalError``.
        """
        with self._engine.begin() as connection:
            if self._schema:
                # NOTE(review): plain case-sensitive text substitution. It does not
                # handle lowercase "from", JOINs, subqueries or qualified names.
                command = command.replace("FROM ", f"FROM {self._schema}.")
            try:
                cursor = connection.execute(text(command))
            except (ProgrammingError, OperationalError) as exc:
                raise NotImplementedError(
                    f"Statement {command!r} is invalid SQL."
                ) from exc
            if cursor.returns_rows:
                # Truncate each value individually so long strings stay bounded
                # before the rows are rendered as text.
                truncated_results = [
                    tuple(
                        self.truncate_word(column, length=self._max_string_length)
                        for column in row
                    )
                    for row in cursor.fetchall()
                ]
                return str(truncated_results), {
                    "result": truncated_results,
                    "col_keys": list(cursor.keys()),
                }
        return "", {}
233

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.