llama-index

Форк
0
98 строк · 3.0 Кб
1
from pathlib import Path
2
from typing import Dict, List, Optional
3

4
from llama_index.legacy.readers.base import BaseReader
5
from llama_index.legacy.schema import Document, ImageDocument
6
from llama_index.legacy.utils import infer_torch_device
7

8

9
class ImageCaptionReader(BaseReader):
10
    """Image parser.
11

12
    Caption image using Blip.
13

14
    """
15

16
    def __init__(
17
        self,
18
        parser_config: Optional[Dict] = None,
19
        keep_image: bool = False,
20
        prompt: Optional[str] = None,
21
    ):
22
        """Init params."""
23
        if parser_config is None:
24
            """Init parser."""
25
            try:
26
                import sentencepiece  # noqa
27
                import torch
28
                from PIL import Image  # noqa
29
                from transformers import BlipForConditionalGeneration, BlipProcessor
30
            except ImportError:
31
                raise ImportError(
32
                    "Please install extra dependencies that are required for "
33
                    "the ImageCaptionReader: "
34
                    "`pip install torch transformers sentencepiece Pillow`"
35
                )
36

37
            device = infer_torch_device()
38
            dtype = torch.float16 if torch.cuda.is_available() else torch.float32
39

40
            processor = BlipProcessor.from_pretrained(
41
                "Salesforce/blip-image-captioning-large"
42
            )
43
            model = BlipForConditionalGeneration.from_pretrained(
44
                "Salesforce/blip-image-captioning-large", torch_dtype=dtype
45
            )
46

47
            parser_config = {
48
                "processor": processor,
49
                "model": model,
50
                "device": device,
51
                "dtype": dtype,
52
            }
53

54
        self._parser_config = parser_config
55
        self._keep_image = keep_image
56
        self._prompt = prompt
57

58
    def load_data(
59
        self, file: Path, extra_info: Optional[Dict] = None
60
    ) -> List[Document]:
61
        """Parse file."""
62
        from PIL import Image
63

64
        from llama_index.legacy.img_utils import img_2_b64
65

66
        # load document image
67
        image = Image.open(file)
68
        if image.mode != "RGB":
69
            image = image.convert("RGB")
70

71
        # Encode image into base64 string and keep in document
72
        image_str: Optional[str] = None
73
        if self._keep_image:
74
            image_str = img_2_b64(image)
75

76
        # Parse image into text
77
        model = self._parser_config["model"]
78
        processor = self._parser_config["processor"]
79

80
        device = self._parser_config["device"]
81
        dtype = self._parser_config["dtype"]
82
        model.to(device)
83

84
        # unconditional image captioning
85

86
        inputs = processor(image, self._prompt, return_tensors="pt").to(device, dtype)
87

88
        out = model.generate(**inputs)
89
        text_str = processor.decode(out[0], skip_special_tokens=True)
90

91
        return [
92
            ImageDocument(
93
                text=text_str,
94
                image=image_str,
95
                image_path=str(file),
96
                metadata=extra_info or {},
97
            )
98
        ]
99

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.