llama-index

Форк
0
176 строк · 5.3 Кб
1
"""Base object types."""
2

3
import os
4
import pickle
5
from abc import abstractmethod
6
from typing import Any, Dict, Generic, Optional, Sequence, TypeVar
7

8
from llama_index.legacy.schema import BaseNode, MetadataMode, TextNode
9
from llama_index.legacy.storage.storage_context import DEFAULT_PERSIST_DIR
10
from llama_index.legacy.utils import concat_dirs
11

12
# Default file name used when persisting an object node mapping to disk.
DEFAULT_PERSIST_FNAME = "object_node_mapping.pickle"
13

14
# Type variable for the kind of object a node mapping handles.
OT = TypeVar("OT")
15

16

17
class BaseObjectNodeMapping(Generic[OT]):
    """Base object node mapping.

    Declares the interface for converting objects of type ``OT`` to nodes and
    back, and for persisting / loading the mapping.

    NOTE(review): the class derives from ``Generic`` only (no ``ABC``), so the
    ``@abstractmethod`` markers below appear to document intent rather than
    being enforced at instantiation time -- confirm before relying on them.
    """

    @classmethod
    @abstractmethod
    def from_objects(
        cls, objs: Sequence[OT], *args: Any, **kwargs: Any
    ) -> "BaseObjectNodeMapping":
        """Initialize node mapping from a list of objects.

        Only needs to be specified if the node mapping
        needs to be initialized with a list of objects.

        """

    def validate_object(self, obj: OT) -> None:
        """Validate object.

        The base implementation performs no checks; subclasses can override
        it to reject objects they cannot handle.
        """

    def add_object(self, obj: OT) -> None:
        """Validate ``obj`` and add it to the mapping.

        Validation is delegated to :meth:`validate_object` and storage to
        :meth:`_add_object`.
        """
        self.validate_object(obj)
        self._add_object(obj)

    @property
    @abstractmethod
    def obj_node_mapping(self) -> Dict[Any, Any]:
        """The mapping data structure between node and object."""

    @abstractmethod
    def _add_object(self, obj: OT) -> None:
        """Store ``obj`` in the mapping.

        Called by :meth:`add_object` after the object has been validated.
        """

    @abstractmethod
    def to_node(self, obj: OT) -> TextNode:
        """Convert a single object to a node."""

    def to_nodes(self, objs: Sequence[OT]) -> Sequence[TextNode]:
        """Convert each object in ``objs`` to a node, preserving order."""
        return [self.to_node(obj) for obj in objs]

    def from_node(self, node: BaseNode) -> OT:
        """Return the object for ``node``, validating it before returning."""
        obj = self._from_node(node)
        self.validate_object(obj)
        return obj

    @abstractmethod
    def _from_node(self, node: BaseNode) -> OT:
        """Look up the object that corresponds to ``node``."""

    @abstractmethod
    def persist(
        self,
        persist_dir: str = DEFAULT_PERSIST_DIR,
        obj_node_mapping_fname: str = DEFAULT_PERSIST_FNAME,
    ) -> None:
        """Persist objs."""

    @classmethod
    def from_persist_dir(
        cls,
        persist_dir: str = DEFAULT_PERSIST_DIR,
        obj_node_mapping_fname: str = DEFAULT_PERSIST_FNAME,
    ) -> "BaseObjectNodeMapping[OT]":
        """Load from serialization.

        Tries the ``from_persist_dir`` of each direct subclass of
        ``BaseObjectNodeMapping`` in turn and returns the first mapping that
        loads successfully.

        Args:
            persist_dir: Directory the mapping was persisted to.
            obj_node_mapping_fname: File name of the persisted mapping.

        Returns:
            The first mapping a subclass manages to load.

        Raises:
            Exception: If no subclass could load the mapping. The exception
                argument is the list of errors collected along the way.
        """
        obj_node_mapping = None
        errors = []
        for subclass in BaseObjectNodeMapping.__subclasses__():  # type: ignore[misc]
            if "from_persist_dir" not in vars(subclass):
                # A subclass without its own loader would resolve to this very
                # method and recurse without bound, so record it and move on.
                errors.append(
                    NotImplementedError(
                        f"{subclass.__name__} does not implement from_persist_dir"
                    )
                )
                continue
            try:
                obj_node_mapping = subclass.from_persist_dir(
                    persist_dir=persist_dir,
                    obj_node_mapping_fname=obj_node_mapping_fname,
                )
                break
            except (NotImplementedError, pickle.PickleError) as err:
                # Any other exception type propagates to the caller unhandled.
                errors.append(err)
        # Compare against None so a successfully loaded but falsy mapping is
        # not mistaken for a failure.
        if obj_node_mapping is not None:
            return obj_node_mapping
        raise Exception(errors)
107

108

109
class SimpleObjectNodeMapping(BaseObjectNodeMapping[Any]):
    """General node mapping that works for any obj.

    More specifically, any object with a meaningful string representation.

    Objects are kept in a dict keyed by ``hash(str(obj))``; the node for an
    object carries ``str(obj)`` as its text, and lookups hash the node content
    to find the object again.
    """

    def __init__(self, objs: Optional[Sequence[Any]] = None) -> None:
        """Validate ``objs`` and index them by the hash of their string form.

        Args:
            objs: Objects to register; ``None`` or empty gives an empty mapping.
        """
        objs = objs or []
        for obj in objs:
            self.validate_object(obj)
        self._objs = {hash(str(obj)): obj for obj in objs}

    @classmethod
    def from_objects(
        cls, objs: Sequence[Any], *args: Any, **kwargs: Any
    ) -> "SimpleObjectNodeMapping":
        """Build a mapping from ``objs``; extra arguments are ignored."""
        return cls(objs)

    @property
    def obj_node_mapping(self) -> Dict[int, Any]:
        """The dict from ``hash(str(obj))`` to the stored object."""
        return self._objs

    @obj_node_mapping.setter
    def obj_node_mapping(self, mapping: Dict[int, Any]) -> None:
        """Replace the stored dict wholesale with ``mapping``."""
        self._objs = mapping

    def _add_object(self, obj: Any) -> None:
        """Store ``obj`` under the hash of its string representation."""
        self._objs[hash(str(obj))] = obj

    def to_node(self, obj: Any) -> TextNode:
        """Return a ``TextNode`` whose text is ``str(obj)``."""
        return TextNode(text=str(obj))

    def _from_node(self, node: BaseNode) -> Any:
        """Return the object stored under the hash of the node's content.

        Raises:
            KeyError: If no object is stored under that hash.
        """
        return self._objs[hash(node.get_content(metadata_mode=MetadataMode.NONE))]

    def persist(
        self,
        persist_dir: str = DEFAULT_PERSIST_DIR,
        obj_node_mapping_fname: str = DEFAULT_PERSIST_FNAME,
    ) -> None:
        """Persist object node mapping.

        NOTE: This may fail depending on whether the object types are
        pickle-able.

        Args:
            persist_dir: Directory to write into; created if missing.
            obj_node_mapping_fname: File name for the pickled mapping.

        Raises:
            ValueError: If pickling raises ``pickle.PickleError``.
        """
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(persist_dir, exist_ok=True)
        obj_node_mapping_path = concat_dirs(persist_dir, obj_node_mapping_fname)
        # Serialize before opening the file so a pickling failure cannot leave
        # a truncated file behind (or clobber a previously persisted one).
        try:
            payload = pickle.dumps(self)
        except pickle.PickleError as err:
            raise ValueError("Objs is not pickleable") from err
        with open(obj_node_mapping_path, "wb") as f:
            f.write(payload)

    @classmethod
    def from_persist_dir(
        cls,
        persist_dir: str = DEFAULT_PERSIST_DIR,
        obj_node_mapping_fname: str = DEFAULT_PERSIST_FNAME,
    ) -> "SimpleObjectNodeMapping":
        """Load a pickled mapping from ``persist_dir``.

        Args:
            persist_dir: Directory the mapping was persisted to.
            obj_node_mapping_fname: File name of the pickled mapping.

        Returns:
            The unpickled mapping, with its lookup keys recomputed for the
            current process.

        Raises:
            ValueError: If unpickling raises ``pickle.PickleError``.
        """
        obj_node_mapping_path = concat_dirs(persist_dir, obj_node_mapping_fname)
        try:
            with open(obj_node_mapping_path, "rb") as f:
                # SECURITY: unpickling can execute arbitrary code; only load
                # files that come from a trusted source.
                simple_object_node_mapping = pickle.load(f)
        except pickle.PickleError as err:
            raise ValueError("Objs cannot be loaded.") from err

        # String hashes are salted per interpreter process, so keys computed
        # by the process that wrote the file generally do not match hash()
        # here. Recompute them so _from_node lookups keep working.
        stored = getattr(simple_object_node_mapping, "_objs", None)
        if isinstance(simple_object_node_mapping, SimpleObjectNodeMapping) and (
            isinstance(stored, dict)
        ):
            simple_object_node_mapping._objs = {
                hash(str(obj)): obj for obj in stored.values()
            }
        return simple_object_node_mapping
177

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.