llama-index
275 lines · 9.1 KB
1import random
2import re
3import signal
4from collections import defaultdict
5from contextlib import contextmanager
6from typing import Any, Dict, List, Optional, Set, Tuple
7
8from llama_index.legacy.program.predefined.evaporate.prompts import (
9DEFAULT_EXPECTED_OUTPUT_PREFIX_TMPL,
10DEFAULT_FIELD_EXTRACT_QUERY_TMPL,
11FN_GENERATION_PROMPT,
12SCHEMA_ID_PROMPT,
13FnGeneratePrompt,
14SchemaIDPrompt,
15)
16from llama_index.legacy.schema import BaseNode, MetadataMode, NodeWithScore, QueryBundle
17from llama_index.legacy.service_context import ServiceContext
18
19
class TimeoutException(Exception):
    """Raised when code guarded by :func:`time_limit` runs past its deadline."""
22
23
@contextmanager
def time_limit(seconds: int) -> Any:
    """Time limit context manager.

    Raises :class:`TimeoutException` if the wrapped block runs for longer
    than ``seconds``. Uses ``SIGALRM``, so it only works on the main thread
    of Unix-like systems.

    NOTE: adapted from https://github.com/HazyResearch/evaporate.

    """

    def signal_handler(signum: Any, frame: Any) -> Any:
        raise TimeoutException("Timed out!")

    # Save the previously installed handler so we can restore it on exit;
    # the original code left our handler installed permanently.
    old_handler = signal.signal(signal.SIGALRM, signal_handler)
    signal.alarm(seconds)
    try:
        yield
    finally:
        # Cancel any pending alarm, then put back the prior handler.
        signal.alarm(0)
        signal.signal(signal.SIGALRM, old_handler)
41
42
def get_function_field_from_attribute(attribute: str) -> str:
    """Turn an attribute name into a safe function-field identifier.

    Every character that is not an ASCII letter or digit is replaced
    with an underscore.

    NOTE: copied from https://github.com/HazyResearch/evaporate.

    """
    return "".join(
        ch if (ch.isascii() and ch.isalnum()) else "_" for ch in attribute
    )
50
51
def extract_field_dicts(result: str, text_chunk: str) -> Set:
    """Extract the set of field names from an LLM schema response.

    ``result`` is expected to contain ``field: value`` lines (possibly
    numbered or bulleted). A field is kept only if some variant of its
    name occurs in ``text_chunk`` and its value is non-empty.
    """
    existing_fields = set()
    # Only the portion before the first "---" separator is considered.
    head = result.split("---")[0].strip("\n")
    lines = [ln.strip("-").strip() for ln in head.split("\n")]
    # Drop leading enumeration markers such as "1." in "1. field: value".
    lines = [ln[2:].strip() if len(ln) > 2 and ln[1] == "." else ln for ln in lines]
    lowered_chunk = text_chunk.lower()
    for line in lines:
        try:
            parts = line.split(": ")
            field = parts[0].strip(":")
            value = ": ".join(parts[1:])
        except Exception:
            print(f"Skipped: {line}")
            continue
        # Accept the field if any spacing/hyphen/underscore variant of its
        # name shows up in the source chunk.
        variants = (
            field,
            field.replace(" ", ""),
            field.replace("-", ""),
            field.replace("_", ""),
        )
        if not any(v.lower() in lowered_chunk for v in variants):
            continue
        if not value:
            continue
        normalized = field.lower().strip("-").strip("_").strip(" ").strip(":")
        if normalized in existing_fields:
            continue
        existing_fields.add(normalized)

    return existing_fields
82
83
# since we define globals below
class EvaporateExtractor:
    """Wrapper around Evaporate.

    Evaporate is an open-source project from Stanford's AI Lab:
    https://github.com/HazyResearch/evaporate.
    Offering techniques for structured datapoint extraction.

    In the current version, we use the function generator
    from a set of documents.

    Args:
        service_context (Optional[ServiceContext]): Service Context to use.
    """

    def __init__(
        self,
        service_context: Optional[ServiceContext] = None,
        schema_id_prompt: Optional[SchemaIDPrompt] = None,
        fn_generate_prompt: Optional[FnGeneratePrompt] = None,
        field_extract_query_tmpl: str = DEFAULT_FIELD_EXTRACT_QUERY_TMPL,
        expected_output_prefix_tmpl: str = DEFAULT_EXPECTED_OUTPUT_PREFIX_TMPL,
        verbose: bool = False,
    ) -> None:
        """Initialize params."""
        # TODO: take in an entire index instead of forming a response builder
        self._service_context = service_context or ServiceContext.from_defaults()
        self._schema_id_prompt = schema_id_prompt or SCHEMA_ID_PROMPT
        self._fn_generate_prompt = fn_generate_prompt or FN_GENERATION_PROMPT
        self._field_extract_query_tmpl = field_extract_query_tmpl
        self._expected_output_prefix_tmpl = expected_output_prefix_tmpl
        self._verbose = verbose

    def identify_fields(
        self, nodes: List[BaseNode], topic: str, fields_top_k: int = 5
    ) -> List:
        """Identify fields from nodes.

        Will extract fields independently per node, and then
        return the top k fields.

        Args:
            nodes (List[BaseNode]): List of nodes to extract fields from.
            topic (str): Topic to use for extraction.
            fields_top_k (int): Number of fields to return.

        """
        field2count: dict = defaultdict(int)
        for node in nodes:
            llm = self._service_context.llm
            # Ask the LLM to propose "field: value" pairs for this chunk.
            result = llm.predict(
                self._schema_id_prompt,
                topic=topic,
                chunk=node.get_content(metadata_mode=MetadataMode.LLM),
            )

            # Keep only fields that actually appear in the chunk text.
            existing_fields = extract_field_dicts(
                result, node.get_content(metadata_mode=MetadataMode.LLM)
            )

            for field in existing_fields:
                field2count[field] += 1

        # Rank fields by how many nodes mentioned them, most frequent first.
        sorted_tups: List[Tuple[str, int]] = sorted(
            field2count.items(), key=lambda x: x[1], reverse=True
        )
        sorted_fields = [f[0] for f in sorted_tups]
        return sorted_fields[:fields_top_k]

    def extract_fn_from_nodes(
        self, nodes: List[BaseNode], field: str, expected_output: Optional[Any] = None
    ) -> str:
        """Extract function from nodes.

        Synthesizes (via the LLM) the source of a Python function named
        ``get_<field>_field(text)`` that extracts ``field`` from raw text,
        then post-processes the string: truncate after the first ``return``,
        drop ``print(`` lines, and keep only lines that belong to the
        function body. Returns ``""`` if no ``return`` line was generated.
        """
        # avoid circular import
        from llama_index.legacy.response_synthesizers import (
            ResponseMode,
            get_response_synthesizer,
        )

        function_field = get_function_field_from_attribute(field)
        # TODO: replace with new response synthesis module

        if expected_output is not None:
            expected_output_str = (
                f"{self._expected_output_prefix_tmpl}{expected_output!s}\n"
            )
        else:
            expected_output_str = ""

        qa_prompt = self._fn_generate_prompt.partial_format(
            attribute=field,
            function_field=function_field,
            expected_output_str=expected_output_str,
        )

        response_synthesizer = get_response_synthesizer(
            service_context=self._service_context,
            text_qa_template=qa_prompt,
            response_mode=ResponseMode.TREE_SUMMARIZE,
        )

        # ignore refine prompt for now
        query_str = self._field_extract_query_tmpl.format(field=function_field)
        query_bundle = QueryBundle(query_str=query_str)
        response = response_synthesizer.synthesize(
            query_bundle,
            [NodeWithScore(node=n, score=1.0) for n in nodes],
        )
        # Wrap the LLM output in a function definition. The inner lines are
        # indented so the leading-whitespace filter below keeps them.
        fn_str = f"""def get_{function_field}_field(text: str):
    \"""
    Function to extract {field}.
    \"""
    {response!s}
"""

        # format fn_str
        return_idx_list = [i for i, s in enumerate(fn_str.split("\n")) if "return" in s]
        if not return_idx_list:
            return ""

        return_idx = return_idx_list[0]
        # Truncate everything after the first line containing "return".
        fn_str = "\n".join(fn_str.split("\n")[: return_idx + 1])
        # Strip debug prints the LLM may have emitted.
        fn_str = "\n".join([s for s in fn_str.split("\n") if "print(" not in s])
        # Keep only lines that are part of the function definition
        # (the "def" header plus indented body lines).
        return "\n".join(
            [s for s in fn_str.split("\n") if s.startswith((" ", "\t", "def"))]
        )

    def run_fn_on_nodes(
        self, nodes: List[BaseNode], fn_str: str, field_name: str, num_timeouts: int = 1
    ) -> List:
        """Run function on nodes.

        Calls python exec().

        There are definitely security holes with this approach, use with caution.

        """
        # NOTE(review): ``num_timeouts`` appears unused — every node gets a
        # fixed 1-second limit and the first timeout propagates. Confirm
        # against the upstream Evaporate implementation.
        function_field = get_function_field_from_attribute(field_name)
        results = []
        for node in nodes:
            # The exec() calls below run in the module's global namespace,
            # so these names must be module globals for the generated
            # function call to read/write them.
            global result
            global node_text
            node_text = node.get_content()  # type: ignore[name-defined]
            # this is temporary
            result = []  # type: ignore[name-defined]
            try:
                with time_limit(1):
                    # First define the generated function, then invoke it.
                    exec(fn_str, globals())
                    exec(f"result = get_{function_field}_field(node_text)", globals())
            except TimeoutException:
                raise
            results.append(result)  # type: ignore[name-defined]
        return results

    def extract_datapoints_with_fn(
        self,
        nodes: List[BaseNode],
        topic: str,
        sample_k: int = 5,
        fields_top_k: int = 5,
    ) -> List[Dict]:
        """Extract datapoints from a list of nodes, given a topic.

        Returns one dict per input node, mapping each identified field to
        the value the generated extraction function produced for that node.
        """
        idxs = list(range(len(nodes)))
        sample_k = min(sample_k, len(nodes))
        # Identify fields / generate functions on a random subset only, to
        # bound LLM calls; the generated functions then run on all nodes.
        subset_idxs = random.sample(idxs, sample_k)
        subset_nodes = [nodes[si] for si in subset_idxs]

        # get existing fields
        existing_fields = self.identify_fields(
            subset_nodes, topic, fields_top_k=fields_top_k
        )

        # then, for each existing field, generate function
        function_dict = {}
        for field in existing_fields:
            fn = self.extract_fn_from_nodes(subset_nodes, field)
            function_dict[field] = fn

        # then, run function for all nodes
        result_dict = {}
        for field in existing_fields:
            result_list = self.run_fn_on_nodes(nodes, function_dict[field], field)
            result_dict[field] = result_list

        # convert into list of dictionaries
        result_list = []
        for i in range(len(nodes)):
            result_dict_i = {}
            for field in existing_fields:
                result_dict_i[field] = result_dict[field][i]
            result_list.append(result_dict_i)
        return result_list
276