llama-index
149 строк · 4.8 Кб
1"""Utils for jupyter notebook."""
2
3import os4from io import BytesIO5from typing import Any, Dict, List, Tuple6
7import matplotlib.pyplot as plt8import requests9from IPython.display import Markdown, display10from PIL import Image11
12from llama_index.legacy.core.response.schema import Response13from llama_index.legacy.img_utils import b64_2_img14from llama_index.legacy.schema import ImageNode, MetadataMode, NodeWithScore15from llama_index.legacy.utils import truncate_text16
# Default maximum (width, height) passed to ``Image.thumbnail`` by ``display_image``.
DEFAULT_THUMBNAIL_SIZE = (512, 512)
# Default subplot grid shape (rows, columns) used by ``display_image_uris``.
DEFAULT_IMAGE_MATRIX = (3, 3)
# Default number of leading entries of ``image_paths`` that ``display_image_uris`` considers.
DEFAULT_SHOW_TOP_K = 3
21
def display_image(img_str: str, size: Tuple[int, int] = DEFAULT_THUMBNAIL_SIZE) -> None:
    """Render a base64-encoded image string inline in a Jupyter notebook.

    Args:
        img_str: Base64-encoded image data, decoded with ``b64_2_img``.
        size: Maximum (width, height) handed to ``Image.thumbnail`` before the
            image is shown.
    """
    decoded = b64_2_img(img_str)
    # thumbnail() mutates the image in place; its return value is not used.
    decoded.thumbnail(size)
    display(decoded)
28
def display_image_uris(
    image_paths: List[str],
    image_matrix: Tuple[int, int] = DEFAULT_IMAGE_MATRIX,
    top_k: int = DEFAULT_SHOW_TOP_K,
) -> None:
    """Plot local image files in a matplotlib subplot grid for a Jupyter notebook.

    Only the first ``top_k`` entries of ``image_paths`` are considered. Entries
    for which ``os.path.isfile`` is false (missing paths, remote URLs, ...) are
    skipped silently. Plotting stops once every cell of the grid is filled.

    Args:
        image_paths: Candidate local file paths of images to display.
        image_matrix: Grid shape as (rows, columns) for ``plt.subplot``.
        top_k: Maximum number of leading entries of ``image_paths`` to consider.
    """
    rows, cols = image_matrix[0], image_matrix[1]
    max_cells = rows * cols
    images_shown = 0
    plt.figure(figsize=(16, 9))
    for img_path in image_paths[:top_k]:
        if not os.path.isfile(img_path):
            # Best-effort display: non-file entries are ignored, not reported.
            continue

        # Open the file in a context manager so the handle is released as soon
        # as imshow has converted the PIL image to an array, instead of
        # lingering until garbage collection.
        with Image.open(img_path) as image:
            plt.subplot(rows, cols, images_shown + 1)
            plt.imshow(image)
        plt.xticks([])
        plt.yticks([])

        images_shown += 1
        if images_shown >= max_cells:
            break
50
def display_source_node(
    source_node: NodeWithScore,
    source_length: int = 100,
    show_source_metadata: bool = False,
    metadata_mode: MetadataMode = MetadataMode.NONE,
) -> None:
    """Render one retrieved source node as Markdown in a Jupyter notebook.

    Shows the node id, its similarity score and its (truncated) text content.
    Optionally appends the node metadata; image nodes additionally get their
    base64 image rendered below the Markdown when one is present.

    Args:
        source_node: The scored node to render.
        source_length: Maximum text length passed to ``truncate_text``.
        show_source_metadata: Whether to append the node's metadata.
        metadata_mode: Controls which metadata ``get_content`` embeds in the text.
    """
    node = source_node.node
    is_image_node = isinstance(node, ImageNode)

    raw_text = node.get_content(metadata_mode=metadata_mode).strip()
    md_parts = [
        f"**Node ID:** {node.node_id}<br>",
        f"**Similarity:** {source_node.score}<br>",
        f"**Text:** {truncate_text(raw_text, source_length)}<br>",
    ]
    if show_source_metadata:
        md_parts.append(f"**Metadata:** {node.metadata}<br>")
    if is_image_node:
        # The image itself is rendered separately, right after the Markdown.
        md_parts.append("**Image:**")

    display(Markdown("".join(md_parts)))

    if is_image_node and node.image is not None:
        display_image(node.image)
75
def display_metadata(metadata: Dict[str, Any]) -> None:
    """Show a metadata dictionary using IPython's rich ``display``.

    Args:
        metadata: The mapping to render (passed positionally to ``display``).
    """
    payload = metadata
    display(payload)
80
def display_response(
    response: Response,
    source_length: int = 100,
    show_source: bool = False,
    show_metadata: bool = False,
    show_source_metadata: bool = False,
) -> None:
    """Render a query response (and optionally its sources) in a Jupyter notebook.

    Args:
        response: The response to render; a ``None`` answer is shown as "None".
        source_length: Maximum text length shown per source node.
        show_source: Whether to render every source node after the answer.
        show_metadata: Whether to render ``response.metadata`` when it is set.
        show_source_metadata: Whether each rendered source node shows its metadata.
    """
    response_text = "None" if response.response is None else response.response.strip()
    display(Markdown(f"**`Final Response:`** {response_text}"))

    if show_source:
        total = len(response.source_nodes)
        # Positions are shown 1-based, e.g. "Source Node 1/3".
        for position, source_node in enumerate(response.source_nodes, start=1):
            display(Markdown("---"))
            display(Markdown(f"**`Source Node {position}/{total}`**"))
            display_source_node(
                source_node,
                source_length=source_length,
                show_source_metadata=show_source_metadata,
            )

    if show_metadata and response.metadata is not None:
        display_metadata(response.metadata)
110
def _load_node_image(img_node: ImageNode, timeout: float = 30.0) -> Image.Image:
    """Load the image an ImageNode points to, preferring its URL over its local path.

    Args:
        img_node: Node carrying ``image_url`` and/or ``image_path``.
        timeout: Seconds to wait on the HTTP request when fetching ``image_url``.

    Returns:
        A PIL image (path-based images are converted to RGB).

    Raises:
        ValueError: If the node has neither ``image_url`` nor ``image_path``.
        requests.HTTPError: If fetching ``image_url`` returns an error status.
    """
    if img_node.image_url:
        # requests has no default timeout; without one a stalled server would
        # hang the notebook cell indefinitely.
        img_response = requests.get(img_node.image_url, timeout=timeout)
        # Fail with a clear HTTP error instead of handing an error page to PIL.
        img_response.raise_for_status()
        return Image.open(BytesIO(img_response.content))
    if img_node.image_path:
        # convert() returns a new, fully loaded image, so the file can be closed.
        with Image.open(img_node.image_path) as img:
            return img.convert("RGB")
    raise ValueError("A retrieved image must have image_path or image_url specified.")


def display_query_and_multimodal_response(
    query_str: str, response: Response, plot_height: int = 2, plot_width: int = 5
) -> None:
    """Print a query and its multi-modal response, plotting any retrieved images.

    Retrieved images are read from ``response.metadata["image_nodes"]`` (missing
    key or empty metadata means no images) and drawn side by side, each titled
    with its retrieval position. When there are no images, nothing is plotted
    but the query and response text are still printed.

    Args:
        query_str: The query text to print.
        response: The response whose text and image nodes are displayed.
        plot_height: Figure height in inches.
        plot_width: Figure width in inches.

    Raises:
        ValueError: If a retrieved image node has neither an image URL nor path.
    """
    image_nodes = (response.metadata or {}).get("image_nodes") or []
    num_subplots = len(image_nodes)

    fig = None
    # plt.subplots(1, 0) raises, so only build a figure when there are images.
    if num_subplots > 0:
        # squeeze=False always yields a 2-D axes array, so a single image and
        # several images are indexed the same way.
        fig, axarr = plt.subplots(1, num_subplots, squeeze=False)
        fig.set_figheight(plot_height)
        fig.set_figwidth(plot_width)
        for ix, scored_img_node in enumerate(image_nodes):
            ax = axarr[0][ix]
            ax.imshow(_load_node_image(scored_img_node.node))
            ax.set_title(f"Retrieved Position: {ix}", pad=10, fontsize=9)
        fig.tight_layout()

    print(f"Query: {query_str}\n=======")
    print("Retrieved Images:\n")
    if fig is not None:
        plt.show()
    print("=======")
    print(f"Response: {response.response}\n=======\n")