llama-index

Форк
0
97 строк · 3.0 Кб
1
from typing import (
    TYPE_CHECKING,
    Any,
    Dict,
    Iterable,
    List,
    Optional,
    Sequence,
    Type,
    cast,
)

# Import only for type checking to avoid a hard runtime dependency on marvin
# (it is imported lazily inside the extractor's __init__).
if TYPE_CHECKING:
    from marvin import ai_model

from llama_index.legacy.bridge.pydantic import BaseModel, Field
from llama_index.legacy.extractors.interface import BaseExtractor
from llama_index.legacy.schema import BaseNode, TextNode
from llama_index.legacy.utils import get_tqdm_iterable


class MarvinMetadataExtractor(BaseExtractor):
    """Metadata extractor for custom metadata using Marvin.

    Node-level extractor. Extracts the `marvin_metadata` metadata field.

    Args:
        marvin_model: Marvin model to use for extracting metadata
        llm_model_string: (optional) LLM model string to use for extracting metadata

    Usage:
        #create extractor list
        extractors = [
            TitleExtractor(nodes=1, llm=llm),
            MarvinMetadataExtractor(marvin_model=YourMarvinMetadataModel),
        ]

        #create node parser to parse nodes from document
        node_parser = SentenceSplitter(
            text_splitter=text_splitter
        )

        #use node_parser to get nodes from documents
        from llama_index.legacy.ingestion import run_transformations
        nodes = run_transformations(documents, [node_parser] + extractors)
        print(nodes)
    """

    # Forward reference ("ai_model") to handle circular imports: marvin is only
    # imported lazily, so the annotation must stay a string.
    marvin_model: Type["ai_model"] = Field(
        description="The Marvin model to use for extracting custom metadata"
    )
    llm_model_string: Optional[str] = Field(
        description="The LLM model string to use for extracting custom metadata"
    )

    def __init__(
        self,
        marvin_model: Type[BaseModel],
        llm_model_string: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Init params.

        Raises:
            ValueError: if ``marvin_model`` is not a subclass of marvin's
                ``ai_model``.
        """
        # Lazy import: marvin is an optional dependency of this extractor.
        import marvin
        from marvin import ai_model

        if not issubclass(marvin_model, ai_model):
            raise ValueError("marvin_model must be a subclass of ai_model")

        # NOTE: mutates marvin's global settings; affects other marvin users
        # in the same process.
        if llm_model_string:
            marvin.settings.llm_model = llm_model_string

        super().__init__(
            marvin_model=marvin_model, llm_model_string=llm_model_string, **kwargs
        )

    @classmethod
    def class_name(cls) -> str:
        # NOTE(review): name differs from the class name; kept as-is for
        # serialization backward-compatibility.
        return "MarvinEntityExtractor"

    async def aextract(self, nodes: Sequence[BaseNode]) -> List[Dict]:
        """Extract ``marvin_metadata`` for each node via the Marvin model.

        Returns one dict per node, in order. Non-text nodes yield an empty
        dict when ``self.is_text_node_only`` is set.
        """
        # The stored model is annotated with a forward ref; narrow it here.
        # (Previously this re-imported `ai_model` only to shadow it with the
        # cast result — the redundant import is removed.)
        model = cast(Type[BaseModel], self.marvin_model)
        metadata_list: List[Dict] = []

        nodes_queue: Iterable[BaseNode] = get_tqdm_iterable(
            nodes, self.show_progress, "Extracting marvin metadata"
        )
        for node in nodes_queue:
            if self.is_text_node_only and not isinstance(node, TextNode):
                metadata_list.append({})
                continue

            # TODO: Does marvin support async?
            metadata = model(node.get_content())

            metadata_list.append({"marvin_metadata": metadata.dict()})
        return metadata_list

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.