llama-index

Форк
0
182 строки · 7.0 Кб
1
# Auto Merging Retriever
2

3
import logging
4
from collections import defaultdict
5
from typing import Dict, List, Optional, Tuple, cast
6

7
from llama_index.legacy.callbacks.base import CallbackManager
8
from llama_index.legacy.core.base_retriever import BaseRetriever
9
from llama_index.legacy.indices.query.schema import QueryBundle
10
from llama_index.legacy.indices.utils import truncate_text
11
from llama_index.legacy.indices.vector_store.retrievers.retriever import (
12
    VectorIndexRetriever,
13
)
14
from llama_index.legacy.schema import BaseNode, IndexNode, NodeWithScore, QueryBundle
15
from llama_index.legacy.storage.storage_context import StorageContext
16

17
logger = logging.getLogger(__name__)
18

19

20
class AutoMergingRetriever(BaseRetriever):
    """Retriever that merges retrieved chunks into their parent context.

    Nodes are first fetched with the wrapped vector retriever. The result is
    then refined repeatedly until a pass makes no further change:

    * fill in: when exactly one node sits between two neighbouring results
      (the first result's ``next_node`` equals the second result's
      ``prev_node``), that node is loaded from the docstore and inserted;
    * merge: when the fraction of a parent's children present in the results
      is strictly greater than ``simple_ratio_thresh``, those children are
      replaced by the parent node.

    Args:
        vector_retriever: Retriever producing the initial candidate nodes.
        storage_context: Storage context whose docstore is used to look up
            parent and in-between nodes by id.
        simple_ratio_thresh: Merge threshold for
            ``retrieved children / total children`` of a parent.
        verbose: Forwarded to ``BaseRetriever``. The methods below read
            ``self._verbose`` (presumably set by the base class from this
            flag) to decide whether to also print progress messages.
        callback_manager: Forwarded to ``BaseRetriever``.
        object_map: Forwarded to ``BaseRetriever``.
        objects: Forwarded to ``BaseRetriever``.
    """

    def __init__(
        self,
        vector_retriever: VectorIndexRetriever,
        storage_context: StorageContext,
        simple_ratio_thresh: float = 0.5,
        verbose: bool = False,
        callback_manager: Optional[CallbackManager] = None,
        object_map: Optional[dict] = None,
        objects: Optional[List[IndexNode]] = None,
    ) -> None:
        """Store the collaborators and initialise the base retriever."""
        self._vector_retriever = vector_retriever
        self._storage_context = storage_context
        self._simple_ratio_thresh = simple_ratio_thresh
        super().__init__(
            callback_manager=callback_manager,
            object_map=object_map,
            objects=objects,
            verbose=verbose,
        )

    def _get_parents_and_merge(
        self, nodes: List[NodeWithScore]
    ) -> Tuple[List[NodeWithScore], bool]:
        """Replace groups of sibling nodes with their parent where warranted.

        Args:
            nodes: Current list of scored nodes.

        Returns:
            A tuple ``(new_nodes, is_changed)``. ``new_nodes`` holds the
            surviving input nodes in their original order followed by any
            newly added parent nodes; ``is_changed`` is True when at least
            one child was removed in favour of its parent.
        """
        # Group the retrieved nodes by parent, loading each parent from the
        # docstore only once.
        parent_nodes: Dict[str, BaseNode] = {}
        parent_cur_children_dict: Dict[str, List[NodeWithScore]] = defaultdict(list)
        for node in nodes:
            parent_node_info = node.node.parent_node
            if parent_node_info is None:
                continue

            parent_node_id = parent_node_info.node_id
            if parent_node_id not in parent_nodes:
                parent_node = self._storage_context.docstore.get_document(
                    parent_node_id
                )
                parent_nodes[parent_node_id] = cast(BaseNode, parent_node)

            parent_cur_children_dict[parent_node_id].append(node)

        # "Merging" means: drop the retrieved children, add their parent.
        node_ids_to_delete = set()
        nodes_to_add: Dict[str, NodeWithScore] = {}
        for parent_node_id, parent_node in parent_nodes.items():
            parent_child_nodes = parent_node.child_nodes
            # A parent that reports no children is treated as having one, so
            # the ratio below never divides by zero.
            parent_num_children = len(parent_child_nodes) if parent_child_nodes else 1
            parent_cur_children = parent_cur_children_dict[parent_node_id]
            ratio = len(parent_cur_children) / parent_num_children

            if ratio <= self._simple_ratio_thresh:
                continue

            node_ids_to_delete.update(n.node.node_id for n in parent_cur_children)

            parent_node_text = truncate_text(parent_node.text, 100)
            info_str = (
                f"> Merging {len(parent_cur_children)} nodes into parent node.\n"
                f"> Parent node id: {parent_node_id}.\n"
                f"> Parent node text: {parent_node_text}\n"
            )
            logger.info(info_str)
            if self._verbose:
                print(info_str)

            # The parent inherits the mean score of the children it replaces.
            # `parent_cur_children` is never empty here: a parent is only
            # registered together with at least one child.
            avg_score = sum(
                n.get_score() or 0.0 for n in parent_cur_children
            ) / len(parent_cur_children)
            nodes_to_add[parent_node_id] = NodeWithScore(
                node=parent_node, score=avg_score
            )

        new_nodes = [n for n in nodes if n.node.node_id not in node_ids_to_delete]
        new_nodes.extend(nodes_to_add.values())

        is_changed = len(node_ids_to_delete) > 0

        return new_nodes, is_changed

    def _fill_in_nodes(
        self, nodes: List[NodeWithScore]
    ) -> Tuple[List[NodeWithScore], bool]:
        """Insert the single missing node between two neighbouring results.

        For each consecutive pair in ``nodes``, if the first node's
        ``next_node`` is the same reference as the second node's
        ``prev_node``, the referenced node is loaded from the docstore and
        placed between them with the average of the pair's scores.

        A node that is already present in the result list is not inserted a
        second time; a duplicate would be counted twice when computing the
        merge ratio for its parent.

        Args:
            nodes: Current list of scored nodes.

        Returns:
            A tuple ``(new_nodes, is_changed)`` where ``is_changed`` is True
            when at least one node was inserted.
        """
        present_ids = {n.node.node_id for n in nodes}
        new_nodes: List[NodeWithScore] = []
        is_changed = False
        last_idx = len(nodes) - 1
        for idx, node in enumerate(nodes):
            new_nodes.append(node)
            if idx >= last_idx:
                # The final node has no successor to compare against.
                continue

            cur_node = cast(BaseNode, node.node)
            following = nodes[idx + 1]
            gap_info = cur_node.next_node
            if gap_info is None or not (gap_info == following.node.prev_node):
                continue

            gap_id = gap_info.node_id
            if gap_id in present_ids:
                # Already retrieved (or filled in earlier in this pass).
                continue

            next_node = cast(
                BaseNode, self._storage_context.docstore.get_document(gap_id)
            )
            present_ids.add(gap_id)
            is_changed = True

            next_node_text = truncate_text(next_node.get_text(), 100)
            info_str = (
                f"> Filling in node. Node id: {gap_id}\n"
                f"> Node text: {next_node_text}\n"
            )
            logger.info(info_str)
            if self._verbose:
                print(info_str)

            # Score the inserted node as the mean of its two neighbours.
            avg_score = (node.get_score() + following.get_score()) / 2
            new_nodes.append(NodeWithScore(node=next_node, score=avg_score))
        return new_nodes, is_changed

    def _try_merging(
        self, nodes: List[NodeWithScore]
    ) -> Tuple[List[NodeWithScore], bool]:
        """Run one fill-in pass followed by one merge pass.

        Returns:
            The resulting nodes and whether either pass changed anything.
        """
        nodes, is_changed_0 = self._fill_in_nodes(nodes)
        nodes, is_changed_1 = self._get_parents_and_merge(nodes)
        return nodes, is_changed_0 or is_changed_1

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Retrieve nodes for the query and merge them into parent context.

        The initial results come from the wrapped vector retriever; fill-in
        and merge passes are then repeated until a pass reports no change.

        Args:
            query_bundle: Query passed through to the vector retriever.

        Returns:
            The final nodes sorted by score, highest first.
        """
        initial_nodes = self._vector_retriever.retrieve(query_bundle)

        cur_nodes, is_changed = self._try_merging(initial_nodes)
        while is_changed:
            cur_nodes, is_changed = self._try_merging(cur_nodes)

        # sort by similarity
        cur_nodes.sort(key=lambda x: x.get_score(), reverse=True)

        return cur_nodes
183

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.