llama-index

"""Node parser interface."""

import asyncio
from abc import abstractmethod
from copy import deepcopy
from typing import Any, Dict, List, Optional, Sequence, cast

from typing_extensions import Self

from llama_index.legacy.bridge.pydantic import Field
from llama_index.legacy.schema import (
    BaseNode,
    MetadataMode,
    TextNode,
    TransformComponent,
)

DEFAULT_NODE_TEXT_TEMPLATE = """\
[Excerpt from document]\n{metadata_str}\n\
Excerpt:\n-----\n{content}\n-----\n"""
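
# With the default template, a node rendered for the LLM or embedding model
# looks like the following (illustrative; "file_name: report.pdf" stands in
# for the node's formatted metadata string):
#
#   [Excerpt from document]
#   file_name: report.pdf
#   Excerpt:
#   -----
#   <node text>
#   -----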


class BaseExtractor(TransformComponent):
    """Metadata extractor."""

    is_text_node_only: bool = True

    show_progress: bool = Field(default=True, description="Whether to show progress.")

    metadata_mode: MetadataMode = Field(
        default=MetadataMode.ALL, description="Metadata mode to use when reading nodes."
    )

    node_text_template: str = Field(
        default=DEFAULT_NODE_TEXT_TEMPLATE,
        description="Template to represent how node text is mixed with metadata text.",
    )
    disable_template_rewrite: bool = Field(
        default=False, description="Disable the node template rewrite."
    )

    in_place: bool = Field(
        default=True, description="Whether to process nodes in place."
    )

    num_workers: int = Field(
        default=4,
        description="Number of workers to use for concurrent async processing.",
    )
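
    # Example configuration (illustrative; `SomeExtractor` stands for any
    # concrete subclass): work on copies of the nodes, without progress bars:
    #
    #   extractor = SomeExtractor(in_place=False, show_progress=False)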

    @classmethod
    def from_dict(cls, data: Dict[str, Any], **kwargs: Any) -> Self:  # type: ignore
        # `kwargs` is always a dict here, so merge overrides unconditionally.
        data.update(kwargs)

        data.pop("class_name", None)

        llm_predictor = data.get("llm_predictor", None)
        if llm_predictor:
            from llama_index.legacy.llm_predictor.loading import load_predictor

            llm_predictor = load_predictor(llm_predictor)
            data["llm_predictor"] = llm_predictor

        llm = data.get("llm", None)
        if llm:
            from llama_index.legacy.llms.loading import load_llm

            llm = load_llm(llm)
            data["llm"] = llm

        return cls(**data)
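
    # Round-trip example (illustrative): rebuild an extractor from a
    # serialized dict, overriding a field via kwargs:
    #
    #   extractor = SomeExtractor.from_dict(
    #       {"show_progress": True}, show_progress=False
    #   )
    #
    # Keyword overrides are merged into `data`, and serialized "llm" /
    # "llm_predictor" entries are re-loaded via the legacy loading helpers.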

    @classmethod
    def class_name(cls) -> str:
        """Get class name."""
        return "MetadataExtractor"

    @abstractmethod
    async def aextract(self, nodes: Sequence[BaseNode]) -> List[Dict]:
        """Extracts metadata for a sequence of nodes, returning a list of
        metadata dictionaries corresponding to each node.

        Args:
            nodes (Sequence[BaseNode]): nodes to extract metadata from

        """

    def extract(self, nodes: Sequence[BaseNode]) -> List[Dict]:
        """Extracts metadata for a sequence of nodes, returning a list of
        metadata dictionaries corresponding to each node.

        Args:
            nodes (Sequence[BaseNode]): nodes to extract metadata from

        """
        return asyncio.run(self.aextract(nodes))
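
    # NOTE: `asyncio.run()` raises RuntimeError when invoked from an
    # already-running event loop (e.g. inside Jupyter); in async contexts,
    # await `aextract()` directly instead.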

    async def aprocess_nodes(
        self,
        nodes: List[BaseNode],
        excluded_embed_metadata_keys: Optional[List[str]] = None,
        excluded_llm_metadata_keys: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> List[BaseNode]:
        """Post-process nodes parsed from documents.

        Allows extractors to be chained.

        Args:
            nodes (List[BaseNode]): nodes to post-process
            excluded_embed_metadata_keys (Optional[List[str]]):
                keys to exclude from embed metadata
            excluded_llm_metadata_keys (Optional[List[str]]):
                keys to exclude from llm metadata
        """
        if self.in_place:
            new_nodes = nodes
        else:
            new_nodes = [deepcopy(node) for node in nodes]

        cur_metadata_list = await self.aextract(new_nodes)
        for idx, node in enumerate(new_nodes):
            node.metadata.update(cur_metadata_list[idx])

        for idx, node in enumerate(new_nodes):
            if excluded_embed_metadata_keys is not None:
                node.excluded_embed_metadata_keys.extend(excluded_embed_metadata_keys)
            if excluded_llm_metadata_keys is not None:
                node.excluded_llm_metadata_keys.extend(excluded_llm_metadata_keys)
            if not self.disable_template_rewrite:
                if isinstance(node, TextNode):
                    cast(TextNode, node).text_template = self.node_text_template

        return new_nodes
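
    # Chaining example (illustrative): each call returns the node list, so
    # several extractors can run in sequence over the same nodes:
    #
    #   nodes = await title_extractor.aprocess_nodes(nodes)
    #   nodes = await qa_extractor.aprocess_nodes(nodes)
    #
    # where `title_extractor` and `qa_extractor` are hypothetical
    # BaseExtractor instances.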

    def process_nodes(
        self,
        nodes: List[BaseNode],
        excluded_embed_metadata_keys: Optional[List[str]] = None,
        excluded_llm_metadata_keys: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> List[BaseNode]:
        return asyncio.run(
            self.aprocess_nodes(
                nodes,
                excluded_embed_metadata_keys=excluded_embed_metadata_keys,
                excluded_llm_metadata_keys=excluded_llm_metadata_keys,
                **kwargs,
            )
        )

    def __call__(self, nodes: List[BaseNode], **kwargs: Any) -> List[BaseNode]:
        """Post-process nodes parsed from documents.

        Allows extractors to be chained.

        Args:
            nodes (List[BaseNode]): nodes to post-process
        """
        return self.process_nodes(nodes, **kwargs)

    async def acall(self, nodes: List[BaseNode], **kwargs: Any) -> List[BaseNode]:
        """Post-process nodes parsed from documents.

        Allows extractors to be chained.

        Args:
            nodes (List[BaseNode]): nodes to post-process
        """
        return await self.aprocess_nodes(nodes, **kwargs)
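

# Usage sketch (illustrative; `WordCountExtractor` is not a llama-index
# class): subclasses implement only `aextract()`, and the base class supplies
# the sync wrappers, metadata merging, and node-template rewriting.
class WordCountExtractor(BaseExtractor):
    async def aextract(self, nodes: Sequence[BaseNode]) -> List[Dict]:
        # Return one metadata dict per node, in the same order as `nodes`.
        return [
            {
                "word_count": len(
                    node.get_content(metadata_mode=self.metadata_mode).split()
                )
            }
            for node in nodes
        ]


if __name__ == "__main__":
    nodes: List[BaseNode] = [
        TextNode(text="The quick brown fox jumps over the lazy dog.")
    ]
    # __call__ -> process_nodes -> asyncio.run(aprocess_nodes(...)).
    nodes = WordCountExtractor()(nodes)
    print(nodes[0].metadata)  # {'word_count': 9}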