llama-index

Форк
0
1
import logging
2
from abc import abstractmethod
3
from typing import Any, Dict, Generic, List, Optional, Type
4

5
import pandas as pd
6

7
from llama_index.legacy.program.predefined.df import (
8
    DataFrameRow,
9
    DataFrameRowsOnly,
10
    DataFrameValuesPerColumn,
11
)
12
from llama_index.legacy.program.predefined.evaporate.extractor import EvaporateExtractor
13
from llama_index.legacy.program.predefined.evaporate.prompts import (
14
    DEFAULT_FIELD_EXTRACT_QUERY_TMPL,
15
    FN_GENERATION_LIST_PROMPT,
16
    FnGeneratePrompt,
17
    SchemaIDPrompt,
18
)
19
from llama_index.legacy.schema import BaseNode, TextNode
20
from llama_index.legacy.service_context import ServiceContext
21
from llama_index.legacy.types import BasePydanticProgram, Model
22
from llama_index.legacy.utils import print_text
23

24
# Module-level logger named after this module; the program classes below use
# it to emit debug output.
logger = logging.getLogger(__name__)
25

26

27
class BaseEvaporateProgram(BasePydanticProgram, Generic[Model]):
    """Common base for Evaporate-style programs.

    Configure the program with the names of the fields to extract. Fitting
    asks the wrapped ``EvaporateExtractor`` to synthesize one python function
    (kept as source text) per field from a list of training nodes; concrete
    subclasses then apply those functions to the nodes supplied at call time.
    """

    def __init__(
        self,
        extractor: EvaporateExtractor,
        fields_to_extract: Optional[List[str]] = None,
        fields_context: Optional[Dict[str, Any]] = None,
        nodes_to_fit: Optional[List[BaseNode]] = None,
        verbose: bool = False,
    ) -> None:
        """Store the configuration and optionally fit right away.

        Args:
            extractor: Extractor used to synthesize the per-field functions.
            fields_to_extract: Field names to extract; a falsy value becomes
                an empty list.
            fields_context: Optional per-field context keyed by field name; a
                falsy value becomes an empty dict.
            nodes_to_fit: When not ``None``, every field is fit on these
                nodes during construction.
            verbose: Verbosity flag stored for subclasses to consult.
        """
        self._extractor = extractor
        self._verbose = verbose
        self._fields = fields_to_extract or []
        self._fields_context = fields_context or {}
        # NOTE: maps field name -> function source; changes on every fit.
        self._field_fns: Dict[str, str] = {}

        # Eager fitting is only triggered by an explicit list of nodes.
        if nodes_to_fit is not None:
            self._field_fns = self.fit_fields(nodes_to_fit)

    @classmethod
    def from_defaults(
        cls,
        fields_to_extract: Optional[List[str]] = None,
        fields_context: Optional[Dict[str, Any]] = None,
        service_context: Optional[ServiceContext] = None,
        schema_id_prompt: Optional[SchemaIDPrompt] = None,
        fn_generate_prompt: Optional[FnGeneratePrompt] = None,
        field_extract_query_tmpl: str = DEFAULT_FIELD_EXTRACT_QUERY_TMPL,
        nodes_to_fit: Optional[List[BaseNode]] = None,
        verbose: bool = False,
    ) -> "BaseEvaporateProgram":
        """Create a program together with a freshly built extractor.

        The service context, prompts and query template are handed to a new
        ``EvaporateExtractor``; all remaining arguments go to the constructor
        of ``cls``.
        """
        return cls(
            EvaporateExtractor(
                service_context=service_context,
                schema_id_prompt=schema_id_prompt,
                fn_generate_prompt=fn_generate_prompt,
                field_extract_query_tmpl=field_extract_query_tmpl,
            ),
            fields_to_extract=fields_to_extract,
            fields_context=fields_context,
            nodes_to_fit=nodes_to_fit,
            verbose=verbose,
        )

    @property
    def extractor(self) -> EvaporateExtractor:
        """The extractor wrapped by this program."""
        return self._extractor

    def get_function_str(self, field: str) -> str:
        """Return the stored function source for ``field``.

        Raises:
            KeyError: If no function is stored under ``field``.
        """
        return self._field_fns[field]

    def set_fields_to_extract(self, fields: List[str]) -> None:
        """Replace the list of fields to extract."""
        self._fields = fields

    def fit_fields(
        self,
        nodes: List[BaseNode],
        inplace: bool = True,
    ) -> Dict[str, str]:
        """Fit every configured field on ``nodes``.

        Args:
            nodes: Nodes passed to ``fit`` for each field.
            inplace: Forwarded unchanged to ``fit``.

        Returns:
            Mapping from field name to the value returned by ``fit``.

        Raises:
            ValueError: If no fields are configured.
        """
        if len(self._fields) == 0:
            raise ValueError("Must provide at least one field to extract.")

        # Fields are fit one by one, in their configured order; a field with
        # no entry in the context mapping gets ``None`` as its context.
        return {
            name: self.fit(
                nodes,
                name,
                field_context=self._fields_context.get(name),
                inplace=inplace,
            )
            for name in self._fields
        }

    @abstractmethod
    def fit(
        self,
        nodes: List[BaseNode],
        field: str,
        field_context: Optional[Any] = None,
        expected_output: Optional[Any] = None,
        inplace: bool = True,
    ) -> str:
        """Synthesize python code extracting ``field`` from ``nodes``."""
124

125

126
class DFEvaporateProgram(BaseEvaporateProgram[DataFrameRowsOnly]):
    """Evaporate DF program.

    Given a set of fields, extracts a dataframe from a set of nodes.
    Each node corresponds to a row in the dataframe - each value in the row
    corresponds to a field value.

    """

    def fit(
        self,
        nodes: List[BaseNode],
        field: str,
        field_context: Optional[Any] = None,
        expected_output: Optional[Any] = None,
        inplace: bool = True,
    ) -> str:
        """Given the input Nodes and fields, synthesize the python code.

        Args:
            nodes: Nodes handed to the extractor to derive the function from.
            field: Name of the field the function should extract.
            field_context: Accepted for interface compatibility; not used by
                this implementation.
            expected_output: Optional example of the expected output,
                forwarded to the extractor.
            inplace: When true, store the function on this program under
                ``field``.

        Returns:
            The function returned by the extractor.
        """
        # Forward ``expected_output`` instead of silently dropping it, so this
        # program treats the argument the same way as
        # ``MultiValueEvaporateProgram.fit``.
        fn = self._extractor.extract_fn_from_nodes(
            nodes, field, expected_output=expected_output
        )
        logger.debug("Extracted function: %s", fn)
        # Honor the ``verbose`` flag accepted by the constructor.
        if self._verbose:
            print_text(f"Extracted function: {fn}\n", color="blue")
        if inplace:
            self._field_fns[field] = fn
        return fn

    def _inference(
        self, nodes: List[BaseNode], fn_str: str, field_name: str
    ) -> List[Any]:
        """Given the input, call the python code and return the result.

        Delegates to the extractor's ``run_fn_on_nodes`` and returns its
        result unchanged.
        """
        results = self._extractor.run_fn_on_nodes(nodes, fn_str, field_name)
        logger.debug("Results: %s", results)
        return results

    @property
    def output_cls(self) -> Type[DataFrameRowsOnly]:
        """Output class."""
        return DataFrameRowsOnly

    def __call__(self, *args: Any, **kwds: Any) -> DataFrameRowsOnly:
        """Call evaporate on inference data.

        Args:
            **kwds: Must contain either ``nodes`` (used as given) or ``texts``
                (each text is wrapped in a ``TextNode``). ``nodes`` takes
                precedence when both are present.

        Returns:
            A ``DataFrameRowsOnly`` with one ``DataFrameRow`` per row of the
            assembled dataframe; values follow the order of the configured
            fields.

        Raises:
            ValueError: If neither ``nodes`` nor ``texts`` is provided.
            KeyError: If a configured field has no stored function.
        """
        # TODO: either specify `nodes` or `texts` in kwds
        if "nodes" in kwds:
            nodes = kwds["nodes"]
        elif "texts" in kwds:
            nodes = [TextNode(text=t) for t in kwds["texts"]]
        else:
            raise ValueError("Must provide either `nodes` or `texts`.")

        # One column of extracted values per configured field.
        col_dict = {
            field: self._inference(nodes, self._field_fns[field], field)
            for field in self._fields
        }

        df = pd.DataFrame(col_dict, columns=self._fields)

        # convert pd.DataFrame to DataFrameRowsOnly
        df_row_objs = [
            DataFrameRow(row_values=list(row_arr)) for row_arr in df.values
        ]
        return DataFrameRowsOnly(rows=df_row_objs)
184

185

186
class MultiValueEvaporateProgram(BaseEvaporateProgram[DataFrameValuesPerColumn]):
    """Evaporate program that collects multiple values per field.

    For the configured fields and the supplied texts, produces one
    `DataFrameRow` per field; the row's values are the values found for that
    field across the texts.

    Compared with DFEvaporateProgram: 1) every DataFrameRow is
    column-oriented rather than row-oriented, and 2) rows may differ in
    length, since a node is not guaranteed to contribute exactly one value.

    """

    @classmethod
    def from_defaults(
        cls,
        fields_to_extract: Optional[List[str]] = None,
        fields_context: Optional[Dict[str, Any]] = None,
        service_context: Optional[ServiceContext] = None,
        schema_id_prompt: Optional[SchemaIDPrompt] = None,
        fn_generate_prompt: Optional[FnGeneratePrompt] = None,
        field_extract_query_tmpl: str = DEFAULT_FIELD_EXTRACT_QUERY_TMPL,
        nodes_to_fit: Optional[List[BaseNode]] = None,
        verbose: bool = False,
    ) -> "BaseEvaporateProgram":
        """Build the program, defaulting to the list-generation prompt."""
        # Without an explicit prompt, fall back to the list-returning variant
        # of the function-generation prompt.
        if not fn_generate_prompt:
            fn_generate_prompt = FN_GENERATION_LIST_PROMPT
        return super().from_defaults(
            fields_to_extract=fields_to_extract,
            fields_context=fields_context,
            service_context=service_context,
            schema_id_prompt=schema_id_prompt,
            fn_generate_prompt=fn_generate_prompt,
            field_extract_query_tmpl=field_extract_query_tmpl,
            nodes_to_fit=nodes_to_fit,
            verbose=verbose,
        )

    def fit(
        self,
        nodes: List[BaseNode],
        field: str,
        field_context: Optional[Any] = None,
        expected_output: Optional[Any] = None,
        inplace: bool = True,
    ) -> str:
        """Synthesize the python code for ``field`` from the given nodes.

        ``expected_output`` is forwarded to the extractor. The function is
        logged, printed when the program is verbose, and stored on the
        program when ``inplace`` is true. ``field_context`` is not used here.
        """
        generated = self._extractor.extract_fn_from_nodes(
            nodes, field, expected_output=expected_output
        )
        logger.debug(f"Extracted function: {generated}")
        if self._verbose:
            print_text(f"Extracted function: {generated}\n", color="blue")
        if inplace:
            self._field_fns[field] = generated
        return generated

    @property
    def output_cls(self) -> Type[DataFrameValuesPerColumn]:
        """Class of the object returned by calling the program."""
        return DataFrameValuesPerColumn

    def _inference(
        self, nodes: List[BaseNode], fn_str: str, field_name: str
    ) -> List[Any]:
        """Run the python code on the nodes and flatten the results."""
        per_node = self._extractor.run_fn_on_nodes(nodes, fn_str, field_name)
        # Each entry is iterated and its items appended to one flat list.
        flattened: List[Any] = []
        for node_values in per_node:
            flattened.extend(node_values)
        return flattened

    def __call__(self, *args: Any, **kwds: Any) -> DataFrameValuesPerColumn:
        """Run evaporate on the inference data.

        Expects ``nodes`` or ``texts`` in ``kwds`` (``nodes`` wins when both
        are given) and raises ``ValueError`` when neither is present.
        """
        # TODO: either specify `nodes` or `texts` in kwds
        if "nodes" not in kwds and "texts" not in kwds:
            raise ValueError("Must provide either `nodes` or `texts`.")
        if "nodes" in kwds:
            nodes = kwds["nodes"]
        else:
            nodes = [TextNode(text=text) for text in kwds["texts"]]

        # Run every field's function first ...
        values_by_field = {
            name: self._inference(nodes, self._field_fns[name], name)
            for name in self._fields
        }

        # ... then wrap each field's values in a DataFrameRow.
        columns = [
            DataFrameRow(row_values=values_by_field[name]) for name in self._fields
        ]

        return DataFrameValuesPerColumn(columns=columns)
278

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.