llama-index

Форк
0
149 строк · 4.8 Кб
1
"""Utils for jupyter notebook."""
2

3
import os
4
from io import BytesIO
5
from typing import Any, Dict, List, Tuple
6

7
import matplotlib.pyplot as plt
8
import requests
9
from IPython.display import Markdown, display
10
from PIL import Image
11

12
from llama_index.legacy.core.response.schema import Response
13
from llama_index.legacy.img_utils import b64_2_img
14
from llama_index.legacy.schema import ImageNode, MetadataMode, NodeWithScore
15
from llama_index.legacy.utils import truncate_text
16

17
# Bounding box (width, height) that inline thumbnails are shrunk to fit.
DEFAULT_THUMBNAIL_SIZE: Tuple[int, int] = (512, 512)
# Subplot grid layout (rows, cols) used when plotting several images.
DEFAULT_IMAGE_MATRIX: Tuple[int, int] = (3, 3)
# Number of leading image paths considered for display by default.
DEFAULT_SHOW_TOP_K: int = 3
20

21

22
def display_image(img_str: str, size: Tuple[int, int] = DEFAULT_THUMBNAIL_SIZE) -> None:
    """Render a base64-encoded image string inline in a Jupyter notebook.

    Args:
        img_str: Base64-encoded image data, decoded with ``b64_2_img``.
        size: Bounding box (width, height) the displayed image is shrunk to fit.
    """
    thumb = b64_2_img(img_str)
    # PIL's ``thumbnail`` resizes in place (keeping the aspect ratio), so the
    # same object is what gets displayed afterwards.
    thumb.thumbnail(size)
    display(thumb)
27

28

29
def display_image_uris(
    image_paths: List[str],
    image_matrix: Tuple[int, int] = DEFAULT_IMAGE_MATRIX,
    top_k: int = DEFAULT_SHOW_TOP_K,
) -> None:
    """Plot local image files as a grid of subplots for a Jupyter notebook.

    Only the first ``top_k`` entries of ``image_paths`` are considered, and
    entries that are not existing files are skipped silently. At most
    ``rows * cols`` images (taken from ``image_matrix``) are drawn, each in its
    own tick-less subplot of a single 16x9-inch figure.

    Args:
        image_paths: Filesystem paths of the images to show.
        image_matrix: ``(rows, cols)`` layout of the subplot grid.
        top_k: Maximum number of leading paths to consider.
    """
    rows, cols = image_matrix[0], image_matrix[1]
    max_images = rows * cols
    images_shown = 0
    plt.figure(figsize=(16, 9))
    for img_path in image_paths[:top_k]:
        if not os.path.isfile(img_path):
            continue

        # imshow converts the PIL image into an in-memory array, so the file
        # handle can be released right after drawing instead of leaking.
        with Image.open(img_path) as image:
            plt.subplot(rows, cols, images_shown + 1)
            plt.imshow(image)
        plt.xticks([])
        plt.yticks([])

        images_shown += 1
        # Stop once every cell of the grid has been filled.
        if images_shown >= max_images:
            break
49

50

51
def display_source_node(
    source_node: NodeWithScore,
    source_length: int = 100,
    show_source_metadata: bool = False,
    metadata_mode: MetadataMode = MetadataMode.NONE,
) -> None:
    """Show one retrieved node (id, score, text preview) in a Jupyter notebook.

    Args:
        source_node: The scored node to render.
        source_length: Length limit passed to ``truncate_text`` for the text preview.
        show_source_metadata: Also print the node's ``metadata`` mapping.
        metadata_mode: Which metadata ``get_content`` folds into the text.

    Image nodes additionally get an "Image:" label, followed by the image
    itself when the node carries a base64 ``image`` payload.
    """
    node = source_node.node
    is_image_node = isinstance(node, ImageNode)

    preview = truncate_text(
        node.get_content(metadata_mode=metadata_mode).strip(), source_length
    )
    parts = [
        f"**Node ID:** {node.node_id}<br>",
        f"**Similarity:** {source_node.score}<br>",
        f"**Text:** {preview}<br>",
    ]
    if show_source_metadata:
        parts.append(f"**Metadata:** {node.metadata}<br>")
    if is_image_node:
        parts.append("**Image:**")

    display(Markdown("".join(parts)))
    if is_image_node and node.image is not None:
        display_image(node.image)
74

75

76
def display_metadata(metadata: Dict[str, Any]) -> None:
    """Render a metadata mapping as notebook output.

    The mapping is handed to IPython's ``display`` unchanged.
    """
    payload = metadata
    display(payload)
79

80

81
def display_response(
    response: Response,
    source_length: int = 100,
    show_source: bool = False,
    show_metadata: bool = False,
    show_source_metadata: bool = False,
) -> None:
    """Render a query response, optionally with its sources, in a notebook.

    Args:
        response: The response whose text (``"None"`` if absent) is shown first.
        source_length: Text-preview limit forwarded to ``display_source_node``.
        show_source: Render each entry of ``response.source_nodes`` after the text.
        show_metadata: Render ``response.metadata`` when it is not ``None``.
        show_source_metadata: Forwarded to ``display_source_node``.
    """
    raw_text = response.response
    response_text = "None" if raw_text is None else raw_text.strip()
    display(Markdown(f"**`Final Response:`** {response_text}"))

    if show_source:
        nodes = response.source_nodes
        total = len(nodes)
        # Positions are shown 1-based, e.g. "Source Node 1/3".
        for position, node in enumerate(nodes, start=1):
            display(Markdown("---"))
            display(Markdown(f"**`Source Node {position}/{total}`**"))
            display_source_node(
                node,
                source_length=source_length,
                show_source_metadata=show_source_metadata,
            )

    if show_metadata and response.metadata is not None:
        display_metadata(response.metadata)
109

110

111
def _load_retrieved_image(img_node: Any) -> Image.Image:
    """Load the image referenced by a retrieved node (URL first, then path).

    Raises:
        ValueError: If the node has neither ``image_url`` nor ``image_path`` set.
    """
    if img_node.image_url:
        img_response = requests.get(img_node.image_url)
        # The payload lives in memory (BytesIO), so no file handle is held.
        return Image.open(BytesIO(img_response.content))
    if img_node.image_path:
        # convert() materialises the pixels into a new image, so the file can
        # be closed immediately instead of leaking the handle.
        with Image.open(img_node.image_path) as img_file:
            return img_file.convert("RGB")
    raise ValueError(
        "A retrieved image must have image_path or image_url specified."
    )


def display_query_and_multimodal_response(
    query_str: str, response: Response, plot_height: int = 2, plot_width: int = 5
) -> None:
    """Print a query and its response, plotting any retrieved images in between.

    Retrieved images are read from ``response.metadata["image_nodes"]`` (a
    missing or empty entry means "no images") and drawn side by side in a
    single row, titled with their 0-based retrieval position.

    Args:
        query_str: The query text to print.
        response: The multi-modal response to display.
        plot_height: Figure height passed to ``Figure.set_figheight``.
        plot_width: Figure width passed to ``Figure.set_figwidth``.

    Raises:
        ValueError: If a retrieved image node has neither a URL nor a path.
    """
    metadata = response.metadata or {}
    image_nodes = metadata.get("image_nodes") or []
    # Load everything up front so a bad node fails before a figure is opened.
    images = [_load_retrieved_image(scored.node) for scored in image_nodes]

    if images:
        # plt.subplots(1, 0) raises, so the figure is only built when there is
        # something to plot. squeeze=False always yields a 2-D Axes array,
        # which removes the special case for a single image.
        fig, axes = plt.subplots(1, len(images), squeeze=False)
        fig.set_figheight(plot_height)
        fig.set_figwidth(plot_width)
        for position, (ax, image) in enumerate(zip(axes[0], images)):
            ax.imshow(image)
            ax.set_title(f"Retrieved Position: {position}", pad=10, fontsize=9)
        fig.tight_layout()

    print(f"Query: {query_str}\n=======")
    print("Retrieved Images:\n")
    if images:
        plt.show()
    print("=======")
    print(f"Response: {response.response}\n=======\n")
150

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.