demo-ml-pennfudanped

Форк
0
/
image_utils.py 
83 строки · 2.1 Кб
1
import matplotlib.pyplot as plt
2
import torch
3
import numpy as np
4

5

6
def show_image(im):
7
    fix_image_processing()
8

9
    plt.figure(figsize=(5,5), dpi=120)
10
    plt.imshow(im)
11
    plt.axis('off')
12
    plt.show()
13

14

15
def fix_image_processing():
16
    import os
17
    os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
18

19

20
# batch = [num_of_masks, x, y]
21
def merge_masks(batch):
22
    return torch.sum(batch, dim=0)
23

24

25
# batch = [num_of_masks, x, y]
26
def merge_masks_with_colors(batch):
27
    color = 1
28
    for mask in batch:
29
        mask[mask > 0] = color
30
        color += 1
31
    return torch.sum(batch, dim=0)
32

33

34
def merge_image_and_masks(image, masks):
35
    merged_masks_with_color_dim = torch \
36
        .stack((masks,masks,masks)) \
37
        .permute(1,2,0)
38

39
    stacked_example = np.hstack((merged_masks_with_color_dim, image))
40
    return stacked_example
41

42

43
# boxes - torch.Size([num_of_masks, 4])
44
def build_box_masks(x, y, boxes):
45
    box_mask = np.zeros((len(boxes), x, y))
46
    counter = 0
47
    for box in boxes:
48
        xmin, ymin, xmax, ymax = box
49
        box_mask[counter, ymin.long():ymax.long(), xmin.long():xmax.long()] = counter + 1
50
        counter += 1
51
    box_mask = torch.tensor(box_mask)
52
    return box_mask
53

54

55
# image - torch.Size([x, y, 3])
56
# masks - torch.Size([x, y])
57
# box_mask_view - torch.Size([x, y, 3])
58
def merge_image_and_masks_boxes(image, masks, box_mask_view):
59

60
    merged_masks_with_color_dim = torch \
61
        .stack((masks,masks,masks)) \
62
        .permute(1,2,0)
63

64
    merged_box_masks_with_color_dim = torch \
65
        .stack((box_mask_view,box_mask_view,box_mask_view)) \
66
        .permute(1,2,0)
67

68
    stacked_example = np.hstack((
69
        merged_masks_with_color_dim,
70
        image,
71
        merged_box_masks_with_color_dim))
72
    return stacked_example
73

74

75
def show_image_grid(rows, cols, image_batch):
76
    img_count = 0
77
    fig, axes = plt.subplots(nrows=rows, ncols=cols, figsize=(15, 15))
78
    for i in range(rows):
79
        for j in range(cols):
80
            if img_count < len(image_batch):
81
                axes[i, j].imshow(image_batch[img_count].permute(1, 2, 0))
82
                axes[i, j].axis('off')
83
                img_count += 1
84

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.