import random
import re
import signal
from collections import defaultdict
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Set, Tuple

from llama_index.legacy.program.predefined.evaporate.prompts import (
    DEFAULT_EXPECTED_OUTPUT_PREFIX_TMPL,
    DEFAULT_FIELD_EXTRACT_QUERY_TMPL,
    FN_GENERATION_PROMPT,
    SCHEMA_ID_PROMPT,
    FnGeneratePrompt,
    SchemaIDPrompt,
)
from llama_index.legacy.schema import BaseNode, MetadataMode, NodeWithScore, QueryBundle
from llama_index.legacy.service_context import ServiceContext


class TimeoutException(Exception):
    pass


@contextmanager
def time_limit(seconds: int) -> Any:
    """Time limit context manager.

    NOTE: copied from https://github.com/HazyResearch/evaporate.

    """

    def signal_handler(signum: Any, frame: Any) -> Any:
        raise TimeoutException("Timed out!")

    signal.signal(signal.SIGALRM, signal_handler)
    signal.alarm(seconds)
    try:
        yield
    finally:
        signal.alarm(0)
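
# Usage sketch (hypothetical call): abort a long-running extraction after 5s.
#
#     try:
#         with time_limit(5):
#             run_generated_fn()  # hypothetical long-running call
#     except TimeoutException:
#         ...
#
# NOTE: signal.SIGALRM is Unix-only and must be installed from the main thread.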


def get_function_field_from_attribute(attribute: str) -> str:
    """Get function field from attribute.

    NOTE: copied from https://github.com/HazyResearch/evaporate.

    """
    return re.sub(r"[^A-Za-z0-9]", "_", attribute)
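
# For example (doctest-style illustration):
#
#     >>> get_function_field_from_attribute("date of birth")
#     'date_of_birth'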


def extract_field_dicts(result: str, text_chunk: str) -> Set:
    """Extract candidate field names from an LLM schema-identification response."""
    existing_fields = set()
    # only parse the text before the first "---" separator
    result = result.split("---")[0].strip("\n")
    results = result.split("\n")
    results = [r.strip("-").strip() for r in results]
    # strip list numbering of the form "1.", "2.", ...
    results = [r[2:].strip() if len(r) > 2 and r[1] == "." else r for r in results]
    for result in results:
        try:
            field = result.split(": ")[0].strip(":")
            value = ": ".join(result.split(": ")[1:])
        except Exception:
            print(f"Skipped: {result}")
            continue
        field_versions = [
            field,
            field.replace(" ", ""),
            field.replace("-", ""),
            field.replace("_", ""),
        ]
        # keep only fields that actually appear in the source text chunk
        if not any(f.lower() in text_chunk.lower() for f in field_versions):
            continue
        if not value:
            continue
        field = field.lower().strip("-").strip("_").strip(" ").strip(":")
        if field in existing_fields:
            continue
        existing_fields.add(field)

    return existing_fields
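
# For example (illustration): given the LLM response
#     "1. name: John\n2. age: 42\n---"
# and a text_chunk containing the words "name" and "age",
# extract_field_dicts returns {"name", "age"}.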


# NOTE: `run_fn_on_nodes` below defines module-level globals via exec().
class EvaporateExtractor:
    """Wrapper around Evaporate.

    Evaporate is an open-source project from Stanford's AI Lab:
    https://github.com/HazyResearch/evaporate.
    It offers techniques for structured datapoint extraction.

    In the current version, we use Evaporate's function generator
    to derive extraction functions from a set of documents.

    Args:
        service_context (Optional[ServiceContext]): Service Context to use.
    """

    def __init__(
        self,
        service_context: Optional[ServiceContext] = None,
        schema_id_prompt: Optional[SchemaIDPrompt] = None,
        fn_generate_prompt: Optional[FnGeneratePrompt] = None,
        field_extract_query_tmpl: str = DEFAULT_FIELD_EXTRACT_QUERY_TMPL,
        expected_output_prefix_tmpl: str = DEFAULT_EXPECTED_OUTPUT_PREFIX_TMPL,
        verbose: bool = False,
    ) -> None:
        """Initialize params."""
        # TODO: take in an entire index instead of forming a response builder
        self._service_context = service_context or ServiceContext.from_defaults()
        self._schema_id_prompt = schema_id_prompt or SCHEMA_ID_PROMPT
        self._fn_generate_prompt = fn_generate_prompt or FN_GENERATION_PROMPT
        self._field_extract_query_tmpl = field_extract_query_tmpl
        self._expected_output_prefix_tmpl = expected_output_prefix_tmpl
        self._verbose = verbose

    def identify_fields(
        self, nodes: List[BaseNode], topic: str, fields_top_k: int = 5
    ) -> List:
        """Identify fields from nodes.

        Extracts fields independently from each node, then returns
        the top-k most frequent fields.

        Args:
            nodes (List[BaseNode]): List of nodes to extract fields from.
            topic (str): Topic to use for extraction.
            fields_top_k (int): Number of fields to return.

        """
        field2count: dict = defaultdict(int)
        for node in nodes:
            llm = self._service_context.llm
            result = llm.predict(
                self._schema_id_prompt,
                topic=topic,
                chunk=node.get_content(metadata_mode=MetadataMode.LLM),
            )

            existing_fields = extract_field_dicts(
                result, node.get_content(metadata_mode=MetadataMode.LLM)
            )

            for field in existing_fields:
                field2count[field] += 1

        # rank fields by how many nodes they appeared in
        sorted_tups: List[Tuple[str, int]] = sorted(
            field2count.items(), key=lambda x: x[1], reverse=True
        )
        sorted_fields = [f[0] for f in sorted_tups]
        return sorted_fields[:fields_top_k]
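
    # Usage sketch (hypothetical values):
    #
    #     fields = extractor.identify_fields(nodes, topic="country", fields_top_k=3)
    #     # -> e.g. ["capital", "population", "currency"]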

    def extract_fn_from_nodes(
        self, nodes: List[BaseNode], field: str, expected_output: Optional[Any] = None
    ) -> str:
        """Extract an extraction function (as Python source) from nodes."""
        # deferred import to avoid a circular import
        from llama_index.legacy.response_synthesizers import (
            ResponseMode,
            get_response_synthesizer,
        )

        function_field = get_function_field_from_attribute(field)
        # TODO: replace with new response synthesis module

        if expected_output is not None:
            expected_output_str = (
                f"{self._expected_output_prefix_tmpl}{expected_output!s}\n"
            )
        else:
            expected_output_str = ""

        qa_prompt = self._fn_generate_prompt.partial_format(
            attribute=field,
            function_field=function_field,
            expected_output_str=expected_output_str,
        )

        response_synthesizer = get_response_synthesizer(
            service_context=self._service_context,
            text_qa_template=qa_prompt,
            response_mode=ResponseMode.TREE_SUMMARIZE,
        )

        # ignore refine prompt for now
        query_str = self._field_extract_query_tmpl.format(field=function_field)
        query_bundle = QueryBundle(query_str=query_str)
        response = response_synthesizer.synthesize(
            query_bundle,
            [NodeWithScore(node=n, score=1.0) for n in nodes],
        )
        fn_str = f"""def get_{function_field}_field(text: str):
    \"""
    Function to extract {field}.
    \"""
    {response!s}
"""

        # clean up fn_str: truncate after the first `return` statement,
        # drop print() calls, and keep only def/indented (function body) lines
        return_idx_list = [i for i, s in enumerate(fn_str.split("\n")) if "return" in s]
        if not return_idx_list:
            return ""

        return_idx = return_idx_list[0]
        fn_str = "\n".join(fn_str.split("\n")[: return_idx + 1])
        fn_str = "\n".join([s for s in fn_str.split("\n") if "print(" not in s])
        return "\n".join(
            [s for s in fn_str.split("\n") if s.startswith((" ", "\t", "def"))]
        )
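
    # An illustrative, hypothetical value of the cleaned-up function string for
    # field "year" -- the actual body is whatever the LLM generates:
    #
    #     def get_year_field(text: str):
    #         """
    #         Function to extract year.
    #         """
    #         match = re.search(r"\b(19|20)\d{2}\b", text)
    #         return match.group(0) if match else None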

    def run_fn_on_nodes(
        self, nodes: List[BaseNode], fn_str: str, field_name: str, num_timeouts: int = 1
    ) -> List:
        """Run function on nodes.

        Calls python exec().

        There are definitely security holes with this approach, use with caution.

        """
        function_field = get_function_field_from_attribute(field_name)
        results = []
        for node in nodes:
            # `result` and `node_text` are shared with the exec()'d code below
            global result
            global node_text
            node_text = node.get_content()  # type: ignore[name-defined]
            # placeholder default in case the generated function produces nothing
            result = []  # type: ignore[name-defined]
            try:
                with time_limit(1):
                    exec(fn_str, globals())
                    exec(f"result = get_{function_field}_field(node_text)", globals())
            except TimeoutException:
                raise
            results.append(result)  # type: ignore[name-defined]
        return results
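
    # Usage sketch (hypothetical): the generated code runs via exec() with full
    # interpreter privileges; time_limit(1) only guards against runaway loops,
    # not malicious code.
    #
    #     values = extractor.run_fn_on_nodes(nodes, fn_str, "year")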

    def extract_datapoints_with_fn(
        self,
        nodes: List[BaseNode],
        topic: str,
        sample_k: int = 5,
        fields_top_k: int = 5,
    ) -> List[Dict]:
        """Extract datapoints from a list of nodes, given a topic."""
        # sample a subset of nodes to limit LLM calls during schema/function induction
        idxs = list(range(len(nodes)))
        sample_k = min(sample_k, len(nodes))
        subset_idxs = random.sample(idxs, sample_k)
        subset_nodes = [nodes[si] for si in subset_idxs]

        # identify the fields to extract from the sampled nodes
        existing_fields = self.identify_fields(
            subset_nodes, topic, fields_top_k=fields_top_k
        )

        # then, for each field, generate an extraction function
        function_dict = {}
        for field in existing_fields:
            fn = self.extract_fn_from_nodes(subset_nodes, field)
            function_dict[field] = fn

        # then, run each function on all nodes
        result_dict = {}
        for field in existing_fields:
            result_list = self.run_fn_on_nodes(nodes, function_dict[field], field)
            result_dict[field] = result_list

        # convert into a list of dictionaries, one per node
        result_list = []
        for i in range(len(nodes)):
            result_dict_i = {}
            for field in existing_fields:
                result_dict_i[field] = result_dict[field][i]
            result_list.append(result_dict_i)
        return result_list
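

# Example usage (a minimal sketch; `nodes` is assumed to be a list of parsed
# BaseNode objects, e.g. produced by a node parser over your documents):
#
#     extractor = EvaporateExtractor()
#     datapoints = extractor.extract_datapoints_with_fn(nodes, topic="city")
#     # -> e.g. [{"population": "8,336,817", "area": "302.6 sq mi"}, ...]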