llama-index

Форк
0
632 строки · 21.4 Кб
"""
Metadata extractors for nodes.
Currently, only `TextNode` is supported.

Supported metadata:
Node-level:
    - `SummaryExtractor`: Summary of each node, and of the previous and next nodes
    - `QuestionsAnsweredExtractor`: Questions that the node can answer
    - `KeywordExtractor`: Keywords that uniquely identify the node
    - `EntityExtractor`: Named entities found in the node (via SpanMarker)
    - `PydanticProgramExtractor`: Attributes of a Pydantic object extracted from the node
Document-level:
    - `TitleExtractor`: Document title, possibly inferred across multiple nodes

Unimplemented (contributions welcome):
Subsection:
    - Position of node in subsection hierarchy (and associated subtitles)
    - Hierarchically organized summary

The prompts used to generate the metadata are specifically aimed at helping
disambiguate the document or subsection from other similar documents or subsections
(similar to contrastive learning).
"""
22

23
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, cast
24

25
from llama_index.legacy.async_utils import DEFAULT_NUM_WORKERS, run_jobs
26
from llama_index.legacy.bridge.pydantic import Field, PrivateAttr
27
from llama_index.legacy.extractors.interface import BaseExtractor
28
from llama_index.legacy.llm_predictor.base import LLMPredictorType
29
from llama_index.legacy.llms.llm import LLM
30
from llama_index.legacy.llms.utils import resolve_llm
31
from llama_index.legacy.prompts import PromptTemplate
32
from llama_index.legacy.schema import BaseNode, TextNode
33
from llama_index.legacy.types import BasePydanticProgram
34
from llama_index.legacy.utils import get_tqdm_iterable
35

36
# Per-node prompt: asks the LLM for a candidate title of one chunk of text.
DEFAULT_TITLE_NODE_TEMPLATE = (
    "Context: {context_str}. Give a title that summarizes all of "
    "the unique entities, titles or themes found in the context. Title: "
)


# Combine prompt: {context_str} receives the candidate titles joined together.
DEFAULT_TITLE_COMBINE_TEMPLATE = (
    "{context_str}. Based on the above candidate titles and content, "
    "what is the comprehensive title for this document? Title: "
)
44

45

46
class TitleExtractor(BaseExtractor):
    """Title extractor. Useful for long documents. Extracts `document_title`
    metadata field.

    Nodes are grouped by ``ref_doc_id``. For each group, up to ``nodes`` of its
    leading nodes (in input order) are each sent to the LLM with
    ``node_template`` to obtain a candidate title. The candidates are then
    merged into a single title with ``combine_template``, and every node of
    the group receives that title.

    Args:
        llm (Optional[LLM]): LLM
        llm_predictor (Optional[LLMPredictorType]): deprecated; used only when
            ``llm`` is not given
        nodes (int): number of nodes from front to use for title extraction
        node_template (str): template for node-level title clues extraction
        combine_template (str): template for combining node-level clues into
            a document-level title
        num_workers (int): number of workers passed to ``run_jobs`` for the
            per-node candidate requests
    """

    is_text_node_only: bool = False  # can work for mixture of text and non-text nodes
    llm: LLMPredictorType = Field(description="The LLM to use for generation.")
    nodes: int = Field(
        default=5,
        description="The number of nodes to extract titles from.",
        gt=0,
    )
    node_template: str = Field(
        default=DEFAULT_TITLE_NODE_TEMPLATE,
        description="The prompt template to extract titles with.",
    )
    combine_template: str = Field(
        default=DEFAULT_TITLE_COMBINE_TEMPLATE,
        description="The prompt template to merge titles with.",
    )

    def __init__(
        self,
        llm: Optional[LLM] = None,
        # TODO: llm_predictor arg is deprecated
        llm_predictor: Optional[LLMPredictorType] = None,
        nodes: int = 5,
        node_template: str = DEFAULT_TITLE_NODE_TEMPLATE,
        combine_template: str = DEFAULT_TITLE_COMBINE_TEMPLATE,
        num_workers: int = DEFAULT_NUM_WORKERS,
        **kwargs: Any,
    ) -> None:
        """Init params.

        Raises:
            ValueError: if ``nodes`` is smaller than 1.
        """
        if nodes < 1:
            raise ValueError("num_nodes must be >= 1")

        super().__init__(
            llm=llm or llm_predictor or resolve_llm("default"),
            nodes=nodes,
            node_template=node_template,
            combine_template=combine_template,
            num_workers=num_workers,
            **kwargs,
        )

    @classmethod
    def class_name(cls) -> str:
        return "TitleExtractor"

    async def aextract(self, nodes: Sequence[BaseNode]) -> List[Dict]:
        """Return one ``{"document_title": ...}`` dict for each input node.

        Nodes that share a ``ref_doc_id`` receive the same title.
        """
        nodes_by_doc_id = self.separate_nodes_by_ref_id(nodes)
        titles_by_doc_id = await self.extract_titles(nodes_by_doc_id)
        return [{"document_title": titles_by_doc_id[node.ref_doc_id]} for node in nodes]

    def filter_nodes(self, nodes: Sequence[BaseNode]) -> List[BaseNode]:
        """Drop non-``TextNode`` nodes when ``is_text_node_only`` is set.

        NOTE(review): not called by ``aextract`` in this class.
        """
        filtered_nodes: List[BaseNode] = []
        for node in nodes:
            if self.is_text_node_only and not isinstance(node, TextNode):
                continue
            filtered_nodes.append(node)
        return filtered_nodes

    def separate_nodes_by_ref_id(self, nodes: Sequence[BaseNode]) -> Dict:
        """Group nodes by ``ref_doc_id``, keeping at most ``self.nodes`` per group.

        The first nodes encountered for each id are kept; nodes without a
        ``ref_doc_id`` are grouped together under the key ``None``.
        """
        separated_items: Dict[Optional[str], List[BaseNode]] = {}

        for node in nodes:
            bucket = separated_items.setdefault(node.ref_doc_id, [])
            if len(bucket) < self.nodes:
                bucket.append(node)

        return separated_items

    async def extract_titles(self, nodes_by_doc_id: Dict) -> Dict:
        """Produce one combined title per document id.

        Candidate titles are stripped of surrounding whitespace and empty ones
        are dropped before being joined with ", " into the combine prompt; the
        combined title returned by the LLM is stripped as well.
        """
        titles_by_doc_id = {}
        for key, nodes in nodes_by_doc_id.items():
            title_candidates = await self.get_title_candidates(nodes)
            cleaned_candidates = [candidate.strip() for candidate in title_candidates]
            combined_titles = ", ".join(
                candidate for candidate in cleaned_candidates if candidate
            )
            combined_title = await self.llm.apredict(
                PromptTemplate(template=self.combine_template),
                context_str=combined_titles,
            )
            titles_by_doc_id[key] = combined_title.strip()
        return titles_by_doc_id

    async def get_title_candidates(self, nodes: List[BaseNode]) -> List[str]:
        """Ask the LLM for one raw (unstripped) candidate title per node."""
        title_jobs = [
            self.llm.apredict(
                PromptTemplate(template=self.node_template),
                # cast only informs the type checker; ``.text`` must exist on the node
                context_str=cast(TextNode, node).text,
            )
            for node in nodes
        ]
        return await run_jobs(
            title_jobs, show_progress=self.show_progress, workers=self.num_workers
        )
150

151

152
class KeywordExtractor(BaseExtractor):
    """Keyword extractor. Node-level extractor. Extracts
    `excerpt_keywords` metadata field.

    Each node is sent to the LLM on its own, asking for ``keywords`` unique,
    comma-separated keywords; the stripped reply is stored for that node.

    Args:
        llm (Optional[LLM]): LLM
        llm_predictor (Optional[LLMPredictorType]): deprecated; used only when
            ``llm`` is not given
        keywords (int): number of keywords to extract
        num_workers (int): number of workers passed to ``run_jobs``
    """

    llm: LLMPredictorType = Field(description="The LLM to use for generation.")
    keywords: int = Field(
        default=5,
        description="The number of keywords to extract.",
        gt=0,
    )

    def __init__(
        self,
        llm: Optional[LLM] = None,
        # TODO: llm_predictor arg is deprecated
        llm_predictor: Optional[LLMPredictorType] = None,
        keywords: int = 5,
        num_workers: int = DEFAULT_NUM_WORKERS,
        **kwargs: Any,
    ) -> None:
        """Validate ``keywords`` and initialise the model fields.

        Raises:
            ValueError: if ``keywords`` is smaller than 1.
        """
        if keywords < 1:
            raise ValueError("num_keywords must be >= 1")

        # Precedence: explicit llm, then the deprecated predictor, then the default.
        chosen_llm = llm or llm_predictor or resolve_llm("default")
        super().__init__(
            llm=chosen_llm,
            keywords=keywords,
            num_workers=num_workers,
            **kwargs,
        )

    @classmethod
    def class_name(cls) -> str:
        return "KeywordExtractor"

    async def _aextract_keywords_from_node(self, node: BaseNode) -> Dict[str, str]:
        """Extract keywords from a single node.

        Returns:
            ``{}`` when ``is_text_node_only`` is set and the node is not a
            ``TextNode``; otherwise ``{"excerpt_keywords": <stripped reply>}``.
        """
        skip_node = self.is_text_node_only and not isinstance(node, TextNode)
        if skip_node:
            return {}

        content = node.get_content(metadata_mode=self.metadata_mode)

        # TODO: figure out a good way to allow users to customize keyword template
        # "{context_str}" stays a literal placeholder for PromptTemplate; only
        # the keyword count is interpolated here.
        template_str = (
            "{context_str}. "
            f"Give {self.keywords} unique keywords for this document. "
            "Format as comma separated. Keywords: "
        )
        reply = await self.llm.apredict(
            PromptTemplate(template=template_str),
            context_str=content,
        )
        return {"excerpt_keywords": reply.strip()}

    async def aextract(self, nodes: Sequence[BaseNode]) -> List[Dict]:
        """Return one keyword metadata dict per node."""
        jobs = [self._aextract_keywords_from_node(node) for node in nodes]
        return await run_jobs(
            jobs, show_progress=self.show_progress, workers=self.num_workers
        )
218

219

220
# Default prompt for QuestionsAnsweredExtractor; formatted with
# {context_str} and {num_questions}.
DEFAULT_QUESTION_GEN_TMPL = (
    "Here is the context:\n"
    "{context_str}\n"
    "\n"
    "Given the contextual information, "
    "generate {num_questions} questions this context can provide "
    "specific answers to which are unlikely to be found elsewhere.\n"
    "\n"
    "Higher-level summaries of surrounding context may be provided "
    "as well. Try using these summaries to generate better questions "
    "that this context can answer.\n"
    "\n"
)
233

234

235
class QuestionsAnsweredExtractor(BaseExtractor):
    """
    Questions answered extractor. Node-level extractor.
    Extracts `questions_this_excerpt_can_answer` metadata field.

    Args:
        llm (Optional[LLM]): LLM
        llm_predictor (Optional[LLMPredictorType]): deprecated; used only when
            ``llm`` is not given
        questions (int): number of questions to extract
        prompt_template (str): template for question extraction; formatted with
            ``context_str`` and ``num_questions``
        embedding_only (bool): whether to use embedding only
        num_workers (int): number of workers passed to ``run_jobs``
    """

    llm: LLMPredictorType = Field(description="The LLM to use for generation.")
    questions: int = Field(
        default=5,
        description="The number of questions to generate.",
        gt=0,
    )
    prompt_template: str = Field(
        default=DEFAULT_QUESTION_GEN_TMPL,
        description="Prompt template to use when generating questions.",
    )
    # NOTE(review): stored but not read anywhere in this class — confirm whether
    # a caller elsewhere consumes it.
    embedding_only: bool = Field(
        default=True, description="Whether to use metadata for embeddings only."
    )

    def __init__(
        self,
        llm: Optional[LLM] = None,
        # TODO: llm_predictor arg is deprecated
        llm_predictor: Optional[LLMPredictorType] = None,
        questions: int = 5,
        prompt_template: str = DEFAULT_QUESTION_GEN_TMPL,
        embedding_only: bool = True,
        num_workers: int = DEFAULT_NUM_WORKERS,
        **kwargs: Any,
    ) -> None:
        """Init params.

        Raises:
            ValueError: if ``questions`` is smaller than 1.
        """
        if questions < 1:
            raise ValueError("questions must be >= 1")

        super().__init__(
            llm=llm or llm_predictor or resolve_llm("default"),
            questions=questions,
            prompt_template=prompt_template,
            embedding_only=embedding_only,
            num_workers=num_workers,
            **kwargs,
        )

    @classmethod
    def class_name(cls) -> str:
        return "QuestionsAnsweredExtractor"

    async def _aextract_questions_from_node(self, node: BaseNode) -> Dict[str, str]:
        """Extract questions from a node and return its metadata dict.

        Returns ``{}`` when ``is_text_node_only`` is set and the node is not a
        ``TextNode``; otherwise ``{"questions_this_excerpt_can_answer": ...}``
        holding the stripped LLM reply.
        """
        if self.is_text_node_only and not isinstance(node, TextNode):
            return {}

        context_str = node.get_content(metadata_mode=self.metadata_mode)
        prompt = PromptTemplate(template=self.prompt_template)
        questions = await self.llm.apredict(
            prompt, num_questions=self.questions, context_str=context_str
        )

        return {"questions_this_excerpt_can_answer": questions.strip()}

    async def aextract(self, nodes: Sequence[BaseNode]) -> List[Dict]:
        """Return one questions metadata dict per node."""
        questions_jobs = [self._aextract_questions_from_node(node) for node in nodes]

        metadata_list: List[Dict] = await run_jobs(
            questions_jobs, show_progress=self.show_progress, workers=self.num_workers
        )

        return metadata_list
312

313

314
# Default prompt for SummaryExtractor; formatted with {context_str}.
DEFAULT_SUMMARY_EXTRACT_TEMPLATE = (
    "Here is the content of the section:\n"
    "{context_str}\n"
    "\n"
    "Summarize the key topics and entities of the section. \n"
    "Summary: "
)
321

322

323
class SummaryExtractor(BaseExtractor):
    """
    Summary extractor. Node-level extractor with adjacent sharing.
    Extracts `section_summary`, `prev_section_summary`, `next_section_summary`
    metadata fields.

    Every node is summarized once; the summaries of the neighbouring nodes (by
    position in the input sequence) are then copied into the
    ``prev_section_summary`` / ``next_section_summary`` fields when requested.

    Args:
        llm (Optional[LLM]): LLM
        llm_predictor (Optional[LLMPredictorType]): deprecated; used only when
            ``llm`` is not given
        summaries (Optional[List[str]]): list of summaries to extract: 'self',
            'prev', 'next'. ``None`` (the default) means ``["self"]``.
        prompt_template (str): template for summary extraction
        num_workers (int): number of workers passed to ``run_jobs``
    """

    llm: LLMPredictorType = Field(description="The LLM to use for generation.")
    summaries: List[str] = Field(
        description="List of summaries to extract: 'self', 'prev', 'next'"
    )
    prompt_template: str = Field(
        default=DEFAULT_SUMMARY_EXTRACT_TEMPLATE,
        description="Template to use when generating summaries.",
    )

    _self_summary: bool = PrivateAttr()
    _prev_summary: bool = PrivateAttr()
    _next_summary: bool = PrivateAttr()

    def __init__(
        self,
        llm: Optional[LLM] = None,
        # TODO: llm_predictor arg is deprecated
        llm_predictor: Optional[LLMPredictorType] = None,
        summaries: Optional[List[str]] = None,
        prompt_template: str = DEFAULT_SUMMARY_EXTRACT_TEMPLATE,
        num_workers: int = DEFAULT_NUM_WORKERS,
        **kwargs: Any,
    ):
        """Init params.

        Raises:
            ValueError: if ``summaries`` contains anything other than
                'self', 'prev' or 'next'.
        """
        # ``None`` instead of a list literal default: a mutable default list
        # would be one object shared by every call of this constructor.
        if summaries is None:
            summaries = ["self"]

        # validation
        if not all(s in ["self", "prev", "next"] for s in summaries):
            raise ValueError("summaries must be one of ['self', 'prev', 'next']")

        super().__init__(
            llm=llm or llm_predictor or resolve_llm("default"),
            summaries=summaries,
            prompt_template=prompt_template,
            num_workers=num_workers,
            **kwargs,
        )

        # Private flags are assigned only after super().__init__() has finished
        # initialising the pydantic model, rather than onto an uninitialised one.
        self._self_summary = "self" in summaries
        self._prev_summary = "prev" in summaries
        self._next_summary = "next" in summaries

    @classmethod
    def class_name(cls) -> str:
        return "SummaryExtractor"

    async def _agenerate_node_summary(self, node: BaseNode) -> str:
        """Generate a summary for a node.

        Returns ``""`` when ``is_text_node_only`` is set and the node is not a
        ``TextNode``; otherwise the stripped LLM reply.
        """
        if self.is_text_node_only and not isinstance(node, TextNode):
            return ""

        context_str = node.get_content(metadata_mode=self.metadata_mode)
        summary = await self.llm.apredict(
            PromptTemplate(template=self.prompt_template), context_str=context_str
        )

        return summary.strip()

    async def aextract(self, nodes: Sequence[BaseNode]) -> List[Dict]:
        """Return one summary metadata dict per node.

        Raises:
            ValueError: if any node is not a ``TextNode``.
        """
        if not all(isinstance(node, TextNode) for node in nodes):
            raise ValueError("Only `TextNode` is allowed for `Summary` extractor")

        node_summaries_jobs = [self._agenerate_node_summary(node) for node in nodes]

        node_summaries = await run_jobs(
            node_summaries_jobs,
            show_progress=self.show_progress,
            workers=self.num_workers,
        )

        # Extract node-level summary metadata; empty summaries are not recorded.
        metadata_list: List[Dict] = [{} for _ in nodes]
        for i, metadata in enumerate(metadata_list):
            if i > 0 and self._prev_summary and node_summaries[i - 1]:
                metadata["prev_section_summary"] = node_summaries[i - 1]
            if i < len(nodes) - 1 and self._next_summary and node_summaries[i + 1]:
                metadata["next_section_summary"] = node_summaries[i + 1]
            if self._self_summary and node_summaries[i]:
                metadata["section_summary"] = node_summaries[i]

        return metadata_list
414

415

416
# Entity class tags produced by the SpanMarker model, mapped to readable names.
# EntityExtractor uses these names as metadata keys when ``label_entities`` is
# enabled; tags missing from the map are used verbatim.
# NOTE(review): the tag set presumably follows the MultiNERD scheme of the
# default model below — confirm when switching models.
DEFAULT_ENTITY_MAP = {
    "PER": "persons",
    "ORG": "organizations",
    "LOC": "locations",
    "ANIM": "animals",
    "BIO": "biological",
    "CEL": "celestial",
    "DIS": "diseases",
    "EVE": "events",
    "FOOD": "foods",
    "INST": "instruments",
    "MEDIA": "media",
    "PLANT": "plants",
    "MYTH": "mythological",
    "TIME": "times",
    "VEHI": "vehicles",
}

# SpanMarker checkpoint used by EntityExtractor when no ``model_name`` is given.
DEFAULT_ENTITY_MODEL = "tomaarsen/span-marker-mbert-base-multinerd"
435

436

437
class EntityExtractor(BaseExtractor):
    """
    Entity extractor. Extracts `entities` into a metadata field using a default model
    `tomaarsen/span-marker-mbert-base-multinerd` and the SpanMarker library.

    Install SpanMarker with `pip install span-marker`.

    With ``label_entities=False`` every accepted entity is stored under the
    ``entities`` key. With ``label_entities=True`` each entity is stored under
    its class label, mapped through ``entity_map``. Each list holds unique
    entity strings in first-occurrence order.
    """

    model_name: str = Field(
        default=DEFAULT_ENTITY_MODEL,
        description="The model name of the SpanMarker model to use.",
    )
    # Inclusive bounds, enforced by pydantic via ``ge`` / ``le``.
    prediction_threshold: float = Field(
        default=0.5,
        description="The confidence threshold for accepting predictions.",
        ge=0.0,
        le=1.0,
    )
    span_joiner: str = Field(
        default=" ", description="The separator between entity names."
    )
    label_entities: bool = Field(
        default=False, description="Include entity class labels or not."
    )
    device: Optional[str] = Field(
        default=None, description="Device to run model on, i.e. 'cuda', 'cpu'"
    )
    entity_map: Dict[str, str] = Field(
        default_factory=dict,
        description="Mapping of entity class names to usable names.",
    )

    _tokenizer: Callable = PrivateAttr()
    _model: Any = PrivateAttr()

    def __init__(
        self,
        model_name: str = DEFAULT_ENTITY_MODEL,
        prediction_threshold: float = 0.5,
        span_joiner: str = " ",
        label_entities: bool = False,
        device: Optional[str] = None,
        entity_map: Optional[Dict[str, str]] = None,
        tokenizer: Optional[Callable[[str], List[str]]] = None,
        **kwargs: Any,
    ):
        """
        Entity extractor for extracting entities from text and inserting
        into node metadata.

        Args:
            model_name (str):
                Name of the SpanMarker model to use.
            prediction_threshold (float):
                Minimum prediction threshold for entities, within [0, 1].
                Only spans scoring strictly above it are kept. Defaults to 0.5.
            span_joiner (str):
                String to join spans with. Defaults to " ".
            label_entities (bool):
                Whether to label entities with their type. Setting to true can be
                slightly error prone, but can be useful for downstream tasks.
                Defaults to False.
            device (Optional[str]):
                Device to use for SpanMarker model, i.e. "cpu" or "cuda".
                When None the model is not moved and stays wherever
                ``SpanMarkerModel.from_pretrained`` placed it (typically CPU).
            entity_map (Optional[Dict[str, str]]):
                Mapping from entity class name to label, merged over
                ``DEFAULT_ENTITY_MAP`` for this instance only.
            tokenizer (Optional[Callable[[str], List[str]]]):
                Tokenizer to use for splitting text into words.
                Defaults to NLTK word_tokenize.

        Raises:
            ImportError: if ``span_marker`` or ``nltk`` is not installed.
        """
        try:
            from span_marker import SpanMarkerModel
        except ImportError:
            raise ImportError(
                "SpanMarker is not installed. Install with `pip install span-marker`."
            )

        try:
            from nltk.tokenize import word_tokenize
        except ImportError:
            raise ImportError("NLTK is not installed. Install with `pip install nltk`.")

        # Copy the module-level default so per-instance overrides never leak
        # into DEFAULT_ENTITY_MAP (and thus into later instances).
        base_entity_map = dict(DEFAULT_ENTITY_MAP)
        if entity_map is not None:
            base_entity_map.update(entity_map)

        super().__init__(
            model_name=model_name,
            prediction_threshold=prediction_threshold,
            span_joiner=span_joiner,
            label_entities=label_entities,
            device=device,
            entity_map=base_entity_map,
            **kwargs,
        )

        # Load the (potentially large) model only after field validation
        # succeeded, and assign private attributes on the initialised model.
        model = SpanMarkerModel.from_pretrained(model_name)
        if device is not None:
            model = model.to(device)
        self._model = model
        self._tokenizer = tokenizer or word_tokenize

    @classmethod
    def class_name(cls) -> str:
        return "EntityExtractor"

    async def aextract(self, nodes: Sequence[BaseNode]) -> List[Dict]:
        """Run the SpanMarker model over each node and collect its entities.

        Returns one dict per node mapping a metadata key (``"entities"`` or an
        entity label) to a list of unique entity strings.
        """
        # Extract node-level entity metadata
        metadata_list: List[Dict] = [{} for _ in nodes]
        metadata_queue: Iterable[int] = get_tqdm_iterable(
            range(len(nodes)), self.show_progress, "Extracting entities"
        )

        for i in metadata_queue:
            # Dict keys act as an insertion-ordered set: duplicates collapse and
            # the resulting list order is deterministic, unlike iterating a set.
            found: Dict[str, Dict[str, None]] = {}
            node_text = nodes[i].get_content(metadata_mode=self.metadata_mode)
            words = self._tokenizer(node_text)
            spans = self._model.predict(words)
            for span in spans:
                if span["score"] > self.prediction_threshold:
                    ent_label = self.entity_map.get(span["label"], span["label"])
                    metadata_label = ent_label if self.label_entities else "entities"
                    entity_text = self.span_joiner.join(span["span"])
                    found.setdefault(metadata_label, {})[entity_text] = None

            metadata_list[i] = {label: list(texts) for label, texts in found.items()}

        return metadata_list
571

572

573
# Default prompt for PydanticProgramExtractor; filled via str.format with
# {context_str} (the node content) and {class_name} (the output class name).
DEFAULT_EXTRACT_TEMPLATE_STR = (
    "Here is the content of the section:\n"
    "----------------\n"
    "{context_str}\n"
    "----------------\n"
    "Given the contextual information, extract out a {class_name} object."
)
580

581

582
class PydanticProgramExtractor(BaseExtractor):
    """Pydantic program extractor.

    For every node, ``extract_template_str`` is filled with the node content
    and the program's output class name, passed to ``program`` under
    ``input_key``, and the resulting Pydantic object's ``.dict()`` becomes the
    node's metadata.
    """

    program: BasePydanticProgram = Field(
        ..., description="Pydantic program to extract."
    )
    input_key: str = Field(
        default="input",
        description=(
            "Key to use as input to the program (the program "
            "template string must expose this key)."
        ),
    )
    extract_template_str: str = Field(
        default=DEFAULT_EXTRACT_TEMPLATE_STR,
        description="Template to use for extraction.",
    )

    @classmethod
    def class_name(cls) -> str:
        # NOTE(review): differs from the Python class name; left unchanged since
        # presumably serialized configs refer to it — confirm before renaming.
        return "PydanticModelExtractor"

    async def _acall_program(self, node: BaseNode) -> Dict[str, Any]:
        """Run the program on one node and return the object's fields as a dict.

        Returns ``{}`` when ``is_text_node_only`` is set and the node is not a
        ``TextNode``.
        """
        if self.is_text_node_only and not isinstance(node, TextNode):
            return {}

        node_content = node.get_content(metadata_mode=self.metadata_mode)
        output_cls_name = self.program.output_cls.__name__
        prompt_input = self.extract_template_str.format(
            context_str=node_content,
            class_name=output_cls_name,
        )

        program_kwargs = {self.input_key: prompt_input}
        result = await self.program.acall(**program_kwargs)
        return result.dict()

    async def aextract(self, nodes: Sequence[BaseNode]) -> List[Dict]:
        """Extract pydantic program."""
        jobs = [self._acall_program(node) for node in nodes]
        return await run_jobs(
            jobs, show_progress=self.show_progress, workers=self.num_workers
        )
633

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.