llama-index
93 строки · 3.0 Кб
1from pathlib import Path2from typing import Dict, List, Optional3
4from llama_index.legacy.readers.base import BaseReader5from llama_index.legacy.schema import Document, ImageDocument6from llama_index.legacy.utils import infer_torch_device7
8
class ImageVisionLLMReader(BaseReader):
    """Image parser.

    Caption image using Blip2 (a multimodal VisionLLM similar to GPT4).
    """

    def __init__(
        self,
        parser_config: Optional[Dict] = None,
        keep_image: bool = False,
        prompt: str = "Question: describe what you see in this image. Answer:",
    ) -> None:
        """Init params.

        Args:
            parser_config: Optional pre-built config dict with keys
                "processor", "model", "device" and "dtype". When None, a
                default Blip2 processor/model pair is downloaded and used.
            keep_image: Whether to keep a base64 encoding of the image on
                the resulting document.
            prompt: Prompt passed to the model together with the image.

        Raises:
            ImportError: If the optional vision dependencies (torch,
                transformers, sentencepiece, Pillow) are not installed and
                no ``parser_config`` was supplied.
        """
        if parser_config is None:
            try:
                import sentencepiece  # noqa
                import torch
                from PIL import Image  # noqa
                from transformers import (
                    Blip2ForConditionalGeneration,
                    Blip2Processor,
                )
            except ImportError as err:
                # BUGFIX: the message previously named "ImageCaptionReader";
                # this class is ImageVisionLLMReader. Also chain the cause.
                raise ImportError(
                    "Please install extra dependencies that are required for "
                    "the ImageVisionLLMReader: "
                    "`pip install torch transformers sentencepiece Pillow`"
                ) from err

            device = infer_torch_device()
            # Half precision on CUDA to reduce memory; full precision on CPU.
            dtype = torch.float16 if torch.cuda.is_available() else torch.float32
            processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
            model = Blip2ForConditionalGeneration.from_pretrained(
                "Salesforce/blip2-opt-2.7b", torch_dtype=dtype
            )
            parser_config = {
                "processor": processor,
                "model": model,
                "device": device,
                "dtype": dtype,
            }

        self._parser_config = parser_config
        self._keep_image = keep_image
        self._prompt = prompt

    def load_data(
        self, file: Path, extra_info: Optional[Dict] = None
    ) -> List[Document]:
        """Caption the image at ``file`` and return it as an ImageDocument.

        Args:
            file: Path to the image file to caption.
            extra_info: Optional metadata stored on the resulting document.

        Returns:
            A single-element list containing an ``ImageDocument`` whose
            ``text`` is the generated caption.
        """
        from PIL import Image

        from llama_index.legacy.img_utils import img_2_b64

        # Load the document image and normalize it to RGB mode.
        image = Image.open(file)
        if image.mode != "RGB":
            image = image.convert("RGB")

        # Encode image into a base64 string and keep it on the document
        # only when requested at construction time.
        image_str: Optional[str] = None
        if self._keep_image:
            image_str = img_2_b64(image)

        # Parse image into text with the configured Blip2 model.
        model = self._parser_config["model"]
        processor = self._parser_config["processor"]

        device = self._parser_config["device"]
        dtype = self._parser_config["dtype"]
        model.to(device)

        # Generate a caption conditioned on the configured prompt.
        inputs = processor(image, self._prompt, return_tensors="pt").to(device, dtype)

        out = model.generate(**inputs)
        text_str = processor.decode(out[0], skip_special_tokens=True)

        return [
            ImageDocument(
                text=text_str,
                image=image_str,
                image_path=str(file),
                metadata=extra_info or {},
            )
        ]