from pathlib import Path
from typing import Dict, List, Optional

from llama_index.legacy.readers.base import BaseReader
from llama_index.legacy.schema import Document, ImageDocument
from llama_index.legacy.utils import infer_torch_device


class ImageCaptionReader(BaseReader):
    """Image parser.

    Caption an image using the BLIP model.
    """

    def __init__(
        self,
        parser_config: Optional[Dict] = None,
        keep_image: bool = False,
        prompt: Optional[str] = None,
    ):
        """Init params."""
        if parser_config is None:
            # No config was supplied: import the optional dependencies and
            # build a default BLIP captioning configuration.
            try:
                import sentencepiece  # noqa
                import torch
                from PIL import Image  # noqa
                from transformers import BlipForConditionalGeneration, BlipProcessor
            except ImportError:
                raise ImportError(
                    "Please install extra dependencies that are required for "
                    "the ImageCaptionReader: "
                    "`pip install torch transformers sentencepiece Pillow`"
                )

            device = infer_torch_device()
            # Use half precision on GPU to reduce memory use; full precision on CPU.
            dtype = torch.float16 if torch.cuda.is_available() else torch.float32

            processor = BlipProcessor.from_pretrained(
                "Salesforce/blip-image-captioning-large"
            )
            model = BlipForConditionalGeneration.from_pretrained(
                "Salesforce/blip-image-captioning-large", torch_dtype=dtype
            )

            parser_config = {
                "processor": processor,
                "model": model,
                "device": device,
                "dtype": dtype,
            }

        self._parser_config = parser_config
        self._keep_image = keep_image
        self._prompt = prompt
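
    # A preloaded configuration can also be injected directly, which skips the
    # model download and lets one model be shared across readers (a sketch;
    # "processor" and "model" here stand for any BLIP-compatible pair, and the
    # "device"/"dtype" values must match how the model was loaded):
    #
    #     reader = ImageCaptionReader(
    #         parser_config={
    #             "processor": processor,
    #             "model": model,
    #             "device": "cuda",
    #             "dtype": torch.float16,
    #         }
    #     )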

    def load_data(
        self, file: Path, extra_info: Optional[Dict] = None
    ) -> List[Document]:
        """Parse file."""
        from PIL import Image

        from llama_index.legacy.img_utils import img_2_b64

        # Load the document image and normalize it to RGB.
        image = Image.open(file)
        if image.mode != "RGB":
            image = image.convert("RGB")

        # Encode the image into a base64 string and keep it in the document.
        image_str: Optional[str] = None
        if self._keep_image:
            image_str = img_2_b64(image)

        # Parse the image into text with the BLIP captioning model.
        model = self._parser_config["model"]
        processor = self._parser_config["processor"]

        device = self._parser_config["device"]
        dtype = self._parser_config["dtype"]
        model.to(device)

        # Caption the image; the caption is conditioned on the prompt when one
        # is given, otherwise BLIP generates an unconditional caption.
        inputs = processor(image, self._prompt, return_tensors="pt").to(device, dtype)

        out = model.generate(**inputs)
        text_str = processor.decode(out[0], skip_special_tokens=True)

        return [
            ImageDocument(
                text=text_str,
                image=image_str,
                image_path=str(file),
                metadata=extra_info or {},
            )
        ]
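
# Example usage (a minimal sketch: "photo.jpg" is a placeholder path, and the
# first run downloads the BLIP weights from the Hugging Face Hub):
if __name__ == "__main__":
    reader = ImageCaptionReader(keep_image=True, prompt="a photograph of")
    docs = reader.load_data(Path("photo.jpg"))
    print(docs[0].text)  # the generated caption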