pytorch

Форк
0
126 строк · 4.1 Кб
1
# mypy: allow-untyped-defs
2
import numpy as np
3

4

5
# Functions for converting
6
def figure_to_image(figures, close=True):
7
    """Render matplotlib figure to numpy format.
8

9
    Note that this requires the ``matplotlib`` package.
10

11
    Args:
12
        figures (matplotlib.pyplot.figure or list of figures): figure or a list of figures
13
        close (bool): Flag to automatically close the figure
14

15
    Returns:
16
        numpy.array: image in [CHW] order
17
    """
18
    import matplotlib.pyplot as plt
19
    import matplotlib.backends.backend_agg as plt_backend_agg
20

21
    def render_to_rgb(figure):
22
        canvas = plt_backend_agg.FigureCanvasAgg(figure)
23
        canvas.draw()
24
        data: np.ndarray = np.frombuffer(canvas.buffer_rgba(), dtype=np.uint8)
25
        w, h = figure.canvas.get_width_height()
26
        image_hwc = data.reshape([h, w, 4])[:, :, 0:3]
27
        image_chw = np.moveaxis(image_hwc, source=2, destination=0)
28
        if close:
29
            plt.close(figure)
30
        return image_chw
31

32
    if isinstance(figures, list):
33
        images = [render_to_rgb(figure) for figure in figures]
34
        return np.stack(images)
35
    else:
36
        image = render_to_rgb(figures)
37
        return image
38

39

40
def _prepare_video(V):
41
    """
42
    Convert a 5D tensor into 4D tensor.
43

44
    Convesrion is done from [batchsize, time(frame), channel(color), height, width]  (5D tensor)
45
    to [time(frame), new_width, new_height, channel] (4D tensor).
46

47
    A batch of images are spreaded to a grid, which forms a frame.
48
    e.g. Video with batchsize 16 will have a 4x4 grid.
49
    """
50
    b, t, c, h, w = V.shape
51

52
    if V.dtype == np.uint8:
53
        V = np.float32(V) / 255.0
54

55
    def is_power2(num):
56
        return num != 0 and ((num & (num - 1)) == 0)
57

58
    # pad to nearest power of 2, all at once
59
    if not is_power2(V.shape[0]):
60
        len_addition = int(2 ** V.shape[0].bit_length() - V.shape[0])
61
        V = np.concatenate((V, np.zeros(shape=(len_addition, t, c, h, w))), axis=0)
62

63
    n_rows = 2 ** ((b.bit_length() - 1) // 2)
64
    n_cols = V.shape[0] // n_rows
65

66
    V = np.reshape(V, newshape=(n_rows, n_cols, t, c, h, w))
67
    V = np.transpose(V, axes=(2, 0, 4, 1, 5, 3))
68
    V = np.reshape(V, newshape=(t, n_rows * h, n_cols * w, c))
69

70
    return V
71

72

73
def make_grid(I, ncols=8):
74
    # I: N1HW or N3HW
75
    assert isinstance(I, np.ndarray), "plugin error, should pass numpy array here"
76
    if I.shape[1] == 1:
77
        I = np.concatenate([I, I, I], 1)
78
    assert I.ndim == 4 and I.shape[1] == 3
79
    nimg = I.shape[0]
80
    H = I.shape[2]
81
    W = I.shape[3]
82
    ncols = min(nimg, ncols)
83
    nrows = int(np.ceil(float(nimg) / ncols))
84
    canvas = np.zeros((3, H * nrows, W * ncols), dtype=I.dtype)
85
    i = 0
86
    for y in range(nrows):
87
        for x in range(ncols):
88
            if i >= nimg:
89
                break
90
            canvas[:, y * H : (y + 1) * H, x * W : (x + 1) * W] = I[i]
91
            i = i + 1
92
    return canvas
93

94
    # if modality == 'IMG':
95
    #     if x.dtype == np.uint8:
96
    #         x = x.astype(np.float32) / 255.0
97

98

99
def convert_to_HWC(tensor, input_format):  # tensor: numpy array
100
    assert len(set(input_format)) == len(
101
        input_format
102
    ), f"You can not use the same dimension shordhand twice.         input_format: {input_format}"
103
    assert len(tensor.shape) == len(
104
        input_format
105
    ), f"size of input tensor and input format are different. \
106
        tensor shape: {tensor.shape}, input_format: {input_format}"
107
    input_format = input_format.upper()
108

109
    if len(input_format) == 4:
110
        index = [input_format.find(c) for c in "NCHW"]
111
        tensor_NCHW = tensor.transpose(index)
112
        tensor_CHW = make_grid(tensor_NCHW)
113
        return tensor_CHW.transpose(1, 2, 0)
114

115
    if len(input_format) == 3:
116
        index = [input_format.find(c) for c in "HWC"]
117
        tensor_HWC = tensor.transpose(index)
118
        if tensor_HWC.shape[2] == 1:
119
            tensor_HWC = np.concatenate([tensor_HWC, tensor_HWC, tensor_HWC], 2)
120
        return tensor_HWC
121

122
    if len(input_format) == 2:
123
        index = [input_format.find(c) for c in "HW"]
124
        tensor = tensor.transpose(index)
125
        tensor = np.stack([tensor, tensor, tensor], 2)
126
        return tensor
127

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.