llama-index
97 строк · 3.0 Кб
1from typing import (2TYPE_CHECKING,3Any,4Dict,5Iterable,6List,7Optional,8Sequence,9Type,10cast,11)
12
13if TYPE_CHECKING:14from marvin import ai_model15
16from llama_index.legacy.bridge.pydantic import BaseModel, Field17from llama_index.legacy.extractors.interface import BaseExtractor18from llama_index.legacy.schema import BaseNode, TextNode19from llama_index.legacy.utils import get_tqdm_iterable20
21
class MarvinMetadataExtractor(BaseExtractor):
    """Metadata extractor for custom metadata using Marvin.

    Node-level extractor. Extracts the `marvin_metadata` metadata field.

    Args:
        marvin_model: Marvin model to use for extracting metadata
        llm_model_string: (optional) LLM model string to use for extracting
            metadata

    Usage:
        # create extractor list
        extractors = [
            TitleExtractor(nodes=1, llm=llm),
            MarvinMetadataExtractor(marvin_model=YourMarvinMetadataModel),
        ]

        # create node parser to parse nodes from document
        node_parser = SentenceSplitter(
            text_splitter=text_splitter
        )

        # use node_parser to get nodes from documents
        from llama_index.legacy.ingestion import run_transformations
        nodes = run_transformations(documents, [node_parser] + extractors)
        print(nodes)
    """

    # Forward reference to handle circular imports
    marvin_model: Type["ai_model"] = Field(
        description="The Marvin model to use for extracting custom metadata"
    )
    llm_model_string: Optional[str] = Field(
        description="The LLM model string to use for extracting custom metadata"
    )

    def __init__(
        self,
        marvin_model: Type[BaseModel],
        llm_model_string: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Validate the Marvin model and initialise the extractor.

        Args:
            marvin_model: Class that must pass `issubclass(..., ai_model)`.
            llm_model_string: When truthy, assigned to
                `marvin.settings.llm_model`.
            **kwargs: Forwarded unchanged to the parent initialiser.

        Raises:
            ValueError: If `marvin_model` is not a subclass of `ai_model`.
        """
        # Imported lazily so that marvin is only required when this
        # extractor is actually instantiated.
        import marvin
        from marvin import ai_model

        if not issubclass(marvin_model, ai_model):
            raise ValueError("marvin_model must be a subclass of ai_model")

        if llm_model_string:
            # NOTE: this writes to marvin's settings object rather than to
            # this instance, so it is not scoped to a single extractor.
            marvin.settings.llm_model = llm_model_string

        super().__init__(
            marvin_model=marvin_model, llm_model_string=llm_model_string, **kwargs
        )

    @classmethod
    def class_name(cls) -> str:
        """Return the string identifier for this extractor class."""
        # NOTE(review): the identifier differs from the class name
        # ("MarvinMetadataExtractor"). It is a runtime string that other code
        # may rely on, so it is intentionally left unchanged -- confirm
        # before renaming.
        return "MarvinEntityExtractor"

    async def aextract(self, nodes: Sequence[BaseNode]) -> List[Dict]:
        """Extract Marvin metadata for each node.

        Args:
            nodes: Nodes to run the Marvin model on.

        Returns:
            One dict per input node, in input order. The dict is
            `{"marvin_metadata": <model>.dict()}`, or `{}` for a node that is
            not a `TextNode` when `self.is_text_node_only` is set.
        """
        from marvin import ai_model

        # Bind the instance's model to its own local name instead of
        # rebinding the imported `ai_model` symbol.
        model = cast(ai_model, self.marvin_model)
        metadata_list: List[Dict] = []

        nodes_queue: Iterable[BaseNode] = get_tqdm_iterable(
            nodes, self.show_progress, "Extracting marvin metadata"
        )
        for node in nodes_queue:
            if self.is_text_node_only and not isinstance(node, TextNode):
                # Keep the output aligned with the input: skipped nodes
                # still contribute an (empty) entry.
                metadata_list.append({})
                continue

            # TODO: Does marvin support async?
            # The call below is not awaited; it runs inline in this coroutine.
            metadata = model(node.get_content())

            metadata_list.append({"marvin_metadata": metadata.dict()})
        return metadata_list