llama-index
632 lines · 21.4 KB
1"""
2Metadata extractors for nodes.
3Currently, only `TextNode` is supported.
4
5Supported metadata:
6Node-level:
7- `SummaryExtractor`: Summary of each node, and pre and post nodes
8- `QuestionsAnsweredExtractor`: Questions that the node can answer
9- `KeywordsExtractor`: Keywords that uniquely identify the node
10Document-level:
11- `TitleExtractor`: Document title, possibly inferred across multiple nodes
12
13Unimplemented (contributions welcome):
14Subsection:
15- Position of node in subsection hierarchy (and associated subtitles)
16- Hierarchically organized summary
17
18The prompts used to generate the metadata are specifically aimed to help
19disambiguate the document or subsection from other similar documents or subsections.
20(similar with contrastive learning)
21"""
22
23from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, cast24
25from llama_index.legacy.async_utils import DEFAULT_NUM_WORKERS, run_jobs26from llama_index.legacy.bridge.pydantic import Field, PrivateAttr27from llama_index.legacy.extractors.interface import BaseExtractor28from llama_index.legacy.llm_predictor.base import LLMPredictorType29from llama_index.legacy.llms.llm import LLM30from llama_index.legacy.llms.utils import resolve_llm31from llama_index.legacy.prompts import PromptTemplate32from llama_index.legacy.schema import BaseNode, TextNode33from llama_index.legacy.types import BasePydanticProgram34from llama_index.legacy.utils import get_tqdm_iterable35
# Prompt applied per node to elicit a title candidate from that node's text.
# Single template variable: `context_str`.
DEFAULT_TITLE_NODE_TEMPLATE = """\
Context: {context_str}. Give a title that summarizes all of \
the unique entities, titles or themes found in the context. Title: """


# Prompt used to merge the per-node title candidates (joined by ", ") into
# one document-level title. Single template variable: `context_str`.
DEFAULT_TITLE_COMBINE_TEMPLATE = """\
{context_str}. Based on the above candidate titles and content, \
what is the comprehensive title for this document? Title: """
class TitleExtractor(BaseExtractor):
    """Title extractor. Useful for long documents. Extracts `document_title`
    metadata field.

    Args:
        llm (Optional[LLM]): LLM
        nodes (int): number of nodes from front to use for title extraction
        node_template (str): template for node-level title clues extraction
        combine_template (str): template for combining node-level clues into
            a document-level title
    """

    is_text_node_only: bool = False  # can work for mixture of text and non-text nodes
    llm: LLMPredictorType = Field(description="The LLM to use for generation.")
    nodes: int = Field(
        default=5,
        description="The number of nodes to extract titles from.",
        gt=0,
    )
    node_template: str = Field(
        default=DEFAULT_TITLE_NODE_TEMPLATE,
        description="The prompt template to extract titles with.",
    )
    combine_template: str = Field(
        default=DEFAULT_TITLE_COMBINE_TEMPLATE,
        description="The prompt template to merge titles with.",
    )

    def __init__(
        self,
        llm: Optional[LLM] = None,
        # TODO: llm_predictor arg is deprecated
        llm_predictor: Optional[LLMPredictorType] = None,
        nodes: int = 5,
        node_template: str = DEFAULT_TITLE_NODE_TEMPLATE,
        combine_template: str = DEFAULT_TITLE_COMBINE_TEMPLATE,
        num_workers: int = DEFAULT_NUM_WORKERS,
        **kwargs: Any,
    ) -> None:
        """Init params."""
        if nodes < 1:
            # FIX: message previously referenced a non-existent `num_nodes`
            # argument; the actual parameter is `nodes`.
            raise ValueError("nodes must be >= 1")

        super().__init__(
            llm=llm or llm_predictor or resolve_llm("default"),
            nodes=nodes,
            node_template=node_template,
            combine_template=combine_template,
            num_workers=num_workers,
            **kwargs,
        )

    @classmethod
    def class_name(cls) -> str:
        return "TitleExtractor"

    async def aextract(self, nodes: Sequence[BaseNode]) -> List[Dict]:
        """Return one `{"document_title": ...}` dict per input node.

        Nodes are grouped by `ref_doc_id` so every node belonging to the same
        source document receives the same title.
        """
        nodes_by_doc_id = self.separate_nodes_by_ref_id(nodes)
        titles_by_doc_id = await self.extract_titles(nodes_by_doc_id)
        return [{"document_title": titles_by_doc_id[node.ref_doc_id]} for node in nodes]

    def filter_nodes(self, nodes: Sequence[BaseNode]) -> List[BaseNode]:
        """Drop non-text nodes when `is_text_node_only` is set."""
        filtered_nodes: List[BaseNode] = []
        for node in nodes:
            if self.is_text_node_only and not isinstance(node, TextNode):
                continue
            filtered_nodes.append(node)
        return filtered_nodes

    def separate_nodes_by_ref_id(self, nodes: Sequence[BaseNode]) -> Dict:
        """Group nodes by `ref_doc_id`, keeping at most `self.nodes` per doc."""
        separated_items: Dict[Optional[str], List[BaseNode]] = {}

        for node in nodes:
            key = node.ref_doc_id
            if key not in separated_items:
                separated_items[key] = []

            # only the first `self.nodes` nodes of each document contribute
            # title candidates
            if len(separated_items[key]) < self.nodes:
                separated_items[key].append(node)

        return separated_items

    async def extract_titles(self, nodes_by_doc_id: Dict) -> Dict:
        """Produce one title per document by combining per-node candidates."""
        titles_by_doc_id = {}
        for key, nodes in nodes_by_doc_id.items():
            title_candidates = await self.get_title_candidates(nodes)
            combined_titles = ", ".join(title_candidates)
            titles_by_doc_id[key] = await self.llm.apredict(
                PromptTemplate(template=self.combine_template),
                context_str=combined_titles,
            )
        return titles_by_doc_id

    async def get_title_candidates(self, nodes: List[BaseNode]) -> List[str]:
        """Ask the LLM for one title candidate per node (jobs run concurrently)."""
        title_jobs = [
            self.llm.apredict(
                PromptTemplate(template=self.node_template),
                context_str=cast(TextNode, node).text,
            )
            for node in nodes
        ]
        return await run_jobs(
            title_jobs, show_progress=self.show_progress, workers=self.num_workers
        )
151
class KeywordExtractor(BaseExtractor):
    """Keyword extractor. Node-level extractor. Extracts
    `excerpt_keywords` metadata field.

    Args:
        llm (Optional[LLM]): LLM
        keywords (int): number of keywords to extract
    """

    llm: LLMPredictorType = Field(description="The LLM to use for generation.")
    keywords: int = Field(
        default=5, description="The number of keywords to extract.", gt=0
    )

    def __init__(
        self,
        llm: Optional[LLM] = None,
        # TODO: llm_predictor arg is deprecated
        llm_predictor: Optional[LLMPredictorType] = None,
        keywords: int = 5,
        num_workers: int = DEFAULT_NUM_WORKERS,
        **kwargs: Any,
    ) -> None:
        """Init params."""
        if keywords < 1:
            # FIX: message previously referenced a non-existent `num_keywords`
            # argument; the actual parameter is `keywords`.
            raise ValueError("keywords must be >= 1")

        super().__init__(
            llm=llm or llm_predictor or resolve_llm("default"),
            keywords=keywords,
            num_workers=num_workers,
            **kwargs,
        )

    @classmethod
    def class_name(cls) -> str:
        return "KeywordExtractor"

    async def _aextract_keywords_from_node(self, node: BaseNode) -> Dict[str, str]:
        """Extract keywords from a node and return its metadata dict."""
        if self.is_text_node_only and not isinstance(node, TextNode):
            return {}

        # TODO: figure out a good way to allow users to customize keyword template
        context_str = node.get_content(metadata_mode=self.metadata_mode)
        keywords = await self.llm.apredict(
            PromptTemplate(
                template=f"""\
{{context_str}}. Give {self.keywords} unique keywords for this \
document. Format as comma separated. Keywords: """
            ),
            context_str=context_str,
        )

        return {"excerpt_keywords": keywords.strip()}

    async def aextract(self, nodes: Sequence[BaseNode]) -> List[Dict]:
        """Run keyword extraction concurrently over all nodes; one dict per node."""
        keyword_jobs = []
        for node in nodes:
            keyword_jobs.append(self._aextract_keywords_from_node(node))

        metadata_list: List[Dict] = await run_jobs(
            keyword_jobs, show_progress=self.show_progress, workers=self.num_workers
        )

        return metadata_list
219
# Prompt for QuestionsAnsweredExtractor. Template variables: `context_str`
# (the node content, possibly including neighboring summaries) and
# `num_questions`.
DEFAULT_QUESTION_GEN_TMPL = """\
Here is the context:
{context_str}

Given the contextual information, \
generate {num_questions} questions this context can provide \
specific answers to which are unlikely to be found elsewhere.

Higher-level summaries of surrounding context may be provided \
as well. Try using these summaries to generate better questions \
that this context can answer.

"""
233
234
class QuestionsAnsweredExtractor(BaseExtractor):
    """
    Questions answered extractor. Node-level extractor.
    Extracts `questions_this_excerpt_can_answer` metadata field.

    Args:
        llm (Optional[LLM]): LLM
        questions (int): number of questions to extract
        prompt_template (str): template for question extraction,
        embedding_only (bool): whether to use embedding only
    """

    llm: LLMPredictorType = Field(description="The LLM to use for generation.")
    questions: int = Field(
        default=5,
        description="The number of questions to generate.",
        gt=0,
    )
    prompt_template: str = Field(
        default=DEFAULT_QUESTION_GEN_TMPL,
        description="Prompt template to use when generating questions.",
    )
    # FIX: description typo "emebddings" -> "embeddings"
    embedding_only: bool = Field(
        default=True, description="Whether to use metadata for embeddings only."
    )

    def __init__(
        self,
        llm: Optional[LLM] = None,
        # TODO: llm_predictor arg is deprecated
        llm_predictor: Optional[LLMPredictorType] = None,
        questions: int = 5,
        prompt_template: str = DEFAULT_QUESTION_GEN_TMPL,
        embedding_only: bool = True,
        num_workers: int = DEFAULT_NUM_WORKERS,
        **kwargs: Any,
    ) -> None:
        """Init params."""
        if questions < 1:
            raise ValueError("questions must be >= 1")

        super().__init__(
            llm=llm or llm_predictor or resolve_llm("default"),
            questions=questions,
            prompt_template=prompt_template,
            embedding_only=embedding_only,
            num_workers=num_workers,
            **kwargs,
        )

    @classmethod
    def class_name(cls) -> str:
        return "QuestionsAnsweredExtractor"

    async def _aextract_questions_from_node(self, node: BaseNode) -> Dict[str, str]:
        """Extract questions from a node and return its metadata dict."""
        if self.is_text_node_only and not isinstance(node, TextNode):
            return {}

        context_str = node.get_content(metadata_mode=self.metadata_mode)
        prompt = PromptTemplate(template=self.prompt_template)
        questions = await self.llm.apredict(
            prompt, num_questions=self.questions, context_str=context_str
        )

        return {"questions_this_excerpt_can_answer": questions.strip()}

    async def aextract(self, nodes: Sequence[BaseNode]) -> List[Dict]:
        """Run question generation concurrently over all nodes; one dict per node."""
        questions_jobs = []
        for node in nodes:
            questions_jobs.append(self._aextract_questions_from_node(node))

        metadata_list: List[Dict] = await run_jobs(
            questions_jobs, show_progress=self.show_progress, workers=self.num_workers
        )

        return metadata_list
313
# Prompt for SummaryExtractor. Single template variable: `context_str`.
DEFAULT_SUMMARY_EXTRACT_TEMPLATE = """\
Here is the content of the section:
{context_str}

Summarize the key topics and entities of the section. \

Summary: """
322
class SummaryExtractor(BaseExtractor):
    """
    Summary extractor. Node-level extractor with adjacent sharing.
    Extracts `section_summary`, `prev_section_summary`, `next_section_summary`
    metadata fields.

    Args:
        llm (Optional[LLM]): LLM
        summaries (List[str]): list of summaries to extract: 'self', 'prev', 'next'
        prompt_template (str): template for summary extraction
    """

    llm: LLMPredictorType = Field(description="The LLM to use for generation.")
    summaries: List[str] = Field(
        description="List of summaries to extract: 'self', 'prev', 'next'"
    )
    prompt_template: str = Field(
        default=DEFAULT_SUMMARY_EXTRACT_TEMPLATE,
        description="Template to use when generating summaries.",
    )

    # Which of the three summary kinds to emit, derived from `summaries`.
    _self_summary: bool = PrivateAttr()
    _prev_summary: bool = PrivateAttr()
    _next_summary: bool = PrivateAttr()

    def __init__(
        self,
        llm: Optional[LLM] = None,
        # TODO: llm_predictor arg is deprecated
        llm_predictor: Optional[LLMPredictorType] = None,
        summaries: Optional[List[str]] = None,
        prompt_template: str = DEFAULT_SUMMARY_EXTRACT_TEMPLATE,
        num_workers: int = DEFAULT_NUM_WORKERS,
        **kwargs: Any,
    ):
        # FIX: the original signature used a mutable default (`["self"]`);
        # a None sentinel preserves the same effective default safely.
        if summaries is None:
            summaries = ["self"]
        # validation
        if not all(s in ["self", "prev", "next"] for s in summaries):
            raise ValueError("summaries must be one of ['self', 'prev', 'next']")
        self._self_summary = "self" in summaries
        self._prev_summary = "prev" in summaries
        self._next_summary = "next" in summaries

        super().__init__(
            llm=llm or llm_predictor or resolve_llm("default"),
            summaries=summaries,
            prompt_template=prompt_template,
            num_workers=num_workers,
            **kwargs,
        )

    @classmethod
    def class_name(cls) -> str:
        return "SummaryExtractor"

    async def _agenerate_node_summary(self, node: BaseNode) -> str:
        """Generate a summary for a node (empty string for skipped nodes)."""
        if self.is_text_node_only and not isinstance(node, TextNode):
            return ""

        context_str = node.get_content(metadata_mode=self.metadata_mode)
        summary = await self.llm.apredict(
            PromptTemplate(template=self.prompt_template), context_str=context_str
        )

        return summary.strip()

    async def aextract(self, nodes: Sequence[BaseNode]) -> List[Dict]:
        """Summarize each node, then attach self/prev/next summaries per node."""
        if not all(isinstance(node, TextNode) for node in nodes):
            raise ValueError("Only `TextNode` is allowed for `Summary` extractor")

        node_summaries_jobs = []
        for node in nodes:
            node_summaries_jobs.append(self._agenerate_node_summary(node))

        node_summaries = await run_jobs(
            node_summaries_jobs,
            show_progress=self.show_progress,
            workers=self.num_workers,
        )

        # Extract node-level summary metadata; empty summaries are omitted.
        metadata_list: List[Dict] = [{} for _ in nodes]
        for i, metadata in enumerate(metadata_list):
            if i > 0 and self._prev_summary and node_summaries[i - 1]:
                metadata["prev_section_summary"] = node_summaries[i - 1]
            if i < len(nodes) - 1 and self._next_summary and node_summaries[i + 1]:
                metadata["next_section_summary"] = node_summaries[i + 1]
            if self._self_summary and node_summaries[i]:
                metadata["section_summary"] = node_summaries[i]

        return metadata_list
415
# Maps label codes emitted by the default SpanMarker NER model to
# human-readable metadata keys (used when `label_entities=True`).
DEFAULT_ENTITY_MAP = {
    "PER": "persons",
    "ORG": "organizations",
    "LOC": "locations",
    "ANIM": "animals",
    "BIO": "biological",
    "CEL": "celestial",
    "DIS": "diseases",
    "EVE": "events",
    "FOOD": "foods",
    "INST": "instruments",
    "MEDIA": "media",
    "PLANT": "plants",
    "MYTH": "mythological",
    "TIME": "times",
    "VEHI": "vehicles",
}

# Default SpanMarker checkpoint loaded by `EntityExtractor`.
DEFAULT_ENTITY_MODEL = "tomaarsen/span-marker-mbert-base-multinerd"
436
class EntityExtractor(BaseExtractor):
    """
    Entity extractor. Extracts `entities` into a metadata field using a default model
    `tomaarsen/span-marker-mbert-base-multinerd` and the SpanMarker library.

    Install SpanMarker with `pip install span-marker`.
    """

    model_name: str = Field(
        default=DEFAULT_ENTITY_MODEL,
        description="The model name of the SpanMarker model to use.",
    )
    # FIX: the original used `gte`/`lte`, which are not pydantic constraint
    # keywords and were silently ignored; `ge`/`le` actually enforce 0..1.
    prediction_threshold: float = Field(
        default=0.5,
        description="The confidence threshold for accepting predictions.",
        ge=0.0,
        le=1.0,
    )
    span_joiner: str = Field(
        default=" ", description="The separator between entity names."
    )
    label_entities: bool = Field(
        default=False, description="Include entity class labels or not."
    )
    device: Optional[str] = Field(
        default=None, description="Device to run model on, i.e. 'cuda', 'cpu'"
    )
    entity_map: Dict[str, str] = Field(
        default_factory=dict,
        description="Mapping of entity class names to usable names.",
    )

    _tokenizer: Callable = PrivateAttr()
    _model: Any = PrivateAttr()

    def __init__(
        self,
        model_name: str = DEFAULT_ENTITY_MODEL,
        prediction_threshold: float = 0.5,
        span_joiner: str = " ",
        label_entities: bool = False,
        device: Optional[str] = None,
        entity_map: Optional[Dict[str, str]] = None,
        tokenizer: Optional[Callable[[str], List[str]]] = None,
        **kwargs: Any,
    ):
        """
        Entity extractor for extracting entities from text and inserting
        into node metadata.

        Args:
            model_name (str):
                Name of the SpanMarker model to use.
            prediction_threshold (float):
                Minimum prediction threshold for entities. Defaults to 0.5.
            span_joiner (str):
                String to join spans with. Defaults to " ".
            label_entities (bool):
                Whether to label entities with their type. Setting to true can be
                slightly error prone, but can be useful for downstream tasks.
                Defaults to False.
            device (Optional[str]):
                Device to use for SpanMarker model, i.e. "cpu" or "cuda".
                Loads onto "cpu" by default.
            entity_map (Optional[Dict[str, str]]):
                Mapping from entity class name to label.
            tokenizer (Optional[Callable[[str], List[str]]]):
                Tokenizer to use for splitting text into words.
                Defaults to NLTK word_tokenize.
        """
        try:
            from span_marker import SpanMarkerModel
        except ImportError:
            raise ImportError(
                "SpanMarker is not installed. Install with `pip install span-marker`."
            )

        try:
            from nltk.tokenize import word_tokenize
        except ImportError:
            raise ImportError("NLTK is not installed. Install with `pip install nltk`.")

        self._model = SpanMarkerModel.from_pretrained(model_name)
        if device is not None:
            self._model = self._model.to(device)

        self._tokenizer = tokenizer or word_tokenize

        # FIX: copy the default map before merging user overrides; the
        # original aliased DEFAULT_ENTITY_MAP and `.update()` mutated the
        # module-level constant for every later instance.
        base_entity_map = dict(DEFAULT_ENTITY_MAP)
        if entity_map is not None:
            base_entity_map.update(entity_map)

        super().__init__(
            model_name=model_name,
            prediction_threshold=prediction_threshold,
            span_joiner=span_joiner,
            label_entities=label_entities,
            device=device,
            entity_map=base_entity_map,
            **kwargs,
        )

    @classmethod
    def class_name(cls) -> str:
        return "EntityExtractor"

    async def aextract(self, nodes: Sequence[BaseNode]) -> List[Dict]:
        """Run NER over each node and collect entity strings into metadata."""
        # Extract node-level entity metadata
        metadata_list: List[Dict] = [{} for _ in nodes]
        metadata_queue: Iterable[int] = get_tqdm_iterable(
            range(len(nodes)), self.show_progress, "Extracting entities"
        )

        for i in metadata_queue:
            metadata = metadata_list[i]
            node_text = nodes[i].get_content(metadata_mode=self.metadata_mode)
            words = self._tokenizer(node_text)
            spans = self._model.predict(words)
            for span in spans:
                if span["score"] > self.prediction_threshold:
                    ent_label = self.entity_map.get(span["label"], span["label"])
                    metadata_label = ent_label if self.label_entities else "entities"

                    # sets deduplicate repeated mentions within a node
                    if metadata_label not in metadata:
                        metadata[metadata_label] = set()

                    metadata[metadata_label].add(self.span_joiner.join(span["span"]))

        # convert metadata from set to list
        for metadata in metadata_list:
            for key, val in metadata.items():
                metadata[key] = list(val)

        return metadata_list
572
# Prompt for PydanticProgramExtractor. Template variables: `context_str`
# (node content) and `class_name` (name of the pydantic output class).
DEFAULT_EXTRACT_TEMPLATE_STR = """\
Here is the content of the section:
----------------
{context_str}
----------------
Given the contextual information, extract out a {class_name} object.\
"""
581
class PydanticProgramExtractor(BaseExtractor):
    """Pydantic program extractor.

    Runs a pydantic program over each node's content and returns the
    extracted object's attributes as that node's metadata dict.
    """

    program: BasePydanticProgram = Field(
        ..., description="Pydantic program to extract."
    )
    input_key: str = Field(
        default="input",
        description=(
            "Key to use as input to the program (the program "
            "template string must expose this key)."
        ),
    )
    extract_template_str: str = Field(
        default=DEFAULT_EXTRACT_TEMPLATE_STR,
        description="Template to use for extraction.",
    )

    @classmethod
    def class_name(cls) -> str:
        return "PydanticModelExtractor"

    async def _acall_program(self, node: BaseNode) -> Dict[str, Any]:
        """Run the program on one node and return its attribute dict."""
        # Non-text nodes produce no metadata when restricted to text nodes.
        if self.is_text_node_only and not isinstance(node, TextNode):
            return {}

        content = node.get_content(metadata_mode=self.metadata_mode)
        prompt_input = self.extract_template_str.format(
            context_str=content,
            class_name=self.program.output_cls.__name__,
        )
        result = await self.program.acall(**{self.input_key: prompt_input})
        return result.dict()

    async def aextract(self, nodes: Sequence[BaseNode]) -> List[Dict]:
        """Extract pydantic program."""
        jobs = [self._acall_program(node) for node in nodes]
        metadata_list: List[Dict] = await run_jobs(
            jobs, show_progress=self.show_progress, workers=self.num_workers
        )
        return metadata_list