llama-index

from pathlib import Path
from typing import Dict, List, Optional

from llama_index.legacy.readers.base import BaseReader
from llama_index.legacy.schema import Document, ImageDocument
from llama_index.legacy.utils import infer_torch_device


class ImageVisionLLMReader(BaseReader):
    """Image parser.

    Caption an image using BLIP-2 (a multimodal vision-language model).
    """

    def __init__(
        self,
        parser_config: Optional[Dict] = None,
        keep_image: bool = False,
        prompt: str = "Question: describe what you see in this image. Answer:",
    ):
        """Init params."""
        if parser_config is None:
            try:
                import sentencepiece  # noqa
                import torch
                from PIL import Image  # noqa
                from transformers import Blip2ForConditionalGeneration, Blip2Processor
            except ImportError:
                raise ImportError(
                    "Please install extra dependencies that are required for "
                    "the ImageVisionLLMReader: "
                    "`pip install torch transformers sentencepiece Pillow`"
                )
            device = infer_torch_device()
            dtype = torch.float16 if torch.cuda.is_available() else torch.float32
            processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
            model = Blip2ForConditionalGeneration.from_pretrained(
                "Salesforce/blip2-opt-2.7b", torch_dtype=dtype
            )
            parser_config = {
                "processor": processor,
                "model": model,
                "device": device,
                "dtype": dtype,
            }

        self._parser_config = parser_config
        self._keep_image = keep_image
        self._prompt = prompt

    def load_data(
        self, file: Path, extra_info: Optional[Dict] = None
    ) -> List[Document]:
        """Parse file."""
        from PIL import Image

        from llama_index.legacy.img_utils import img_2_b64

        # Load the document image and normalize to RGB
        image = Image.open(file)
        if image.mode != "RGB":
            image = image.convert("RGB")

        # Encode image into a base64 string and keep it in the document
        image_str: Optional[str] = None
        if self._keep_image:
            image_str = img_2_b64(image)

        # Parse image into text
        model = self._parser_config["model"]
        processor = self._parser_config["processor"]

        device = self._parser_config["device"]
        dtype = self._parser_config["dtype"]
        model.to(device)

        # Prompted (conditional) image captioning
        inputs = processor(image, self._prompt, return_tensors="pt").to(device, dtype)

        out = model.generate(**inputs)
        text_str = processor.decode(out[0], skip_special_tokens=True)

        return [
            ImageDocument(
                text=text_str,
                image=image_str,
                image_path=str(file),
                metadata=extra_info or {},
            )
        ]
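
For reference, a minimal usage sketch (not part of the file above): the path "./receipt.png" and the extra_info dict are hypothetical, and the first run downloads the Salesforce/blip2-opt-2.7b weights from the Hugging Face Hub.

from pathlib import Path

# Hypothetical usage of the reader defined above.
reader = ImageVisionLLMReader(keep_image=True)
docs = reader.load_data(Path("./receipt.png"), extra_info={"source": "receipt"})
print(docs[0].text)  # caption generated by BLIP-2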
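
Because __init__ only builds the BLIP-2 pipeline when parser_config is None, an already-loaded model can be shared across several readers instead of being reloaded each time. A sketch of that pattern, assuming torch and transformers are installed; the dict keys mirror the ones __init__ builds:

import torch
from transformers import Blip2ForConditionalGeneration, Blip2Processor

# Load BLIP-2 once and reuse it; keys match the dict built in __init__.
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
device = "cuda" if torch.cuda.is_available() else "cpu"
shared_config = {
    "processor": Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b"),
    "model": Blip2ForConditionalGeneration.from_pretrained(
        "Salesforce/blip2-opt-2.7b", torch_dtype=dtype
    ),
    "device": device,
    "dtype": dtype,
}

reader_a = ImageVisionLLMReader(parser_config=shared_config)
reader_b = ImageVisionLLMReader(parser_config=shared_config, keep_image=True)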