llama-index

Форк
0
118 строк · 3.8 Кб
1
"""Image parser.
2

3
Contains parsers for image files.
4

5
"""
6

7
import re
8
from pathlib import Path
9
from typing import Dict, List, Optional
10

11
from llama_index.legacy.readers.base import BaseReader
12
from llama_index.legacy.schema import Document, ImageDocument
13
from llama_index.legacy.utils import infer_torch_device
14

15

16
class ImageReader(BaseReader):
    """Image parser.

    Extract text from images using DONUT
    (the ``naver-clova-ix/donut-base-finetuned-cord-v2`` checkpoint).
    """

    def __init__(
        self,
        parser_config: Optional[Dict] = None,
        keep_image: bool = False,
        parse_text: bool = False,
    ):
        """Init parser.

        Args:
            parser_config: Optional mapping with preloaded ``"processor"`` and
                ``"model"`` entries. When ``None`` and ``parse_text`` is True,
                the DONUT processor/model pair is downloaded from the
                HuggingFace hub.
            keep_image: If True, a base64-encoded copy of the image is stored
                on the returned document.
            parse_text: If True, run DONUT text extraction in ``load_data``.

        Raises:
            ImportError: If ``parse_text`` is True and the optional
                dependencies (torch, transformers, sentencepiece, Pillow)
                are not installed.
        """
        # Only pull in the heavy ML dependencies when text parsing is
        # actually requested and no preloaded config was supplied.
        if parser_config is None and parse_text:
            try:
                import sentencepiece  # noqa
                import torch  # noqa
                from PIL import Image  # noqa
                from transformers import DonutProcessor, VisionEncoderDecoderModel
            except ImportError:
                # Fixed: the message previously referenced "ImageCaptionReader"
                # (copy-paste from the sibling caption reader); this class is
                # ImageReader.
                raise ImportError(
                    "Please install extra dependencies that are required for "
                    "the ImageReader: "
                    "`pip install torch transformers sentencepiece Pillow`"
                )

            processor = DonutProcessor.from_pretrained(
                "naver-clova-ix/donut-base-finetuned-cord-v2"
            )
            model = VisionEncoderDecoderModel.from_pretrained(
                "naver-clova-ix/donut-base-finetuned-cord-v2"
            )
            parser_config = {"processor": processor, "model": model}

        self._parser_config = parser_config
        self._keep_image = keep_image
        self._parse_text = parse_text

    def load_data(
        self, file: Path, extra_info: Optional[Dict] = None
    ) -> List[Document]:
        """Parse file.

        Args:
            file: Path to the image file to load.
            extra_info: Optional metadata dict attached to the document.

        Returns:
            A single-element list with an :class:`ImageDocument` holding the
            extracted text (empty string when ``parse_text`` is False), the
            optional base64 image, and the source path.
        """
        from PIL import Image

        from llama_index.legacy.img_utils import img_2_b64

        # load document image; DONUT and img_2_b64 expect RGB input
        image = Image.open(file)
        if image.mode != "RGB":
            image = image.convert("RGB")

        # Encode image into base64 string and keep in document
        image_str: Optional[str] = None
        if self._keep_image:
            image_str = img_2_b64(image)

        # Parse image into text
        text_str: str = ""
        if self._parse_text:
            # __init__ guarantees a config whenever parse_text is True;
            # this can only fail if the instance was mutated afterwards.
            assert self._parser_config is not None
            model = self._parser_config["model"]
            processor = self._parser_config["processor"]

            device = infer_torch_device()
            model.to(device)

            # prepare decoder inputs: CORD-v2 task prompt starts generation
            task_prompt = "<s_cord-v2>"
            decoder_input_ids = processor.tokenizer(
                task_prompt, add_special_tokens=False, return_tensors="pt"
            ).input_ids

            pixel_values = processor(image, return_tensors="pt").pixel_values

            outputs = model.generate(
                pixel_values.to(device),
                decoder_input_ids=decoder_input_ids.to(device),
                max_length=model.decoder.config.max_position_embeddings,
                early_stopping=True,
                pad_token_id=processor.tokenizer.pad_token_id,
                eos_token_id=processor.tokenizer.eos_token_id,
                use_cache=True,
                num_beams=3,
                # forbid the unknown token so decoding never emits it
                bad_words_ids=[[processor.tokenizer.unk_token_id]],
                return_dict_in_generate=True,
            )

            sequence = processor.batch_decode(outputs.sequences)[0]
            sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(
                processor.tokenizer.pad_token, ""
            )
            # remove first task start token
            text_str = re.sub(r"<.*?>", "", sequence, count=1).strip()

        return [
            ImageDocument(
                text=text_str,
                image=image_str,
                image_path=str(file),
                metadata=extra_info or {},
            )
        ]
119

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.