llama-index

"""Slides parser.

Contains parsers for .pptx files.

"""

import os
from pathlib import Path
from typing import Dict, List, Optional

from llama_index.legacy.readers.base import BaseReader
from llama_index.legacy.schema import Document
from llama_index.legacy.utils import infer_torch_device


class PptxReader(BaseReader):
    """Powerpoint parser.

    Extract text, caption images, and specify slides.

    """

    def __init__(self) -> None:
        """Init parser."""
        try:
            import torch  # noqa
            from PIL import Image  # noqa
            from pptx import Presentation  # noqa
            from transformers import (
                AutoTokenizer,
                VisionEncoderDecoderModel,
                ViTFeatureExtractor,
            )
        except ImportError:
            raise ImportError(
                "Please install extra dependencies that are required for "
                "the PptxReader: "
                "`pip install torch transformers python-pptx Pillow`"
            )

        # load the captioning model, feature extractor, and tokenizer once at init
        model = VisionEncoderDecoderModel.from_pretrained(
            "nlpconnect/vit-gpt2-image-captioning"
        )
        feature_extractor = ViTFeatureExtractor.from_pretrained(
            "nlpconnect/vit-gpt2-image-captioning"
        )
        tokenizer = AutoTokenizer.from_pretrained(
            "nlpconnect/vit-gpt2-image-captioning"
        )

        self.parser_config = {
            "feature_extractor": feature_extractor,
            "model": model,
            "tokenizer": tokenizer,
        }

    def caption_image(self, tmp_image_file: str) -> str:
        """Generate text caption of image."""
        from PIL import Image

        model = self.parser_config["model"]
        feature_extractor = self.parser_config["feature_extractor"]
        tokenizer = self.parser_config["tokenizer"]

        # run the model on the best available device (GPU if present)
        device = infer_torch_device()
        model.to(device)

        # beam-search settings for caption generation
        max_length = 16
        num_beams = 4
        gen_kwargs = {"max_length": max_length, "num_beams": num_beams}

        i_image = Image.open(tmp_image_file)
        # the feature extractor expects RGB input
        if i_image.mode != "RGB":
            i_image = i_image.convert(mode="RGB")

        pixel_values = feature_extractor(
            images=[i_image], return_tensors="pt"
        ).pixel_values
        pixel_values = pixel_values.to(device)

        output_ids = model.generate(pixel_values, **gen_kwargs)

        preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        return preds[0].strip()

    def load_data(
        self,
        file: Path,
        extra_info: Optional[Dict] = None,
    ) -> List[Document]:
        """Parse file."""
        from pptx import Presentation

        presentation = Presentation(file)
        result = ""
        for i, slide in enumerate(presentation.slides):
            result += f"\n\nSlide #{i}: \n"
            for shape in slide.shapes:
                if hasattr(shape, "image"):
                    image = shape.image
                    # get image "file" contents
                    image_bytes = image.blob
                    # temporarily save the image to feed into model
                    image_filename = f"tmp_image.{image.ext}"
                    with open(image_filename, "wb") as f:
                        f.write(image_bytes)
                    result += f"\n Image: {self.caption_image(image_filename)}\n\n"

                    os.remove(image_filename)
                if hasattr(shape, "text"):
                    result += f"{shape.text}\n"

        return [Document(text=result, metadata=extra_info or {})]
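

# Usage sketch (not part of the original module): running this file directly
# parses a local deck and prints the extracted text; the .pptx path below is
# hypothetical.
if __name__ == "__main__":
    reader = PptxReader()
    documents = reader.load_data(file=Path("slides/deck.pptx"))
    print(documents[0].text)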
