llama-index

Форк
0
672 строки · 23.9 Кб
1
"""Query Pipeline."""
2

3
import json
4
import uuid
5
from typing import (
6
    Any,
7
    Callable,
8
    Dict,
9
    List,
10
    Optional,
11
    Sequence,
12
    Tuple,
13
    Union,
14
    cast,
15
    get_args,
16
)
17

18
import networkx
19

20
from llama_index.legacy.async_utils import run_jobs
21
from llama_index.legacy.bridge.pydantic import Field
22
from llama_index.legacy.callbacks import CallbackManager
23
from llama_index.legacy.callbacks.schema import CBEventType, EventPayload
24
from llama_index.legacy.core.query_pipeline.query_component import (
25
    QUERY_COMPONENT_TYPE,
26
    ChainableMixin,
27
    InputKeys,
28
    Link,
29
    OutputKeys,
30
    QueryComponent,
31
)
32
from llama_index.legacy.utils import print_text
33

34

35
def get_output(
    src_key: Optional[str],
    output_dict: Dict[str, Any],
) -> Any:
    """Select the value a link forwards from an upstream module's output dict.

    Args:
        src_key: Output key to read. If ``None``, ``output_dict`` must contain
            exactly one entry, and that entry's value is returned.
        output_dict: Output of the upstream module (output key -> value).

    Returns:
        The selected output value.

    Raises:
        ValueError: if ``src_key`` is ``None`` and ``output_dict`` does not
            have exactly one key.
        KeyError: if ``src_key`` is not present in ``output_dict``.
    """
    if src_key is not None:
        return output_dict[src_key]

    # Without an explicit key the choice is only unambiguous for one output.
    if len(output_dict) != 1:
        raise ValueError("Output dict must have exactly one key.")
    return next(iter(output_dict.values()))
49

50

51
def add_output_to_module_inputs(
    dest_key: Optional[str],
    output: Any,
    module: QueryComponent,
    module_inputs: Dict[str, Any],
) -> None:
    """Store ``output`` in ``module_inputs`` under the input key it should fill.

    Args:
        dest_key: Input key of ``module`` to fill. If ``None``, ``module`` must
            have exactly one free required input key, and that key is used.
        output: Value forwarded from the upstream module.
        module: Downstream module receiving the value.
        module_inputs: Input dict being accumulated for ``module``; mutated in
            place.

    Raises:
        ValueError: if ``dest_key`` is ``None`` and ``module`` does not have
            exactly one free required input key.
    """
    if dest_key is not None:
        module_inputs[dest_key] = output
        return

    free_keys = module.free_req_input_keys
    # Only an unambiguous (single) remaining key can be filled implicitly.
    if len(free_keys) != 1:
        raise ValueError(
            "Module input keys must have exactly one key if "
            "dest_key is not specified. Remaining keys "
            f"in module: {free_keys}"
        )
    module_inputs[next(iter(free_keys))] = output
71

72

73
def print_debug_input(
    module_key: str,
    input: Dict[str, Any],
    val_str_len: int = 200,
) -> None:
    """Print the inputs a single module is about to be run with.

    Each value is stringified; strings longer than ``val_str_len`` characters
    are cut to that length and suffixed with ``"..."``.
    """
    lines = [f"> Running module {module_key} with input: "]
    for name, val in input.items():
        text = str(val)
        if len(text) > val_str_len:
            text = text[:val_str_len] + "..."
        lines.append(f"{name}: {text}")

    # Header and every "key: value" line end in a newline, plus one blank line.
    print_text("\n".join(lines) + "\n\n", color="llama_lavender")
90

91

92
def print_debug_input_multi(
    module_keys: List[str],
    module_inputs: List[Dict[str, Any]],
    val_str_len: int = 200,
) -> None:
    """Print the inputs of several modules that are run together.

    ``module_keys`` and ``module_inputs`` are paired positionally. Values are
    stringified and truncated to ``val_str_len`` characters (plus ``"..."``).
    """

    def _truncate(value: Any) -> str:
        text = str(value)
        return text[:val_str_len] + "..." if len(text) > val_str_len else text

    sections = []
    for key, module_input in zip(module_keys, module_inputs):
        body = "".join(f"{k}: {_truncate(v)}\n" for k, v in module_input.items())
        # Each module's section is followed by a blank line.
        sections.append(f"Module key: {key}. Input: \n{body}\n")

    print_text(
        "> Running modules and inputs in parallel: \n" + "".join(sections) + "\n",
        color="llama_lavender",
    )
112

113

114
# Return a copy of a graph with callable (non-serializable) attributes removed.
# https://stackoverflow.com/questions/23268421/networkx-how-to-access-attributes-of-objects-as-nodes
def clean_graph_attributes_copy(graph: networkx.MultiDiGraph) -> networkx.MultiDiGraph:
    """Return a copy of ``graph`` with every callable node/edge attribute removed.

    Used to obtain a serializable view of the pipeline DAG, whose edges can
    carry callables such as ``condition_fn`` and ``input_fn``.

    ``graph.copy()`` is an independent *shallow* copy: the copy gets its own
    attribute dicts, but attribute values are shared with ``graph``. Deleting
    keys from the copy's dicts therefore leaves ``graph`` unchanged.
    """
    graph_copy = graph.copy()

    def _drop_callables(attributes: Dict[str, Any]) -> None:
        # Collect keys first so entries can be deleted without mutating the
        # dict while iterating over it.
        for key in [k for k, v in attributes.items() if callable(v)]:
            del attributes[key]

    for _, attributes in graph_copy.nodes(data=True):
        _drop_callables(attributes)
    for _, _, attributes in graph_copy.edges(data=True):
        _drop_callables(attributes)

    return graph_copy
133

134

135
# One element of a chain passed to `QueryPipeline.add_chain`: either a query
# component (added to the pipeline under a freshly generated key) or the string
# key of a module in the pipeline.
CHAIN_COMPONENT_TYPE = Union[QUERY_COMPONENT_TYPE, str]
136

137

138
class QueryPipeline(QueryComponent):
    """A query pipeline that can allow arbitrary chaining of different modules.

    A pipeline itself is a query component, and can be used as a module in another pipeline.

    Modules are stored in ``module_dict`` (module key -> component). Their
    connections are stored in ``dag``: each node is a module key, and each edge
    carries the ``src_key``, ``dest_key``, ``condition_fn`` and ``input_fn``
    attributes given to ``add_link``.
    """

    callback_manager: CallbackManager = Field(
        default_factory=lambda: CallbackManager([]), exclude=True
    )

    module_dict: Dict[str, QueryComponent] = Field(
        default_factory=dict, description="The modules in the pipeline."
    )
    dag: networkx.MultiDiGraph = Field(
        default_factory=networkx.MultiDiGraph, description="The DAG of the pipeline."
    )
    verbose: bool = Field(
        default=False, description="Whether to print intermediate steps."
    )
    show_progress: bool = Field(
        default=False,
        description="Whether to show progress bar (currently async only).",
    )
    num_workers: int = Field(
        default=4, description="Number of workers to use (currently async only)."
    )

    class Config:
        arbitrary_types_allowed = True

    def __init__(
        self,
        callback_manager: Optional[CallbackManager] = None,
        chain: Optional[Sequence[CHAIN_COMPONENT_TYPE]] = None,
        modules: Optional[Dict[str, QUERY_COMPONENT_TYPE]] = None,
        links: Optional[List[Link]] = None,
        **kwargs: Any,
    ) -> None:
        """Create a pipeline, optionally populated from ``chain`` or ``modules``/``links``.

        Args:
            callback_manager: Callback manager for the pipeline. Defaults to an
                empty ``CallbackManager``.
            chain: Linear sequence of modules or module keys (see ``add_chain``).
                Cannot be combined with ``modules``/``links``.
            modules: Mapping of module key -> module to add.
            links: Links between the modules.
            **kwargs: Forwarded to the parent model constructor.
        """
        super().__init__(
            callback_manager=callback_manager or CallbackManager([]),
            **kwargs,
        )

        self._init_graph(chain=chain, modules=modules, links=links)

    def _init_graph(
        self,
        chain: Optional[Sequence[CHAIN_COMPONENT_TYPE]] = None,
        modules: Optional[Dict[str, QUERY_COMPONENT_TYPE]] = None,
        links: Optional[List[Link]] = None,
    ) -> None:
        """Populate the graph from the constructor arguments.

        Raises:
            ValueError: if ``chain`` is combined with ``modules``/``links``, or
                if a link's source module is not in the pipeline.
        """
        if chain is not None:
            if modules is not None or links is not None:
                raise ValueError("Cannot specify both chain and modules/links in init.")
            self.add_chain(chain)
            return

        if modules is not None:
            self.add_modules(modules)
        # Apply links even when `modules` is None, so that links pointing at
        # unknown modules fail loudly in `add_link` instead of being dropped.
        if links is not None:
            self.add_links(links)

    def add_chain(self, chain: Sequence[CHAIN_COMPONENT_TYPE]) -> None:
        """Add a chain of modules to the pipeline.

        This is a special form of pipeline that is purely sequential/linear.
        This allows a more concise way of specifying a pipeline.

        Query components in ``chain`` are added under freshly generated UUID
        keys. String elements are treated as keys of modules in the pipeline.
        Each element is then linked to the next one.

        Raises:
            ValueError: if an element is neither a query component nor a string.
        """
        module_keys: List[str] = []
        for module in chain:
            if isinstance(module, get_args(QUERY_COMPONENT_TYPE)):
                module_key = str(uuid.uuid4())
                self.add(module_key, cast(QUERY_COMPONENT_TYPE, module))
                module_keys.append(module_key)
            elif isinstance(module, str):
                module_keys.append(module)
            else:
                raise ValueError("Chain must be a sequence of modules or module keys.")

        # Link each consecutive pair of modules.
        for src, dest in zip(module_keys, module_keys[1:]):
            self.add_link(src=src, dest=dest)

    def add_links(
        self,
        links: List[Link],
    ) -> None:
        """Add links to the pipeline, in order.

        Raises:
            ValueError: if an element of ``links`` is not a ``Link``.
        """
        for link in links:
            if not isinstance(link, Link):
                raise ValueError("Link must be of type `Link`.")
            self.add_link(**link.dict())

    def add_modules(self, module_dict: Dict[str, QUERY_COMPONENT_TYPE]) -> None:
        """Add every module in ``module_dict`` (module key -> module)."""
        for module_key, module in module_dict.items():
            self.add(module_key, module)

    def add(self, module_key: str, module: QUERY_COMPONENT_TYPE) -> None:
        """Add a module to the pipeline under ``module_key``.

        ``ChainableMixin`` objects are converted with ``as_query_component()``
        before they are stored.

        Raises:
            ValueError: if ``module_key`` is already used in the pipeline.
        """
        if module_key in self.module_dict:
            raise ValueError(f"Module {module_key} already exists in pipeline.")

        if isinstance(module, ChainableMixin):
            module = module.as_query_component()

        self.module_dict[module_key] = cast(QueryComponent, module)
        self.dag.add_node(module_key)

    def add_link(
        self,
        src: str,
        dest: str,
        src_key: Optional[str] = None,
        dest_key: Optional[str] = None,
        condition_fn: Optional[Callable] = None,
        input_fn: Optional[Callable] = None,
    ) -> None:
        """Add a link between two modules.

        Args:
            src: Key of the upstream module. It must already be in the pipeline.
            dest: Key of the downstream module. It may be added after the link;
                running the pipeline verifies that every linked module exists.
            src_key: Output key of ``src`` to forward. If ``None``, ``src`` must
                produce exactly one output.
            dest_key: Input key of ``dest`` to fill. If ``None``, ``dest`` must
                have exactly one free required input key.
            condition_fn: Optional predicate on the forwarded output. When it
                returns a falsy value, the link is not followed and ``dest`` is
                removed from the run queue.
            input_fn: Optional transform applied to the forwarded output before
                it is passed to ``dest``.

        Raises:
            ValueError: if ``src`` is not in the pipeline.
        """
        if src not in self.module_dict:
            raise ValueError(f"Module {src} does not exist in pipeline.")
        self.dag.add_edge(
            src,
            dest,
            src_key=src_key,
            dest_key=dest_key,
            condition_fn=condition_fn,
            input_fn=input_fn,
        )

    def get_root_keys(self) -> List[str]:
        """Get the keys of modules with no upstream links."""
        return self._get_root_keys()

    def get_leaf_keys(self) -> List[str]:
        """Get the keys of modules with no downstream links."""
        return self._get_leaf_keys()

    def _get_root_keys(self) -> List[str]:
        """Return DAG nodes with in-degree 0."""
        return [v for v, d in self.dag.in_degree() if d == 0]

    def _get_leaf_keys(self) -> List[str]:
        """Return DAG nodes with out-degree 0."""
        return [v for v, d in self.dag.out_degree() if d == 0]

    def set_callback_manager(self, callback_manager: CallbackManager) -> None:
        """Set ``callback_manager`` on the pipeline and on every module in it."""
        self.callback_manager = callback_manager
        for module in self.module_dict.values():
            module.set_callback_manager(callback_manager)

    @staticmethod
    def _get_query_payload(inputs: Dict[str, Any]) -> str:
        """Serialize run inputs for the QUERY callback event payload.

        Falls back to JSON-encoding ``str(inputs)`` when ``json.dumps`` rejects
        the inputs. ``TypeError`` is raised for unserializable values and
        ``ValueError`` for circular references.
        """
        try:
            return json.dumps(inputs)
        except (TypeError, ValueError):
            return json.dumps(str(inputs))

    def run(
        self,
        *args: Any,
        return_values_direct: bool = True,
        callback_manager: Optional[CallbackManager] = None,
        **kwargs: Any,
    ) -> Any:
        """Run the pipeline (single root module, single leaf module).

        Args:
            *args: At most one positional input. It is assigned to the root
                module's single free required input key.
            return_values_direct: If True and the leaf output is a dict with a
                single key, return that value instead of the dict.
            callback_manager: Callback manager for this run. It is also set on
                the pipeline and all of its modules.
            **kwargs: Inputs for the root module.
        """
        callback_manager = callback_manager or self.callback_manager
        self.set_callback_manager(callback_manager)
        with self.callback_manager.as_trace("query"):
            query_payload = self._get_query_payload(kwargs)
            with self.callback_manager.event(
                CBEventType.QUERY, payload={EventPayload.QUERY_STR: query_payload}
            ):
                return self._run(
                    *args, return_values_direct=return_values_direct, **kwargs
                )

    def run_multi(
        self,
        module_input_dict: Dict[str, Any],
        callback_manager: Optional[CallbackManager] = None,
    ) -> Dict[str, Any]:
        """Run the pipeline for multiple roots.

        Args:
            module_input_dict: Mapping of root module key -> input dict for that
                module. Its keys must be exactly the pipeline's root keys.
            callback_manager: Callback manager for this run. It is also set on
                the pipeline and all of its modules.

        Returns:
            Mapping of leaf module key -> that module's output dict, for each
            leaf module that ran.
        """
        callback_manager = callback_manager or self.callback_manager
        self.set_callback_manager(callback_manager)
        with self.callback_manager.as_trace("query"):
            query_payload = self._get_query_payload(module_input_dict)
            with self.callback_manager.event(
                CBEventType.QUERY, payload={EventPayload.QUERY_STR: query_payload}
            ):
                return self._run_multi(module_input_dict)

    async def arun(
        self,
        *args: Any,
        return_values_direct: bool = True,
        callback_manager: Optional[CallbackManager] = None,
        **kwargs: Any,
    ) -> Any:
        """Async version of ``run``. Independent modules are run concurrently."""
        callback_manager = callback_manager or self.callback_manager
        self.set_callback_manager(callback_manager)
        with self.callback_manager.as_trace("query"):
            query_payload = self._get_query_payload(kwargs)
            with self.callback_manager.event(
                CBEventType.QUERY, payload={EventPayload.QUERY_STR: query_payload}
            ):
                return await self._arun(
                    *args, return_values_direct=return_values_direct, **kwargs
                )

    async def arun_multi(
        self,
        module_input_dict: Dict[str, Any],
        callback_manager: Optional[CallbackManager] = None,
    ) -> Dict[str, Any]:
        """Async version of ``run_multi``. Independent modules are run concurrently."""
        callback_manager = callback_manager or self.callback_manager
        self.set_callback_manager(callback_manager)
        with self.callback_manager.as_trace("query"):
            query_payload = self._get_query_payload(module_input_dict)
            with self.callback_manager.event(
                CBEventType.QUERY, payload={EventPayload.QUERY_STR: query_payload}
            ):
                return await self._arun_multi(module_input_dict)

    def _get_root_key_and_kwargs(
        self, *args: Any, **kwargs: Any
    ) -> Tuple[str, Dict[str, Any]]:
        """Resolve the single root module and its input kwargs for ``_run``.

        A single positional argument is mapped onto the root module's only free
        required input key.

        Raises:
            ValueError: if the pipeline does not have exactly one root; if more
                than one positional argument is given, or positional and keyword
                arguments are mixed; or if a positional argument is given but
                the root does not have exactly one free required input key.
        """
        # For multiple roots, callers must use `run_multi`.
        root_keys = self._get_root_keys()
        if len(root_keys) != 1:
            raise ValueError("Only one root is supported.")
        root_key = root_keys[0]

        root_module = self.module_dict[root_key]
        if len(args) > 0:
            if len(args) > 1:
                raise ValueError("Only one arg is allowed.")
            if len(kwargs) > 0:
                raise ValueError("No kwargs allowed if args is specified.")
            if len(root_module.free_req_input_keys) != 1:
                raise ValueError("Only one free input key is allowed.")
            kwargs[next(iter(root_module.free_req_input_keys))] = args[0]
        return root_key, kwargs

    def _get_single_result_output(
        self,
        result_outputs: Dict[str, Any],
        return_values_direct: bool,
    ) -> Any:
        """Get result output from a single module.

        If the module's output dict has a single key and
        ``return_values_direct`` is True, return the value directly.

        Raises:
            ValueError: if ``result_outputs`` does not contain exactly one
                module's output.
        """
        if len(result_outputs) != 1:
            raise ValueError("Only one output is supported.")

        result_output = next(iter(result_outputs.values()))
        if (
            return_values_direct
            and isinstance(result_output, dict)
            and len(result_output) == 1
        ):
            return next(iter(result_output.values()))
        return result_output

    def _run(self, *args: Any, return_values_direct: bool = True, **kwargs: Any) -> Any:
        """Run the pipeline.

        Assume that there is a single root module and a single output module.

        For multi-input and multi-outputs, please see `run_multi`.
        """
        root_key, kwargs = self._get_root_key_and_kwargs(*args, **kwargs)
        result_outputs = self._run_multi({root_key: kwargs})
        return self._get_single_result_output(result_outputs, return_values_direct)

    async def _arun(
        self, *args: Any, return_values_direct: bool = True, **kwargs: Any
    ) -> Any:
        """Run the pipeline.

        Assume that there is a single root module and a single output module.

        For multi-input and multi-outputs, please see `run_multi`.
        """
        root_key, kwargs = self._get_root_key_and_kwargs(*args, **kwargs)
        result_outputs = await self._arun_multi({root_key: kwargs})
        return self._get_single_result_output(result_outputs, return_values_direct)

    def _validate_inputs(self, module_input_dict: Dict[str, Any]) -> None:
        """Check that the graph and the run inputs are consistent.

        Raises:
            ValueError: if a link references a module key that was never added
                to the pipeline, or if the keys of ``module_input_dict`` are not
                exactly the pipeline's root keys.
        """
        # `add_link` only validates `src`, so a link to a never-added `dest`
        # leaves a DAG node without a module. Report it clearly here instead of
        # failing with a KeyError mid-run.
        missing_modules = [
            key for key in self.dag.nodes if key not in self.module_dict
        ]
        if missing_modules:
            raise ValueError(
                "Links reference modules that are not in the pipeline: "
                f"{missing_modules}"
            )

        root_keys = self._get_root_keys()
        if set(root_keys) != set(module_input_dict.keys()):
            raise ValueError(
                "Expected root keys do not match up with input keys.\n"
                f"Expected root keys: {root_keys}\n"
                f"Input keys: {module_input_dict.keys()}\n"
            )

    def _init_module_inputs(
        self, module_input_dict: Dict[str, Any]
    ) -> Dict[str, Dict[str, Any]]:
        """Build the per-module input dicts for a run.

        Every module starts with an empty dict, which is filled as its upstream
        modules run. Root modules start with the inputs from
        ``module_input_dict``.
        """
        all_module_inputs: Dict[str, Dict[str, Any]] = {
            module_key: {} for module_key in self.module_dict
        }
        all_module_inputs.update(module_input_dict)
        return all_module_inputs

    def _process_component_output(
        self,
        queue: List[str],
        output_dict: Dict[str, Any],
        module_key: str,
        all_module_inputs: Dict[str, Dict[str, Any]],
        result_outputs: Dict[str, Any],
    ) -> List[str]:
        """Route a finished module's output to its downstream modules.

        A leaf module (no outgoing edges) has its whole ``output_dict`` stored
        in ``result_outputs``. For each outgoing edge, the value selected by
        ``src_key`` is first checked with the edge's ``condition_fn``, if any.
        If the condition passes, the value is transformed by ``input_fn`` (if
        any) and stored in the destination's inputs. Otherwise the destination
        is removed from the queue.

        Returns:
            The updated queue. ``queue`` itself is not modified.
        """
        new_queue = queue.copy()
        if self.dag.out_degree(module_key) == 0:
            result_outputs[module_key] = output_dict
            return new_queue

        for _, dest, attr in self.dag.edges(module_key, data=True):
            output = get_output(attr.get("src_key"), output_dict)

            condition_fn = attr.get("condition_fn")
            if condition_fn is not None and not condition_fn(output):
                # `dest` may already have been pruned, e.g. by a parallel
                # conditional edge or by another module in the same async
                # batch, so only remove it if it is still queued.
                # NOTE(review): modules downstream of `dest` stay queued and
                # run with whatever inputs they received; confirm this is
                # intended.
                if dest in new_queue:
                    new_queue.remove(dest)
                continue

            # Only transform values that are actually forwarded.
            input_fn = attr.get("input_fn")
            dest_output = input_fn(output) if input_fn is not None else output
            add_output_to_module_inputs(
                attr.get("dest_key"),
                dest_output,
                self.module_dict[dest],
                all_module_inputs[dest],
            )

        return new_queue

    def _run_multi(self, module_input_dict: Dict[str, Any]) -> Dict[str, Any]:
        """Run the pipeline for multiple roots, one module at a time.

        ``module_input_dict`` maps each root module key to its input dict
        (input key -> value). Modules run sequentially in topological order.

        Returns:
            Mapping of leaf module key -> output dict, for each leaf that ran.
        """
        self._validate_inputs(module_input_dict)
        queue = list(networkx.topological_sort(self.dag))
        all_module_inputs = self._init_module_inputs(module_input_dict)
        result_outputs: Dict[str, Any] = {}

        while queue:
            module_key = queue.pop(0)
            module = self.module_dict[module_key]
            module_input = all_module_inputs[module_key]

            if self.verbose:
                print_debug_input(module_key, module_input)
            output_dict = module.run_component(**module_input)

            queue = self._process_component_output(
                queue, output_dict, module_key, all_module_inputs, result_outputs
            )

        return result_outputs

    async def _arun_multi(self, module_input_dict: Dict[str, Any]) -> Dict[str, Any]:
        """Run the pipeline for multiple roots, running independent modules concurrently.

        Takes the same inputs and returns the same outputs as ``_run_multi``. In
        each round, every queued module with no ancestor still in the queue is
        run through ``run_jobs`` (passing ``num_workers`` as ``workers``). The
        round's outputs are then routed downstream.
        """
        self._validate_inputs(module_input_dict)
        queue = list(networkx.topological_sort(self.dag))
        all_module_inputs = self._init_module_inputs(module_input_dict)
        result_outputs: Dict[str, Any] = {}

        # Ancestor sets depend only on the fixed DAG, so compute them once
        # instead of per module per scheduling round.
        ancestors = {
            module_key: networkx.ancestors(self.dag, module_key) for module_key in queue
        }

        while queue:
            # Modules whose ancestors have all finished (or were pruned) are
            # independent of each other and can run in parallel.
            pending = set(queue)
            popped_nodes = [
                module_key
                for module_key in queue
                if ancestors[module_key].isdisjoint(pending)
            ]
            popped_set = set(popped_nodes)
            queue = [module_key for module_key in queue if module_key not in popped_set]

            if self.verbose:
                print_debug_input_multi(
                    popped_nodes,
                    [all_module_inputs[module_key] for module_key in popped_nodes],
                )

            tasks = [
                self.module_dict[module_key].arun_component(
                    **all_module_inputs[module_key]
                )
                for module_key in popped_nodes
            ]
            output_dicts = await run_jobs(
                tasks, show_progress=self.show_progress, workers=self.num_workers
            )

            for output_dict, module_key in zip(output_dicts, popped_nodes):
                queue = self._process_component_output(
                    queue, output_dict, module_key, all_module_inputs, result_outputs
                )

        return result_outputs

    def _validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """Validate component inputs during run_component."""
        raise NotImplementedError

    def validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """Validate component inputs (no-op; inputs are returned unchanged)."""
        return input

    def _validate_component_outputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
        raise NotImplementedError

    def validate_component_outputs(self, output: Dict[str, Any]) -> Dict[str, Any]:
        """Validate component outputs."""
        # NOTE: we override this to do nothing
        return output

    def _run_component(self, **kwargs: Any) -> Dict[str, Any]:
        """Run the pipeline as a component, returning the leaf's output dict."""
        return self.run(return_values_direct=False, **kwargs)

    async def _arun_component(self, **kwargs: Any) -> Dict[str, Any]:
        """Async version of ``_run_component``."""
        return await self.arun(return_values_direct=False, **kwargs)

    @property
    def input_keys(self) -> InputKeys:
        """Input keys of the single root module.

        Raises:
            ValueError: if the pipeline does not have exactly one root.
        """
        root_keys = self._get_root_keys()
        if len(root_keys) != 1:
            raise ValueError("Only one root is supported.")
        root_module = self.module_dict[root_keys[0]]
        return root_module.input_keys

    @property
    def output_keys(self) -> OutputKeys:
        """Output keys of the single leaf module.

        Raises:
            ValueError: if the pipeline does not have exactly one leaf.
        """
        leaf_keys = self._get_leaf_keys()
        if len(leaf_keys) != 1:
            raise ValueError("Only one leaf is supported.")
        leaf_module = self.module_dict[leaf_keys[0]]
        return leaf_module.output_keys

    @property
    def sub_query_components(self) -> List[QueryComponent]:
        """All modules in the pipeline."""
        return list(self.module_dict.values())

    @property
    def clean_dag(self) -> networkx.MultiDiGraph:
        """Copy of ``dag`` with callable attributes (e.g. ``condition_fn``) removed."""
        return clean_graph_attributes_copy(self.dag)
673

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.