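# Hypothesis-driven tests for the Caffe2 convolution operators: NCHW/NHWC
# layouts, grouped and dilated variants, 1D/2D/3D kernels, and the EIGEN /
# CUDNN / MIOPEN / MKLDNN engines, including gradient checks.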
import collections
import functools
import unittest

import caffe2.python._import_c_extension as C
import caffe2.python.hip_test_util as hiputl
import caffe2.python.hypothesis_test_util as hu
import caffe2.python.serialized_test.serialized_test_util as serial
import hypothesis.strategies as st
import numpy as np
from caffe2.proto import caffe2_pb2
from caffe2.python import brew, core, utils, workspace
from caffe2.python.model_helper import ModelHelper
from hypothesis import assume, given, settings

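# Note: workspace.GetCuDNNVersion() reports the cuDNN version as a single
# integer (major * 1000 + minor * 100 + patch, e.g. 6021 for cuDNN 6.0.21),
# which is why the helper below compares against 6000 for "v6 or newer".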
def _cudnn_supports(dilation=False, nhwc=False, backward=False):
    """Return True if cuDNN supports this configuration."""
    v = workspace.GetCuDNNVersion()
    if backward:
        if nhwc:
            # nhwc isn't supported in backward ops.
            return False
    else:
        # Forward mode.
        if dilation and v < 6000:
            # Dilation not supported until v6
            return False
        if dilation and nhwc:
            # Dilation and NHWC not supported together
            return False
    return True

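# The helper below builds a Hypothesis strategy over cuDNN algorithm indices
# (0 .. count - 1) for the forward, data-gradient and filter-gradient passes.
# When the C extension does not expose the algorithm counts (e.g. a build
# without cuDNN), it falls back to sampling -1, which the conv operators
# appear to treat as "do not force a particular algorithm".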
def _cudnn_convolution_algo_count(direction):
    try:
        if direction == "fwd":
            return st.integers(0, C.cudnn_convolution_fwd_algo_count - 1)
        elif direction == "dgrad":
            return st.integers(0, C.cudnn_convolution_bwd_data_algo_count - 1)
        elif direction == "wgrad":
            return st.integers(0, C.cudnn_convolution_bwd_filter_algo_count - 1)
        else:
            assert False
    except Exception:
        return st.sampled_from([-1])

class TestConvolution(serial.SerializedTestCase):
    # CUDNN does NOT support different padding values, so we skip it here.
    @given(
        op_type=st.sampled_from(["Conv", "Conv2D"]),
        stride_h=st.integers(1, 3),
        stride_w=st.integers(1, 3),
        pad_t=st.integers(0, 3),
        pad_l=st.integers(0, 3),
        pad_b=st.integers(0, 3),
        pad_r=st.integers(0, 3),
        kernel=st.integers(3, 5),
        size=st.integers(1, 8),
        input_channels=st.integers(1, 3),
        output_channels=st.integers(1, 3),
        batch_size=st.integers(0, 3),
        group=st.integers(1, 2),
        order=st.sampled_from(["NCHW", "NHWC"]),
        engine=st.sampled_from(["", "EIGEN"]),
        shared_buffer=st.booleans(),
        use_bias=st.booleans(),
        **hu.gcs
    )
    @settings(deadline=None, max_examples=50)
    def test_convolution_separate_stride_pad_gradients(
        self,
        op_type,
        stride_h,
        stride_w,
        pad_t,
        pad_l,
        pad_b,
        pad_r,
        kernel,
        size,
        input_channels,
        output_channels,
        batch_size,
        group,
        order,
        engine,
        shared_buffer,
        use_bias,
        gc,
        dc,
    ):
        # TODO: Group conv in NHWC not implemented for GPU yet.
        assume(group == 1 or order == "NCHW" or gc.device_type == caffe2_pb2.CPU)
        if group != 1 and order == "NHWC":
            dc = [d for d in dc if d.device_type == caffe2_pb2.CPU]
        # Group conv not implemented with EIGEN engine.
        assume(group == 1 or engine != "EIGEN")

        input_channels *= group
        output_channels *= group

        op = core.CreateOperator(
            op_type,
            ["X", "w", "b"] if use_bias else ["X", "w"],
            ["Y"],
            stride_h=stride_h,
            stride_w=stride_w,
            pad_t=pad_t,
            pad_l=pad_l,
            pad_b=pad_b,
            pad_r=pad_r,
            kernel=kernel,
            group=group,
            order=order,
            engine=engine,
            shared_buffer=int(shared_buffer),
        )
        X = (
            np.random.rand(batch_size, size, size, input_channels).astype(np.float32)
            - 0.5
        )
        w = (
            np.random.rand(
                output_channels, kernel, kernel, int(input_channels / group)
            ).astype(np.float32)
            - 0.5
        )
        b = np.random.rand(output_channels).astype(np.float32) - 0.5
        if order == "NCHW":
            X = utils.NHWC2NCHW(X)
            w = utils.NHWC2NCHW(w)

        inputs = [X, w, b] if use_bias else [X, w]

        # Error handling path.
        if size + pad_r + pad_l < kernel or size + pad_t + pad_b < kernel:
            with self.assertRaises(RuntimeError):
                self.assertDeviceChecks(dc, op, inputs, [0])
            return

        self.assertDeviceChecks(dc, op, inputs, [0])
        for i in range(len(inputs)):
            self.assertGradientChecks(gc, op, inputs, i, [0])

    # CUDNN does NOT support different padding values, so we skip it here.
    @given(
        op_type=st.sampled_from(["Conv", "Conv2D"]),
        stride_h=st.integers(1, 3),
        stride_w=st.integers(1, 3),
        pad_t=st.integers(0, 3),
        pad_l=st.integers(0, 3),
        pad_b=st.integers(0, 3),
        pad_r=st.integers(0, 3),
        kernel=st.integers(1, 5),
        size=st.integers(7, 10),
        input_channels=st.integers(1, 8),
        output_channels=st.integers(1, 8),
        batch_size=st.integers(0, 3),
        engine=st.sampled_from(["", "EIGEN"]),
        use_bias=st.booleans(),
        **hu.gcs
    )
    @settings(deadline=None)
    def test_convolution_separate_stride_pad_layout(
        self,
        op_type,
        stride_h,
        stride_w,
        pad_t,
        pad_l,
        pad_b,
        pad_r,
        kernel,
        size,
        input_channels,
        output_channels,
        batch_size,
        engine,
        use_bias,
        gc,
        dc,
    ):
        X = (
            np.random.rand(batch_size, size, size, input_channels).astype(np.float32)
            - 0.5
        )
        w = (
            np.random.rand(output_channels, kernel, kernel, input_channels).astype(
                np.float32
            )
            - 0.5
        )
        b = np.random.rand(output_channels).astype(np.float32) - 0.5
        outputs = {}
        for order in ["NCHW", "NHWC"]:
            op = core.CreateOperator(
                op_type,
                ["X", "w", "b"] if use_bias else ["X", "w"],
                ["Y"],
                stride_h=stride_h,
                stride_w=stride_w,
                kernel=kernel,
                pad_t=pad_t,
                pad_l=pad_l,
                pad_b=pad_b,
                pad_r=pad_r,
                order=order,
                engine=engine,
                device_option=gc,
            )
            if order == "NCHW":
                X_f = utils.NHWC2NCHW(X)
                w_f = utils.NHWC2NCHW(w)
            else:
                X_f = X
                w_f = w
            self.ws.create_blob("X").feed(X_f, device_option=gc)
            self.ws.create_blob("w").feed(w_f, device_option=gc)
            self.ws.create_blob("b").feed(b, device_option=gc)
            self.ws.run(op)
            outputs[order] = self.ws.blobs["Y"].fetch()
        np.testing.assert_allclose(
            outputs["NCHW"], utils.NHWC2NCHW(outputs["NHWC"]), atol=1e-4, rtol=1e-4
        )

    @given(
        op_type=st.sampled_from(["Conv", "Conv2D"]),
        stride=st.integers(1, 3),
        pad=st.integers(0, 3),
        kernel=st.integers(1, 5),
        dilation=st.integers(1, 3),
        size=st.integers(7, 10),
        input_channels=st.integers(1, 8),
        output_channels=st.integers(1, 8),
        batch_size=st.integers(0, 3),
        group=st.integers(1, 2),
        order=st.sampled_from(["NCHW", "NHWC"]),
        engine=st.sampled_from(["", "CUDNN", "MKLDNN"]),
        use_bias=st.booleans(),
        force_algo_fwd=_cudnn_convolution_algo_count("fwd"),
        force_algo_dgrad=_cudnn_convolution_algo_count("dgrad"),
        force_algo_wgrad=_cudnn_convolution_algo_count("wgrad"),
        **hu.gcs
    )
    @settings(max_examples=20, deadline=None)
    def test_convolution_gradients(
        self,
        op_type,
        stride,
        pad,
        kernel,
        dilation,
        size,
        input_channels,
        output_channels,
        batch_size,
        group,
        order,
        engine,
        use_bias,
        force_algo_fwd,
        force_algo_dgrad,
        force_algo_wgrad,
        gc,
        dc,
    ):
        # TODO: Group conv in NHWC not implemented for GPU yet.
        assume(
            group == 1
            or (order == "NCHW" or gc.device_type == caffe2_pb2.CPU)
            and engine != "MKLDNN"
        )
        if group != 1 and order == "NHWC":
            dc = [d for d in dc if d.device_type == caffe2_pb2.CPU]

        input_channels *= group
        output_channels *= group
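        # Effective extent of the dilated kernel: a kernel of size k with
        # dilation d touches d * (k - 1) + 1 input positions, which is what
        # the padded input size is compared against below.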
        dkernel = dilation * (kernel - 1) + 1

        if engine == "CUDNN":
            if hiputl.run_in_hip(gc, dc):
                assume((order == "NCHW") and not (dilation > 1 and group > 1))
            else:
                assume(
                    _cudnn_supports(
                        dilation=(dilation > 1), nhwc=(order == "NHWC"), backward=True
                    )
                )

        assume(engine != "MKLDNN" or use_bias is True)

        op = core.CreateOperator(
            op_type,
            ["X", "w", "b"] if use_bias else ["X", "w"],
            ["Y"],
            stride=stride,
            kernel=kernel,
            dilation=dilation,
            pad=pad,
            group=group,
            order=order,
            engine=engine,
            force_algo_fwd=force_algo_fwd,
            force_algo_dgrad=force_algo_dgrad,
            force_algo_wgrad=force_algo_wgrad,
        )
        X = (
            np.random.rand(batch_size, size, size, input_channels).astype(np.float32)
            - 0.5
        )
        w = (
            np.random.rand(
                output_channels, kernel, kernel, int(input_channels / group)
            ).astype(np.float32)
            - 0.5
        )
        b = np.random.rand(output_channels).astype(np.float32) - 0.5
        if order == "NCHW":
            X = utils.NHWC2NCHW(X)
            w = utils.NHWC2NCHW(w)

        inputs = [X, w, b] if use_bias else [X, w]
        # Error handling path.
        if size + pad + pad < dkernel:
            with self.assertRaises(RuntimeError):
                self.assertDeviceChecks(dc, op, inputs, [0])
            return

        try:
            self.assertDeviceChecks(dc, op, inputs, [0])
        except RuntimeError as e:
            es = str(e)
            # CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM should always have an
            # implementation
            if (
                "status == CUDNN_STATUS_SUCCESS" not in es
                or "CUDNN_STATUS_NOT_SUPPORTED" not in es
                or force_algo_fwd == 0
            ):
                raise e

        for i in range(len(inputs)):
            try:
                self.assertGradientChecks(gc, op, inputs, i, [0])
            except RuntimeError as e:
                es = str(e)
                if (
                    "status == CUDNN_STATUS_SUCCESS" not in es
                    or "CUDNN_STATUS_NOT_SUPPORTED" not in es
                ):
                    raise e

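    # Shared driver for the 1D/3D tests below: it exercises both the generic
    # "Conv" operator and the dimension-specific "Conv{n}D" variant, passing
    # per-dimension strides / kernels / dilations / pads lists.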
    def _nd_convolution(
        self,
        n,
        input_channels_per_group,
        output_channels_per_group,
        batch_size,
        stride,
        size,
        kernel,
        dilation,
        pad,
        group,
        order,
        use_bias,
        engine,
        force_algo_fwd,
        force_algo_dgrad,
        force_algo_wgrad,
        gc,
        dc,
    ):
        # TODO: Group conv in NHWC not implemented for GPU yet.
        # TODO: Group 1D conv in NCHW not implemented for GPU yet.
        assume(
            group == 1
            or (n != 1 and order == "NCHW")
            or gc.device_type == caffe2_pb2.CPU
        )
        if group != 1 and (n == 1 or order == "NHWC"):
            dc = [d for d in dc if d.device_type == caffe2_pb2.CPU]

        input_channels = group * input_channels_per_group
        output_channels = group * output_channels_per_group

        dkernel = dilation * (kernel - 1) + 1
        for op_type in ["Conv", "Conv" + str(n) + "D"]:
            op = core.CreateOperator(
                op_type,
                ["X", "w", "b"] if use_bias else ["X", "w"],
                ["Y"],
                strides=[stride] * n,
                kernels=[kernel] * n,
                dilations=[dilation] * n,
                pads=[pad] * n * 2,
                group=group,
                order=order,
                engine=engine,
                force_algo_fwd=force_algo_fwd,
                force_algo_dgrad=force_algo_dgrad,
                force_algo_wgrad=force_algo_wgrad,
            )

            input_dims = [batch_size, input_channels]
            input_dims.extend([size] * n)
            filter_dims = [output_channels, input_channels // group]
            filter_dims.extend([kernel] * n)

            X = np.random.rand(*input_dims).astype(np.float32) - 0.5
            w = np.random.rand(*filter_dims).astype(np.float32) - 0.5
            b = np.random.rand(output_channels).astype(np.float32) - 0.5
            if order == "NHWC":
                X = utils.NCHW2NHWC(X)
                w = utils.NCHW2NHWC(w)

            inputs = [X, w, b] if use_bias else [X, w]

            # Error handling path.
            if size + pad + pad < dkernel:
                with self.assertRaises(RuntimeError):
                    self.assertDeviceChecks(dc, op, inputs, [0])
                return

            self.assertDeviceChecks(dc, op, inputs, [0])
            for i in range(len(inputs)):
                self.assertGradientChecks(gc, op, inputs, i, [0])

    @given(
        input_channels=st.integers(1, 3),
        output_channels=st.integers(1, 2),
        batch_size=st.integers(0, 3),
        stride=st.integers(1, 3),
        size=st.integers(7, 10),
        kernel=st.integers(1, 2),
        dilation=st.integers(1, 3),
        pad=st.integers(0, 3),
        group=st.integers(1, 2),
        order=st.sampled_from(["NCHW", "NHWC"]),
        use_bias=st.booleans(),
        engine=st.sampled_from(["", "CUDNN"]),
        force_algo_fwd=_cudnn_convolution_algo_count("fwd"),
        force_algo_dgrad=_cudnn_convolution_algo_count("dgrad"),
        force_algo_wgrad=_cudnn_convolution_algo_count("wgrad"),
        **hu.gcs
    )
    @settings(deadline=10000)
    def test_1d_convolution(
        self,
        input_channels,
        output_channels,
        batch_size,
        stride,
        size,
        kernel,
        dilation,
        pad,
        group,
        order,
        use_bias,
        engine,
        force_algo_fwd,
        force_algo_dgrad,
        force_algo_wgrad,
        gc,
        dc,
    ):
        if hiputl.run_in_hip(gc, dc):
            # Currently MIOPEN only supports 2D conv.
            assume(engine != "CUDNN")  # CUDNN is aliased to MIOPEN for HIP
        # TODO: 1D conv in NHWC not implemented for GPU yet.
        assume(order == "NCHW" or gc.device_type == caffe2_pb2.CPU)
        if order == "NHWC":
            dc = [d for d in dc if d.device_type == caffe2_pb2.CPU]

        self._nd_convolution(
            1,
            input_channels,
            output_channels,
            batch_size,
            stride,
            size,
            kernel,
            dilation,
            pad,
            group,
            order,
            use_bias,
            engine,
            force_algo_fwd,
            force_algo_dgrad,
            force_algo_wgrad,
            gc,
            dc,
        )

    @given(
        input_channels=st.integers(1, 2),
        output_channels=st.integers(1, 2),
        batch_size=st.integers(0, 2),
        stride=st.integers(1, 2),
        size=st.integers(4, 5),
        kernel=st.integers(1, 2),
        dilation=st.integers(1, 2),
        pad=st.integers(0, 2),
        group=st.integers(1, 2),
        order=st.sampled_from(["NCHW", "NHWC"]),
        use_bias=st.booleans(),
        engine=st.sampled_from(["", "MIOPEN"]),  # TODO: add "CUDNN"
        force_algo_fwd=_cudnn_convolution_algo_count("fwd"),
        force_algo_dgrad=_cudnn_convolution_algo_count("dgrad"),
        force_algo_wgrad=_cudnn_convolution_algo_count("wgrad"),
        **hu.gcs
    )
    @settings(max_examples=20, deadline=None)
    def test_3d_convolution(
        self,
        input_channels,
        output_channels,
        batch_size,
        stride,
        size,
        kernel,
        dilation,
        pad,
        group,
        order,
        use_bias,
        engine,
        force_algo_fwd,
        force_algo_dgrad,
        force_algo_wgrad,
        gc,
        dc,
    ):
        # TODO: 3D conv in NHWC not implemented for GPU yet.
        assume(order == "NCHW" or gc.device_type == caffe2_pb2.CPU)
        if order == "NHWC":
            dc = [d for d in dc if d.device_type == caffe2_pb2.CPU]
        self._nd_convolution(
            3,
            input_channels,
            output_channels,
            batch_size,
            stride,
            size,
            kernel,
            dilation,
            pad,
            group,
            order,
            use_bias,
            engine,
            force_algo_fwd,
            force_algo_dgrad,
            force_algo_wgrad,
            gc,
            dc,
        )

    @given(
        op_type=st.sampled_from(["Conv", "Conv3D"]),
        batch_size=st.integers(0, 2),
        stride=st.integers(1, 2),
        size=st.integers(3, 5),
        kernel=st.integers(1, 2),
        dilation=st.integers(1, 2),
        pad=st.integers(0, 2),
        use_bias=st.booleans(),
        force_algo_fwd=_cudnn_convolution_algo_count("fwd"),
        force_algo_dgrad=_cudnn_convolution_algo_count("dgrad"),
        force_algo_wgrad=_cudnn_convolution_algo_count("wgrad"),
        **hu.gcs_no_hip
    )  # MIOPEN doesn't support 3D conv yet
    @settings(deadline=10000)
    def test_3d_convolution_cudnn_nchw(
        self,
        op_type,
        batch_size,
        stride,
        size,
        kernel,
        dilation,
        pad,
        use_bias,
        force_algo_fwd,
        force_algo_dgrad,
        force_algo_wgrad,
        gc,
        dc,
    ):
        input_channels = 1
        output_channels = 1
        n = 3
        dkernel = dilation * (kernel - 1) + 1
        order = "NCHW"

        op = core.CreateOperator(
            op_type,
            ["X", "w", "b"] if use_bias else ["X", "w"],
            ["Y"],
            strides=[stride] * n,
            kernels=[kernel] * n,
            dilations=[dilation] * n,
            pads=[pad] * n * 2,
            order=order,
            engine="CUDNN",
            force_algo_fwd=force_algo_fwd,
            force_algo_dgrad=force_algo_dgrad,
            force_algo_wgrad=force_algo_wgrad,
        )

        input_dims = [batch_size, input_channels]
        input_dims.extend([size] * n)
        filter_dims = [output_channels, input_channels]
        filter_dims.extend([kernel] * n)
        X = np.random.rand(*input_dims).astype(np.float32) - 0.5
        w = np.random.rand(*filter_dims).astype(np.float32) - 0.5
        b = np.random.rand(output_channels).astype(np.float32) - 0.5

        inputs = [X, w, b] if use_bias else [X, w]

        # Error handling path.
        if size + pad + pad < dkernel:
            with self.assertRaises(RuntimeError):
                self.assertDeviceChecks(dc, op, inputs, [0])
            return

        try:
            self.assertDeviceChecks(dc, op, inputs, [0])
        except RuntimeError as e:
            es = str(e)
            # CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM should always have an
            # implementation
            if (
                "status == CUDNN_STATUS_SUCCESS" not in es
                or "CUDNN_STATUS_NOT_SUPPORTED" not in es
                or force_algo_fwd == 0
            ):
                raise e

        for i in range(len(inputs)):
            try:
                self.assertGradientChecks(gc, op, inputs, i, [0])
            except RuntimeError as e:
                es = str(e)
                if (
                    "status == CUDNN_STATUS_SUCCESS" not in es
                    or "CUDNN_STATUS_NOT_SUPPORTED" not in es
                ):
                    raise e

    @given(
        op_type=st.sampled_from(["Conv", "Conv2D"]),
        stride=st.integers(1, 3),
        pad=st.integers(0, 3),
        kernel=st.integers(1, 5),
        dilation=st.integers(1, 3),
        size=st.integers(7, 10),
        input_channels=st.integers(1, 8),
        output_channels=st.integers(1, 8),
        batch_size=st.integers(0, 3),
        use_bias=st.booleans(),
        **hu.gcs
    )
    @settings(deadline=None, max_examples=50)
    def test_convolution_layout(
        self,
        op_type,
        stride,
        pad,
        kernel,
        dilation,
        size,
        input_channels,
        output_channels,
        batch_size,
        use_bias,
        gc,
        dc,
    ):
        assume(size >= dilation * (kernel - 1) + 1)

        X = (
            np.random.rand(batch_size, size, size, input_channels).astype(np.float32)
            - 0.5
        )
        w = (
            np.random.rand(output_channels, kernel, kernel, input_channels).astype(
                np.float32
            )
            - 0.5
        )
        b = np.random.rand(output_channels).astype(np.float32) - 0.5
        Output = collections.namedtuple("Output", ["Y", "engine", "order"])
        outputs = []

        for order in ["NCHW", "NHWC"]:
            engine_list = [""]
            if hiputl.run_in_hip(gc, dc):
                if order == "NCHW":
                    engine_list.append("MIOPEN")
            else:
                if _cudnn_supports(dilation=(dilation > 1), nhwc=(order == "NHWC")):
                    engine_list.append("CUDNN")

            for engine in engine_list:
                op = core.CreateOperator(
                    op_type,
                    ["X", "w", "b"] if use_bias else ["X", "w"],
                    ["Y"],
                    stride=stride,
                    kernel=kernel,
                    dilation=dilation,
                    pad=pad,
                    order=order,
                    engine=engine,
                    device_option=gc,
                    exhaustive_search=True,
                )
                if order == "NCHW":
                    X_f = utils.NHWC2NCHW(X)
                    w_f = utils.NHWC2NCHW(w)
                else:
                    X_f = X
                    w_f = w
                self.assertDeviceChecks(
                    dc, op, [X_f, w_f, b] if use_bias else [X_f, w_f], [0]
                )
                self.ws.create_blob("X").feed(X_f, device_option=gc)
                self.ws.create_blob("w").feed(w_f, device_option=gc)
                self.ws.create_blob("b").feed(b, device_option=gc)
                self.ws.run(op)
                outputs.append(
                    Output(Y=self.ws.blobs["Y"].fetch(), engine=engine, order=order)
                )

        def canonical(o):
            if o.order == "NHWC":
                return utils.NHWC2NCHW(o.Y)
            else:
                return o.Y

        for o in outputs:
            np.testing.assert_allclose(
                canonical(outputs[0]), canonical(o), atol=1e-4, rtol=1e-4
            )

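    # The test below runs the same net repeatedly under different executors
    # and worker counts and checks that the computed gradients do not change
    # between runs.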
    @given(
        num_workers=st.integers(1, 4),
        net_type=st.sampled_from(
            ["simple", "dag"]
            + (
                ["async_dag"]
                if workspace.has_gpu_support
                else []
            )
        ),
        engine=st.sampled_from(["CUDNN", ""]),
        **hu.gcs_no_hip
    )
    @settings(deadline=None)
    def test_convolution_sync(self, net_type, num_workers, engine, gc, dc):
        m = ModelHelper(name="test_model")
        n = 1
        d = 2
        depth = 3
        iters = 5
        h = 5
        w = 5
        workspace.ResetWorkspace()

        use_cudnn = engine == "CUDNN"

        np.random.seed(1701)
        # Build a binary tree of conv layers, summing at each node.
        for i in reversed(range(depth)):
            for j in range(2 ** i):
                bottom_1 = "{}_{}".format(i + 1, 2 * j)
                bottom_2 = "{}_{}".format(i + 1, 2 * j + 1)
                mid_1 = "{}_{}_m".format(i + 1, 2 * j)
                mid_2 = "{}_{}_m".format(i + 1, 2 * j + 1)
                top = "{}_{}".format(i, j)
                w1, b1, w2, b2 = np.random.randn(4).tolist()
                brew.conv(
                    m,
                    bottom_1,
                    mid_1,
                    dim_in=d,
                    dim_out=d,
                    kernel=3,
                    weight_init=("ConstantFill", {"value": w1}),
                    bias_init=("ConstantFill", {"value": b1}),
                    cudnn_state=np.random.randint(0, 3),
                    stride=1,
                    pad=1,
                    deterministic=1,
                    use_cudnn=use_cudnn,
                    engine=engine,
                )
                brew.conv(
                    m,
                    bottom_2,
                    mid_2,
                    dim_in=d,
                    dim_out=d,
                    kernel=3,
                    stride=1,
                    pad=1,
                    weight_init=("ConstantFill", {"value": w2}),
                    bias_init=("ConstantFill", {"value": b2}),
                    deterministic=1,
                    cudnn_state=np.random.randint(0, 3),
                    use_cudnn=use_cudnn,
                    engine=engine,
                )
                m.net.Sum([mid_1, mid_2], top)

        m.net.Flatten(["0_0"], ["0_0_flat"])
        m.net.SquaredL2Distance(["0_0_flat", "label"], "xent")
        m.net.AveragedLoss("xent", "loss")
        input_to_grad = m.AddGradientOperators(["loss"])
        m.Proto().device_option.CopyFrom(gc)
        m.param_init_net.Proto().device_option.CopyFrom(gc)
        m.Proto().type = net_type
        m.Proto().num_workers = num_workers
        self.ws.run(m.param_init_net)

        def run():
            import numpy as np

            np.random.seed(1701)
            input_blobs = ["{}_{}".format(depth, j) for j in range(2 ** depth)]
            for input_blob in input_blobs:
                self.ws.create_blob(input_blob).feed(
                    np.random.randn(n, d, h, w).astype(np.float32), device_option=gc
                )
                self.ws.create_blob("label").feed(
                    np.random.randn(n, d * h * w).astype(np.float32), device_option=gc
                )
            self.ws.run(m.net)
            gradients = [
                self.ws.blobs[str(input_to_grad[input_blob])].fetch()
                for input_blob in input_blobs
            ]
            return gradients
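        # With fixed seeds, constant-filled weights and deterministic=1, every
        # run should produce bitwise-identical gradients; the squared-sum
        # check below additionally guards against silent numerical drift (the
        # constant appears to be the value this fixed configuration produces).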
        outputs = [run() for _ in range(iters)]
        for output in outputs[1:]:
            np.testing.assert_array_equal(outputs[0], output)
            np.testing.assert_allclose(
                np.sum(np.square(output)), 1763719461732352.0, rtol=1e-5
            )

    def test_use_cudnn_engine_interactions(self):
        """Make sure the use_cudnn and engine kwargs work as expected."""
        for model_default in [None, True, False]:
            arg_scope = {}
            if model_default is not None:
                arg_scope["use_cudnn"] = model_default
            else:
                model_default = True  # the default

            model = ModelHelper(arg_scope=arg_scope)
            self.assertEqual(model.arg_scope["use_cudnn"], model_default)
            f = functools.partial(brew.conv, model, "conv_in", "conv_out", 10, 10, 5)

            for op_cudnn in [None, True, False]:
                for op_engine in [None, "", "CUDNN"]:
                    kwargs = {}
                    if op_cudnn is not None:
                        kwargs["use_cudnn"] = op_cudnn
                    else:
                        op_cudnn = False  # the default
                    if op_engine is not None:
                        kwargs["engine"] = op_engine

                    calculated_cudnn = kwargs.get("use_cudnn", model_default)
                    expected_engine = kwargs.get(
                        "engine", "CUDNN" if calculated_cudnn else ""
                    )
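                    # brew.conv is expected to reject contradictory settings:
                    # use_cudnn=False combined with engine="CUDNN", or
                    # use_cudnn=True combined with an explicitly empty engine.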
                    if (calculated_cudnn is False and op_engine == "CUDNN") or (
                        calculated_cudnn is True and op_engine == ""
                    ):
                        with self.assertRaises(ValueError):
                            f(**kwargs)
                    else:
                        f(**kwargs)
                        self.assertEqual(model.Proto().op[-1].engine, expected_engine)

    @given(
        op_type=st.sampled_from(["Conv", "Conv2D"]),
        N=st.integers(0, 3),
        G=st.integers(1, 3),
        DX=st.integers(1, 3),
        DY=st.integers(1, 3),
        H=st.integers(1, 3),
        W=st.integers(1, 3),
        use_bias=st.booleans(),
        order=st.sampled_from(["NCHW", "NHWC"]),
        force_algo_fwd=_cudnn_convolution_algo_count("fwd"),
        force_algo_dgrad=_cudnn_convolution_algo_count("dgrad"),
        force_algo_wgrad=_cudnn_convolution_algo_count("wgrad"),
        **hu.gcs
    )
    @settings(deadline=10000)
    def test_1x1_conv(
        self,
        op_type,
        N,
        G,
        DX,
        DY,
        H,
        W,
        use_bias,
        order,
        force_algo_fwd,
        force_algo_dgrad,
        force_algo_wgrad,
        gc,
        dc,
    ):
        if hiputl.run_in_hip(gc, dc):
            assume(order == "NCHW")
        if order == "NHWC":
            G = 1

        C = G * DX
        M = G * DY

        op = core.CreateOperator(
            op_type,
            ["X", "filter", "bias"] if use_bias else ["X", "filter"],
            ["Y"],
            stride_h=1,
            stride_w=1,
            pad_t=0,
            pad_l=0,
            pad_b=0,
            pad_r=0,
            kernel=1,
            order=order,
            group=G,
            force_algo_fwd=force_algo_fwd,
            force_algo_dgrad=force_algo_dgrad,
            force_algo_wgrad=force_algo_wgrad,
        )

        if order == "NCHW":
            X = np.random.randn(N, C, H, W).astype(np.float32)
            filter = np.random.randn(M, DX, 1, 1).astype(np.float32)
        else:
            X = np.random.randn(N, H, W, C).astype(np.float32)
            filter = np.random.randn(M, 1, 1, DX).astype(np.float32)
        bias = np.random.randn(M).astype(np.float32)
        inputs = [X, filter, bias] if use_bias else [X, filter]
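        # A 1x1 convolution with G groups reduces, for every sample and every
        # group, to a (DY x DX) matrix multiplied with the DX input channels
        # at each spatial position; the two reference implementations below
        # spell that out for the NCHW and NHWC layouts.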
        def conv_1x1_nchw_ref(X, filter, bias=None):
            if N == 0:
                Y = np.zeros(shape=(N, M, H, W), dtype=np.float32)
                return [Y]

            X = X.reshape(N, G, DX, -1)
            filter = filter.reshape(G, DY, DX)
            Y = np.zeros(shape=(N, G, DY, H * W), dtype=np.float32)
            for i in range(N):
                for j in range(G):
                    Y[i, j, :, :] = np.dot(filter[j, :, :], X[i, j, :, :])
            Y = Y.reshape(N, M, H, W)
            if bias is not None:
                bias = bias.reshape(1, M, 1, 1)
                Y = np.add(Y, bias)
            return [Y]

        def conv_1x1_nhwc_ref(X, filter, bias=None):
            if N == 0:
                Y = np.zeros(shape=(N, H, W, M), dtype=np.float32)
                return [Y]

            X = X.reshape(N, -1, G, DX)
            filter = filter.reshape(G, DY, DX)
            Y = np.zeros(shape=(N, H * W, G, DY), dtype=np.float32)
            for i in range(N):
                for j in range(G):
                    Y[i, :, j, :] = np.dot(X[i, :, j, :], filter[j, :, :].transpose())
            Y = Y.reshape(N, H, W, M)
            if bias is not None:
                bias = bias.reshape(1, 1, 1, M)
                Y = np.add(Y, bias)
            return [Y]

        if order == "NCHW":
            conv_1x1_ref = conv_1x1_nchw_ref
        else:
            conv_1x1_ref = conv_1x1_nhwc_ref
        self.assertReferenceChecks(
            device_option=gc, op=op, inputs=inputs, reference=conv_1x1_ref
        )
        self.assertDeviceChecks(dc, op, inputs, [0])
        for i in range(len(inputs)):
            self.assertGradientChecks(gc, op, inputs, i, [0])


if __name__ == "__main__":
    unittest.main()