llama-index
176 строк · 5.3 Кб
1"""Base object types."""
2
3import os
4import pickle
5from abc import abstractmethod
6from typing import Any, Dict, Generic, Optional, Sequence, TypeVar
7
8from llama_index.legacy.schema import BaseNode, MetadataMode, TextNode
9from llama_index.legacy.storage.storage_context import DEFAULT_PERSIST_DIR
10from llama_index.legacy.utils import concat_dirs
11
# Default file name that ``persist`` / ``from_persist_dir`` read and write
# inside the persist directory.
DEFAULT_PERSIST_FNAME: str = "object_node_mapping.pickle"

# Type variable for the kind of object a node mapping converts to and from nodes.
OT = TypeVar("OT")
15
16
class BaseObjectNodeMapping(Generic[OT]):
    """Base object node mapping.

    Translates between objects of type ``OT`` and ``TextNode``s. Subclasses
    provide the conversion hooks (``to_node`` / ``_from_node``), the storage
    hooks (``_add_object`` / ``obj_node_mapping``) and persistence
    (``persist`` / optionally ``from_persist_dir``).

    NOTE: the class derives from ``Generic`` rather than ``abc.ABC``, so the
    ``@abstractmethod`` markers below document the contract but are not
    enforced when a subclass is instantiated.
    """

    @classmethod
    @abstractmethod
    def from_objects(
        cls, objs: Sequence[OT], *args: Any, **kwargs: Any
    ) -> "BaseObjectNodeMapping":
        """Initialize node mapping from a list of objects.

        Only needs to be specified if the node mapping
        needs to be initialized with a list of objects.

        """

    def validate_object(self, obj: OT) -> None:
        """Validate object.

        A no-op here; subclasses may override it to reject unsupported
        objects by raising. It is called by ``add_object`` before an object
        is stored and by ``from_node`` after an object is recovered.
        """

    def add_object(self, obj: OT) -> None:
        """Add object.

        Runs ``validate_object`` on ``obj`` and then stores it through the
        subclass hook ``_add_object``.
        """
        self.validate_object(obj)
        self._add_object(obj)

    @property
    @abstractmethod
    def obj_node_mapping(self) -> Dict[Any, Any]:
        """The mapping data structure between node and object."""

    @abstractmethod
    def _add_object(self, obj: OT) -> None:
        """Store ``obj`` in the mapping (subclass hook used by ``add_object``)."""

    @abstractmethod
    def to_node(self, obj: OT) -> TextNode:
        """Convert a single object into a ``TextNode``."""

    def to_nodes(self, objs: Sequence[OT]) -> Sequence[TextNode]:
        """Convert every object in ``objs`` with ``to_node``, keeping order."""
        return [self.to_node(obj) for obj in objs]

    def from_node(self, node: BaseNode) -> OT:
        """Recover the object represented by ``node``.

        The object comes from the subclass hook ``_from_node`` and is passed
        through ``validate_object`` before being returned.
        """
        obj = self._from_node(node)
        self.validate_object(obj)
        return obj

    @abstractmethod
    def _from_node(self, node: BaseNode) -> OT:
        """Return the object for ``node`` (subclass hook used by ``from_node``)."""

    @abstractmethod
    def persist(
        self,
        persist_dir: str = DEFAULT_PERSIST_DIR,
        obj_node_mapping_fname: str = DEFAULT_PERSIST_FNAME,
    ) -> None:
        """Persist objs."""

    @classmethod
    def from_persist_dir(
        cls,
        persist_dir: str = DEFAULT_PERSIST_DIR,
        obj_node_mapping_fname: str = DEFAULT_PERSIST_FNAME,
    ) -> "BaseObjectNodeMapping[OT]":
        """Load from serialization.

        Tries ``from_persist_dir`` on each direct subclass of
        ``BaseObjectNodeMapping`` and returns the first mapping one of them
        loads. A subclass that raises ``NotImplementedError`` or
        ``pickle.PickleError`` is recorded and the next one is tried; any
        other exception propagates unchanged.

        Raises:
            Exception: with the list of collected errors when no subclass
                produced a mapping.
        """
        # The underlying function of this generic loader. A subclass that
        # does not override ``from_persist_dir`` resolves to this same
        # function, and calling it would re-enter this loop forever.
        generic_loader = BaseObjectNodeMapping.from_persist_dir.__func__  # type: ignore[attr-defined]

        obj_node_mapping = None
        errors = []
        for subclass in BaseObjectNodeMapping.__subclasses__():
            loader = subclass.from_persist_dir
            if getattr(loader, "__func__", None) is generic_loader:
                continue
            try:
                obj_node_mapping = loader(
                    persist_dir=persist_dir,
                    obj_node_mapping_fname=obj_node_mapping_fname,
                )
                break
            except (NotImplementedError, pickle.PickleError) as err:
                # raise unhandled exception otherwise
                errors.append(err)
        # Compare against None: a valid mapping may still be falsy.
        if obj_node_mapping is not None:
            return obj_node_mapping
        raise Exception(errors)
107
108
class SimpleObjectNodeMapping(BaseObjectNodeMapping[Any]):
    """General node mapping that works for any obj.

    More specifically, any object with a meaningful string representation.
    Each object is stored under the key ``hash(str(obj))`` and becomes a
    ``TextNode`` whose text is ``str(obj)``. ``_from_node`` finds it again
    by hashing the node's metadata-free content. Objects with equal string
    forms share a key, so the one stored last wins.
    """

    def __init__(self, objs: Optional[Sequence[Any]] = None) -> None:
        """Build the mapping from ``objs`` (an empty mapping when ``None``)."""
        objs = objs or []
        # Validate every object before the dictionary is built.
        for obj in objs:
            self.validate_object(obj)
        self._objs: Dict[int, Any] = {hash(str(obj)): obj for obj in objs}

    @classmethod
    def from_objects(
        cls, objs: Sequence[Any], *args: Any, **kwargs: Any
    ) -> "SimpleObjectNodeMapping":
        """Create a mapping from ``objs``; extra arguments are ignored."""
        return cls(objs)

    @property
    def obj_node_mapping(self) -> Dict[int, Any]:
        """The underlying ``hash(str(obj)) -> obj`` dictionary."""
        return self._objs

    @obj_node_mapping.setter
    def obj_node_mapping(self, mapping: Dict[int, Any]) -> None:
        self._objs = mapping

    def _add_object(self, obj: Any) -> None:
        """Store ``obj`` under ``hash(str(obj))``, replacing any equal-string entry."""
        self._objs[hash(str(obj))] = obj

    def to_node(self, obj: Any) -> TextNode:
        """Wrap ``str(obj)`` in a ``TextNode``."""
        return TextNode(text=str(obj))

    def _from_node(self, node: BaseNode) -> Any:
        """Look up the object whose string form equals the node's text.

        Raises:
            KeyError: if no stored object has that string form.
        """
        return self._objs[hash(node.get_content(metadata_mode=MetadataMode.NONE))]

    def persist(
        self,
        persist_dir: str = DEFAULT_PERSIST_DIR,
        obj_node_mapping_fname: str = DEFAULT_PERSIST_FNAME,
    ) -> None:
        """Persist object node mapping.

        NOTE: This may fail depending on whether the object types are
        pickle-able.

        Raises:
            ValueError: if the mapping (or an object in it) cannot be pickled.
        """
        os.makedirs(persist_dir, exist_ok=True)
        obj_node_mapping_path = concat_dirs(persist_dir, obj_node_mapping_fname)
        # Serialize before opening the file: "wb" truncates immediately, so a
        # pickling failure would otherwise wipe out a previously saved copy.
        try:
            payload = pickle.dumps(self)
        except (pickle.PickleError, TypeError, AttributeError) as err:
            # Unpicklable objects often surface as TypeError (e.g. locks) or
            # AttributeError (e.g. local objects), not only PicklingError.
            raise ValueError("Objs is not pickleable") from err
        with open(obj_node_mapping_path, "wb") as f:
            f.write(payload)

    @classmethod
    def from_persist_dir(
        cls,
        persist_dir: str = DEFAULT_PERSIST_DIR,
        obj_node_mapping_fname: str = DEFAULT_PERSIST_FNAME,
    ) -> "SimpleObjectNodeMapping":
        """Load a mapping previously written by ``persist``.

        SECURITY: unpickling can execute arbitrary code; only load files from
        a trusted location.

        Raises:
            ValueError: if the file exists but cannot be unpickled.
            OSError: if the file cannot be opened (propagated unchanged).
        """
        obj_node_mapping_path = concat_dirs(persist_dir, obj_node_mapping_fname)
        try:
            with open(obj_node_mapping_path, "rb") as f:
                simple_object_node_mapping = pickle.load(f)
        except (
            pickle.PickleError,
            EOFError,
            AttributeError,
            ImportError,
            IndexError,
        ) as err:
            # pickle.load may fail with any of these besides UnpicklingError.
            raise ValueError("Objs cannot be loaded.") from err
        if isinstance(simple_object_node_mapping, SimpleObjectNodeMapping):
            # str hashes are salted per interpreter process (PYTHONHASHSEED),
            # so keys pickled by another process would not match
            # hash(node text) here. Recompute them from the stored objects.
            simple_object_node_mapping._objs = {
                hash(str(obj)): obj
                for obj in simple_object_node_mapping._objs.values()
            }
        return simple_object_node_mapping
177