llama-index

Форк
0
94 строки · 3.0 Кб
1
"""Table node mapping."""
2

3
from typing import Any, Dict, Optional, Sequence
4

5
from llama_index.legacy.bridge.pydantic import BaseModel
6
from llama_index.legacy.objects.base_node_mapping import (
7
    DEFAULT_PERSIST_DIR,
8
    DEFAULT_PERSIST_FNAME,
9
    BaseObjectNodeMapping,
10
)
11
from llama_index.legacy.schema import BaseNode, TextNode
12
from llama_index.legacy.utilities.sql_wrapper import SQLDatabase
13

14

15
class SQLTableSchema(BaseModel):
16
    """Lightweight representation of a SQL table."""
17

18
    table_name: str
19
    context_str: Optional[str] = None
20

21

22
class SQLTableNodeMapping(BaseObjectNodeMapping[SQLTableSchema]):
23
    """SQL Table node mapping."""
24

25
    def __init__(self, sql_database: SQLDatabase) -> None:
26
        self._sql_database = sql_database
27

28
    @classmethod
29
    def from_objects(
30
        cls,
31
        objs: Sequence[SQLTableSchema],
32
        *args: Any,
33
        sql_database: Optional[SQLDatabase] = None,
34
        **kwargs: Any,
35
    ) -> "BaseObjectNodeMapping":
36
        """Initialize node mapping."""
37
        if sql_database is None:
38
            raise ValueError("Must provide sql_database")
39
        # ignore objs, since we are building from sql_database
40
        return cls(sql_database)
41

42
    def _add_object(self, obj: SQLTableSchema) -> None:
43
        raise NotImplementedError
44

45
    def to_node(self, obj: SQLTableSchema) -> TextNode:
46
        """To node."""
47
        # taken from existing schema logic
48
        table_text = (
49
            f"Schema of table {obj.table_name}:\n"
50
            f"{self._sql_database.get_single_table_info(obj.table_name)}\n"
51
        )
52

53
        metadata = {"name": obj.table_name}
54

55
        if obj.context_str is not None:
56
            table_text += f"Context of table {obj.table_name}:\n"
57
            table_text += obj.context_str
58
            metadata["context"] = obj.context_str
59

60
        return TextNode(
61
            text=table_text,
62
            metadata=metadata,
63
            excluded_embed_metadata_keys=["name", "context"],
64
            excluded_llm_metadata_keys=["name", "context"],
65
        )
66

67
    def _from_node(self, node: BaseNode) -> SQLTableSchema:
68
        """From node."""
69
        if node.metadata is None:
70
            raise ValueError("Metadata must be set")
71
        return SQLTableSchema(
72
            table_name=node.metadata["name"], context_str=node.metadata.get("context")
73
        )
74

75
    @property
76
    def obj_node_mapping(self) -> Dict[int, Any]:
77
        """The mapping data structure between node and object."""
78
        raise NotImplementedError("Subclasses should implement this!")
79

80
    def persist(
81
        self, persist_dir: str = ..., obj_node_mapping_fname: str = ...
82
    ) -> None:
83
        """Persist objs."""
84
        raise NotImplementedError("Subclasses should implement this!")
85

86
    @classmethod
87
    def from_persist_dir(
88
        cls,
89
        persist_dir: str = DEFAULT_PERSIST_DIR,
90
        obj_node_mapping_fname: str = DEFAULT_PERSIST_FNAME,
91
    ) -> "SQLTableNodeMapping":
92
        raise NotImplementedError(
93
            "This object node mapping does not support persist method."
94
        )
95

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.