BasicSR

Форк
0
/
flow_util.py 
170 строк · 6.0 Кб
1
# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/video/optflow.py  # noqa: E501
2
import cv2
3
import numpy as np
4
import os
5

6

7
def flowread(flow_path, quantize=False, concat_axis=0, *args, **kwargs):
    """Read an optical flow map.

    Args:
        flow_path (str): Path of the flow file. Either a .flo file
            ('PIEH' header, int32 width, int32 height, float32 data) or,
            if ``quantize`` is True, a single-channel image holding the
            quantized dx and dy maps concatenated along ``concat_axis``.
        quantize (bool): whether to read quantized pair, if set to True,
            remaining args will be passed to :func:`dequantize_flow`.
        concat_axis (int): The axis that dx and dy are concatenated,
            can be either 0 or 1. Ignored if quantize is False.

    Returns:
        ndarray: Optical flow represented as a (h, w, 2) float32 numpy array.

    Raises:
        IOError: If the file cannot be read, the quantized image is not
            2-D, or the .flo file has a bad header or is truncated.
    """
    if quantize:
        assert concat_axis in [0, 1]
        cat_flow = cv2.imread(flow_path, cv2.IMREAD_UNCHANGED)
        if cat_flow is None:
            # cv2.imread reports a missing/unreadable file by returning None, not by raising.
            raise IOError(f'Failed to read quantized flow file: {flow_path}')
        if cat_flow.ndim != 2:
            raise IOError(f'{flow_path} is not a valid quantized flow file, its dimension is {cat_flow.ndim}.')
        assert cat_flow.shape[concat_axis] % 2 == 0
        dx, dy = np.split(cat_flow, 2, axis=concat_axis)
        flow = dequantize_flow(dx, dy, *args, **kwargs)
    else:
        with open(flow_path, 'rb') as f:
            try:
                header = f.read(4).decode('utf-8')
            except UnicodeDecodeError as err:
                raise IOError(f'Invalid flow file: {flow_path}') from err
            if header != 'PIEH':
                raise IOError(f'Invalid flow file: {flow_path}, header does not contain PIEH')

            dims = np.fromfile(f, np.int32, 2)
            if dims.size != 2:
                raise IOError(f'Invalid flow file: {flow_path}, missing width/height')
            # Convert to Python ints so that w * h * 2 cannot overflow int32.
            w, h = int(dims[0]), int(dims[1])
            if w < 0 or h < 0:
                # A negative count would make np.fromfile read the whole rest of the file.
                raise IOError(f'Invalid flow file: {flow_path}, negative size {w}x{h}')
            count = w * h * 2
            data = np.fromfile(f, np.float32, count)
            if data.size != count:
                raise IOError(f'Invalid flow file: {flow_path}, expected {count} values but got {data.size}')
            flow = data.reshape((h, w, 2))

    return flow.astype(np.float32)
43

44

45
def flowwrite(flow, filename, quantize=False, concat_axis=0, *args, **kwargs):
    """Write optical flow to file.

    If the flow is not quantized, it will be saved as a .flo file losslessly.
    Otherwise the quantized dx and dy maps are concatenated along
    ``concat_axis`` and written as a single image. This is much smaller, and
    lossy if a lossy format such as jpeg is chosen by the file extension.

    Args:
        flow (ndarray): (h, w, 2) array of optical flow.
        filename (str): Output filepath.
        quantize (bool): Whether to quantize the flow and save it as one
            image. If set to True, remaining args will be passed to
            :func:`quantize_flow`.
        concat_axis (int): The axis that dx and dy are concatenated,
            can be either 0 or 1. Ignored if quantize is False.

    Raises:
        IOError: If ``cv2.imwrite`` reports failure while writing the
            quantized image.
    """
    if not quantize:
        with open(filename, 'wb') as f:
            # .flo layout: 'PIEH' tag, int32 width, int32 height, then the float32 flow values.
            f.write('PIEH'.encode('utf-8'))
            np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f)
            flow.astype(np.float32).tofile(f)
    else:
        assert concat_axis in [0, 1]
        dx, dy = quantize_flow(flow, *args, **kwargs)
        dxdy = np.concatenate((dx, dy), axis=concat_axis)
        dirname = os.path.dirname(filename)
        # os.makedirs('') raises FileNotFoundError, so skip it for bare filenames.
        if dirname:
            os.makedirs(dirname, exist_ok=True)
        # cv2.imwrite signals failure through its boolean return value rather than raising.
        if not cv2.imwrite(filename, dxdy):
            raise IOError(f'Failed to write quantized flow file: {filename}')
74

75

76
def quantize_flow(flow, max_val=0.02, norm=True):
    """Quantize a flow field into a pair of compact uint8 maps.

    The quantized maps are small enough to be stored as jpeg images.

    Args:
        flow (ndarray): (h, w, 2) array of optical flow.
        max_val (float): Flow components outside [-max_val, max_val]
            (after the optional normalization) are clipped.
        norm (bool): If True, divide dx by the image width and dy by the
            image height before quantizing.

    Returns:
        tuple[ndarray]: The quantized ``(dx, dy)`` pair.
    """
    height, width, _ = flow.shape
    parts = [flow[..., 0], flow[..., 1]]
    if norm:
        # Build new arrays rather than dividing in place, so the caller's flow is untouched.
        parts = [parts[0] / width, parts[1] / height]
    # 255 levels rather than 256: with an odd level count a bin is centered
    # on 0, so a zero flow value comes back as exactly 0 after dequantization.
    return tuple(quantize(part, -max_val, max_val, 255, np.uint8) for part in parts)
100

101

102
def dequantize_flow(dx, dy, max_val=0.02, denorm=True):
    """Rebuild a flow field from its quantized horizontal/vertical maps.

    Args:
        dx (ndarray): Quantized dx, shaped (h, w) or (h, w, 1).
        dy (ndarray): Quantized dy, same shape as ``dx``.
        max_val (float): The ``max_val`` that was used when quantizing.
        denorm (bool): If True, scale dx by the width and dy by the height,
            undoing the normalization done by ``quantize_flow(norm=True)``.

    Returns:
        ndarray: Dequantized flow, with dx and dy stacked depth-wise via
        ``np.dstack``.
    """
    assert dx.shape == dy.shape
    assert dx.ndim == 2 or (dx.ndim == 3 and dx.shape[-1] == 1)

    # dx and dy share a shape (asserted above), so either one gives h and w.
    height, width = dx.shape[0], dx.shape[1]
    bounds = (-max_val, max_val)
    flow_x = dequantize(dx, *bounds, 255)
    flow_y = dequantize(dy, *bounds, 255)

    if denorm:
        flow_x *= width
        flow_y *= height
    return np.dstack((flow_x, flow_y))
124

125

126
def quantize(arr, min_val, max_val, levels, dtype=np.int64):
    """Quantize an array of (-inf, inf) to [0, levels-1].

    Values are first clipped to ``[min_val, max_val]`` and then mapped
    linearly onto ``levels`` equal-width bins indexed ``0 .. levels - 1``.

    Args:
        arr (ndarray): Input array.
        min_val (scalar): Lower clipping bound; maps to bin 0.
        max_val (scalar): Upper clipping bound; maps to bin ``levels - 1``.
        levels (int): Number of quantization levels. Must be an integer
            greater than 1. Python ints and NumPy integer scalars are accepted.
        dtype (np.dtype): The dtype of the quantized array.

    Returns:
        ndarray: Array of bin indices with dtype ``dtype``.

    Raises:
        ValueError: If ``levels`` is not an integer greater than 1, or if
            ``min_val >= max_val``.
    """
    # bool is a subclass of int, but True/False are never > 1, so they are rejected too.
    if not (isinstance(levels, (int, np.integer)) and levels > 1):
        raise ValueError(f'levels must be an integer greater than 1, but got {levels}')
    if min_val >= max_val:
        raise ValueError(f'min_val ({min_val}) must be smaller than max_val ({max_val})')
    # Normalize to a Python int so NumPy scalar promotion cannot widen the output dtype.
    levels = int(levels)

    arr = np.clip(arr, min_val, max_val) - min_val
    # max_val lands exactly on `levels`, so cap it into the last bin.
    quantized_arr = np.minimum(np.floor(levels * arr / (max_val - min_val)).astype(dtype), levels - 1)

    return quantized_arr
148

149

150
def dequantize(arr, min_val, max_val, levels, dtype=np.float64):
    """Dequantize an array of bin indices back to values in [min_val, max_val].

    Each bin index ``i`` is mapped to the center of its bin:
    ``(i + 0.5) * (max_val - min_val) / levels + min_val``.

    Args:
        arr (ndarray): Input array of bin indices.
        min_val (scalar): Lower bound that was used when quantizing.
        max_val (scalar): Upper bound that was used when quantizing.
        levels (int): Number of quantization levels. Must be an integer
            greater than 1. Python ints and NumPy integer scalars are accepted.
        dtype (np.dtype): The dtype of the dequantized array.

    Returns:
        ndarray: Dequantized array.

    Raises:
        ValueError: If ``levels`` is not an integer greater than 1, or if
            ``min_val >= max_val``.
    """
    # bool is a subclass of int, but True/False are never > 1, so they are rejected too.
    if not (isinstance(levels, (int, np.integer)) and levels > 1):
        raise ValueError(f'levels must be an integer greater than 1, but got {levels}')
    if min_val >= max_val:
        raise ValueError(f'min_val ({min_val}) must be smaller than max_val ({max_val})')
    levels = int(levels)

    # +0.5 maps each index to the center of its bin rather than its lower edge.
    dequantized_arr = (arr + 0.5).astype(dtype) * (max_val - min_val) / levels + min_val

    return dequantized_arr
171

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.