llama-index

Форк
0
188 строк · 7.0 Кб
1
"""Experiment with different indices, models, and more."""
2

3
from __future__ import annotations
4

5
import time
6
from typing import Any, Dict, List, Type
7

8
import pandas as pd
9

10
from llama_index.legacy.callbacks import CallbackManager, TokenCountingHandler
11
from llama_index.legacy.indices.base import BaseIndex
12
from llama_index.legacy.indices.list.base import ListRetrieverMode, SummaryIndex
13
from llama_index.legacy.indices.tree.base import TreeIndex, TreeRetrieverMode
14
from llama_index.legacy.indices.vector_store import VectorStoreIndex
15
from llama_index.legacy.llm_predictor.base import LLMPredictor
16
from llama_index.legacy.schema import Document
17
from llama_index.legacy.utils import get_color_mapping, print_text
18

19
# Index classes that ``Playground.from_docs`` builds when the caller does not
# pass its own ``index_classes``.
DEFAULT_INDEX_CLASSES: List[Type[BaseIndex]] = [
    VectorStoreIndex,
    TreeIndex,
    SummaryIndex,
]

# Type alias: maps an index class to the retriever-mode strings to try for it.
INDEX_SPECIFIC_QUERY_MODES_TYPE = Dict[Type[BaseIndex], List[str]]

# Default retriever modes per index class: the value of every member of the
# tree and list retriever-mode enums, and only "default" for vector stores.
DEFAULT_MODES: INDEX_SPECIFIC_QUERY_MODES_TYPE = {
    TreeIndex: [tree_mode.value for tree_mode in TreeRetrieverMode],
    SummaryIndex: [list_mode.value for list_mode in ListRetrieverMode],
    VectorStoreIndex: ["default"],
}
32

33

34
class Playground:
    """Experiment with indices, models, embeddings, retriever_modes, and more."""

    def __init__(
        self,
        indices: List[BaseIndex],
        retriever_modes: INDEX_SPECIFIC_QUERY_MODES_TYPE = DEFAULT_MODES,
    ):
        """Initialize with indices to experiment with.

        Args:
            indices: A list of BaseIndex's to experiment with
            retriever_modes: A mapping from index class to the list of
                retriever_modes to try for indices of that class. The
                retriever mode specifies which nodes are chosen from the
                index when a query is made. A full list of
                retriever_modes available to each index can be found here:
                https://docs.llamaindex.ai/en/stable/module_guides/querying/retriever/retriever_modes.html

        Raises:
            ValueError: If `indices` is empty or contains a non-BaseIndex
                object, or if `retriever_modes` is empty.
        """
        self._validate_indices(indices)
        self._indices = indices
        self._validate_modes(retriever_modes)
        self._retriever_modes = retriever_modes

        # One color per index position, keyed by the position as a string;
        # `compare` uses it to color each index's printed output.
        index_range = [str(i) for i in range(len(indices))]
        self.index_colors = get_color_mapping(index_range)

    @classmethod
    def from_docs(
        cls,
        documents: List[Document],
        index_classes: List[Type[BaseIndex]] = DEFAULT_INDEX_CLASSES,
        retriever_modes: INDEX_SPECIFIC_QUERY_MODES_TYPE = DEFAULT_MODES,
        **kwargs: Any,
    ) -> Playground:
        """Initialize with Documents using the default list of indices.

        Args:
            documents: A List of Documents to experiment with.
            index_classes: Index classes to build from `documents`; one index
                is built per class, in order.
            retriever_modes: Mapping from index class to the retriever modes
                to try for it (passed through to the constructor).
            **kwargs: Forwarded unchanged to every
                `index_class.from_documents` call.

        Returns:
            A Playground holding one freshly built index per index class.

        Raises:
            ValueError: If `documents` is empty.
        """
        if not documents:
            raise ValueError(
                "Playground must be initialized with a nonempty list of Documents."
            )

        indices = [
            index_class.from_documents(documents, **kwargs)
            for index_class in index_classes
        ]
        return cls(indices, retriever_modes)

    def _validate_indices(self, indices: List[BaseIndex]) -> None:
        """Validate a list of indices.

        Raises:
            ValueError: If the list is empty or any element is not a
                BaseIndex instance.
        """
        if not indices:
            raise ValueError("Playground must have a non-empty list of indices.")
        if not all(isinstance(index, BaseIndex) for index in indices):
            raise ValueError(
                "Every index in Playground should be an instance of BaseIndex."
            )

    @property
    def indices(self) -> List[BaseIndex]:
        """Get Playground's indices."""
        return self._indices

    @indices.setter
    def indices(self, indices: List[BaseIndex]) -> None:
        """Set Playground's indices."""
        self._validate_indices(indices)
        self._indices = indices

    def _validate_modes(self, retriever_modes: INDEX_SPECIFIC_QUERY_MODES_TYPE) -> None:
        """Validate a mapping of retriever_modes.

        Raises:
            ValueError: If the mapping is empty.
        """
        if not retriever_modes:
            # The trailing space after the first sentence matters: adjacent
            # string literals are concatenated with nothing in between.
            raise ValueError(
                "Playground must have a nonzero number of retriever_modes. "
                "Initialize without the `retriever_modes` "
                "argument to use the default list."
            )

    @property
    def retriever_modes(self) -> dict:
        """Get Playground's retriever_modes."""
        return self._retriever_modes

    @retriever_modes.setter
    def retriever_modes(self, retriever_modes: INDEX_SPECIFIC_QUERY_MODES_TYPE) -> None:
        """Set Playground's retriever_modes."""
        self._validate_modes(retriever_modes)
        self._retriever_modes = retriever_modes

    def compare(
        self, query_text: str, to_pandas: bool | None = True
    ) -> pd.DataFrame | List[Dict[str, Any]]:
        """Compare index outputs on an input query.

        Runs `query_text` against every index once per retriever mode
        configured for that index's class, printing each answer as it goes.

        Args:
            query_text (str): Query to run all indices on.
            to_pandas (Optional[bool]): Return results in a pandas dataframe.
                True by default.

        Returns:
            The output of each index along with other data, such as the time it took to
            compute. Results are stored in a Pandas Dataframe or a list of Dicts.

        Raises:
            KeyError: If an index's exact type has no entry in the
                configured retriever_modes mapping.
        """
        print(f"\033[1mQuery:\033[0m\n{query_text}\n")
        result = []
        for i, index in enumerate(self._indices):
            index_name = type(index).__name__
            for retriever_mode in self._retriever_modes[type(index)]:
                # perf_counter is monotonic, so the measured duration cannot
                # be skewed by system clock adjustments.
                start_time = time.perf_counter()

                print_text(
                    f"\033[1m{index_name}\033[0m, retriever mode = {retriever_mode}",
                    end="\n",
                )

                # insert token counter into service context
                # NOTE: this replaces the callback managers on the index's own
                # service context objects; a fresh counter is used per run so
                # the token counts below belong to this combination only.
                service_context = index.service_context
                token_counter = TokenCountingHandler()
                callback_manager = CallbackManager([token_counter])
                if isinstance(service_context.llm_predictor, LLMPredictor):
                    service_context.llm_predictor.llm.callback_manager = (
                        callback_manager
                    )
                    service_context.embed_model.callback_manager = callback_manager

                try:
                    query_engine = index.as_query_engine(
                        retriever_mode=retriever_mode, service_context=service_context
                    )
                except ValueError:
                    # Best effort: a ValueError while building the engine for
                    # this mode skips the combination instead of aborting.
                    continue

                output = query_engine.query(query_text)
                output_text = str(output)
                print_text(output_text, color=self.index_colors[str(i)], end="\n\n")

                duration = time.perf_counter() - start_time

                result.append(
                    {
                        "Index": index_name,
                        "Retriever Mode": retriever_mode,
                        "Output": output_text,
                        "Duration": duration,
                        "Prompt Tokens": token_counter.prompt_llm_token_count,
                        "Completion Tokens": token_counter.completion_llm_token_count,
                        "Embed Tokens": token_counter.total_embedding_token_count,
                    }
                )
        print(f"\nRan {len(result)} combinations in total.")

        if to_pandas:
            return pd.DataFrame(result)
        return result
189

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.