# Page header captured with this file (repository: pytorch, forks: 0):
# image_input_op_test.py - 428 lines, 16.9 KB
1

2

3

4

5

6
import unittest
7
try:
8
    import cv2
9
    import lmdb
10
except ImportError:
11
    pass  # Handled below
12

13
from PIL import Image
14
import numpy as np
15
import shutil
16
import io
17
import sys
18
import tempfile
19

20
# TODO: This test does not exercise rescaling, because the resize
# algorithms used by OpenCV in the C++ operator and in the Python
# reference seem to differ slightly. It does test most other features.
24

25
from hypothesis import given, settings, Verbosity
26
import hypothesis.strategies as st
27

28
from caffe2.proto import caffe2_pb2
29
import caffe2.python.hypothesis_test_util as hu
30

31
from caffe2.python import workspace, core
32

33

34
# Verification routines: each one applies a transformation to the image in
# Python so we can check that the operator produces the same result.
36
def verify_apply_bounding_box(img, box):
    """Reference implementation of the operator's bounding-box crop.

    Args:
        img: image array indexed as (row, column[, channel]).
        box: sequence whose first four entries are
            (ymin, xmin, bound_height, bound_width).

    Returns:
        A view of ``img`` restricted to the box, or ``img`` itself when the
        box is not usable: an entry is not a plain ``int`` (e.g. a float
        NaN), an entry is negative (e.g. the ``-1`` "no box" placeholder),
        or the box extends past the bottom/right edge of the image.
    """
    ymin, xmin, bound_height, bound_width = box[0], box[1], box[2], box[3]
    coords = (ymin, xmin, bound_height, bound_width)
    # A plain int can never be NaN, so the type check also covers the NaN
    # case. The previous version wrapped the ``< 0`` test inside
    # ``np.isnan(...)``, so negative sizes were never rejected here.
    if any(type(value) is not int or value < 0 for value in coords):
        return img
    if ymin + bound_height > img.shape[0] or xmin + bound_width > img.shape[1]:
        return img
    # Same result as skimage.util.crop with (before, after) margins, without
    # pulling in scikit-image for a plain slice.
    return img[ymin:ymin + bound_height, xmin:xmin + bound_width]
51

52

53
# This function is called but not used. It will trip on assert False if
54
# the arguments are wrong (improper example)
55
def verify_rescale(img, minsize):
56
    # Here we use OpenCV transformation to match the C code
57
    scale_amount = float(minsize) / min(img.shape[0], img.shape[1])
58
    if scale_amount <= 1.0:
59
        return img
60

61
    print("Scale amount is %f -- should be < 1.0; got shape %s" %
62
          (scale_amount, str(img.shape)))
63
    assert False
64
    img_cv = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
65
    output_shape = (int(np.ceil(scale_amount * img_cv.shape[0])),
66
                    int(np.ceil(scale_amount * img_cv.shape[1])))
67
    resized = cv2.resize(img_cv,
68
                         dsize=output_shape,
69
                         interpolation=cv2.INTER_AREA)
70

71
    resized = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)
72
    assert resized.shape[0] >= minsize
73
    assert resized.shape[1] >= minsize
74
    return resized
75

76

77
def verify_crop(img, crop):
    """Reference implementation of the operator's center crop.

    Cuts a ``crop`` x ``crop`` window out of the middle of ``img``. The
    offsets use floor division, so when the leftover margin is odd the
    extra row/column is dropped from the bottom/right edge.

    Args:
        img: image array indexed as (row, column[, channel]).
        crop: side length of the square window.

    Returns:
        A view of ``img`` with shape ``(crop, crop, ...)``.

    Raises:
        AssertionError: if either image side is smaller than ``crop``.
    """
    height, width = img.shape[0], img.shape[1]
    assert height >= crop
    assert width >= crop
    # (size - crop) // 2 is already 0 when size == crop, so no special case
    # is needed.
    y_offset = (height - crop) // 2
    x_offset = (width - crop) // 2
    # Equivalent to skimage.util.crop with these margins, but without the
    # scikit-image dependency (which the test class does not skip on).
    cropped = img[y_offset:y_offset + crop, x_offset:x_offset + crop]
    assert cropped.shape[0] == crop
    assert cropped.shape[1] == crop
    return cropped
96

97

98
def verify_color_normalize(img, means, stds):
    """Reference implementation of per-channel mean/std normalization.

    ``img`` channel ``c`` is paired with ``means[2 - c]`` and ``stds[2 - c]``
    (the channel order is reversed, i.e. the RGB/BGR inversion). The
    arithmetic is done on the 0-255 scale, matching the pixel range the
    operator works on, and the result is mapped back by dividing by 255.
    The input array is not modified.
    """
    scaled = img * 255.0
    for channel, source in enumerate((2, 1, 0)):
        scaled[:, :, channel] = (
            (scaled[:, :, channel] - means[source]) / stds[source])
    return scaled * (1.0 / 255.0)
106

107

108
# Printing function (for debugging)
109
def caffe2_img(img):
110
    # Convert RGB to BGR
111
    img = img[:, :, (2, 1, 0)]
112
    # Convert HWC to CHW
113
    img = img.swapaxes(1, 2).swapaxes(0, 1)
114
    img = img * 255.0
115
    return img.astype(np.int32)
116

117

118
# Bounding box is ymin, xmin, height, width
119
def create_test(output_dir, width, height, default_bound, minsize, crop, means,
120
                stds, count, label_type, num_labels, output1=None,
121
                output2_size=None):
122
    print("Creating a temporary lmdb database of %d pictures..." % (count))
123

124
    if default_bound is None:
125
        default_bound = [-1] * 4
126

127
    LMDB_MAP_SIZE = 1 << 40
128
    env = lmdb.open(output_dir, map_size=LMDB_MAP_SIZE, subdir=True)
129
    index = 0
130
    # Create images and the expected results
131
    expected_results = []
132
    with env.begin(write=True) as txn:
133
        while index < count:
134
            img_array = np.random.random_integers(
135
                0, 255, [height, width, 3]).astype(np.uint8)
136
            img_obj = Image.fromarray(img_array)
137
            img_str = io.BytesIO()
138
            img_obj.save(img_str, 'PNG')
139

140
            # Create a random bounding box for every other image
141
            # ymin, xmin, bound_height, bound_width
142
            # TODO: To ensure that we never need to scale, we
143
            # ensure that the bounding-box is larger than the
144
            # minsize parameter
145
            bounding_box = list(default_bound)
146
            do_default_bound = True
147
            if index % 2 == 0:
148
                if height > minsize and width > minsize:
149
                    do_default_bound = False
150
                    bounding_box[0:2] = [np.random.randint(a) for a in
151
                                         (height - minsize, width - minsize)]
152
                    bounding_box[2:4] = [np.random.randint(a) + minsize for a in
153
                                         (height - bounding_box[0] - minsize + 1,
154
                                          width - bounding_box[1] - minsize + 1)]
155
                    # print("Bounding box is %s" % (str(bounding_box)))
156
            # Create expected result
157
            img_expected = img_array.astype(np.float32) * (1.0 / 255.0)
158
            # print("Orig image: %s" % (str(caffe2_img(img_expected))))
159
            img_expected = verify_apply_bounding_box(
160
                img_expected,
161
                bounding_box)
162
            # print("Bounded image: %s" % (str(caffe2_img(img_expected))))
163

164
            img_expected = verify_rescale(img_expected, minsize)
165

166
            img_expected = verify_crop(img_expected, crop)
167
            # print("Crop image: %s" % (str(caffe2_img(img_expected))))
168

169
            img_expected = verify_color_normalize(img_expected, means, stds)
170
            # print("Color image: %s" % (str(caffe2_img(img_expected))))
171

172
            tensor_protos = caffe2_pb2.TensorProtos()
173
            image_tensor = tensor_protos.protos.add()
174
            image_tensor.data_type = 4  # string data
175
            image_tensor.string_data.append(img_str.getvalue())
176
            img_str.close()
177

178
            label_tensor = tensor_protos.protos.add()
179
            label_tensor.data_type = 2  # int32 data
180
            assert (label_type >= 0 and label_type <= 3)
181
            if label_type == 0:
182
                label_tensor.int32_data.append(index)
183
                expected_label = index
184
            elif label_type == 1:
185
                binary_labels = np.random.randint(2, size=num_labels)
186
                for idx, val in enumerate(binary_labels.tolist()):
187
                    if val == 1:
188
                        label_tensor.int32_data.append(idx)
189
                expected_label = binary_labels
190
            elif label_type == 2:
191
                embedding_label = np.random.randint(100, size=num_labels)
192
                for _idx, val in enumerate(embedding_label.tolist()):
193
                    label_tensor.int32_data.append(val)
194
                expected_label = embedding_label
195
            elif label_type == 3:
196
                weight_tensor = tensor_protos.protos.add()
197
                weight_tensor.data_type = 1  # float weights
198
                binary_labels = np.random.randint(2, size=num_labels)
199
                expected_label = np.zeros(num_labels).astype(np.float32)
200
                for idx, val in enumerate(binary_labels.tolist()):
201
                    expected_label[idx] = val * idx
202
                    if val == 1:
203
                        label_tensor.int32_data.append(idx)
204
                        weight_tensor.float_data.append(idx)
205

206
            if output1:
207
                output1_tensor = tensor_protos.protos.add()
208
                output1_tensor.data_type = 1  # float data
209
                output1_tensor.float_data.append(output1)
210

211
            output2 = []
212
            if output2_size:
213
                output2_tensor = tensor_protos.protos.add()
214
                output2_tensor.data_type = 2  # int32 data
215
                values = np.random.randint(1024, size=output2_size)
216
                for val in values.tolist():
217
                    output2.append(val)
218
                    output2_tensor.int32_data.append(val)
219

220
            expected_results.append(
221
                [caffe2_img(img_expected), expected_label, output1, output2])
222

223
            if not do_default_bound:
224
                bounding_tensor = tensor_protos.protos.add()
225
                bounding_tensor.data_type = 2  # int32 data
226
                bounding_tensor.int32_data.extend(bounding_box)
227

228
            txn.put(
229
                '{}'.format(index).encode('ascii'),
230
                tensor_protos.SerializeToString()
231
            )
232
            index = index + 1
233
        # End while
234
    # End with
235
    return expected_results
236

237

238
def run_test(
        size_tuple, means, stds, label_type, num_labels, is_test, scale_jitter_type,
        color_jitter, color_lighting, dc, validator, output1=None, output2_size=None):
    """Run the ImageInput operator over a freshly generated LMDB database.

    Writes ``count_images`` records with ``create_test`` (one with its own
    bounding box and one without). Then, for every device option in ``dc``,
    it runs an ``ImageInput`` operator in a temporary workspace and calls
    ``validator(expected_images, device_option, count_images)`` to check the
    produced blobs. The temporary database is always removed afterwards.

    Args:
        size_tuple: (width, height, minsize, crop).
        means, stds: per-channel normalization constants, converted to
            floats before use.
        label_type, num_labels: label encoding, forwarded both to the data
            generator and to the operator.
        is_test, scale_jitter_type, color_jitter, color_lighting: operator
            arguments, forwarded unchanged.
        dc: iterable of device options to run on.
        validator: callable that checks the workspace blobs after each run.
        output1: optional float stored per record and exposed as 'output1'.
        output2_size: optional length of the extra int 'output2' blob.
    """
    # TODO: Does not test on GPU and does not test use_gpu_transform.
    # WARNING: Using ModelHelper automatically does NHWC to NCHW
    # transformation if needed; this test creates the operator directly.
    width, height, minsize, crop = size_tuple
    means = [float(m) for m in means]
    stds = [float(s) for s in stds]
    # (ymin, xmin, height, width) for records without their own box. The
    # same tuple feeds both the reference data and the operator's bounding_*
    # arguments, so the two cannot drift apart.
    default_bound = (3, 5, height - 3, width - 5)
    count_images = 2  # One with bounding box and one without
    out_dir = tempfile.mkdtemp()
    try:
        expected_images = create_test(
            out_dir,
            width=width,
            height=height,
            default_bound=default_bound,
            minsize=minsize,
            crop=crop,
            means=means,
            stds=stds,
            count=count_images,
            label_type=label_type,
            num_labels=num_labels,
            output1=output1,
            output2_size=output2_size
        )
        outputs = ['data', 'label']
        output_sizes = []
        if output1:
            outputs.append('output1')
            output_sizes.append(1)
        if output2_size:
            outputs.append('output2')
            output_sizes.append(output2_size)
        for device_option in dc:
            with hu.temp_workspace():
                reader_net = core.Net('reader')
                reader_net.CreateDB([], 'DB', db=out_dir, db_type="lmdb")
                workspace.RunNetOnce(reader_net)
                imageop = core.CreateOperator(
                    'ImageInput',
                    ['DB'],
                    outputs,
                    batch_size=count_images,
                    color=3,
                    minsize=minsize,
                    crop=crop,
                    is_test=is_test,
                    bounding_ymin=default_bound[0],
                    bounding_xmin=default_bound[1],
                    bounding_height=default_bound[2],
                    bounding_width=default_bound[3],
                    mean_per_channel=means,
                    std_per_channel=stds,
                    # NOTE(review): device_type 1 presumably denotes CUDA;
                    # the validator relies on the same convention.
                    use_gpu_transform=(device_option.device_type == 1),
                    label_type=label_type,
                    num_labels=num_labels,
                    output_sizes=output_sizes,
                    scale_jitter_type=scale_jitter_type,
                    color_jitter=color_jitter,
                    color_lighting=color_lighting
                )
                imageop.device_option.CopyFrom(device_option)
                main_net = core.Net('main')
                main_net.Proto().op.extend([imageop])
                workspace.RunNetOnce(main_net)
                validator(expected_images, device_option, count_images)
    finally:
        # Clean up even when validation fails. Hypothesis re-runs failing
        # examples while shrinking, and each run would otherwise leak a DB.
        shutil.rmtree(out_dir)
316

317

318
def _size_tuple_strategy(width_height):
    """Build the (width, height, minsize, crop) strategy for one size pair.

    ``minsize`` is min(width - 6, height - 4). That keeps it below the
    default bounding box that run_test uses (width - 5, height - 3), so the
    reference never needs to rescale. ``crop`` is drawn from [1, minsize].
    """
    width, height = width_height
    minsize = min(width - 6, height - 4)
    return st.tuples(
        st.just(width), st.just(height), st.just(minsize),
        st.integers(min_value=1, max_value=minsize))


# Hypothesis strategies shared by both TestImport property tests.
_IMAGE_INPUT_STRATEGIES = dict(
    size_tuple=st.tuples(
        st.integers(min_value=8, max_value=4096),
        st.integers(min_value=8, max_value=4096)).flatmap(_size_tuple_strategy),
    means=st.tuples(st.integers(min_value=0, max_value=255),
                    st.integers(min_value=0, max_value=255),
                    st.integers(min_value=0, max_value=255)),
    stds=st.tuples(st.floats(min_value=1, max_value=10),
                   st.floats(min_value=1, max_value=10),
                   st.floats(min_value=1, max_value=10)),
    label_type=st.integers(0, 3),
    num_labels=st.integers(min_value=8, max_value=4096),
    is_test=st.integers(min_value=0, max_value=1),
    scale_jitter_type=st.integers(min_value=0, max_value=1),
    color_jitter=st.integers(min_value=0, max_value=1),
    color_lighting=st.integers(min_value=0, max_value=1),
)


@unittest.skipIf('cv2' not in sys.modules, 'python-opencv is not installed')
@unittest.skipIf('lmdb' not in sys.modules, 'python-lmdb is not installed')
class TestImport(hu.HypothesisTestCase):
    """Property tests for the ImageInput operator reading from LMDB."""

    def validate_image_and_label(
            self, expected_images, device_option, count_images, label_type,
            is_test, scale_jitter_type, color_jitter, color_lighting):
        """Check the 'label' and 'data' blobs against the expected records.

        ``expected_images`` is the list returned by ``create_test``. In test
        mode (``is_test != 0``) each pixel may differ by at most 1. In
        training mode only the shapes are compared, because preprocessing is
        randomized. ``scale_jitter_type``, ``color_jitter`` and
        ``color_lighting`` are accepted for interface compatibility but are
        not used here.
        """
        labels = workspace.FetchBlob('label')
        result = workspace.FetchBlob('data').astype(np.int32)
        # The reference images are CHW. Without use_gpu_transform the
        # operator emits NHWC, so move the reference channels last to match.
        if device_option.device_type != 1:
            expected = [img.swapaxes(0, 1).swapaxes(1, 2)
                        for (img, _, _, _) in expected_images]
        else:
            expected = [img for (img, _, _, _) in expected_images]
        for i in range(count_images):
            expected_label = expected_images[i][1]
            if label_type == 0:
                self.assertEqual(labels[i], expected_label)
            else:
                # Exact comparison. The old ``(diff > 0).sum() == 0`` check
                # accepted labels that were smaller than expected.
                np.testing.assert_array_equal(labels[i], expected_label)
            if is_test == 0:
                # When training data preparation is randomized (e.g. random
                # cropping, Inception-style random sized cropping, color
                # jittering, color lighting), we only compare blob shapes.
                # Comparing whole tuples also catches a rank mismatch, which
                # a zip() over the dimensions would silently ignore.
                self.assertEqual(expected[i].shape, result[i].shape)
            else:
                # Allow an absolute difference of at most 1 per pixel, in
                # either direction (presumably for rounding differences
                # between the C++ and reference pipelines).
                self.assertEqual(
                    (np.abs(expected[i] - result[i]) > 1).sum(), 0)

    @given(**dict(_IMAGE_INPUT_STRATEGIES, **hu.gcs))
    @settings(verbosity=Verbosity.verbose, max_examples=10, deadline=None)
    def test_imageinput(
            self, size_tuple, means, stds, label_type,
            num_labels, is_test, scale_jitter_type, color_jitter, color_lighting,
            gc, dc):
        """ImageInput output matches the Python reference pipeline."""
        def validator(expected_images, device_option, count_images):
            self.validate_image_and_label(
                expected_images, device_option, count_images, label_type,
                is_test, scale_jitter_type, color_jitter, color_lighting)

        run_test(
            size_tuple, means, stds, label_type, num_labels, is_test,
            scale_jitter_type, color_jitter, color_lighting, dc, validator)

    @given(**dict(_IMAGE_INPUT_STRATEGIES,
                  output1=st.floats(min_value=1, max_value=10),
                  output2_size=st.integers(min_value=2, max_value=10),
                  **hu.gcs))
    @settings(verbosity=Verbosity.verbose, max_examples=10, deadline=None)
    def test_imageinput_with_additional_outputs(
            self, size_tuple, means, stds, label_type,
            num_labels, is_test, scale_jitter_type, color_jitter, color_lighting,
            output1, output2_size, gc, dc):
        """Same as test_imageinput, plus the extra 'output1'/'output2' blobs."""
        def validator(expected_images, device_option, count_images):
            self.validate_image_and_label(
                expected_images, device_option, count_images, label_type,
                is_test, scale_jitter_type, color_jitter, color_lighting)

            output1_result = workspace.FetchBlob('output1')
            output2_result = workspace.FetchBlob('output2')

            for i in range(count_images):
                # output1 goes through a float32 proto field (float_data), so
                # exact equality with the float64 hypothesis value can fail.
                np.testing.assert_allclose(
                    output1_result[i], expected_images[i][2], rtol=1e-6)
                np.testing.assert_array_equal(
                    output2_result[i], expected_images[i][3])

        run_test(
            size_tuple, means, stds, label_type, num_labels, is_test,
            scale_jitter_type, color_jitter, color_lighting, dc,
            validator, output1, output2_size)
429

430

431
# Script entry point: run this module's tests with the unittest runner
# (``unittest`` is already imported at the top of the file).
if __name__ == '__main__':
    unittest.main()
434

# Cookie notice (web-page residue captured with this file, translated):
# "Use of cookies. We use cookies in accordance with the Privacy Policy and
# the Cookie Policy. By clicking 'Accept', you consent to SberTech JSC
# processing your personal data in order to improve our website and the
# GitVerse service and to make them more convenient to use. You can disable
# cookies yourself in your browser settings."