# Source listing: llama-index (171 lines, 5.4 KB)
"""Metadata extractor interface."""
2
3import asyncio4from abc import abstractmethod5from copy import deepcopy6from typing import Any, Dict, List, Optional, Sequence, cast7
8from typing_extensions import Self9
10from llama_index.legacy.bridge.pydantic import Field11from llama_index.legacy.schema import (12BaseNode,13MetadataMode,14TextNode,15TransformComponent,16)
17
# Default template used to render a node's text together with its metadata.
# It is formatted with two placeholders: ``{metadata_str}`` and ``{content}``.
# The trailing backslashes are line continuations inside the literal, so the
# only newlines in the resulting string are the explicit ``\n`` escapes.
DEFAULT_NODE_TEXT_TEMPLATE = """\
[Excerpt from document]\n{metadata_str}\n\
Excerpt:\n-----\n{content}\n-----\n"""
21
22
class BaseExtractor(TransformComponent):
    """Metadata extractor.

    Abstract base for components that compute one metadata dictionary per
    node (``aextract``) and merge the result into each node's ``metadata``
    (``aprocess_nodes``). Subclasses must implement ``aextract``; the
    synchronous entry points are thin wrappers that drive the coroutine
    versions with ``asyncio.run``.
    """

    # Not read anywhere in this class; presumably consulted by subclasses
    # to skip non-text nodes -- verify against the concrete extractors.
    is_text_node_only: bool = True

    show_progress: bool = Field(default=True, description="Whether to show progress.")

    metadata_mode: MetadataMode = Field(
        default=MetadataMode.ALL, description="Metadata mode to use when reading nodes."
    )

    node_text_template: str = Field(
        default=DEFAULT_NODE_TEXT_TEMPLATE,
        description="Template to represent how node text is mixed with metadata text.",
    )
    disable_template_rewrite: bool = Field(
        default=False, description="Disable the node template rewrite."
    )

    in_place: bool = Field(
        default=True, description="Whether to process nodes in place."
    )

    num_workers: int = Field(
        default=4,
        description="Number of workers to use for concurrent async processing.",
    )

    @classmethod
    def from_dict(cls, data: Dict[str, Any], **kwargs: Any) -> Self:  # type: ignore
        """Build an extractor from a serialized dictionary.

        Args:
            data (Dict[str, Any]): serialized constructor arguments. A
                ``"class_name"`` entry is discarded; truthy ``"llm_predictor"``
                and ``"llm"`` entries are passed through ``load_predictor`` /
                ``load_llm`` before construction.
            **kwargs: extra constructor arguments; they override entries of
                ``data`` with the same key.

        Returns:
            A new instance of ``cls``.
        """
        # Work on a shallow copy so the caller's mapping is not modified
        # (keys are merged, popped and replaced below).
        params: Dict[str, Any] = {**data, **kwargs}

        params.pop("class_name", None)

        llm_predictor = params.get("llm_predictor", None)
        if llm_predictor:
            # Imported lazily, as in the rest of this method, to keep the
            # loader modules out of module import time.
            from llama_index.legacy.llm_predictor.loading import load_predictor

            params["llm_predictor"] = load_predictor(llm_predictor)

        llm = params.get("llm", None)
        if llm:
            from llama_index.legacy.llms.loading import load_llm

            params["llm"] = load_llm(llm)

        return cls(**params)

    @classmethod
    def class_name(cls) -> str:
        """Get class name."""
        return "MetadataExtractor"

    @abstractmethod
    async def aextract(self, nodes: Sequence[BaseNode]) -> List[Dict]:
        """Extracts metadata for a sequence of nodes, returning a list of
        metadata dictionaries corresponding to each node.

        Args:
            nodes (Sequence[BaseNode]): nodes to extract metadata from

        """

    def extract(self, nodes: Sequence[BaseNode]) -> List[Dict]:
        """Extracts metadata for a sequence of nodes, returning a list of
        metadata dictionaries corresponding to each node.

        Synchronous wrapper around ``aextract``. It uses ``asyncio.run``,
        which raises ``RuntimeError`` when called from a thread that already
        has a running event loop; use ``aextract`` there instead.

        Args:
            nodes (Sequence[BaseNode]): nodes to extract metadata from

        """
        return asyncio.run(self.aextract(nodes))

    async def aprocess_nodes(
        self,
        nodes: List[BaseNode],
        excluded_embed_metadata_keys: Optional[List[str]] = None,
        excluded_llm_metadata_keys: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> List[BaseNode]:
        """Post process nodes parsed from documents.

        Allows extractors to be chained.

        Args:
            nodes (List[BaseNode]): nodes to post-process
            excluded_embed_metadata_keys (Optional[List[str]]):
                keys to exclude from embed metadata
            excluded_llm_metadata_keys (Optional[List[str]]):
                keys to exclude from llm metadata
            **kwargs: accepted for interface compatibility; not used here.

        Returns:
            The processed nodes: the input list itself when ``in_place`` is
            true, otherwise a list of deep copies.
        """
        if self.in_place:
            new_nodes = nodes
        else:
            new_nodes = [deepcopy(node) for node in nodes]

        # aextract is expected to return one dict per node, in node order.
        cur_metadata_list = await self.aextract(new_nodes)
        for idx, node in enumerate(new_nodes):
            node.metadata.update(cur_metadata_list[idx])

        for node in new_nodes:
            if excluded_embed_metadata_keys is not None:
                node.excluded_embed_metadata_keys.extend(excluded_embed_metadata_keys)
            if excluded_llm_metadata_keys is not None:
                node.excluded_llm_metadata_keys.extend(excluded_llm_metadata_keys)
            # Only TextNode instances get the template assigned here.
            if not self.disable_template_rewrite and isinstance(node, TextNode):
                node.text_template = self.node_text_template

        return new_nodes

    def process_nodes(
        self,
        nodes: List[BaseNode],
        excluded_embed_metadata_keys: Optional[List[str]] = None,
        excluded_llm_metadata_keys: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> List[BaseNode]:
        """Synchronous wrapper around ``aprocess_nodes``.

        Takes the same arguments and returns the same result. It uses
        ``asyncio.run``, so it cannot be called from a thread that already
        has a running event loop.
        """
        return asyncio.run(
            self.aprocess_nodes(
                nodes,
                excluded_embed_metadata_keys=excluded_embed_metadata_keys,
                excluded_llm_metadata_keys=excluded_llm_metadata_keys,
                **kwargs,
            )
        )

    def __call__(self, nodes: List[BaseNode], **kwargs: Any) -> List[BaseNode]:
        """Post process nodes parsed from documents.

        Allows extractors to be chained.

        Args:
            nodes (List[BaseNode]): nodes to post-process
            **kwargs: forwarded to ``process_nodes``.
        """
        return self.process_nodes(nodes, **kwargs)

    async def acall(self, nodes: List[BaseNode], **kwargs: Any) -> List[BaseNode]:
        """Post process nodes parsed from documents.

        Allows extractors to be chained.

        Args:
            nodes (List[BaseNode]): nodes to post-process
            **kwargs: forwarded to ``aprocess_nodes``.
        """
        return await self.aprocess_nodes(nodes, **kwargs)