llama-index

import asyncio
from enum import Enum
from typing import Dict, List, Optional, Tuple, cast

from llama_index.legacy.async_utils import run_async_tasks
from llama_index.legacy.callbacks.base import CallbackManager
from llama_index.legacy.constants import DEFAULT_SIMILARITY_TOP_K
from llama_index.legacy.llms.utils import LLMType, resolve_llm
from llama_index.legacy.prompts import PromptTemplate
from llama_index.legacy.prompts.mixin import PromptDictType
from llama_index.legacy.retrievers import BaseRetriever
from llama_index.legacy.schema import IndexNode, NodeWithScore, QueryBundle

QUERY_GEN_PROMPT = (
    "You are a helpful assistant that generates multiple search queries based on a "
    "single input query. Generate {num_queries} search queries, one on each line, "
    "related to the following input query:\n"
    "Query: {query}\n"
    "Queries:\n"
)


class FUSION_MODES(str, Enum):
    """Enum for different fusion modes."""

    RECIPROCAL_RANK = "reciprocal_rerank"  # apply reciprocal rank fusion
    SIMPLE = "simple"  # simple re-ordering of results based on original scores


class QueryFusionRetriever(BaseRetriever):
    def __init__(
        self,
        retrievers: List[BaseRetriever],
        llm: Optional[LLMType] = "default",
        query_gen_prompt: Optional[str] = None,
        mode: FUSION_MODES = FUSION_MODES.SIMPLE,
        similarity_top_k: int = DEFAULT_SIMILARITY_TOP_K,
        num_queries: int = 4,
        use_async: bool = True,
        verbose: bool = False,
        callback_manager: Optional[CallbackManager] = None,
        objects: Optional[List[IndexNode]] = None,
        object_map: Optional[dict] = None,
    ) -> None:
        self.num_queries = num_queries
        self.query_gen_prompt = query_gen_prompt or QUERY_GEN_PROMPT
        self.similarity_top_k = similarity_top_k
        self.mode = mode
        self.use_async = use_async

        self._retrievers = retrievers
        self._llm = resolve_llm(llm)
        super().__init__(
            callback_manager=callback_manager,
            object_map=object_map,
            objects=objects,
            verbose=verbose,
        )

    def _get_prompts(self) -> PromptDictType:
        """Get prompts."""
        return {"query_gen_prompt": PromptTemplate(self.query_gen_prompt)}

    def _update_prompts(self, prompts: PromptDictType) -> None:
        """Update prompts."""
        if "query_gen_prompt" in prompts:
            self.query_gen_prompt = cast(
                PromptTemplate, prompts["query_gen_prompt"]
            ).template

    def _get_queries(self, original_query: str) -> List[str]:
        prompt_str = self.query_gen_prompt.format(
            num_queries=self.num_queries - 1,
            query=original_query,
        )
        response = self._llm.complete(prompt_str)

        # assume the LLM put each generated query on its own line
        queries = response.text.split("\n")
        if self._verbose:
            queries_str = "\n".join(queries)
            print(f"Generated queries:\n{queries_str}")
        return queries
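    # Illustrative note (not in the original source): with the default
    # num_queries=4, the prompt above asks the LLM for 3 queries, so for an
    # input like "What is climate change?" the raw completion might look like
    #     "causes of climate change\neffects of global warming\nclimate change definition"
    # and _get_queries returns those lines split on "\n". Any blank lines the
    # LLM emits are passed through unchanged.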

    def _reciprocal_rerank_fusion(
        self, results: Dict[Tuple[str, int], List[NodeWithScore]]
    ) -> List[NodeWithScore]:
        """Apply reciprocal rank fusion.

        The original paper uses k=60 for best results:
        https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf
        """
        k = 60.0  # `k` is a parameter used to control the impact of outlier rankings.
        fused_scores = {}
        text_to_node = {}

        # compute reciprocal rank scores
        for nodes_with_scores in results.values():
            for rank, node_with_score in enumerate(
                sorted(nodes_with_scores, key=lambda x: x.score or 0.0, reverse=True)
            ):
                text = node_with_score.node.get_content()
                text_to_node[text] = node_with_score
                if text not in fused_scores:
                    fused_scores[text] = 0.0
                fused_scores[text] += 1.0 / (rank + k)

        # sort results
        reranked_results = dict(
            sorted(fused_scores.items(), key=lambda x: x[1], reverse=True)
        )

        # adjust node scores
        reranked_nodes: List[NodeWithScore] = []
        for text, score in reranked_results.items():
            reranked_nodes.append(text_to_node[text])
            reranked_nodes[-1].score = score

        return reranked_nodes
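    # Worked example (illustrative, not from the original source): with k = 60,
    # a node that appears at rank 0 in one result list and rank 2 in another is
    # scored 1.0 / (0 + 60) + 1.0 / (2 + 60) ~= 0.0167 + 0.0161 ~= 0.0328,
    # so nodes retrieved near the top of several lists outrank nodes that
    # appear high in only one list.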

    def _simple_fusion(
        self, results: Dict[Tuple[str, int], List[NodeWithScore]]
    ) -> List[NodeWithScore]:
        """Apply simple fusion."""
        # Use a dict to de-duplicate nodes
        all_nodes: Dict[str, NodeWithScore] = {}
        for nodes_with_scores in results.values():
            for node_with_score in nodes_with_scores:
                text = node_with_score.node.get_content()
                if text in all_nodes:
                    # keep the best score seen for this node; treat missing scores as 0.0
                    score = max(
                        node_with_score.score or 0.0, all_nodes[text].score or 0.0
                    )
                    all_nodes[text].score = score
                else:
                    all_nodes[text] = node_with_score

        return sorted(all_nodes.values(), key=lambda x: x.score or 0.0, reverse=True)

    def _run_nested_async_queries(
        self, queries: List[str]
    ) -> Dict[Tuple[str, int], List[NodeWithScore]]:
        tasks, task_queries = [], []
        for query in queries:
            for i, retriever in enumerate(self._retrievers):
                tasks.append(retriever.aretrieve(query))
                task_queries.append(query)

        task_results = run_async_tasks(tasks)

        results = {}
        for i, (query, query_result) in enumerate(zip(task_queries, task_results)):
            results[(query, i)] = query_result

        return results

    async def _run_async_queries(
        self, queries: List[str]
    ) -> Dict[Tuple[str, int], List[NodeWithScore]]:
        tasks, task_queries = [], []
        for query in queries:
            for i, retriever in enumerate(self._retrievers):
                tasks.append(retriever.aretrieve(query))
                task_queries.append(query)

        task_results = await asyncio.gather(*tasks)

        results = {}
        for i, (query, query_result) in enumerate(zip(task_queries, task_results)):
            results[(query, i)] = query_result

        return results

    def _run_sync_queries(
        self, queries: List[str]
    ) -> Dict[Tuple[str, int], List[NodeWithScore]]:
        results = {}
        for query in queries:
            for i, retriever in enumerate(self._retrievers):
                results[(query, i)] = retriever.retrieve(query)

        return results
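    # Note (added commentary, not in the original source): the three helpers
    # above differ only in how the per-retriever calls are driven --
    # run_async_tasks drives the coroutines from the synchronous _retrieve
    # path, asyncio.gather is awaited from _aretrieve inside an event loop,
    # and the sync variant calls retrieve() directly. The (query, i) keys
    # merely keep entries distinct; the fusion methods only consume
    # results.values().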

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        if self.num_queries > 1:
            queries = self._get_queries(query_bundle.query_str)
        else:
            queries = [query_bundle.query_str]

        if self.use_async:
            results = self._run_nested_async_queries(queries)
        else:
            results = self._run_sync_queries(queries)

        if self.mode == FUSION_MODES.RECIPROCAL_RANK:
            return self._reciprocal_rerank_fusion(results)[: self.similarity_top_k]
        elif self.mode == FUSION_MODES.SIMPLE:
            return self._simple_fusion(results)[: self.similarity_top_k]
        else:
            raise ValueError(f"Invalid fusion mode: {self.mode}")

    async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        if self.num_queries > 1:
            queries = self._get_queries(query_bundle.query_str)
        else:
            queries = [query_bundle.query_str]

        results = await self._run_async_queries(queries)

        if self.mode == FUSION_MODES.RECIPROCAL_RANK:
            return self._reciprocal_rerank_fusion(results)[: self.similarity_top_k]
        elif self.mode == FUSION_MODES.SIMPLE:
            return self._simple_fusion(results)[: self.similarity_top_k]
        else:
            raise ValueError(f"Invalid fusion mode: {self.mode}")
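# ---------------------------------------------------------------------------
# Usage sketch (added for illustration; not part of the original module).
# It assumes the llama_index.legacy top-level exports (VectorStoreIndex,
# SimpleDirectoryReader), a local "./data" directory with documents, and a
# configured default LLM/embedding model (e.g. an OpenAI API key) so that
# resolve_llm("default") and the index can work. Swap in whatever
# BaseRetriever instances you actually use.
if __name__ == "__main__":
    from llama_index.legacy import SimpleDirectoryReader, VectorStoreIndex

    documents = SimpleDirectoryReader("./data").load_data()
    index = VectorStoreIndex.from_documents(documents)

    # Fuse results from two differently-sized vector retrievers over the same
    # index; in practice these would typically be different retriever types.
    fusion_retriever = QueryFusionRetriever(
        retrievers=[
            index.as_retriever(similarity_top_k=2),
            index.as_retriever(similarity_top_k=8),
        ],
        mode=FUSION_MODES.RECIPROCAL_RANK,
        similarity_top_k=4,
        num_queries=4,  # one LLM prompt asking for 3 generated queries
        use_async=False,  # run the retrievers synchronously for the example
        verbose=True,
    )

    for node_with_score in fusion_retriever.retrieve("How does query fusion work?"):
        print(f"{node_with_score.score:.4f}  {node_with_score.node.get_content()[:80]}")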
