1"""Image parser.
2
3Contains parsers for image files.
4
5"""
6
7import re8from pathlib import Path9from typing import Dict, List, Optional10
11from llama_index.legacy.readers.base import BaseReader12from llama_index.legacy.schema import Document, ImageDocument13from llama_index.legacy.utils import infer_torch_device14
15
16class ImageReader(BaseReader):17"""Image parser.18
19Extract text from images using DONUT.
20
21"""
22
23def __init__(24self,25parser_config: Optional[Dict] = None,26keep_image: bool = False,27parse_text: bool = False,28):29"""Init parser."""30if parser_config is None and parse_text:31try:32import sentencepiece # noqa33import torch # noqa34from PIL import Image # noqa35from transformers import DonutProcessor, VisionEncoderDecoderModel36except ImportError:37raise ImportError(38"Please install extra dependencies that are required for "39"the ImageCaptionReader: "40"`pip install torch transformers sentencepiece Pillow`"41)42
43processor = DonutProcessor.from_pretrained(44"naver-clova-ix/donut-base-finetuned-cord-v2"45)46model = VisionEncoderDecoderModel.from_pretrained(47"naver-clova-ix/donut-base-finetuned-cord-v2"48)49parser_config = {"processor": processor, "model": model}50
51self._parser_config = parser_config52self._keep_image = keep_image53self._parse_text = parse_text54
55def load_data(56self, file: Path, extra_info: Optional[Dict] = None57) -> List[Document]:58"""Parse file."""59from PIL import Image60
61from llama_index.legacy.img_utils import img_2_b6462
63# load document image64image = Image.open(file)65if image.mode != "RGB":66image = image.convert("RGB")67
68# Encode image into base64 string and keep in document69image_str: Optional[str] = None70if self._keep_image:71image_str = img_2_b64(image)72
73# Parse image into text74text_str: str = ""75if self._parse_text:76assert self._parser_config is not None77model = self._parser_config["model"]78processor = self._parser_config["processor"]79
80device = infer_torch_device()81model.to(device)82
83# prepare decoder inputs84task_prompt = "<s_cord-v2>"85decoder_input_ids = processor.tokenizer(86task_prompt, add_special_tokens=False, return_tensors="pt"87).input_ids88
89pixel_values = processor(image, return_tensors="pt").pixel_values90
91outputs = model.generate(92pixel_values.to(device),93decoder_input_ids=decoder_input_ids.to(device),94max_length=model.decoder.config.max_position_embeddings,95early_stopping=True,96pad_token_id=processor.tokenizer.pad_token_id,97eos_token_id=processor.tokenizer.eos_token_id,98use_cache=True,99num_beams=3,100bad_words_ids=[[processor.tokenizer.unk_token_id]],101return_dict_in_generate=True,102)103
104sequence = processor.batch_decode(outputs.sequences)[0]105sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(106processor.tokenizer.pad_token, ""107)108# remove first task start token109text_str = re.sub(r"<.*?>", "", sequence, count=1).strip()110
111return [112ImageDocument(113text=text_str,114image=image_str,115image_path=str(file),116metadata=extra_info or {},117)118]119
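

# Minimal usage sketch: it assumes the optional DONUT dependencies
# (torch, transformers, sentencepiece, Pillow) are installed and uses a
# hypothetical local image path "receipt.png".
if __name__ == "__main__":
    reader = ImageReader(parse_text=True, keep_image=True)
    docs = reader.load_data(Path("receipt.png"))
    # The returned ImageDocument carries the extracted text plus the base64
    # image payload and the original file path.
    print(docs[0].text)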