llama-index
277 строк · 9.7 Кб
1import logging
2from abc import abstractmethod
3from typing import Any, Dict, Generic, List, Optional, Type
4
5import pandas as pd
6
7from llama_index.legacy.program.predefined.df import (
8DataFrameRow,
9DataFrameRowsOnly,
10DataFrameValuesPerColumn,
11)
12from llama_index.legacy.program.predefined.evaporate.extractor import EvaporateExtractor
13from llama_index.legacy.program.predefined.evaporate.prompts import (
14DEFAULT_FIELD_EXTRACT_QUERY_TMPL,
15FN_GENERATION_LIST_PROMPT,
16FnGeneratePrompt,
17SchemaIDPrompt,
18)
19from llama_index.legacy.schema import BaseNode, TextNode
20from llama_index.legacy.service_context import ServiceContext
21from llama_index.legacy.types import BasePydanticProgram, Model
22from llama_index.legacy.utils import print_text
23
# Module-level logger named after this module, so log records can be filtered
# per module by the application's logging configuration.
logger = logging.getLogger(__name__)
25
26
class BaseEvaporateProgram(BasePydanticProgram, Generic[Model]):
    """Common base for Evaporate-style extraction programs.

    The caller names the fields to extract. Fitting hands a list of training
    nodes to the wrapped ``EvaporateExtractor``, which synthesizes one python
    function per field; calling the program then applies those functions to
    the inference nodes.
    """

    def __init__(
        self,
        extractor: EvaporateExtractor,
        fields_to_extract: Optional[List[str]] = None,
        fields_context: Optional[Dict[str, Any]] = None,
        nodes_to_fit: Optional[List[BaseNode]] = None,
        verbose: bool = False,
    ) -> None:
        """Store the configuration and optionally fit immediately.

        Args:
            extractor: Extractor used to synthesize and run field functions.
            fields_to_extract: Names of the fields to extract; a falsy value
                is replaced by an empty list.
            fields_context: Optional per-field context keyed by field name; a
                falsy value is replaced by an empty dict.
            nodes_to_fit: When not ``None``, ``fit_fields`` is run on these
                nodes during construction.
            verbose: Verbosity flag, stored as ``self._verbose``.
        """
        self._extractor = extractor
        self._verbose = verbose
        self._fields = fields_to_extract or []
        self._fields_context = fields_context or {}
        # NOTE: this mapping (field name -> function string) changes with
        # each call to `fit`.
        self._field_fns: Dict[str, str] = {}

        # Nothing to fit on: leave the function table empty.
        if nodes_to_fit is None:
            return
        self._field_fns = self.fit_fields(nodes_to_fit)

    @classmethod
    def from_defaults(
        cls,
        fields_to_extract: Optional[List[str]] = None,
        fields_context: Optional[Dict[str, Any]] = None,
        service_context: Optional[ServiceContext] = None,
        schema_id_prompt: Optional[SchemaIDPrompt] = None,
        fn_generate_prompt: Optional[FnGeneratePrompt] = None,
        field_extract_query_tmpl: str = DEFAULT_FIELD_EXTRACT_QUERY_TMPL,
        nodes_to_fit: Optional[List[BaseNode]] = None,
        verbose: bool = False,
    ) -> "BaseEvaporateProgram":
        """Build a program around a freshly constructed ``EvaporateExtractor``.

        ``service_context``, ``schema_id_prompt``, ``fn_generate_prompt`` and
        ``field_extract_query_tmpl`` are passed to the extractor; the remaining
        arguments are passed to the class constructor.
        """
        return cls(
            EvaporateExtractor(
                service_context=service_context,
                schema_id_prompt=schema_id_prompt,
                fn_generate_prompt=fn_generate_prompt,
                field_extract_query_tmpl=field_extract_query_tmpl,
            ),
            fields_to_extract=fields_to_extract,
            fields_context=fields_context,
            nodes_to_fit=nodes_to_fit,
            verbose=verbose,
        )

    @property
    def extractor(self) -> EvaporateExtractor:
        """The wrapped extractor."""
        return self._extractor

    def get_function_str(self, field: str) -> str:
        """Return the function string stored for ``field``.

        Raises:
            KeyError: If no function is stored under ``field``.
        """
        return self._field_fns[field]

    def set_fields_to_extract(self, fields: List[str]) -> None:
        """Replace the list of fields to extract."""
        self._fields = fields

    def fit_fields(
        self,
        nodes: List[BaseNode],
        inplace: bool = True,
    ) -> Dict[str, str]:
        """Fit every configured field on ``nodes``.

        Args:
            nodes: Training nodes passed to ``fit`` for each field.
            inplace: Forwarded unchanged to ``fit``.

        Returns:
            Mapping of field name to the value returned by ``fit``.

        Raises:
            ValueError: If no fields are configured.
        """
        if len(self._fields) == 0:
            raise ValueError("Must provide at least one field to extract.")

        return {
            name: self.fit(
                nodes,
                name,
                field_context=self._fields_context.get(name),
                inplace=inplace,
            )
            for name in self._fields
        }

    @abstractmethod
    def fit(
        self,
        nodes: List[BaseNode],
        field: str,
        field_context: Optional[Any] = None,
        expected_output: Optional[Any] = None,
        inplace: bool = True,
    ) -> str:
        """Synthesize the python code for ``field`` from the input nodes."""
124
125
class DFEvaporateProgram(BaseEvaporateProgram[DataFrameRowsOnly]):
    """Evaporate DF program.

    Given a set of fields, extracts a dataframe from a set of nodes.
    Each node corresponds to a row in the dataframe - each value in the row
    corresponds to a field value.
    """

    def fit(
        self,
        nodes: List[BaseNode],
        field: str,
        field_context: Optional[Any] = None,
        expected_output: Optional[Any] = None,
        inplace: bool = True,
    ) -> str:
        """Synthesize the extraction function for ``field`` from ``nodes``.

        Args:
            nodes: Training nodes handed to the extractor.
            field: Name of the field to synthesize a function for.
            field_context: Accepted for interface compatibility; unused by
                this implementation.
            expected_output: Optional expected output, forwarded to the
                extractor.
            inplace: When true, store the function under ``field`` so later
                calls to the program use it.

        Returns:
            The function returned by the extractor.
        """
        # Forward `expected_output` instead of silently dropping it, matching
        # MultiValueEvaporateProgram.fit.
        fn = self._extractor.extract_fn_from_nodes(
            nodes, field, expected_output=expected_output
        )
        logger.debug(f"Extracted function: {fn}")
        # Honor the `verbose` flag accepted by the constructor.
        if self._verbose:
            print_text(f"Extracted function: {fn}\n", color="blue")
        if inplace:
            self._field_fns[field] = fn
        return fn

    def _inference(
        self, nodes: List[BaseNode], fn_str: str, field_name: str
    ) -> List[Any]:
        """Run the function ``fn_str`` on ``nodes`` and return the results."""
        results = self._extractor.run_fn_on_nodes(nodes, fn_str, field_name)
        logger.debug(f"Results: {results}")
        return results

    @property
    def output_cls(self) -> Type[DataFrameRowsOnly]:
        """Output class."""
        return DataFrameRowsOnly

    def __call__(self, *args: Any, **kwds: Any) -> DataFrameRowsOnly:
        """Call evaporate on inference data.

        Args:
            **kwds: Must contain either ``nodes`` (used as-is) or ``texts``
                (each text is wrapped in a ``TextNode``). ``nodes`` takes
                precedence when both are given.

        Returns:
            A ``DataFrameRowsOnly`` with one ``DataFrameRow`` per dataframe
            row; values are ordered like the configured fields.

        Raises:
            ValueError: If neither ``nodes`` nor ``texts`` is provided.
            KeyError: If a configured field has no stored function.
        """
        if "nodes" in kwds:
            nodes = kwds["nodes"]
        elif "texts" in kwds:
            nodes = [TextNode(text=t) for t in kwds["texts"]]
        else:
            raise ValueError("Must provide either `nodes` or `texts`.")

        # One column of results per field.
        col_dict = {
            field: self._inference(nodes, self._field_fns[field], field)
            for field in self._fields
        }
        df = pd.DataFrame(col_dict, columns=self._fields)

        # Convert pd.DataFrame to DataFrameRowsOnly.
        # NOTE(review): `df.values` is a single array, so columns of different
        # dtypes may be upcast to a common dtype - confirm this is intended.
        df_row_objs = [
            DataFrameRow(row_values=list(row_arr)) for row_arr in df.values
        ]
        return DataFrameRowsOnly(rows=df_row_objs)
184
185
class MultiValueEvaporateProgram(BaseEvaporateProgram[DataFrameValuesPerColumn]):
    """Multi-value Evaporate program.

    Given a set of fields and texts, extracts a list of ``DataFrameRow``
    objects across those texts. Each ``DataFrameRow`` corresponds to a field,
    and each value in the row is a value for that field.

    Differences from ``DFEvaporateProgram``:
    1) each ``DataFrameRow`` is column-oriented (instead of row-oriented), and
    2) each ``DataFrameRow`` can be variable length (not guaranteed to have
       one value per node).
    """

    @classmethod
    def from_defaults(
        cls,
        fields_to_extract: Optional[List[str]] = None,
        fields_context: Optional[Dict[str, Any]] = None,
        service_context: Optional[ServiceContext] = None,
        schema_id_prompt: Optional[SchemaIDPrompt] = None,
        fn_generate_prompt: Optional[FnGeneratePrompt] = None,
        field_extract_query_tmpl: str = DEFAULT_FIELD_EXTRACT_QUERY_TMPL,
        nodes_to_fit: Optional[List[BaseNode]] = None,
        verbose: bool = False,
    ) -> "BaseEvaporateProgram":
        """Build the program, defaulting to the list-returning prompt.

        Delegates to the parent ``from_defaults``. The only difference is that
        a falsy ``fn_generate_prompt`` is replaced by
        ``FN_GENERATION_LIST_PROMPT``, so generated functions return a list.
        """
        return super().from_defaults(
            fields_to_extract=fields_to_extract,
            fields_context=fields_context,
            service_context=service_context,
            schema_id_prompt=schema_id_prompt,
            fn_generate_prompt=fn_generate_prompt or FN_GENERATION_LIST_PROMPT,
            field_extract_query_tmpl=field_extract_query_tmpl,
            nodes_to_fit=nodes_to_fit,
            verbose=verbose,
        )

    def fit(
        self,
        nodes: List[BaseNode],
        field: str,
        field_context: Optional[Any] = None,
        expected_output: Optional[Any] = None,
        inplace: bool = True,
    ) -> str:
        """Synthesize the extraction function for ``field`` from ``nodes``.

        ``expected_output`` is forwarded to the extractor; ``field_context``
        is accepted but unused here. When ``inplace`` is true, the function is
        stored under ``field``. The function is logged at debug level and
        printed when the program is verbose.
        """
        synthesized = self._extractor.extract_fn_from_nodes(
            nodes, field, expected_output=expected_output
        )
        logger.debug(f"Extracted function: {synthesized}")
        if self._verbose:
            print_text(f"Extracted function: {synthesized}\n", color="blue")
        if inplace:
            self._field_fns[field] = synthesized
        return synthesized

    @property
    def output_cls(self) -> Type[DataFrameValuesPerColumn]:
        """Output class."""
        return DataFrameValuesPerColumn

    def _inference(
        self, nodes: List[BaseNode], fn_str: str, field_name: str
    ) -> List[Any]:
        """Run ``fn_str`` on ``nodes`` and flatten the per-node results."""
        per_node = self._extractor.run_fn_on_nodes(nodes, fn_str, field_name)
        # Concatenate the per-node result collections into one flat list.
        flattened: List[Any] = []
        for node_values in per_node:
            flattened.extend(node_values)
        return flattened

    def __call__(self, *args: Any, **kwds: Any) -> DataFrameValuesPerColumn:
        """Call evaporate on inference data.

        Expects either ``nodes`` or ``texts`` as a keyword argument; ``nodes``
        takes precedence, and each text is wrapped in a ``TextNode``.

        Raises:
            ValueError: If neither ``nodes`` nor ``texts`` is provided.
        """
        # TODO: either specify `nodes` or `texts` in kwds
        if "nodes" not in kwds and "texts" not in kwds:
            raise ValueError("Must provide either `nodes` or `texts`.")
        nodes = (
            kwds["nodes"]
            if "nodes" in kwds
            else [TextNode(text=text) for text in kwds["texts"]]
        )

        # Run every field's function first, then build the output rows.
        values_by_field = {
            name: self._inference(nodes, self._field_fns[name], name)
            for name in self._fields
        }
        columns = [
            DataFrameRow(row_values=values_by_field[name]) for name in self._fields
        ]
        return DataFrameValuesPerColumn(columns=columns)
278