pytorch

Форк
0
/
test_convolution.py 
4039 строк · 150.5 Кб
1
# Owner(s): ["module: nn"]
2
import itertools
3
import math
4
import unittest
5
import warnings
6
from itertools import product
7

8
import torch
9
import torch.autograd.forward_ad as fwAD
10
import torch.backends.cudnn as cudnn
11
import torch.nn as nn
12
import torch.nn.functional as F
13
from torch.testing import make_tensor
14
from torch.testing._internal.common_cuda import (
15
    TEST_CUDA,
16
    TEST_CUDNN,
17
    tf32_is_not_fp32,
18
    tf32_on_and_off,
19
)
20
from torch.testing._internal.common_device_type import (
21
    disablecuDNN,
22
    disableMkldnn,
23
    dtypes,
24
    dtypesIfCUDA,
25
    instantiate_device_type_tests,
26
    largeTensorTest,
27
    onlyCPU,
28
    onlyCUDA,
29
    onlyNativeDeviceTypes,
30
    precisionOverride,
31
    skipCPUIfNoMkldnn,
32
    skipCUDAIfCudnnVersionLessThan,
33
    skipCUDAIfMiopen,
34
    skipCUDAIfNoCudnn,
35
    skipCUDAIfNoMiopen,
36
    skipCUDAIfNotMiopenSuggestNHWC,
37
    skipCUDAIfRocm,
38
    skipCUDAIfRocmVersionLessThan,
39
    skipMeta,
40
)
41
from torch.testing._internal.common_dtype import (
42
    floating_and_complex_types_and,
43
    floating_types_and,
44
)
45
from torch.testing._internal.common_nn import _test_module_empty_input, NNTestCase
46
from torch.testing._internal.common_utils import (
47
    download_file,
48
    dtype2prec_DONTUSE,
49
    gradcheck,
50
    GRADCHECK_NONDET_TOL,
51
    gradgradcheck,
52
    instantiate_parametrized_tests,
53
    parametrize as parametrize_test,
54
    run_tests,
55
    set_default_dtype,
56
    skipIfNotMiopenSuggestNHWC,
57
    skipIfRocmVersionLessThan,
58
    subtest,
59
    TEST_SCIPY,
60
    TEST_WITH_ROCM,
61
)
62

63

64
# Truthy when the suite runs under ROCm or when tf32_is_not_fp32() reports True.
# Tests in this file consult it before adding CUDA bfloat16 configurations.
AMPERE_OR_ROCM = TEST_WITH_ROCM or tf32_is_not_fp32()
65

66

67
if TEST_SCIPY:
68
    import scipy.ndimage
69
    import scipy.signal
70

71

72
class TestConvolutionNN(NNTestCase):
73
    _do_cuda_memory_leak_check = True
74
    _do_cuda_non_default_stream = True
75

76
    def test_conv_backcompat(self):
        """Load a Conv2d pickled by an old PyTorch release and run it forward."""
        from torch.serialization import SourceChangeWarning

        # The checkpoint was produced with PyTorch 1.0.1 on Python 2 via:
        #
        #     import torch
        #     from torch import nn
        #     m = nn.Conv2d(1, 1, 1)
        #     torch.save(m, 'legacy_conv2d.pt')
        #
        # NB: the pickle also carries some Unicode data.
        legacy_path = download_file(
            "https://download.pytorch.org/test_data/legacy_conv2d.pt"
        )
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", SourceChangeWarning)
            # weights_only=False because the legacy file stores the whole module.
            legacy_conv = torch.load(
                legacy_path, encoding="utf-8", weights_only=False
            )
        sample = torch.randn((1, 1, 1, 1), dtype=torch.float)
        out_size = legacy_conv(sample).size()
        self.assertEqual(out_size, (1, 1, 1, 1))
94

95
    def test_invalid_conv1d(self):
        """Conv1d rejects oversized kernels and negative strides for each dtype."""
        oversized_kernel_pattern = (
            r"Calculated padded input size per channel: \(4\). "
            r"Kernel size: \(10\). Kernel size can\'t be greater than actual input size"
        )
        dtypes_under_test = (
            torch.half,
            torch.bfloat16,
            torch.float,
            torch.double,
            torch.cfloat,
            torch.cdouble,
        )
        for dtype in dtypes_under_test:
            # A length-4 signal cannot hold a width-10 kernel.
            conv = nn.Conv1d(
                in_channels=3, out_channels=33, kernel_size=10, stride=1, bias=True
            ).to(dtype)
            signal = torch.randn(1, 3, 4).to(dtype)
            with self.assertRaisesRegex(RuntimeError, oversized_kernel_pattern):
                conv(signal)

            # Negative stride check
            conv = nn.Conv1d(
                in_channels=3, out_channels=6, kernel_size=3, stride=-1, bias=True
            ).to(dtype)
            signal = torch.randn(1, 3, 4).to(dtype)
            with self.assertRaisesRegex(
                RuntimeError, "non-positive stride is not supported"
            ):
                conv(signal)
124

125
    def test_mismatch_shape_conv2d(self):
        """F.conv2d raises a descriptive error when handed a 5-D input."""
        expected_pattern = (
            r"Expected 3D \(unbatched\) or 4D \(batched\) input to conv2d, but got "
            r"input of size: \[1, 10, 1, 28, 28\]"
        )
        for dtype in (torch.float, torch.cfloat):
            five_d_input = torch.randn(1, 10, 1, 28, 28, dtype=dtype)
            kernel = torch.randn(6, 1, 5, 5, dtype=dtype)
            with self.assertRaisesRegex(RuntimeError, expected_pattern):
                F.conv2d(five_d_input, kernel)
136

137
    def test_conv2d_discontiguous_weight(self):
        """conv2d accepts a non-contiguous weight, with and without MKLDNN."""
        # Regression test for https://github.com/pytorch/pytorch/issues/55781
        for dtype in (torch.float, torch.cfloat):
            ones_input = torch.ones(64, 16, 16, 16, dtype=dtype)
            dense_weight = torch.arange(0, 1.0, 1 / 2.0**10).reshape(32, 16, 1, 2)
            # Keeping every other element of the last dim yields a strided view.
            strided_weight = dense_weight.to(dtype)[:, :, :, ::2]
            self.assertFalse(strided_weight.is_contiguous())

            result = F.conv2d(ones_input, strided_weight, None)
            if torch.backends.mkldnn.is_available():
                # Disable MKLDNN explicitly, so that either NNPACK or THCNN will be used
                with torch.backends.mkldnn.flags(enabled=False):
                    result_no_mkldnn = F.conv2d(ones_input, strided_weight, None)
                    self.assertEqual(result, result_no_mkldnn)
            self.assertEqual(result.sum(), 4186112.0)
154

155
    def test_invalid_conv2d(self):
        """Conv2d rejects invalid kernel/stride configurations for every dtype.

        For each dtype this checks that:
          * a dilated kernel whose extent exceeds the input raises RuntimeError;
          * a kernel larger than the input raises with the expected message;
          * negative and zero strides raise "non-positive stride is not supported".
        """
        oversized_kernel_pattern = (
            r"Calculated padded input size per channel: \(1 x 1\). "
            r"Kernel size: \(10 x 10\). Kernel size can\'t be greater than actual input size"
        )
        for dtype in [
            torch.half,
            torch.bfloat16,
            torch.float,
            torch.double,
            torch.cfloat,
            torch.cdouble,
        ]:
            module = torch.nn.Conv2d(1, 1, kernel_size=3, dilation=2, stride=2).to(
                dtype
            )
            input = torch.empty(1, 1, 4, 4).to(dtype)
            with self.assertRaises(RuntimeError):
                module(input)

            # Cast module and input to the loop dtype, as test_invalid_conv1d does;
            # previously this case always ran in float32 on every iteration.
            module = nn.Conv2d(
                in_channels=3, out_channels=33, kernel_size=10, stride=1, bias=True
            ).to(dtype)
            input = torch.randn(1, 3, 1, 1).to(dtype)
            with self.assertRaisesRegex(RuntimeError, oversized_kernel_pattern):
                module(input)

            # Negative stride check
            module = nn.Conv2d(
                in_channels=3, out_channels=6, kernel_size=4, stride=-1, bias=True
            ).to(dtype)
            input = torch.randn(1, 3, 4, 4).to(dtype)
            with self.assertRaisesRegex(
                RuntimeError, "non-positive stride is not supported"
            ):
                module(input)

            # Zero stride check
            module = nn.Conv2d(
                in_channels=3, out_channels=6, kernel_size=4, stride=0, bias=True
            ).to(dtype)
            input = torch.randn(1, 3, 4, 4).to(dtype)
            with self.assertRaisesRegex(
                RuntimeError, "non-positive stride is not supported"
            ):
                module(input)
200

201
    def test_invalid_conv3d(self):
        """Conv3d rejects an oversized dilated kernel and a negative stride.

        Both checks are run with the module and input cast to each dtype.
        """
        for dtype in [
            torch.half,
            torch.bfloat16,
            torch.float,
            torch.double,
            torch.cfloat,
            torch.cdouble,
        ]:
            module = torch.nn.Conv3d(1, 1, kernel_size=3, dilation=2, stride=2).to(
                dtype
            )
            input = torch.empty(1, 1, 4, 4, 4).to(dtype)
            with self.assertRaises(RuntimeError):
                module(input)

            # Negative stride check. Cast to the loop dtype so each iteration
            # covers a different dtype instead of repeating the float32 case.
            module = torch.nn.Conv3d(1, 1, kernel_size=3, stride=-2).to(dtype)
            input = torch.empty(1, 1, 4, 4, 4).to(dtype)
            with self.assertRaisesRegex(
                RuntimeError, "non-positive stride is not supported"
            ):
                module(input)
223

224
    def test_conv_invalid_groups(self):
        """Constructing a conv module with groups <= 0 raises ValueError."""
        bad_group_cases = (
            (torch.nn.Conv1d, 0),
            (torch.nn.Conv2d, -1),
            (torch.nn.Conv3d, -2),
        )
        for conv_cls, bad_groups in bad_group_cases:
            with self.assertRaisesRegex(
                ValueError, "groups must be a positive integer"
            ):
                conv_cls(1, 1, kernel_size=3, dilation=2, stride=2, groups=bad_groups)
231

232
    def test_Conv1d_module_same_padding(self):
        """nn.Conv1d with padding="same" agrees with the functional conv1d."""
        x = torch.rand(1, 1, 20)

        # Module vs functional: no stride/dilation, asymmetric padding
        conv = nn.Conv1d(
            in_channels=1, out_channels=1, kernel_size=10, padding="same"
        )
        reference = F.conv1d(x, conv.weight, conv.bias, padding="same")
        self.assertEqual(reference, conv(x))

        # With dilation, symmetric padding
        conv = nn.Conv1d(
            in_channels=1, out_channels=1, kernel_size=10, padding="same", dilation=2
        )
        reference = F.conv1d(x, conv.weight, conv.bias, padding="same", dilation=2)
        self.assertEqual(reference, conv(x))

        # Non-zero padding_mode, which requires explicit padding
        conv = nn.Conv1d(
            in_channels=1,
            out_channels=1,
            kernel_size=10,
            padding="same",
            padding_mode="replicate",
        )
        manually_padded = F.pad(x, [4, 5], mode="replicate")
        reference = F.conv1d(manually_padded, conv.weight, conv.bias, padding="valid")
        self.assertEqual(reference, conv(x))
        self.assertEqual(x.size(), reference.size())

        # Construction with an invalid padding string raises
        with self.assertRaisesRegex(ValueError, "Invalid padding string"):
            nn.Conv1d(in_channels=3, out_channels=33, kernel_size=10, padding="foo")

        # Construction with same padding and strides raises
        with self.assertRaisesRegex(ValueError, "padding='same'"):
            nn.Conv1d(
                in_channels=3, out_channels=33, kernel_size=10, padding="same", stride=2
            )
272

273
    def test_Conv2d_module_same_padding(self):
        """nn.Conv2d with padding="same" agrees with the functional conv2d."""
        x = torch.rand(1, 1, 9, 20)

        # Module vs functional: no stride/dilation, both symmetric and
        # asymmetric padding
        conv = nn.Conv2d(
            in_channels=1, out_channels=1, kernel_size=(5, 10), padding="same"
        )
        reference = F.conv2d(x, conv.weight, conv.bias, padding="same")
        self.assertEqual(reference, conv(x))

        # With dilation, symmetric padding
        conv = nn.Conv2d(
            in_channels=1,
            out_channels=1,
            kernel_size=(3, 4),
            padding="same",
            dilation=(1, 2),
        )
        reference = F.conv2d(
            x, conv.weight, conv.bias, padding="same", dilation=(1, 2)
        )
        self.assertEqual(reference, conv(x))

        # Non-zero padding_mode, which requires explicit padding
        conv = nn.Conv2d(
            in_channels=1,
            out_channels=1,
            kernel_size=(3, 4),
            padding="same",
            padding_mode="reflect",
        )
        manually_padded = F.pad(x, [1, 2, 1, 1], mode="reflect")
        reference = F.conv2d(manually_padded, conv.weight, conv.bias, padding="valid")
        self.assertEqual(reference, conv(x))
        self.assertEqual(x.size(), reference.size())

        # Construction with an invalid padding string raises
        with self.assertRaisesRegex(ValueError, "Invalid padding string"):
            nn.Conv2d(in_channels=3, out_channels=33, kernel_size=10, padding="foo")

        # Construction with same padding and any non-unit stride raises
        for bad_stride in (2, (1, 3), (4, 1)):
            with self.assertRaisesRegex(ValueError, "padding='same'"):
                nn.Conv2d(
                    in_channels=3,
                    out_channels=33,
                    kernel_size=10,
                    padding="same",
                    stride=bad_stride,
                )
336

337
    def test_Conv3d_module_same_padding(self):
        """nn.Conv3d with padding="same" agrees with the functional conv3d.

        Covers plain and dilated convolutions, a non-zero ``padding_mode``
        (compared against an explicitly padded "valid" convolution), and the
        constructor errors for an unknown padding string and for
        ``padding="same"`` combined with a non-unit stride.
        """
        # Compare module against functional:
        x = torch.rand(1, 1, 4, 4, 4)
        # without dilation, both symmetric and asymmetric padding
        module = nn.Conv3d(
            in_channels=1, out_channels=1, kernel_size=(2, 3, 4), padding="same"
        )
        expect = F.conv3d(x, module.weight, module.bias, padding="same")
        self.assertEqual(expect, module(x))

        # with dilation, both symmetric and asymmetric padding
        module = nn.Conv3d(
            in_channels=1,
            out_channels=1,
            kernel_size=(2, 3, 4),
            padding="same",
            dilation=(3, 2, 1),
        )
        expect = F.conv3d(
            x, module.weight, module.bias, padding="same", dilation=(3, 2, 1)
        )
        self.assertEqual(expect, module(x))

        # Test non-zero padding_mode, requiring explicit padding
        module = nn.Conv3d(
            in_channels=1,
            out_channels=1,
            kernel_size=(2, 3, 4),
            padding="same",
            padding_mode="circular",
        )
        x_padded = F.pad(x, [1, 2, 1, 1, 0, 1], mode="circular")
        expect = F.conv3d(x_padded, module.weight, module.bias, padding="valid")
        self.assertEqual(expect, module(x))
        self.assertEqual(x.size(), expect.size())

        # Test construction with invalid padding string raises
        with self.assertRaisesRegex(ValueError, "Invalid padding string"):
            nn.Conv3d(in_channels=3, out_channels=33, kernel_size=10, padding="foo")

        # Test construction with same padding and strides raises.
        # These cases must build nn.Conv3d: the earlier code instantiated
        # nn.Conv2d with 3-element strides, so Conv3d itself was never checked.
        for bad_stride in (2, (1, 1, 3), (1, 4, 1), (5, 1, 1)):
            with self.assertRaisesRegex(ValueError, "padding='same'"):
                nn.Conv3d(
                    in_channels=3,
                    out_channels=33,
                    kernel_size=10,
                    padding="same",
                    stride=bad_stride,
                )
408

409
    @unittest.skipIf(not TEST_CUDA, "CUDA not available")
    def test_thnn_conv_strided_padded_dilated(self):
        """With cuDNN disabled, CUDA convs match the CPU result and pass gradcheck."""
        conv_variants = (
            (torch.nn.functional.conv2d, 2, False),
            (torch.nn.functional.conv_transpose2d, 2, True),
            (torch.nn.functional.conv3d, 3, False),
            (torch.nn.functional.conv_transpose3d, 3, True),
        )
        # (stride, padding, dilation) combinations to exercise.
        hyperparams = (
            (2, 0, 1),
            (1, 1, 1),
            (2, 1, 1),
            (1, 0, 2),
        )
        tensor_opts = {"dtype": torch.double, "device": "cuda", "requires_grad": True}
        for convfn, dims, _transposed in conv_variants:
            for stride, padding, dilation in hyperparams:
                kwargs = {"stride": stride, "padding": padding, "dilation": dilation}

                def apply_conv(x, w, b):
                    return convfn(x, w, b, **kwargs)

                inputs = torch.randn((1, 2) + (4,) * dims, **tensor_opts)
                weight = torch.randn((2, 2) + (1,) * dims, **tensor_opts)
                bias = torch.randn(2, **tensor_opts)

                with torch.backends.cudnn.flags(enabled=False):
                    res = apply_conv(inputs, weight, bias)
                res_cpu = apply_conv(inputs.cpu(), weight.cpu(), bias.cpu())
                self.assertEqual(res, res_cpu)

                with torch.backends.cudnn.flags(enabled=False):
                    torch.autograd.gradcheck(apply_conv, (inputs, weight, bias))
                    torch.autograd.gradcheck(
                        apply_conv, (inputs.cpu(), weight.cpu(), bias.cpu())
                    )
448

449
    def test_Conv2d_inconsistent_types(self):
        """conv2d raises when input and weight dtypes differ."""
        float_input = torch.randn(4, 1, 7, 7, dtype=torch.float)
        double_weight = torch.randn(1, 1, 3, 3, dtype=torch.double)
        # Mixing float input with double weight should raise an exception
        with self.assertRaises(RuntimeError):
            nn.functional.conv2d(float_input, double_weight)
        # ... whereas matching dtypes should work
        nn.functional.conv2d(float_input.float(), double_weight.float())
456

457
    @unittest.skipIf(not TEST_CUDA, "CUDA not available")
    def test_Conv2d_inconsistent_types_on_GPU_without_cudnn(self):
        """CUDA conv2d with cuDNN disabled raises on mismatched dtypes."""
        float_input = torch.randn(4, 1, 7, 7, dtype=torch.float, device="cuda")
        double_weight = torch.randn(1, 1, 3, 3, dtype=torch.double, device="cuda")
        double_bias = torch.randn(1, dtype=torch.double, device="cuda")

        with torch.backends.cudnn.flags(enabled=False):
            # Mismatched input/weight dtypes should raise an exception
            with self.assertRaises(RuntimeError):
                nn.functional.conv2d(float_input, double_weight)
            # A mismatched bias dtype should raise as well
            with self.assertRaises(RuntimeError):
                nn.functional.conv2d(float_input, double_weight.float(), double_bias)

            # ... whereas matching dtypes should work
            nn.functional.conv2d(
                float_input.float(), double_weight.float(), double_bias.float()
            )
475

476
    def test_Conv2d_1x1(self):
        """gradcheck a 1x1 conv2d with MKLDNN both disabled and enabled."""
        in_channels = 2
        out_channels = 2
        conv = torch.nn.Conv2d(in_channels, out_channels, 1, bias=False)
        conv = conv.to(dtype=torch.double)
        x = torch.randn(1, in_channels, 5, 5, requires_grad=True, dtype=torch.double)
        for mkldnn_enabled in (False, True):
            with torch.backends.mkldnn.flags(enabled=mkldnn_enabled):
                gradcheck(F.conv2d, (x, conv.weight))
486

487
    def test_Conv2d_OneDNN(self):
        """Grouped Conv2d weight gradients match with MKLDNN enabled and disabled."""

        def weight_grad(group_val=24, dilation=1):
            # Build a conv with groups == channels, force all-ones weights, run
            # forward/backward on all-ones data and return the weight gradient.
            ifm = torch.ones([1, group_val, 6, 6], dtype=torch.float32)
            ones_weight = torch.ones([group_val, 1, 3, 3], dtype=torch.float32)
            conv = torch.nn.Conv2d(
                in_channels=group_val,
                out_channels=group_val,
                kernel_size=[3, 3],
                stride=[2, 2],
                padding=[1, 1],
                dilation=[dilation, dilation],
                groups=group_val,
                bias=False,
                padding_mode="zeros",
            )
            conv.weight.data = ones_weight

            out = conv(ifm)
            out.backward(torch.ones(out.shape, dtype=torch.float32))
            return conv.weight.grad

        for group_val in (24, 48, 23, 25):
            for dilation in (1, 2):
                with torch.backends.mkldnn.flags(enabled=False):
                    grad_without_onednn = weight_grad(group_val, dilation)

                with torch.backends.mkldnn.flags(enabled=True):
                    grad_with_onednn = weight_grad(group_val, dilation)

                self.assertEqual(grad_without_onednn, grad_with_onednn)
518

519
    @unittest.skipIf(not TEST_CUDA, "CUDA not available")
    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
    def test_cudnn_non_contiguous(self):
        """Conv1d forward on a non-contiguous CUDA input completes without error."""
        base = torch.randn(192, 16, 50).cuda()
        # Permute, materialise, and permute back: same shape, different strides.
        strided_input = base.permute(0, 2, 1).contiguous().permute(0, 2, 1)
        conv = torch.nn.Conv1d(
            in_channels=16, out_channels=32, kernel_size=2, bias=True
        ).cuda()
        # Only the absence of an exception is checked.
        conv(strided_input)
528

529
    @unittest.skipIf(not TEST_CUDA, "CUDA not available")
    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
    def test_cudnn_not_mutate_stride(self):
        """torch.convolution leaves the weight's strides untouched."""
        # NOTE(review): both tensors are created without a device argument even
        # though the test is gated on CUDA/cuDNN -- confirm whether
        # device="cuda" was intended.
        weight = torch.randn(64, 64, 1, 1)
        x = torch.randn(2, 64, 10, 10).to(memory_format=torch.channels_last)
        original_stride = weight.stride()

        def conv(x, weight):
            return torch.convolution(
                x,
                weight,
                stride=(1, 1),
                padding=(0, 0),
                dilation=(1, 1),
                transposed=False,
                output_padding=(0, 0),
                groups=1,
                bias=None,
            )

        # Channels-last input: expect an NHWC result and unchanged weight strides.
        out_nhwc = conv(x, weight)
        self.assertEqual(weight.stride(), original_stride)
        self.assertTrue(out_nhwc.is_contiguous(memory_format=torch.channels_last))

        # Contiguous input: expect a contiguous result with the same values.
        x = x.contiguous(memory_format=torch.contiguous_format)
        out_c = conv(x, weight)
        self.assertTrue(out_c.is_contiguous(memory_format=torch.contiguous_format))
        self.assertEqual(out_c, out_nhwc)
        self.assertEqual(weight.stride(), original_stride)
559

560
    @unittest.skipIf(not TEST_CUDA, "CUDA not available")
    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
    def test_Conv2d_inconsistent_types_on_GPU_with_cudnn(self):
        """CUDA conv2d with cuDNN enabled raises on mismatched dtypes."""
        float_input = torch.randn(4, 1, 7, 7, dtype=torch.float, device="cuda")
        double_weight = torch.randn(1, 1, 3, 3, dtype=torch.double, device="cuda")
        double_bias = torch.randn(1, dtype=torch.double, device="cuda")

        with torch.backends.cudnn.flags(enabled=True):
            # Mismatched input/weight dtypes should raise an exception
            with self.assertRaises(RuntimeError):
                nn.functional.conv2d(float_input, double_weight)
            # A mismatched bias dtype should raise as well
            with self.assertRaises(RuntimeError):
                nn.functional.conv2d(float_input, double_weight.float(), double_bias)

            # ... whereas matching dtypes should work
            nn.functional.conv2d(
                float_input.float(), double_weight.float(), double_bias.float()
            )
579

580
    def test_Conv2d_missing_argument(self):
        """Calling a Conv2d module with None as input raises TypeError."""
        conv = nn.Conv2d(3, 3, 3)
        with self.assertRaises(TypeError):
            conv(None)
583

584
    def test_Conv2d_backward_twice(self):
        """A second backward through the same Conv2d output raises RuntimeError."""
        x = torch.randn(2, 3, 5, 5)
        conv = nn.Conv2d(3, 3, 3)
        out = conv(x)
        out.sum().backward()
        # The second backward is expected to fail with a retain_graph hint.
        with self.assertRaisesRegex(RuntimeError, "Specify retain_graph=True"):
            out.sum().backward()
592

593
    def test_conv_modules_raise_error_on_incorrect_input_size(self):
        """Conv and ConvTranspose modules reject inputs of the wrong rank.

        Each module is paired with two invalid ranks: one below and one above
        the ranks it accepts. The input is created with the module's dtype so
        that a dtype mismatch cannot be what raises the RuntimeError.
        """
        for dtype in [torch.half, torch.bfloat16, torch.double, torch.float]:
            modules = [
                nn.Conv1d(3, 8, 3).to(dtype),
                nn.ConvTranspose1d(3, 8, 3).to(dtype),
                nn.Conv2d(3, 8, 3).to(dtype),
                nn.ConvTranspose2d(3, 8, 3).to(dtype),
                nn.Conv3d(3, 8, 3).to(dtype),
                nn.ConvTranspose3d(3, 8, 3).to(dtype),
            ]

            # One (too small, too large) rank pair per module above.
            invalid_input_dims = [(1, 4), (1, 4), (2, 5), (2, 5), (3, 6), (3, 6)]

            for invalid_dims, module in zip(invalid_input_dims, modules):
                for dims in invalid_dims:
                    # Every dimension has size 3; only the rank is wrong.
                    input = torch.empty((3,) * dims, dtype=dtype)
                    with self.assertRaises(RuntimeError):
                        module(input)
610

611
    def test_conv_shapecheck(self):
        """Convs raise when the kernel does not fit the input and run otherwise."""

        def check(should_raise, module, input_size, dtype):
            # Batch of 3 with the given per-sample shape, cast to dtype.
            sample = torch.empty(3, *input_size).to(dtype)
            if not should_raise:
                # Just run it to ensure no exception is raised.
                module(sample)
                return
            with self.assertRaises(RuntimeError):
                module(sample)

        dtypes_under_test = (
            torch.half,
            torch.bfloat16,
            torch.float,
            torch.double,
            torch.cfloat,
            torch.cdouble,
        )
        for dtype in dtypes_under_test:
            # Conv1d
            length_2 = (1, 2)
            check(True, nn.Conv1d(1, 1, 3).to(dtype), length_2, dtype)
            check(True, nn.Conv1d(1, 1, 3, stride=2).to(dtype), length_2, dtype)
            check(False, nn.Conv1d(1, 1, 2).to(dtype), length_2, dtype)
            check(False, nn.Conv1d(1, 1, 2, stride=2).to(dtype), length_2, dtype)
            check(
                False, nn.Conv1d(1, 1, 3, stride=2, padding=1).to(dtype), length_2, dtype
            )

            # Conv2d
            k2d = (3, 3)
            check(True, nn.Conv2d(1, 1, k2d).to(dtype), (1, 2, 2), dtype)
            check(False, nn.Conv2d(1, 1, k2d).to(dtype), (1, 3, 3), dtype)
            check(False, nn.Conv2d(1, 1, k2d, padding=1).to(dtype), (1, 2, 2), dtype)

            # Conv3d
            k3d = (3, 3, 3)
            check(True, nn.Conv3d(1, 1, k3d).to(dtype), (1, 2, 2, 2), dtype)
            check(False, nn.Conv3d(1, 1, k3d).to(dtype), (1, 3, 3, 3), dtype)
            check(
                False,
                nn.Conv3d(1, 1, k3d, padding=1).to(dtype),
                (1, 2, 2, 2),
                dtype,
            )
651

652
    def test_ConvTranspose2d_output_size(self):
        """ConvTranspose2d honours output_size in 18..20 and rejects other sizes."""
        deconv = nn.ConvTranspose2d(3, 4, 3, 3, 0, 2)
        x = torch.randn(2, 3, 6, 6)
        accepted = range(18, 21)
        for h, w in itertools.product(range(15, 22), repeat=2):
            if h in accepted and w in accepted:
                out = deconv(x, output_size=(h, w))
                self.assertEqual(out.size()[2:], (h, w))
            else:
                with self.assertRaises(ValueError):
                    deconv(x, (h, w))
662

663
    def test_ConvTranspose2d_output_size_downsample_upsample(self):
        """Conv2d followed by a matching ConvTranspose2d restores the input size."""
        batch, channels, hidden = 2, 3, 2
        # Same iteration order as nesting h, w, kernel, dilation, stride, padding.
        grid = itertools.product(
            range(13, 24),  # h
            range(13, 17),  # w
            range(2, 5),  # kernel size
            range(1, 5),  # dilation
            range(1, 4),  # stride
            range(3),  # padding
        )
        for h, w, k, d, s, p in grid:
            shared = {"kernel_size": k, "stride": s, "padding": p, "dilation": d}
            down = nn.Conv2d(in_channels=channels, out_channels=hidden, **shared)
            up = nn.ConvTranspose2d(in_channels=hidden, out_channels=channels, **shared)

            x = torch.randn(batch, channels, h, w)
            restored = up(down(x), output_size=x.shape)

            self.assertEqual(restored.size()[2:], x.size()[2:])
694

695
    def test_ConvTranspose3d_correct_output_size(self):
        """ConvTranspose3d accepts a 5-element output_size and produces that shape."""
        # Check that ConvTranspose3d can take a 5d output_size.
        m = nn.ConvTranspose3d(2, 2, 2)
        i = torch.rand(1, 2, 1, 1, 1)
        requested_size = (1, 2, 2, 2, 2)
        out = m(i, output_size=requested_size)
        # Verify the produced shape as well as the absence of an exception:
        # kernel 2 with stride 1 maps each spatial extent 1 to 2.
        self.assertEqual(out.size(), requested_size)
700

701
    @unittest.skipIf(not TEST_CUDA, "CUDA not available")
    def test_ConvTranspose2d_half_cublas_gemm(self):
        """Half ConvTranspose2d forward/backward runs on CUDA with cuDNN disabled."""
        with torch.backends.cudnn.flags(enabled=False):
            x = torch.randn(1, 1, 16, 16, device="cuda", dtype=torch.half)
            deconv = nn.ConvTranspose2d(1, 1, 3, stride=2, padding=1, output_padding=1)
            deconv = deconv.cuda().half()
            # Only the absence of an exception is checked.
            deconv(x).mean().backward()
712

713
    # For https://github.com/pytorch/pytorch/pull/1273
    # Almost identical to `test_Conv2d_naive_groups`
    @torch.backends.cudnn.flags(enabled=True, benchmark=False)
    @unittest.skipIf(TEST_WITH_ROCM, "Skipped on ROCm, since it is failing on ROCm 5.7")
    def test_Conv2d_groups_nobias(self):
        """A bias-free groups=2 conv matches two independent convs on each half."""
        configs = [("cpu", torch.float)]
        if TEST_CUDA:
            configs.extend([("cuda", torch.float), ("cuda", torch.half)])
        if AMPERE_OR_ROCM:
            configs.append(("cuda", torch.bfloat16))

        for device, dtype in configs:
            grouped = nn.Conv2d(4, 4, kernel_size=3, groups=2, bias=False)
            grouped = grouped.to(device, dtype)
            x = torch.randn(2, 4, 6, 6, device=device, dtype=dtype, requires_grad=True)
            grouped_out = grouped(x)
            grad_output = torch.randn(2, 4, 4, 4, device=device, dtype=dtype)
            grouped_out.backward(grad_output)

            # Rebuild each group as a standalone conv over its channel slice.
            part_outputs, part_input_grads, part_weight_grads = [], [], []
            for channels in (slice(None, 2), slice(2, None)):
                part = nn.Conv2d(2, 2, kernel_size=3, bias=False).to(device, dtype)
                part.weight.data.copy_(grouped.weight.data[channels])
                part_x = x.data[:, channels].contiguous().requires_grad_(True)
                part_out = part(part_x)
                part_out.backward(grad_output[:, channels].contiguous())

                part_outputs.append(part_out)
                part_input_grads.append(part_x.grad.data)
                part_weight_grads.append(part.weight.grad.data)

            self.assertEqual(grouped_out, torch.cat(part_outputs, 1))
            self.assertEqual(
                x.grad.data,
                torch.cat(part_input_grads, 1),
                atol=dtype2prec_DONTUSE[dtype],
                rtol=0,
            )
            # Half precision gets a looser tolerance on the weight gradient.
            weight_atol = 1e-1 if dtype == torch.half else dtype2prec_DONTUSE[dtype]
            self.assertEqual(
                grouped.weight.grad.data,
                torch.cat(part_weight_grads, 0),
                atol=weight_atol,
                rtol=0,
            )
755

756
    # Almost identical to the above `test_Conv2d_naive_groups`
757
    # Covering special case when group > 1, input-channel / group < 16 and output-channel is multiple of 16
758
    # See also https://github.com/pytorch/pytorch/pull/18463#issuecomment-476563686
759
    # and https://github.com/pytorch/pytorch/pull/18463#issuecomment-477001024
760
    @torch.backends.cudnn.flags(enabled=True, benchmark=False)
761
    @unittest.skipIf(TEST_WITH_ROCM, "Skipped on ROCm, since it is failing on ROCm 5.7")
762
    def test_Conv2d_groups_nobias_v2(self):
        # Same comparison as the plain groups/no-bias test, but with 16 output
        # channels (8 per group) and only 2 input channels per group.
        torch.manual_seed(123)
        configs = [("cpu", torch.float)]
        if TEST_CUDA:
            configs.extend([("cuda", torch.float), ("cuda", torch.half)])
        if AMPERE_OR_ROCM:
            configs.append(("cuda", torch.bfloat16))

        # (input-channel slice, output-channel slice) for each of the two groups.
        group_slices = (
            (slice(None, 2), slice(None, 8)),
            (slice(2, None), slice(8, None)),
        )
        for device, dtype in configs:
            grouped = nn.Conv2d(4, 16, kernel_size=3, groups=2, bias=False).to(
                device, dtype
            )
            full_input = torch.randn(
                2, 4, 6, 6, device=device, dtype=dtype, requires_grad=True
            )
            full_output = grouped(full_input)
            upstream_grad = torch.randn(2, 16, 4, 4, device=device, dtype=dtype)
            full_output.backward(upstream_grad)

            convs, part_inputs, part_outputs = [], [], []
            for in_ch, out_ch in group_slices:
                conv = nn.Conv2d(2, 8, kernel_size=3, bias=False).to(device, dtype)
                conv.weight.data.copy_(grouped.weight.data[out_ch])
                part_input = (
                    full_input.data[:, in_ch].contiguous().requires_grad_(True)
                )
                part_output = conv(part_input)
                part_output.backward(upstream_grad[:, out_ch].contiguous())
                convs.append(conv)
                part_inputs.append(part_input)
                part_outputs.append(part_output)

            self.assertEqual(full_output, torch.cat(part_outputs, 1))
            self.assertEqual(
                full_input.grad.data,
                torch.cat([p.grad.data for p in part_inputs], 1),
                atol=dtype2prec_DONTUSE[dtype],
                rtol=0,
            )
            # Half precision gets a looser absolute tolerance on the weight grad.
            weight_atol = 1e-1 if dtype == torch.half else dtype2prec_DONTUSE[dtype]
            self.assertEqual(
                grouped.weight.grad.data,
                torch.cat([c.weight.grad.data for c in convs], 0),
                atol=weight_atol,
                rtol=0,
            )
801

802
    # CPU-only test for group conv3d fast implementation using bmm
803
    # See: https://github.com/pytorch/pytorch/pull/36355
804
    def test_Conv3d_groups_nobias(self):
        # CPU float32 only: a bias-free Conv3d with groups=2 is compared with
        # two independent Conv3d modules built from its weight halves.
        torch.manual_seed(123)
        cpu_float = {"device": "cpu", "dtype": torch.float}
        prec = dtype2prec_DONTUSE[torch.float]

        grouped = nn.Conv3d(4, 16, kernel_size=3, groups=2, bias=False).to(
            "cpu", torch.float
        )
        full_input = torch.randn(2, 4, 6, 6, 6, requires_grad=True, **cpu_float)
        full_output = grouped(full_input)
        upstream_grad = torch.randn(2, 16, 4, 4, 4, **cpu_float)
        full_output.backward(upstream_grad)

        # (input-channel slice, output-channel slice) for each of the two groups.
        group_slices = (
            (slice(None, 2), slice(None, 8)),
            (slice(2, None), slice(8, None)),
        )
        convs, part_inputs, part_outputs = [], [], []
        for in_ch, out_ch in group_slices:
            conv = nn.Conv3d(2, 8, kernel_size=3, bias=False).to("cpu", torch.float)
            conv.weight.data.copy_(grouped.weight.data[out_ch])
            part_input = full_input.data[:, in_ch].contiguous().requires_grad_(True)
            part_output = conv(part_input)
            part_output.backward(upstream_grad[:, out_ch].contiguous())
            convs.append(conv)
            part_inputs.append(part_input)
            part_outputs.append(part_output)

        self.assertEqual(full_output, torch.cat(part_outputs, 1))
        self.assertEqual(
            full_input.grad.data,
            torch.cat([p.grad.data for p in part_inputs], 1),
            atol=prec,
            rtol=0,
        )
        # The weight gradient is additionally given a relative tolerance.
        self.assertEqual(
            grouped.weight.grad.data,
            torch.cat([c.weight.grad.data for c in convs], 0),
            atol=prec,
            rtol=prec,
        )
839

840
    def test_Conv3d_groups_wbias(self):
        # CPU float32 only: a biased Conv3d with groups=2 is compared with two
        # independent Conv3d modules that receive the matching weight and bias
        # halves. Output, input grad, weight grad and bias grad must agree.
        torch.manual_seed(123)
        cpu_float = {"device": "cpu", "dtype": torch.float}
        prec = dtype2prec_DONTUSE[torch.float]

        grouped = nn.Conv3d(4, 16, kernel_size=3, groups=2, bias=True).to(
            "cpu", torch.float
        )
        full_input = torch.randn(2, 4, 6, 6, 6, requires_grad=True, **cpu_float)
        full_output = grouped(full_input)
        upstream_grad = torch.randn(2, 16, 4, 4, 4, **cpu_float)
        full_output.backward(upstream_grad)

        # (input-channel slice, output-channel slice) for each of the two groups.
        group_slices = (
            (slice(None, 2), slice(None, 8)),
            (slice(2, None), slice(8, None)),
        )
        convs, part_inputs, part_outputs = [], [], []
        for in_ch, out_ch in group_slices:
            conv = nn.Conv3d(2, 8, kernel_size=3, bias=True).to("cpu", torch.float)
            conv.weight.data.copy_(grouped.weight.data[out_ch])
            conv.bias.data.copy_(grouped.bias.data[out_ch])
            part_input = full_input.data[:, in_ch].contiguous().requires_grad_(True)
            part_output = conv(part_input)
            part_output.backward(upstream_grad[:, out_ch].contiguous())
            convs.append(conv)
            part_inputs.append(part_input)
            part_outputs.append(part_output)

        self.assertEqual(full_output, torch.cat(part_outputs, 1))
        self.assertEqual(
            full_input.grad.data,
            torch.cat([p.grad.data for p in part_inputs], 1),
            atol=prec,
            rtol=prec,
        )
        self.assertEqual(
            grouped.weight.grad.data,
            torch.cat([c.weight.grad.data for c in convs], 0),
            atol=prec,
            rtol=prec,
        )
        self.assertEqual(
            grouped.bias.grad.data,
            torch.cat([c.bias.grad.data for c in convs], 0),
            atol=prec,
            rtol=prec,
        )
883

884
    def test_conv_tbc(self):
        # Numerically check the gradients of F.conv_tbc with a pad of 3.
        # The default dtype is switched to double so the random tensors below
        # are created in double precision.
        def conv_tbc(seq, kernel, offset, pad):
            return F.conv_tbc(seq, kernel, offset, pad)

        with set_default_dtype(torch.double):
            inp = torch.randn(9, 4, 5, requires_grad=True)
            weight = torch.randn(3, 5, 6, requires_grad=True)
            bias = torch.randn(6, requires_grad=True)
            gradcheck(conv_tbc, (inp, weight, bias, 3))
893

894
    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
895
    @unittest.skipIf(not TEST_CUDNN, "needs cudnn")
896
    @skipIfRocmVersionLessThan((4, 3))
897
    @skipIfNotMiopenSuggestNHWC
898
    def test_grouped_conv_cudnn_nhwc_support(self):
        # Smoke test meant to catch holes in channels-last (NHWC) support for
        # grouped convolution on earlier cuDNN versions. It only has to run
        # without raising; results are not inspected.
        def half_channels_last(*shape):
            return torch.randn(shape, dtype=torch.float16, device="cuda").to(
                memory_format=torch.channels_last
            )

        stride = padding = dilation = (1, 1)
        output_padding = (0, 0)
        groups = 4

        data = half_channels_last(16, 16, 8, 8)
        weight = half_channels_last(8, 4, 3, 3)
        # Regular grouped convolution (transposed=False).
        torch.convolution(
            data, weight, None, stride, padding, dilation, False, output_padding, groups
        )

        data = half_channels_last(16, 8, 8, 8)
        # Transposed grouped convolution reusing the same weight.
        torch.convolution(
            data, weight, None, stride, padding, dilation, True, output_padding, groups
        )
915

916
    @unittest.expectedFailure
917
    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
918
    @unittest.skipIf(not TEST_CUDNN, "needs cudnn")
919
    def test_conv_cudnn_memory_layout_dominance(self):
        # Desired behavior: the memory layout of conv.weight dominates the
        # layout of the output. This differs from current behavior, so the
        # test carries `expectedFailure` until follow-up PRs fix it.
        x = torch.randint(
            1, 10, (2, 8, 4, 4), dtype=torch.float32, device="cuda", requires_grad=True
        )
        conv = nn.Conv2d(8, 4, 3).cuda().float()

        # Contiguous weight, contiguous input.
        self.assertTrue(conv(x).is_contiguous())

        # Contiguous weight, channels-last input.
        x = x.contiguous(memory_format=torch.channels_last)
        self.assertTrue(conv(x).is_contiguous())

        # Channels-last weight, channels-last input.
        conv.weight.data = conv.weight.contiguous(memory_format=torch.channels_last)
        self.assertTrue(conv(x).is_contiguous(memory_format=torch.channels_last))

        # Channels-last weight, contiguous input.
        x = x.contiguous()
        self.assertTrue(conv(x).is_contiguous(memory_format=torch.channels_last))
943

944
    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
945
    def test_cudnn_noncontiguous_weight(self):
        # Noncontiguous weights must be made contiguous before being passed to
        # cuDNN. An expanded (noncontiguous) weight has to give the same
        # result as its contiguous copy.
        signal = torch.tensor([1, 1, 1], dtype=torch.double, device="cuda").view(
            1, 1, 3
        )
        expanded_weight = torch.tensor([1], dtype=torch.double, device="cuda").expand(
            1, 1, 2
        )
        dense_weight = (
            torch.tensor([1], dtype=torch.double, device="cuda")
            .expand(1, 1, 2)
            .contiguous()
        )
        from_expanded = F.conv1d(
            signal, expanded_weight, bias=None, stride=2, dilation=2
        )
        from_dense = F.conv1d(signal, dense_weight, bias=None, stride=2, dilation=2)
        self.assertEqual(from_expanded, from_dense)
959

960
    def run_grad_conv_test(self, func_forward, func_backward, dim=1, gradient="input"):
        """Compare a functional conv-gradient helper against autograd.

        Args:
            func_forward: forward convolution, called as
                ``func_forward(input, weight, stride=, padding=, dilation=, bias=)``.
            func_backward: gradient helper under test. It receives the input
                (or its shape when ``gradient == "input"``), the weight (or its
                shape when ``gradient == "weight"``), the upstream gradient and
                the same stride/padding/dilation.
            dim: number of spatial dimensions appended to the shapes.
            gradient: ``"input"`` to check the gradient w.r.t. the input;
                any other value differentiates w.r.t. the weight.
        """
        wrt_input = gradient == "input"
        for kern, inp_size in [(3, 6), (3, 7), (4, 9)]:
            for batch, stride, padding, chan_in, chan_out, dilation in product(
                [1, 2], [1, 2], [0, 1, 2], [2], [3], [1]
            ):
                for has_bias in [True, False]:
                    input_shape = [batch, chan_in] + [inp_size] * dim
                    weight_shape = [chan_out, chan_in] + [kern] * dim

                    input = torch.randn(input_shape, requires_grad=True)
                    weight = torch.randn(weight_shape, requires_grad=True)
                    # Pass None explicitly when no bias is wanted. Previously
                    # the bias from the preceding iteration leaked through, so
                    # the bias-free path was never exercised.
                    bias = (
                        torch.randn([chan_out], requires_grad=True)
                        if has_bias
                        else None
                    )
                    output = func_forward(
                        input,
                        weight,
                        stride=stride,
                        padding=padding,
                        dilation=dilation,
                        bias=bias,
                    )

                    gradient_o = torch.randn(output.shape)
                    (expected_grad,) = torch.autograd.grad(
                        output, input if wrt_input else weight, gradient_o
                    )

                    self.assertEqual(
                        expected_grad,
                        func_backward(
                            input_shape if wrt_input else input,
                            weight_shape if (gradient == "weight") else weight,
                            gradient_o,
                            stride=stride,
                            padding=padding,
                            dilation=dilation,
                        ),
                    )
1001

1002
    def test_grad_conv1d_input(self):
        # Check F.grad.conv1d_input against autograd's gradient w.r.t. the input.
        self.run_grad_conv_test(
            func_forward=F.conv1d,
            func_backward=F.grad.conv1d_input,
            dim=1,
            gradient="input",
        )
1004

1005
    def test_grad_conv1d_weight(self):
        # Check F.grad.conv1d_weight against autograd's gradient w.r.t. the weight.
        self.run_grad_conv_test(
            func_forward=F.conv1d,
            func_backward=F.grad.conv1d_weight,
            dim=1,
            gradient="weight",
        )
1007

1008
    def test_grad_conv2d_input(self):
        # Check F.grad.conv2d_input against autograd's gradient w.r.t. the input.
        self.run_grad_conv_test(
            func_forward=F.conv2d,
            func_backward=F.grad.conv2d_input,
            dim=2,
            gradient="input",
        )
1010

1011
    def test_grad_conv2d_weight(self):
        # Check F.grad.conv2d_weight against autograd's gradient w.r.t. the weight.
        self.run_grad_conv_test(
            func_forward=F.conv2d,
            func_backward=F.grad.conv2d_weight,
            dim=2,
            gradient="weight",
        )
1013

1014
    def test_grad_conv3d_input(self):
        # Check F.grad.conv3d_input against autograd's gradient w.r.t. the input.
        self.run_grad_conv_test(
            func_forward=F.conv3d,
            func_backward=F.grad.conv3d_input,
            dim=3,
            gradient="input",
        )
1016

1017
    def test_grad_conv3d_weight(self):
        # Check F.grad.conv3d_weight against autograd's gradient w.r.t. the weight.
        self.run_grad_conv_test(
            func_forward=F.conv3d,
            func_backward=F.grad.conv3d_weight,
            dim=3,
            gradient="weight",
        )
1019

1020
    @unittest.skipIf(not torch._nnpack_available(), "NNPACK unavailable")
1021
    def test_nnpack_conv(self):
        # Compare torch._nnpack_spatial_convolution with F.conv2d: forward
        # output and the gradients w.r.t. input and weight must match within
        # an absolute tolerance of 3e-4, both with and without a bias.
        for kern, inp_size in [(3, 6), (3, 7), (4, 9)]:
            for batch, stride, padding, chan_in, chan_out in product(
                [1, 2, 3, 4], [1, 2], [0, 1, 2], [2], [3]
            ):
                for has_bias in [True, False]:
                    input_shape = [batch, chan_in] + [inp_size] * 2
                    weight_shape = [chan_out, chan_in] + [kern] * 2

                    input = torch.randn(
                        input_shape, requires_grad=True, dtype=torch.float
                    )
                    weight = torch.randn(
                        weight_shape, requires_grad=True, dtype=torch.float
                    )
                    # Pass None explicitly when no bias is wanted. Previously
                    # the bias from the preceding iteration leaked through, so
                    # the bias-free path was never exercised.
                    bias = (
                        torch.randn([chan_out], requires_grad=True, dtype=torch.float)
                        if has_bias
                        else None
                    )
                    output = torch._nnpack_spatial_convolution(
                        input, weight, stride=stride, padding=padding, bias=bias
                    )
                    output_expected = torch.nn.functional.conv2d(
                        input, weight, stride=stride, padding=padding, bias=bias
                    )
                    self.assertEqual(output, output_expected, atol=3e-4, rtol=0)

                    gradient_o = torch.randn(output.shape, dtype=torch.float)

                    grads = torch.autograd.grad(output, [input, weight], gradient_o)
                    grads_expected = torch.autograd.grad(
                        output_expected, [input, weight], gradient_o
                    )
                    for gr, gr_expected in zip(grads, grads_expected):
                        self.assertEqual(gr, gr_expected, atol=3e-4, rtol=0)
1059

1060
    def test_conv_padding_mode(self):
        # Invalid padding_mode values, whether an unknown string or a
        # non-string, must be rejected by Conv2d with a ValueError.
        for bad_mode in ("xyz", 3):
            with self.assertRaisesRegex(ValueError, "padding_mode must be one of"):
                nn.Conv2d(3, 3, 3, padding_mode=bad_mode)

        # ConvTranspose2d rejects a padding_mode of "reflect".
        with self.assertRaisesRegex(ValueError, 'Only "zeros" '):
            nn.ConvTranspose2d(3, 3, 3, padding_mode="reflect")
1069

1070
    def test_functional_grad_conv(self):
        # For 1D, 2D and 3D dilated convolutions, the helpers in torch.nn.grad
        # must return the same input/weight gradients as autograd.
        cases = [
            (1, F.conv1d, torch.nn.grad.conv1d_input, torch.nn.grad.conv1d_weight),
            (2, F.conv2d, torch.nn.grad.conv2d_input, torch.nn.grad.conv2d_weight),
            (3, F.conv3d, torch.nn.grad.conv3d_input, torch.nn.grad.conv3d_weight),
        ]
        for ndim, conv, grad_input_fn, grad_weight_fn in cases:
            # Single batch, single channel; spatial size 5 with a size-3 kernel.
            x = torch.randn(1, 1, *([5] * ndim), requires_grad=True)
            w = torch.randn(1, 1, *([3] * ndim), requires_grad=True)
            out = conv(x, w, dilation=2)
            upstream = torch.randn(out.shape)

            expected_gx, expected_gw = torch.autograd.grad(out, (x, w), upstream)

            actual_gx = grad_input_fn(x.shape, w, upstream, dilation=2)
            self.assertEqual(actual_gx, expected_gx)

            actual_gw = grad_weight_fn(x, w.shape, upstream, dilation=2)
            self.assertEqual(actual_gw, expected_gw)
1130

1131
    def test_functional_grad_conv2d(self):
        # Sweep stride, kernel size, groups and dilation for conv2d and check
        # torch.nn.grad.conv2d_input / conv2d_weight against autograd.
        batch_size, in_ch, out_ch, spatial = 4, 8, 16, 32

        sweep = product([1, 2], [1, 3, 5], [1, 2, 4], [1, 2])
        for stride, kernel_size, groups, dilation in sweep:
            conv_kwargs = {
                "stride": stride,
                "padding": kernel_size // 2,
                "dilation": dilation,
                "groups": groups,
            }

            x = torch.empty(batch_size, in_ch, spatial, spatial)
            x = x.uniform_(-8.0, 8.0).requires_grad_(True)

            w = torch.empty(out_ch, in_ch // groups, kernel_size, kernel_size)
            w = w.uniform_(-4.0, 4.0).requires_grad_(True)

            out = F.conv2d(x, w, **conv_kwargs)
            upstream = torch.randn(out.shape)

            expected_gx, expected_gw = torch.autograd.grad(out, (x, w), upstream)

            actual_gx = torch.nn.grad.conv2d_input(
                x.shape, w, upstream, **conv_kwargs
            )
            self.assertEqual(actual_gx, expected_gx)

            actual_gw = torch.nn.grad.conv2d_weight(
                x, w.shape, upstream, **conv_kwargs
            )
            self.assertEqual(actual_gw, expected_gw)
1196

1197
    def test_permute_conv2d_issue_120211(self):
        # Regression test: a grouped conv2d on a permuted (NHWC -> NCHW view)
        # image must not fail for any horizontal kernel radius in [0, 128).
        for radius in range(128):
            image = torch.rand(1, 1024, 1024, 3).permute(0, 3, 1, 2)
            kernel_x = torch.zeros([3, 1, 1, radius * 2 + 1], device=image.device)
            # One group per channel; this should not fail.
            torch.nn.functional.conv2d(image, kernel_x, groups=image.shape[-3])
1207

1208
    def test_conv3d_issue_120406(self):
1209
        # This should not fail
1210
        F.conv3d(torch.ones(2, 3, 8, 9, 26), torch.ones(3, 1, 1, 1, 17), groups=3)
1211

1212
    def test_conv1d_issue_120547(self):
1213
        weight = torch.ones([16, 1, 32])
1214
        bias = torch.ones([16])
1215
        stride, padding, dilation, groups = (1, 16, 1, 16)
1216
        input = torch.rand((1, 1, 16))
1217
        input = input.transpose(1, 2)
1218
        # This should not fail
1219
        F.conv1d(input, weight, bias, stride, padding, dilation, groups)
1220

1221

1222
class TestConvolutionNNDeviceType(NNTestCase):
1223
    def run_conv_double_back_test(
        self,
        kern,
        stride,
        padding,
        chan_in,
        chan_out,
        batch_size,
        inp_size,
        dilation,
        no_weight,
        groups=1,
        use_cuda=False,
        use_bias=True,
        dtype=torch.double,
    ):
        """Check double backward of ``F.conv2d`` for one configuration.

        Builds random input, weight and (optionally) bias tensors on the CPU
        or CUDA device. For ``dtype == torch.float`` it only reports whether
        the first-order input gradient still requires grad; otherwise it
        returns the result of ``gradgradcheck``.

        ``no_weight=True`` creates the weight without ``requires_grad``.
        """
        device = torch.device("cuda" if use_cuda else "cpu")
        tensor_opts = {"device": device, "dtype": dtype}

        x = torch.randn(
            batch_size, chan_in, inp_size, inp_size, requires_grad=True, **tensor_opts
        )
        weight = torch.randn(
            chan_out,
            chan_in // groups,
            kern,
            kern,
            requires_grad=not no_weight,
            **tensor_opts,
        )
        inputs = (x, weight)
        if use_bias:
            inputs += (torch.randn(chan_out, requires_grad=True, **tensor_opts),)

        def func(lx, lweight, lbias=None):
            # We disable cudnn during forward to avoid finite difference
            # imprecision issues.
            with cudnn.flags(enabled=False):
                return F.conv2d(lx, lweight, lbias, stride, padding, dilation, groups)

        dummy_out = func(*inputs)
        grad_y = torch.randn_like(
            dummy_out, device=device, dtype=dtype, requires_grad=True
        )

        # Issue #15353: test mkldnn double backward, don't run gradgradcheck
        # due to imprecision issues.
        if dtype == torch.float:
            (first_order,) = torch.autograd.grad(
                dummy_out.sum(), x, create_graph=True
            )
            return first_order.requires_grad

        return gradgradcheck(func, inputs, (grad_y,))
1295

1296
    @onlyCUDA
1297
    @skipCUDAIfNoCudnn
1298
    @dtypes(
1299
        *floating_and_complex_types_and(
1300
            torch.half, *[torch.bfloat16] if AMPERE_OR_ROCM else []
1301
        )
1302
    )
1303
    def test_Conv2d_deterministic_cudnn(self, device, dtype):
        # With cuDNN in deterministic mode, two Conv2d modules holding
        # identical parameters must produce bitwise-equal outputs and
        # parameter gradients (atol=0, rtol=0).
        inputs = torch.randn(2, 3, 5, 5, device=device, dtype=dtype, requires_grad=True)
        param_names = ("bias", "weight")
        with cudnn.flags(enabled=True, benchmark=True, deterministic=True):
            reference = torch.nn.Conv2d(3, 3, 3).to(device, dtype)
            replica = torch.nn.Conv2d(3, 3, 3).to(device, dtype)
            for name in param_names:
                getattr(replica, name).data.copy_(getattr(reference, name).data)

            out_ref = reference(inputs)
            out_rep = replica(inputs)
            self.assertEqual(out_ref, out_rep, atol=0.0, rtol=0)

            upstream = torch.randn(out_ref.size(), device=device, dtype=dtype)
            out_ref.backward(upstream)
            out_rep.backward(upstream)
            for name in param_names:
                self.assertEqual(
                    getattr(reference, name).grad.data,
                    getattr(replica, name).grad.data,
                    atol=0.0,
                    rtol=0,
                )
1322

1323
    @onlyCUDA
1324
    @dtypes(
1325
        *floating_types_and(torch.half, *[torch.bfloat16] if AMPERE_OR_ROCM else [])
1326
    )
1327
    def test_Conv2d_large_workspace(self, device, dtype):
        # These sizes require huge cuDNN workspaces. Make sure we choose a
        # reasonable algorithm that does not run out of memory, both with and
        # without cuDNN benchmarking.
        shapes = [
            (1, 256, 109, 175),
            (1, 256, 80, 128),
            (1, 256, 120, 192),
        ]

        def forward_backward_all(benchmark):
            with torch.backends.cudnn.flags(enabled=True, benchmark=benchmark):
                conv = torch.nn.Conv2d(256, 256, kernel_size=3, padding=1)
                conv = conv.to(device, dtype)
                for shape in shapes:
                    sample = torch.randn(shape, device=device, dtype=dtype)
                    result = conv(sample.detach().clone().requires_grad_())
                    result.backward(torch.ones_like(result))

        for benchmark in (False, True):
            forward_backward_all(benchmark=benchmark)
1348

1349
    @onlyCUDA
1350
    @dtypes(torch.half, torch.float)
1351
    def test_ConvTranspose2d_large_output_padding(self, device, dtype):
        # Smoke test: chain three stride-2 ConvTranspose2d layers that use
        # output_padding=1 and run forward + backward. Nothing is asserted;
        # the final synchronize surfaces any asynchronous CUDA error.
        channel_plan = [(128, 64), (64, 32), (32, 3)]
        stages = [
            torch.nn.ConvTranspose2d(
                c_in, c_out, kernel_size=3, stride=2, padding=1, output_padding=1
            ).to(device=device, dtype=dtype)
            for c_in, c_out in channel_plan
        ]
        activation = torch.rand(
            1, 128, 6, 6, device=device, dtype=dtype, requires_grad=True
        )
        for stage in stages:
            activation = stage(activation)
        activation.backward(torch.randn_like(activation))
        torch.cuda.synchronize()
1367

1368
    @onlyCUDA
1369
    @dtypes(torch.float, torch.double, torch.half)
1370
    # Very similar to test_Conv2d_naive_groups but with special care to handle
1371
    # the number of groups == number of input channels
1372
    @torch.backends.cudnn.flags(enabled=True, benchmark=False)
1373
    @tf32_on_and_off(0.01)
1374
    def test_Conv2d_depthwise_naive_groups(self, device, dtype):
        # Depthwise case (groups == number of input channels): a Conv2d with
        # groups=2 over 2 input channels must agree with two single-channel
        # convolutions built from its weight/bias slices. Run for depth
        # multipliers 1 and 2.
        for depth_multiplier in [1, 2]:
            m = nn.Conv2d(2, 2 * depth_multiplier, kernel_size=3, groups=2).to(
                device, dtype
            )
            # Create the input on the device under test instead of a
            # hard-coded "cuda", matching every other tensor in this test.
            i = (
                torch.randn(2, 2, 6, 6, device=device, dtype=dtype)
                .div_(2)
                .requires_grad_()
            )
            output = m(i)
            grad_output = (
                torch.randn(2, 2 * depth_multiplier, 4, 4, device=device, dtype=dtype)
                / 2
            )
            output.backward(grad_output)

            # Number of output channels produced by each group.
            offset = 1 * depth_multiplier

            m1 = nn.Conv2d(1, 1 * depth_multiplier, kernel_size=3).to(device, dtype)
            m1.weight.data = m.weight.data[:offset].clone()
            m1.bias.data = m.bias.data[:offset].clone()
            i1 = i.detach()[:, :1].clone().requires_grad_()
            output1 = m1(i1)
            output1.backward(grad_output[:, :offset].contiguous())

            m2 = nn.Conv2d(1, 1 * depth_multiplier, kernel_size=3).to(device, dtype)
            m2.weight.data.copy_(m.weight.data[offset:])
            m2.bias.data.copy_(m.bias.data[offset:])
            i2 = i.detach()[:, 1:].clone().requires_grad_()
            output2 = m2(i2)
            output2.backward(grad_output[:, offset:].contiguous())

            self.assertEqual(
                output,
                torch.cat([output1, output2], 1),
                atol=dtype2prec_DONTUSE[dtype],
                rtol=0,
            )
            self.assertEqual(
                i.grad.data,
                torch.cat([i1.grad.data, i2.grad.data], 1),
                atol=dtype2prec_DONTUSE[dtype],
                rtol=0,
            )
            self.assertEqual(
                m.bias.grad.data,
                torch.cat([m1.bias.grad.data, m2.bias.grad.data], 0),
                atol=dtype2prec_DONTUSE[dtype],
                rtol=0,
            )
            self.assertEqual(
                m.weight.grad.data,
                torch.cat([m1.weight.grad.data, m2.weight.grad.data], 0),
                atol=dtype2prec_DONTUSE[dtype],
                rtol=0,
            )
1431

1432
    @onlyCUDA
1433
    @dtypes(torch.float, torch.double, torch.half)
1434
    @torch.backends.cudnn.flags(enabled=True, benchmark=False)
1435
    @tf32_on_and_off(0.01)
1436
    def test_Conv3d_depthwise_naive_groups(self, device, dtype):
        # Depthwise case (groups == number of input channels): a Conv3d with
        # groups=2 over 2 input channels must agree with two single-channel
        # convolutions built from its weight/bias slices. Run for depth
        # multipliers 1 and 2.
        for depth_multiplier in [1, 2]:
            m = nn.Conv3d(2, 2 * depth_multiplier, kernel_size=3, groups=2).to(
                device, dtype
            )
            # Create the input on the device under test instead of a
            # hard-coded "cuda", matching every other tensor in this test.
            i = (
                torch.randn(2, 2, 6, 6, 6, device=device, dtype=dtype)
                .div_(2)
                .requires_grad_()
            )
            output = m(i)
            grad_output = (
                torch.randn(
                    2, 2 * depth_multiplier, 4, 4, 4, device=device, dtype=dtype
                )
                / 2
            )
            output.backward(grad_output)

            # Number of output channels produced by each group.
            offset = 1 * depth_multiplier

            m1 = nn.Conv3d(1, 1 * depth_multiplier, kernel_size=3).to(device, dtype)
            m1.weight.data = m.weight.data[:offset].clone()
            m1.bias.data = m.bias.data[:offset].clone()
            i1 = i.detach()[:, :1].clone().requires_grad_()
            output1 = m1(i1)
            output1.backward(grad_output[:, :offset].contiguous())

            m2 = nn.Conv3d(1, 1 * depth_multiplier, kernel_size=3).to(device, dtype)
            m2.weight.data.copy_(m.weight.data[offset:])
            m2.bias.data.copy_(m.bias.data[offset:])
            i2 = i.detach()[:, 1:].clone().requires_grad_()
            output2 = m2(i2)
            output2.backward(grad_output[:, offset:].contiguous())

            # Query the capability of the device actually under test rather
            # than always device 0.
            is_cuda_sm86 = device.startswith(
                "cuda"
            ) and torch.cuda.get_device_capability(device) == (8, 6)
            # float32 on SM 8.6 gets looser tolerances for the output and the
            # weight gradient.
            atol, rtol = (
                (3e-4, 3e-2)
                if dtype == torch.float32 and is_cuda_sm86
                else (dtype2prec_DONTUSE[dtype], 0)
            )

            self.assertEqual(
                output, torch.cat([output1, output2], 1), atol=atol, rtol=rtol
            )
            self.assertEqual(
                i.grad.data,
                torch.cat([i1.grad.data, i2.grad.data], 1),
                atol=dtype2prec_DONTUSE[dtype],
                rtol=0,
            )
            self.assertEqual(
                m.bias.grad.data,
                torch.cat([m1.bias.grad.data, m2.bias.grad.data], 0),
                atol=dtype2prec_DONTUSE[dtype],
                rtol=0,
            )
            self.assertEqual(
                m.weight.grad.data,
                torch.cat([m1.weight.grad.data, m2.weight.grad.data], 0),
                atol=atol,
                rtol=rtol,
            )
1500

1501
    @onlyCUDA
1502
    @dtypes(
1503
        *floating_types_and(torch.half, *[torch.bfloat16] if AMPERE_OR_ROCM else [])
1504
    )
1505
    def test_noncontig_conv_grad(self, device, dtype):
        """Conv2d backward must give the same input gradient for a
        non-contiguous grad_output as for its contiguous copy."""
        # FIXME: remove after adding non-contiguous grad tests for all modules
        conv = nn.Conv2d(3, 5, kernel_size=3, padding=1).to(device, dtype)
        inp = torch.randn(
            2, 3, 10, 10, dtype=dtype, device=device, requires_grad=True
        )
        out = conv(inp)

        # Indexing dim 1 of a 5-D tensor produces a strided (non-contiguous) view.
        strided_grad = torch.randn(2, 2, 5, 10, 10, dtype=dtype, device=device)[:, 1]
        assert not strided_grad.is_contiguous()

        # First pass: backprop the strided gradient and stash the result.
        out.backward(strided_grad, retain_graph=True)
        self.assertIsNotNone(inp.grad)
        grad_from_strided = inp.grad.data.clone()
        inp.grad.data.zero_()

        # Second pass: same values, contiguous layout.
        out.backward(strided_grad.contiguous())
        self.assertEqual(
            grad_from_strided,
            inp.grad.data,
            atol=dtype2prec_DONTUSE[dtype],
            rtol=0,
        )
1524

1525
    @onlyCUDA
1526
    @dtypes(torch.double)
1527
    def test_conv_double_backward(self, device, dtype):
1528
        with torch.backends.cudnn.flags(enabled=True, deterministic=True):
1529
            # Double backward only runs with DoubleTensor due to precision reason
1530
            batch_size = 1
1531
            for kern, inp_size, dilations in [(3, 5, [1, 2]), (4, 9, [1])]:
1532
                for stride, padding, chan_in, chan_out, dilation in product(
1533
                    [1], [2], [2], [3], dilations
1534
                ):
1535
                    no_weight = stride == 2
1536
                    result = self.run_conv_double_back_test(
1537
                        kern,
1538
                        stride,
1539
                        padding,
1540
                        chan_in,
1541
                        chan_out,
1542
                        batch_size,
1543
                        inp_size,
1544
                        dilation,
1545
                        no_weight,
1546
                        use_cuda=True,
1547
                        dtype=dtype,
1548
                    )
1549
                    self.assertTrue(
1550
                        result,
1551
                        "Conv double backward test failed with parameters:"
1552
                        + "\nkern: "
1553
                        + str(kern)
1554
                        + "\nstride: "
1555
                        + str(stride)
1556
                        + "\npadding: "
1557
                        + str(padding)
1558
                        + "\nchan_in: "
1559
                        + str(chan_in)
1560
                        + "\nchan_out: "
1561
                        + str(chan_out)
1562
                        + "\nbatch_size: "
1563
                        + str(batch_size)
1564
                        + "\ninp_size: "
1565
                        + str(inp_size)
1566
                        + "\ndilation: "
1567
                        + str(dilation),
1568
                    )
1569

1570
    def test_conv_double_backward_no_bias(self):
1571
        kern = 3
1572
        stride = 2
1573
        chan_in, chan_out = 2, 4
1574
        batch_size = 2
1575
        inp_size = 5
1576
        padding = 1
1577
        dilation = 1
1578
        no_weight = False
1579
        use_bias = True
1580
        result = self.run_conv_double_back_test(
1581
            kern,
1582
            stride,
1583
            padding,
1584
            chan_in,
1585
            chan_out,
1586
            batch_size,
1587
            inp_size,
1588
            dilation,
1589
            no_weight,
1590
            use_bias=use_bias,
1591
        )
1592
        self.assertTrue(
1593
            result,
1594
            "Conv double backward test failed with parameters:"
1595
            + "\nkern: "
1596
            + str(kern)
1597
            + "\nstride: "
1598
            + str(stride)
1599
            + "\npadding: "
1600
            + str(padding)
1601
            + "\nchan_in: "
1602
            + str(chan_in)
1603
            + "\nchan_out: "
1604
            + str(chan_out)
1605
            + "\nbatch_size: "
1606
            + str(batch_size)
1607
            + "\ninp_size: "
1608
            + str(inp_size)
1609
            + "\ndilation: "
1610
            + str(dilation),
1611
        )
1612

1613
    def test_conv_double_backward_groups(self):
1614
        kern = 3
1615
        stride = 1
1616
        padding = 2
1617
        chan_in, chan_out = 2, 4
1618
        batch_size = 2
1619
        inp_size = 6
1620
        dilation = 1
1621
        no_weight = False
1622
        groups = 2
1623
        result = self.run_conv_double_back_test(
1624
            kern,
1625
            stride,
1626
            padding,
1627
            chan_in * groups,
1628
            chan_out * groups,
1629
            batch_size,
1630
            inp_size,
1631
            dilation,
1632
            no_weight,
1633
            groups=groups,
1634
        )
1635
        self.assertTrue(
1636
            result,
1637
            "Conv double backward test failed with parameters:"
1638
            + "\nkern: "
1639
            + str(kern)
1640
            + "\nstride: "
1641
            + str(stride)
1642
            + "\npadding: "
1643
            + str(padding)
1644
            + "\nchan_in: "
1645
            + str(chan_in)
1646
            + "\nchan_out: "
1647
            + str(chan_out)
1648
            + "\nbatch_size: "
1649
            + str(batch_size)
1650
            + "\ninp_size: "
1651
            + str(inp_size)
1652
            + "\ndilation: "
1653
            + str(dilation)
1654
            + "\ngroups: "
1655
            + str(groups),
1656
        )
1657

1658
    def test_conv_double_backward_stride(self):
1659
        batch_size = 2
1660

1661
        # Cannot provide ggW when stride is > 1
1662
        for kern, inp_size, dilations in [(3, 5, [1, 2]), (3, 7, [1])]:
1663
            for stride, padding, chan_in, chan_out, dilation in product(
1664
                [2], [0, 1], [1], [2], dilations
1665
            ):
1666
                no_weight = False
1667
                self.run_conv_double_back_test(
1668
                    kern,
1669
                    stride,
1670
                    padding,
1671
                    chan_in,
1672
                    chan_out,
1673
                    batch_size,
1674
                    inp_size,
1675
                    dilation,
1676
                    no_weight,
1677
                )
1678

1679
    @dtypes(torch.float, torch.cfloat)
1680
    @torch.backends.cudnn.flags(enabled=True, benchmark=False)
1681
    def test_conv1d_same_padding(self, device, dtype):
1682
        # Test padding='same' outputs the correct shape
1683
        test_args = [
1684
            # in_size
1685
            range(50, 55),
1686
            # kernel_size
1687
            [1, 2, 3, 8],
1688
            # dilation
1689
            range(1, 4),
1690
            # stride
1691
            [1],
1692
        ]
1693
        for in_size, k_size, dilation, stride in itertools.product(*test_args):
1694
            x = torch.rand(1, 1, in_size, device=device, dtype=dtype)
1695
            y = torch.rand(1, 1, k_size, device=device, dtype=dtype)
1696
            z = F.conv1d(x, y, padding="same", dilation=dilation, stride=stride)
1697
            self.assertEqual(z.size(2), int(math.ceil(in_size / stride)))
1698

1699
        # Compare F.conv1d padding='same' output against manual padding
1700
        # Without strides/dilation
1701
        x = torch.rand(1, 1, 12, device=device, dtype=dtype)
1702
        y = torch.rand(1, 1, 3, device=device, dtype=dtype)
1703
        expect = F.conv1d(x, y, padding=1)
1704
        actual = F.conv1d(x, y, padding="same")
1705
        self.assertEqual(expect, actual)
1706

1707
        # With dilation
1708
        x = torch.rand(1, 1, 12, device=device, dtype=dtype)
1709
        y = torch.rand(1, 1, 4, device=device, dtype=dtype)
1710
        expect = F.conv1d(x, y, padding=3, dilation=2)
1711
        actual = F.conv1d(x, y, padding="same", dilation=2)
1712
        self.assertEqual(expect, actual)
1713

1714
        # Dilation with asymmetric padding
1715
        expect = F.conv1d(x, y, padding=5, dilation=3)[..., 1:]
1716
        actual = F.conv1d(x, y, padding="same", dilation=3)
1717
        self.assertEqual(expect, actual)
1718

1719
    @dtypes(torch.float, torch.cfloat)
1720
    def test_conv2d_same_padding(self, device, dtype):
1721
        if dtype is torch.cfloat:
1722
            rtol, atol = 2e-6, 2e-6
1723
        else:
1724
            rtol, atol = None, None
1725
        # Compare F.conv2d padding='same' output against manual padding
1726
        # Without strides/dilation
1727
        x = torch.rand(1, 1, 10, 11, device=device, dtype=dtype)
1728
        y = torch.rand(1, 1, 4, 5, device=device, dtype=dtype)
1729
        expect = F.conv2d(x, y, padding=(2, 2))[..., 1:, :]
1730
        actual = F.conv2d(x, y, padding="same")
1731
        self.assertEqual(expect, actual, rtol=rtol, atol=atol)
1732

1733
        # With dilation
1734
        y = torch.rand(1, 1, 3, 4, device=device, dtype=dtype)
1735
        expect = F.conv2d(x, y, padding=(2, 3), dilation=2)
1736
        actual = F.conv2d(x, y, padding="same", dilation=2)
1737
        self.assertEqual(expect, actual, rtol=rtol, atol=atol)
1738

1739
        # Dilation with asymmetric padding
1740
        y = torch.rand(1, 1, 4, 4, device=device, dtype=dtype)
1741
        expect = F.conv2d(x, y, padding=5, dilation=3)[..., 1:, 1:]
1742
        actual = F.conv2d(x, y, padding="same", dilation=3)
1743
        self.assertEqual(expect, actual, rtol=rtol, atol=atol)
1744

1745
    @dtypes(torch.float, torch.cfloat)
1746
    def test_conv3d_same_padding(self, device, dtype):
1747
        if dtype is torch.cfloat:
1748
            rtol, atol = 2e-6, 2e-6
1749
        else:
1750
            rtol, atol = None, None
1751
        # Compare F.conv3d padding='same' output against manual padding
1752
        # Without strides/dilation
1753
        x = torch.rand(1, 1, 10, 11, 12, device=device, dtype=dtype)
1754
        y = torch.rand(1, 1, 1, 2, 5, device=device, dtype=dtype)
1755
        expect = F.conv3d(x, y, padding=(0, 1, 2))[..., :, 1:, :]
1756
        actual = F.conv3d(x, y, padding="same")
1757
        self.assertEqual(expect, actual, rtol=rtol, atol=atol)
1758

1759
        # With dilation
1760
        expect = F.conv3d(x, y, padding=(0, 1, 4), dilation=2)
1761
        actual = F.conv3d(x, y, padding="same", dilation=2)
1762
        self.assertEqual(expect, actual, rtol=rtol, atol=atol)
1763

1764
        # Dilation with asymmetric padding
1765
        y = torch.rand(1, 1, 4, 4, 4, device=device, dtype=dtype)
1766
        expect = F.conv3d(x, y, padding=5, dilation=3)[..., 1:, 1:, 1:]
1767
        actual = F.conv3d(x, y, padding="same", dilation=3)
1768
        self.assertEqual(expect, actual, rtol=rtol, atol=atol)
1769

1770
    @dtypes(torch.float, torch.cfloat)
1771
    def test_conv1d_valid_padding(self, device, dtype):
1772
        # Test F.conv1d padding='valid' is the same as no padding
1773
        x = torch.rand(1, 1, 10, device=device, dtype=dtype)
1774
        y = torch.rand(1, 1, 4, device=device, dtype=dtype)
1775
        expect = F.conv1d(x, y)
1776
        actual = F.conv1d(x, y, padding="valid")
1777
        self.assertEqual(expect, actual)
1778

1779
    @dtypes(torch.float, torch.cfloat)
1780
    def test_conv2d_valid_padding(self, device, dtype):
1781
        # Test F.conv2d padding='valid' is the same as no padding
1782
        x = torch.rand(1, 1, 1, 10, device=device, dtype=dtype)
1783
        y = torch.rand(1, 1, 1, 4, device=device, dtype=dtype)
1784
        expect = F.conv2d(x, y)
1785
        actual = F.conv2d(x, y, padding="valid")
1786
        self.assertEqual(expect, actual)
1787

1788
    @dtypes(torch.float, torch.cfloat)
1789
    def test_conv3d_valid_padding(self, device, dtype):
1790
        # Test F.conv3d padding='valid' is the same as no padding
1791
        x = torch.rand(1, 1, 1, 1, 10, dtype=dtype, device=device)
1792
        y = torch.rand(1, 1, 1, 1, 4, dtype=dtype, device=device)
1793
        expect = F.conv3d(x, y)
1794
        actual = F.conv3d(x, y, padding="valid")
1795
        self.assertEqual(expect, actual)
1796

1797
    @dtypes(torch.float, torch.cfloat)
1798
    def test_conv1d_same_padding_backward(self, device, dtype):
1799
        # Test F.conv1d gradients work with padding='same'
1800
        x = torch.rand(1, 1, 12, dtype=dtype, device=device, requires_grad=True)
1801
        y = torch.rand(1, 1, 4, dtype=dtype, device=device, requires_grad=True)
1802

1803
        # Symmetric padding
1804
        z = F.conv1d(x, y, padding=3, dilation=2)
1805
        z.sum().abs().backward()
1806
        gx_expect, gy_expect = x.grad, y.grad
1807
        x.grad, y.grad = None, None
1808

1809
        z = F.conv1d(x, y, padding="same", dilation=2)
1810
        z.sum().abs().backward()
1811
        self.assertEqual(gx_expect, x.grad)
1812
        self.assertEqual(gy_expect, y.grad)
1813
        x.grad, y.grad = None, None
1814

1815
        # Asymmetric padding
1816
        z = F.conv1d(x, y, padding=2)[..., 1:]
1817
        z.sum().abs().backward()
1818
        gx_expect, gy_expect = x.grad, y.grad
1819
        x.grad, y.grad = None, None
1820

1821
        z = F.conv1d(x, y, padding="same")
1822
        z.sum().abs().backward()
1823
        self.assertEqual(gx_expect, x.grad)
1824
        self.assertEqual(gy_expect, y.grad)
1825

1826
    @dtypes(torch.float, torch.cfloat)
1827
    @tf32_on_and_off(0.001)
1828
    def test_conv2d_same_padding_backward(self, device, dtype):
1829
        # Test F.conv2d gradients work with padding='same'
1830
        x = torch.rand(1, 1, 10, 11, device=device, dtype=dtype, requires_grad=True)
1831
        y = torch.rand(1, 1, 4, 5, device=device, dtype=dtype, requires_grad=True)
1832

1833
        # Symmetric padding
1834
        z = F.conv2d(x, y, padding=(3, 4), dilation=2)
1835
        z.sum().abs().backward()
1836
        gx_expect, gy_expect = x.grad, y.grad
1837
        x.grad, y.grad = None, None
1838

1839
        z = F.conv2d(x, y, padding="same", dilation=2)
1840
        z.sum().abs().backward()
1841
        self.assertEqual(gx_expect, x.grad)
1842
        self.assertEqual(gy_expect, y.grad)
1843
        x.grad, y.grad = None, None
1844

1845
        # Asymmetric padding
1846
        y = torch.rand(1, 1, 4, 4, device=device, dtype=dtype, requires_grad=True)
1847
        z = F.conv2d(x, y, padding=2)[..., 1:, 1:]
1848
        z.sum().abs().backward()
1849
        gx_expect, gy_expect = x.grad, y.grad
1850
        x.grad, y.grad = None, None
1851

1852
        z = F.conv2d(x, y, padding="same")
1853
        z.sum().abs().backward()
1854
        self.assertEqual(gx_expect, x.grad)
1855
        self.assertEqual(gy_expect, y.grad)
1856

1857
    @dtypes(torch.double, torch.cdouble)
1858
    def test_conv3d_same_padding_backward(self, device, dtype):
        """Gradients of ``F.conv3d`` with ``padding="same"`` must match explicit
        padding, and must pass gradcheck (plus gradgradcheck off CUDA)."""
        device_type = torch.device(device).type
        check_forward_ad = device_type != "xla"
        # gradgradcheck is skipped on CUDA:
        # https://github.com/pytorch/pytorch/issues/70702
        run_gradgrad = device_type != "cuda"

        def same_dilated(inp, weight):
            return F.conv3d(inp, weight, padding="same", dilation=2)

        def same_plain(inp, weight):
            return F.conv3d(inp, weight, padding="same")

        def _grads_of(out, inp, weight):
            # Reduce to a scalar, backpropagate, and return both leaf gradients.
            out.sum().abs().backward()
            return inp.grad, weight.grad

        x = torch.rand(1, 1, 1, 11, 12, dtype=dtype, device=device, requires_grad=True)
        y = torch.rand(1, 1, 1, 2, 5, dtype=dtype, device=device, requires_grad=True)

        # Symmetric padding
        gx_ref, gy_ref = _grads_of(
            F.conv3d(x, y, padding=(0, 1, 4), dilation=2), x, y
        )
        x.grad, y.grad = None, None

        gx, gy = _grads_of(same_dilated(x, y), x, y)
        self.assertEqual(gx_ref, gx)
        self.assertEqual(gy_ref, gy)
        x.grad, y.grad = None, None

        gradcheck(
            same_dilated,
            (x, y),
            check_forward_ad=check_forward_ad,
            nondet_tol=1e-5,
        )
        if run_gradgrad:
            gradgradcheck(
                same_dilated,
                (x, y),
                check_fwd_over_rev=True,
            )

        # Asymmetric padding: new kernel, drop the first slice of the last two dims.
        y = torch.rand(1, 1, 1, 4, 4, dtype=dtype, device=device, requires_grad=True)
        gx_ref, gy_ref = _grads_of(F.conv3d(x, y, padding=2)[..., 1:, 1:], x, y)
        x.grad, y.grad = None, None

        gx, gy = _grads_of(same_plain(x, y), x, y)
        self.assertEqual(gx_ref, gx)
        self.assertEqual(gy_ref, gy)

        gradcheck(
            same_plain,
            (x, y),
            check_forward_ad=check_forward_ad,
            nondet_tol=1e-5,
        )
        if run_gradgrad:
            gradgradcheck(
                same_plain,
                (x, y),
                check_fwd_over_rev=True,
            )
1916

1917
    @dtypes(torch.float, torch.cfloat)
1918
    def test_conv1d_valid_padding_backward(self, device, dtype):
1919
        # Test F.conv1d gradients work with padding='valid'
1920
        x = torch.rand(1, 1, 10, dtype=dtype, device=device, requires_grad=True)
1921
        y = torch.rand(1, 1, 4, dtype=dtype, device=device, requires_grad=True)
1922
        F.conv1d(x, y, padding=0).sum().abs().backward()
1923
        gx_expect, gy_expect = x.grad, y.grad
1924
        x.grad, y.grad = None, None
1925

1926
        F.conv1d(x, y, padding="valid").sum().abs().backward()
1927
        gx_actual, gy_actual = x.grad, y.grad
1928
        self.assertEqual(gx_expect, gx_actual)
1929
        self.assertEqual(gy_expect, gy_actual)
1930

1931
    @unittest.skipIf(not TEST_SCIPY, "Scipy required for the test.")
1932
    @dtypes(torch.float, torch.cfloat)
1933
    @parametrize_test("mode", ("valid", "same"))
1934
    def test_conv1d_vs_scipy(self, device, dtype, mode):
        """Compare ``conv1d`` against ``scipy.signal.convolve`` for the given mode,
        using one even-length and one odd-length kernel."""
        signal = make_tensor((1, 10), device=device, dtype=dtype)
        signal_len = signal.shape[1]
        kernel_even = make_tensor((1, 1, 4), device=device, dtype=dtype)
        kernel_odd = make_tensor((1, 1, 5), device=device, dtype=dtype)

        def _compare(inp, kernel, mode):
            # SciPy expects two 1-D inputs.
            reference = scipy.signal.convolve(
                inp.view(-1).cpu().numpy(), kernel.view(-1).cpu().numpy(), mode=mode
            )

            is_same = mode == "same"
            if is_same:
                # `same` padding in PyTorch conv1d is different from SciPy,
                # so pad by hand and call conv1d without a padding argument.
                half = kernel.shape[2] // 2
                inp = torch.nn.functional.pad(inp, (half, half))
                conv_kwargs = {}
            else:
                conv_kwargs = {"padding": mode}

            # second input is flipped in SciPy's convolve
            flipped = torch.flip(kernel, (2,))
            result = torch.nn.functional.conv1d(inp, flipped, **conv_kwargs).squeeze(0)
            if is_same:
                # Trim back to the original signal length.
                result = result[:signal_len]

            self.assertEqual(result, reference, atol=2e-5, rtol=2e-5)

        # Global dtype for this test suite is torch.double
        # This leads to change in type-promotion
        # and conv1d outputs `complex128` for `complex64` input.
        with set_default_dtype(torch.float):
            for kernel in (kernel_even, kernel_odd):
                _compare(signal, kernel, mode)
1969

1970
    @unittest.skipIf(not TEST_SCIPY, "Scipy required for the test.")
1971
    @dtypes(torch.float, torch.cfloat)
1972
    @parametrize_test("mode", ("valid", "same"))
1973
    def test_conv2d_vs_scipy(self, device, dtype, mode):
        """Compare ``conv2d`` against ``scipy.signal.convolve2d`` for the given
        mode, using one even-sized and one odd-sized kernel."""
        image = make_tensor((1, 5, 10), device=device, dtype=dtype)
        kernel_even = make_tensor((1, 1, 2, 4), device=device, dtype=dtype)
        kernel_odd = make_tensor((1, 1, 3, 5), device=device, dtype=dtype)

        def _compare(inp, kernel, mode):
            # SciPy expects two 2-D inputs.
            reference = scipy.signal.convolve2d(
                inp.squeeze(0).cpu().numpy(),
                kernel.squeeze(0).squeeze(0).cpu().numpy(),
                mode=mode,
            )

            is_same = mode == "same"
            if is_same:
                # `same` padding in PyTorch conv2d is different from SciPy,
                # so pad by hand and call conv2d without a padding argument.
                pad_w = kernel.shape[3] // 2
                pad_h = kernel.shape[2] // 2
                inp = torch.nn.functional.pad(inp, (pad_w, pad_w, pad_h, pad_h))
                conv_kwargs = {}
            else:
                conv_kwargs = {"padding": mode}

            # second input is flipped in SciPy's convolve2d
            flipped = torch.flip(kernel, (2, 3))
            result = torch.nn.functional.conv2d(inp, flipped, **conv_kwargs).squeeze(0)
            if is_same:
                # Trim back to the original 5 x 10 extent.
                result = result[:5, :10]

            self.assertEqual(result, reference, rtol=2e-5, atol=5e-6)

        # Global dtype for this test suite is torch.double
        # This leads to change in type-promotion
        # and conv2d outputs `complex128` for `complex64` input.
        with set_default_dtype(torch.float):
            for kernel in (kernel_even, kernel_odd):
                _compare(image, kernel, mode)
2009

2010
    @unittest.skipIf(not TEST_SCIPY, "Scipy required for the test.")
2011
    @dtypes(torch.float, torch.cfloat)
2012
    @parametrize_test("mode", ("valid", "same"))
2013
    def test_conv3d_vs_scipy(self, device, dtype, mode):
        """Compare ``conv3d`` against ``scipy.signal.convolve`` for the given mode,
        using two different kernel shapes."""
        volume = make_tensor((1, 5, 5, 10), device=device, dtype=dtype)
        kernel_even = make_tensor((1, 1, 2, 2, 4), device=device, dtype=dtype)
        kernel_odd = make_tensor((1, 1, 2, 3, 5), device=device, dtype=dtype)

        def _compare(inp, kernel, mode):
            # SciPy expects two 3-D inputs.
            reference = scipy.signal.convolve(
                inp.squeeze(0).cpu().numpy(),
                kernel.squeeze(0).squeeze(0).cpu().numpy(),
                mode=mode,
            )

            is_same = mode == "same"
            if is_same:
                # `same` padding in PyTorch conv3d is different from SciPy,
                # so pad by hand and call conv3d without a padding argument.
                pad_w = kernel.shape[4] // 2
                pad_h = kernel.shape[3] // 2
                pad_d = kernel.shape[2] // 2
                inp = torch.nn.functional.pad(
                    inp, (pad_w, pad_w, pad_h, pad_h, pad_d, pad_d)
                )
                conv_kwargs = {}
            else:
                conv_kwargs = {"padding": mode}

            # second input is flipped in SciPy's convolve
            flipped = torch.flip(kernel, (2, 3, 4))
            result = torch.nn.functional.conv3d(inp, flipped, **conv_kwargs).squeeze(0)
            if is_same:
                # Trim back to the original 5 x 5 x 10 extent.
                result = result[:5, :5, :10]

            # Much looser tolerances when TF32 is in effect for 32-bit inputs.
            if tf32_is_not_fp32() and dtype in (torch.float, torch.complex64):
                self.assertEqual(result, reference, atol=0.05, rtol=0.05)
            else:
                self.assertEqual(result, reference, rtol=2e-5, atol=5e-6)

        # Global dtype for this test suite is torch.double
        # This leads to change in type-promotion
        # and conv3d outputs `complex128` for `complex64` input.
        with set_default_dtype(torch.float):
            for kernel in (kernel_even, kernel_odd):
                _compare(volume, kernel, mode)
2062

2063
    @dtypes(torch.float, torch.complex64)
2064
    def test_conv2d_valid_padding_backward(self, device, dtype):
2065
        # Test F.conv2d gradients work with padding='valid'
2066
        x = torch.rand(1, 1, 1, 10, device=device, dtype=dtype, requires_grad=True)
2067
        y = torch.rand(1, 1, 1, 4, device=device, dtype=dtype, requires_grad=True)
2068
        F.conv2d(x, y, padding=0).sum().abs().backward()
2069
        gx_expect, gy_expect = x.grad, y.grad
2070
        x.grad, y.grad = None, None
2071

2072
        F.conv2d(x, y, padding="valid").sum().abs().backward()
2073
        gx_actual, gy_actual = x.grad, y.grad
2074
        self.assertEqual(gx_expect, gx_actual)
2075
        self.assertEqual(gy_expect, gy_actual)
2076

2077
    @dtypes(torch.double, torch.cdouble)
2078
    def test_conv3d_valid_padding_backward(self, device, dtype):
        """Gradients of ``F.conv3d`` with ``padding="valid"`` must match
        ``padding=0`` and pass gradcheck / gradgradcheck."""
        check_forward_ad = torch.device(device).type != "xla"

        def valid_conv(inp, weight):
            return F.conv3d(inp, weight, padding="valid")

        x = torch.rand(1, 1, 1, 1, 10, dtype=dtype, device=device, requires_grad=True)
        y = torch.rand(1, 1, 1, 1, 4, dtype=dtype, device=device, requires_grad=True)

        # Reference gradients from the explicit zero-padding call.
        reference_loss = F.conv3d(x, y, padding=0).sum().abs()
        reference_loss.backward()
        gx_ref, gy_ref = x.grad, y.grad
        x.grad, y.grad = None, None

        # Gradients from the string-padding call.
        valid_loss = valid_conv(x, y).sum().abs()
        valid_loss.backward()
        self.assertEqual(gx_ref, x.grad)
        self.assertEqual(gy_ref, y.grad)

        gradcheck(
            valid_conv,
            (x, y),
            check_forward_ad=check_forward_ad,
        )
        gradgradcheck(
            valid_conv,
            (x, y),
            check_fwd_over_rev=check_forward_ad,
        )
2103

2104
    @parametrize_test("N", range(2, 4), name_fn=lambda N: f"ConvTranspose{N}d")
2105
    def test_conv_transpose_with_output_size_and_no_batch_dim(self, device, N):
2106
        # For inputs with no batch dim, verify output is the correct shape when output_size is set.
2107
        # See https://github.com/pytorch/pytorch/issues/75889
2108
        inp = torch.randn((1, 15, 13) if N == 2 else (1, 15, 13, 13), device=device)
2109
        output_size = (1, 240, 200) if N == 2 else (1, 240, 200, 200)
2110
        ConvTransposeNd = getattr(nn, f"ConvTranspose{N}d")
2111
        m = ConvTransposeNd(
2112
            1, 1, kernel_size=16, stride=16, padding=7, bias=False, device=device
2113
        )
2114
        output = m(inp, output_size=output_size)
2115
        self.assertEqual(output.shape, output_size)
2116

2117
    @skipMeta
2118
    @parametrize_test(
2119
        "input_shape,transposed,dilated,groups,layout,backend_expected",
2120
        [
2121
            # === slow ===
2122
            subtest(
2123
                (
2124
                    (2, 6, 7),
2125
                    False,
2126
                    False,
2127
                    3,
2128
                    torch.strided,
2129
                    torch._C._ConvBackend.Slow2d,
2130
                ),
2131
                decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN],
2132
                name="slow1d",
2133
            ),
2134
            subtest(
2135
                (
2136
                    (2, 6, 7),
2137
                    True,
2138
                    False,
2139
                    3,
2140
                    torch.strided,
2141
                    torch._C._ConvBackend.SlowTranspose2d,
2142
                ),
2143
                decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN],
2144
                name="slow1d_transposed",
2145
            ),
2146
            subtest(
2147
                (
2148
                    (2, 6, 7),
2149
                    False,
2150
                    True,
2151
                    3,
2152
                    torch.strided,
2153
                    torch._C._ConvBackend.SlowDilated2d,
2154
                ),
2155
                decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN],
2156
                name="slow1d_dilated",
2157
            ),
2158
            subtest(
2159
                (
2160
                    (2, 6, 7),
2161
                    True,
2162
                    True,
2163
                    3,
2164
                    torch.strided,
2165
                    torch._C._ConvBackend.SlowTranspose2d,
2166
                ),
2167
                decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN],
2168
                name="slow1d_dilated_transposed",
2169
            ),
2170
            subtest(
2171
                (
2172
                    (2, 6, 7, 8),
2173
                    False,
2174
                    False,
2175
                    3,
2176
                    torch.strided,
2177
                    torch._C._ConvBackend.Slow2d,
2178
                ),
2179
                decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN],
2180
                name="slow2d",
2181
            ),
2182
            subtest(
2183
                (
2184
                    (2, 6, 7, 8),
2185
                    True,
2186
                    False,
2187
                    3,
2188
                    torch.strided,
2189
                    torch._C._ConvBackend.SlowTranspose2d,
2190
                ),
2191
                decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN],
2192
                name="slow2d_transposed",
2193
            ),
2194
            subtest(
2195
                (
2196
                    (2, 6, 7, 8),
2197
                    False,
2198
                    True,
2199
                    3,
2200
                    torch.strided,
2201
                    torch._C._ConvBackend.SlowDilated2d,
2202
                ),
2203
                decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN],
2204
                name="slow2d_dilated",
2205
            ),
2206
            subtest(
2207
                (
2208
                    (2, 6, 7, 8),
2209
                    True,
2210
                    True,
2211
                    3,
2212
                    torch.strided,
2213
                    torch._C._ConvBackend.SlowTranspose2d,
2214
                ),
2215
                decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN],
2216
                name="slow2d_dilated_transposed",
2217
            ),
2218
            subtest(
2219
                (
2220
                    (2, 6, 7, 8, 9),
2221
                    False,
2222
                    False,
2223
                    3,
2224
                    torch.strided,
2225
                    torch._C._ConvBackend.Slow3d,
2226
                ),
2227
                decorators=[onlyCPU, disableMkldnn],
2228
                name="slow3d_cpu",
2229
            ),
2230
            # CUDA doesn't have a slow 3D implementation, so it goes to the dilated 3D implementation instead
2231
            subtest(
2232
                (
2233
                    (2, 6, 7, 8, 9),
2234
                    False,
2235
                    False,
2236
                    3,
2237
                    torch.strided,
2238
                    torch._C._ConvBackend.SlowDilated3d,
2239
                ),
2240
                decorators=[onlyCUDA, disablecuDNN],
2241
                name="slow3d_cuda",
2242
            ),
2243
            # FIXME: RuntimeError: CUDA out of memory.
2244
            # subtest(((2, 6, 7, 8, 9), True, False, 3, torch.strided, torch._C._ConvBackend.SlowTranspose3d),
2245
            #         decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN], name='slow3d_transposed'),
2246
            subtest(
2247
                (
2248
                    (2, 6, 7, 8, 9),
2249
                    False,
2250
                    True,
2251
                    3,
2252
                    torch.strided,
2253
                    torch._C._ConvBackend.SlowDilated3d,
2254
                ),
2255
                decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN],
2256
                name="slow3d_dilated",
2257
            ),
2258
            # FIXME: RuntimeError: CUDA out of memory.
2259
            # subtest(((2, 6, 7, 8, 9), True, True, 3, torch.strided, torch._C._ConvBackend.SlowTranspose3d),
2260
            #         decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN], name='slow3d_dilated_transposed'),
2261
            subtest(
2262
                (
2263
                    (0, 6, 7),
2264
                    False,
2265
                    False,
2266
                    3,
2267
                    torch.strided,
2268
                    torch._C._ConvBackend.Empty,
2269
                ),
2270
                decorators=[onlyNativeDeviceTypes, disableMkldnn],
2271
                name="empty_batch1d",
2272
            ),
2273
            subtest(
2274
                (
2275
                    (2, 0, 7),
2276
                    False,
2277
                    False,
2278
                    3,
2279
                    torch.strided,
2280
                    torch._C._ConvBackend.Empty,
2281
                ),
2282
                decorators=[onlyNativeDeviceTypes, disableMkldnn],
2283
                name="empty_channel1d",
2284
            ),
2285
            subtest(
2286
                (
2287
                    (0, 0, 7),
2288
                    False,
2289
                    False,
2290
                    3,
2291
                    torch.strided,
2292
                    torch._C._ConvBackend.Empty,
2293
                ),
2294
                decorators=[onlyNativeDeviceTypes, disableMkldnn],
2295
                name="empty_batch_channel1d",
2296
            ),
2297
            subtest(
2298
                (
2299
                    (0, 6, 7, 8),
2300
                    False,
2301
                    False,
2302
                    3,
2303
                    torch.strided,
2304
                    torch._C._ConvBackend.Empty,
2305
                ),
2306
                decorators=[onlyNativeDeviceTypes, disableMkldnn],
2307
                name="empty_batch2d",
2308
            ),
2309
            subtest(
2310
                (
2311
                    (2, 0, 7, 8),
2312
                    False,
2313
                    False,
2314
                    3,
2315
                    torch.strided,
2316
                    torch._C._ConvBackend.Empty,
2317
                ),
2318
                decorators=[onlyNativeDeviceTypes, disableMkldnn],
2319
                name="empty_channel2d",
2320
            ),
2321
            subtest(
2322
                (
2323
                    (0, 0, 7, 8),
2324
                    False,
2325
                    False,
2326
                    3,
2327
                    torch.strided,
2328
                    torch._C._ConvBackend.Empty,
2329
                ),
2330
                decorators=[onlyNativeDeviceTypes, disableMkldnn],
2331
                name="empty_batch_channel2d",
2332
            ),
2333
            subtest(
2334
                (
2335
                    (0, 6, 7, 8, 9),
2336
                    False,
2337
                    False,
2338
                    3,
2339
                    torch.strided,
2340
                    torch._C._ConvBackend.Empty,
2341
                ),
2342
                decorators=[onlyNativeDeviceTypes, disableMkldnn],
2343
                name="empty_batch3d",
2344
            ),
2345
            subtest(
2346
                (
2347
                    (2, 0, 7, 8, 9),
2348
                    False,
2349
                    False,
2350
                    3,
2351
                    torch.strided,
2352
                    torch._C._ConvBackend.Empty,
2353
                ),
2354
                decorators=[onlyNativeDeviceTypes, disableMkldnn],
2355
                name="empty_channel3d",
2356
            ),
2357
            subtest(
2358
                (
2359
                    (0, 0, 7, 8, 9),
2360
                    False,
2361
                    False,
2362
                    3,
2363
                    torch.strided,
2364
                    torch._C._ConvBackend.Empty,
2365
                ),
2366
                decorators=[onlyNativeDeviceTypes, disableMkldnn],
2367
                name="empty_batch_channel3d",
2368
            ),
2369
            # === cuda ===
2370
            # Note that disablecuDNN disables miopen as well.
2371
            subtest(
2372
                (
2373
                    (2, 6, 7),
2374
                    False,
2375
                    False,
2376
                    6,
2377
                    torch.strided,
2378
                    torch._C._ConvBackend.CudaDepthwise2d,
2379
                ),
2380
                decorators=[onlyCUDA, disablecuDNN],
2381
                name="cuda_depthwise1d",
2382
            ),
2383
            subtest(
2384
                (
2385
                    (2, 6, 7, 8),
2386
                    False,
2387
                    False,
2388
                    6,
2389
                    torch.strided,
2390
                    torch._C._ConvBackend.CudaDepthwise2d,
2391
                ),
2392
                decorators=[onlyCUDA, disablecuDNN],
2393
                name="cuda_depthwise2d",
2394
            ),
2395
            subtest(
2396
                (
2397
                    (2, 6, 7, 8, 9),
2398
                    False,
2399
                    False,
2400
                    6,
2401
                    torch.strided,
2402
                    torch._C._ConvBackend.CudaDepthwise3d,
2403
                ),
2404
                decorators=[onlyCUDA, disablecuDNN],
2405
                name="cuda_depthwise3d",
2406
            ),
2407
            # === cudnn ===
2408
            subtest(
2409
                (
2410
                    (2, 6, 7),
2411
                    False,
2412
                    False,
2413
                    3,
2414
                    torch.strided,
2415
                    torch._C._ConvBackend.Cudnn,
2416
                ),
2417
                decorators=[onlyCUDA, skipCUDAIfNoCudnn, skipCUDAIfMiopen],
2418
                name="cudnn1d",
2419
            ),
2420
            subtest(
2421
                (
2422
                    (2, 6, 7, 8),
2423
                    False,
2424
                    False,
2425
                    3,
2426
                    torch.strided,
2427
                    torch._C._ConvBackend.Cudnn,
2428
                ),
2429
                decorators=[onlyCUDA, skipCUDAIfNoCudnn, skipCUDAIfMiopen],
2430
                name="cudnn2d",
2431
            ),
2432
            subtest(
2433
                (
2434
                    (2, 6, 7, 8, 9),
2435
                    False,
2436
                    False,
2437
                    3,
2438
                    torch.strided,
2439
                    torch._C._ConvBackend.Cudnn,
2440
                ),
2441
                decorators=[onlyCUDA, skipCUDAIfNoCudnn, skipCUDAIfMiopen],
2442
                name="cudnn3d",
2443
            ),
2444
            subtest(
2445
                (
2446
                    (2, 6, 7),
2447
                    True,
2448
                    False,
2449
                    3,
2450
                    torch.strided,
2451
                    torch._C._ConvBackend.CudnnTranspose,
2452
                ),
2453
                decorators=[onlyCUDA, skipCUDAIfNoCudnn, skipCUDAIfMiopen],
2454
                name="cudnn1d_transposed",
2455
            ),
2456
            subtest(
2457
                (
2458
                    (2, 6, 7, 8),
2459
                    True,
2460
                    False,
2461
                    3,
2462
                    torch.strided,
2463
                    torch._C._ConvBackend.CudnnTranspose,
2464
                ),
2465
                decorators=[onlyCUDA, skipCUDAIfNoCudnn, skipCUDAIfMiopen],
2466
                name="cudnn2d_transposed",
2467
            ),
2468
            # FIXME: RuntimeError: CUDA out of memory.
2469
            # subtest(((2, 6, 7, 8, 9), True, False, 3, torch.strided, torch._C._ConvBackend.CudnnTranspose),
2470
            #         decorators=[onlyCUDA, skipCUDAIfNoCudnn, skipCUDAIfMiopen], name='cudnn3d_transposed'),
2471
            # === miopen ===
2472
            subtest(
2473
                (
2474
                    (2, 6, 7),
2475
                    False,
2476
                    False,
2477
                    3,
2478
                    torch.strided,
2479
                    torch._C._ConvBackend.Miopen,
2480
                ),
2481
                decorators=[onlyCUDA, skipCUDAIfNoMiopen],
2482
                name="miopen1d",
2483
            ),
2484
            subtest(
2485
                (
2486
                    (2, 6, 7, 8),
2487
                    False,
2488
                    False,
2489
                    3,
2490
                    torch.strided,
2491
                    torch._C._ConvBackend.Miopen,
2492
                ),
2493
                decorators=[onlyCUDA, skipCUDAIfNoMiopen],
2494
                name="miopen2d",
2495
            ),
2496
            subtest(
2497
                (
2498
                    (2, 6, 7, 8, 9),
2499
                    False,
2500
                    False,
2501
                    3,
2502
                    torch.strided,
2503
                    torch._C._ConvBackend.Miopen,
2504
                ),
2505
                decorators=[onlyCUDA, skipCUDAIfNoMiopen],
2506
                name="miopen3d",
2507
            ),
2508
            subtest(
2509
                (
2510
                    (2, 6, 7),
2511
                    True,
2512
                    False,
2513
                    3,
2514
                    torch.strided,
2515
                    torch._C._ConvBackend.MiopenTranspose,
2516
                ),
2517
                decorators=[onlyCUDA, skipCUDAIfNoMiopen],
2518
                name="miopen1d_transposed",
2519
            ),
2520
            subtest(
2521
                (
2522
                    (2, 6, 7, 8),
2523
                    True,
2524
                    False,
2525
                    3,
2526
                    torch.strided,
2527
                    torch._C._ConvBackend.MiopenTranspose,
2528
                ),
2529
                decorators=[onlyCUDA, skipCUDAIfNoMiopen],
2530
                name="miopen2d_transposed",
2531
            ),
2532
            subtest(
2533
                (
2534
                    (2, 6, 7, 8, 9),
2535
                    True,
2536
                    False,
2537
                    3,
2538
                    torch.strided,
2539
                    torch._C._ConvBackend.MiopenTranspose,
2540
                ),
2541
                decorators=[onlyCUDA, skipCUDAIfNoMiopen],
2542
                name="miopen3d_transposed",
2543
            ),
2544
            subtest(
2545
                (
2546
                    (2, 6, 7),
2547
                    False,
2548
                    False,
2549
                    6,
2550
                    torch.strided,
2551
                    torch._C._ConvBackend.MiopenDepthwise,
2552
                ),
2553
                decorators=[onlyCUDA, skipCUDAIfNoMiopen],
2554
                name="miopen_depthwise1d",
2555
            ),
2556
            subtest(
2557
                (
2558
                    (2, 6, 7, 8),
2559
                    False,
2560
                    False,
2561
                    6,
2562
                    torch.strided,
2563
                    torch._C._ConvBackend.MiopenDepthwise,
2564
                ),
2565
                decorators=[onlyCUDA, skipCUDAIfNoMiopen],
2566
                name="miopen_depthwise2d",
2567
            ),
2568
            subtest(
2569
                (
2570
                    (2, 6, 7, 8, 9),
2571
                    False,
2572
                    False,
2573
                    6,
2574
                    torch.strided,
2575
                    torch._C._ConvBackend.MiopenDepthwise,
2576
                ),
2577
                decorators=[onlyCUDA, skipCUDAIfNoMiopen],
2578
                name="miopen_depthwise3d",
2579
            ),
2580
            # === mkldnn ===
2581
            subtest(
2582
                (
2583
                    (2, 6, 7),
2584
                    False,
2585
                    False,
2586
                    3,
2587
                    torch._mkldnn,
2588
                    torch._C._ConvBackend.Mkldnn,
2589
                ),
2590
                decorators=[onlyCPU, skipCPUIfNoMkldnn],
2591
                name="mkldnn1d",
2592
            ),
2593
            subtest(
2594
                (
2595
                    (2, 6, 7, 8),
2596
                    False,
2597
                    False,
2598
                    3,
2599
                    torch._mkldnn,
2600
                    torch._C._ConvBackend.Mkldnn,
2601
                ),
2602
                decorators=[onlyCPU, skipCPUIfNoMkldnn],
2603
                name="mkldnn2d",
2604
            ),
2605
            subtest(
2606
                (
2607
                    (2, 6, 7, 8, 9),
2608
                    False,
2609
                    False,
2610
                    3,
2611
                    torch._mkldnn,
2612
                    torch._C._ConvBackend.Mkldnn,
2613
                ),
2614
                decorators=[onlyCPU, skipCPUIfNoMkldnn],
2615
                name="mkldnn3d",
2616
            ),
2617
            # Transposed convolution is broken for mkldnn. See https://github.com/pytorch/pytorch/issues/68775.
2618
            subtest(
2619
                (
2620
                    (2, 6, 7),
2621
                    True,
2622
                    False,
2623
                    3,
2624
                    torch._mkldnn,
2625
                    torch._C._ConvBackend.Mkldnn,
2626
                ),
2627
                decorators=[onlyCPU, skipCPUIfNoMkldnn, unittest.expectedFailure],
2628
                name="mkldnn1d_transposed",
2629
            ),
2630
            subtest(
2631
                (
2632
                    (2, 6, 7, 8),
2633
                    True,
2634
                    False,
2635
                    3,
2636
                    torch._mkldnn,
2637
                    torch._C._ConvBackend.Mkldnn,
2638
                ),
2639
                decorators=[onlyCPU, skipCPUIfNoMkldnn, unittest.expectedFailure],
2640
                name="mkldnn2d_transposed",
2641
            ),
2642
            subtest(
2643
                (
2644
                    (2, 6, 7, 8, 9),
2645
                    True,
2646
                    False,
2647
                    3,
2648
                    torch._mkldnn,
2649
                    torch._C._ConvBackend.Mkldnn,
2650
                ),
2651
                decorators=[onlyCPU, skipCPUIfNoMkldnn, unittest.expectedFailure],
2652
                name="mkldnn3d_transposed",
2653
            ),
2654
            subtest(
2655
                (
2656
                    (2, 6, 7),
2657
                    False,
2658
                    True,
2659
                    3,
2660
                    torch.strided,
2661
                    torch._C._ConvBackend.Mkldnn,
2662
                ),
2663
                decorators=[onlyCPU, skipCPUIfNoMkldnn],
2664
                name="mkldnn1d_cpu_input",
2665
            ),
2666
            subtest(
2667
                (
2668
                    (2, 6, 7, 8),
2669
                    False,
2670
                    True,
2671
                    3,
2672
                    torch.strided,
2673
                    torch._C._ConvBackend.Mkldnn,
2674
                ),
2675
                decorators=[onlyCPU, skipCPUIfNoMkldnn],
2676
                name="mkldnn2d_cpu_input",
2677
            ),
2678
            subtest(
2679
                (
2680
                    (2, 6, 7, 8, 9),
2681
                    False,
2682
                    True,
2683
                    3,
2684
                    torch.strided,
2685
                    torch._C._ConvBackend.Mkldnn,
2686
                ),
2687
                decorators=[onlyCPU, skipCPUIfNoMkldnn],
2688
                name="mkldnn3d_cpu_input",
2689
            ),
2690
            subtest(
2691
                (
2692
                    (0, 6, 7),
2693
                    False,
2694
                    False,
2695
                    3,
2696
                    torch._mkldnn,
2697
                    torch._C._ConvBackend.MkldnnEmpty,
2698
                ),
2699
                decorators=[onlyCPU, skipCPUIfNoMkldnn],
2700
                name="mkldnn_empty_batch1d",
2701
            ),
2702
            subtest(
2703
                (
2704
                    (2, 0, 7),
2705
                    False,
2706
                    False,
2707
                    3,
2708
                    torch._mkldnn,
2709
                    torch._C._ConvBackend.MkldnnEmpty,
2710
                ),
2711
                decorators=[onlyCPU, skipCPUIfNoMkldnn],
2712
                name="mkldnn_empty_channel1d",
2713
            ),
2714
            subtest(
2715
                (
2716
                    (0, 0, 7),
2717
                    False,
2718
                    False,
2719
                    3,
2720
                    torch._mkldnn,
2721
                    torch._C._ConvBackend.MkldnnEmpty,
2722
                ),
2723
                decorators=[onlyCPU, skipCPUIfNoMkldnn],
2724
                name="mkldnn_empty_batch_channel1d",
2725
            ),
2726
            subtest(
2727
                (
2728
                    (0, 6, 7, 8),
2729
                    False,
2730
                    False,
2731
                    3,
2732
                    torch._mkldnn,
2733
                    torch._C._ConvBackend.MkldnnEmpty,
2734
                ),
2735
                decorators=[onlyCPU, skipCPUIfNoMkldnn],
2736
                name="mkldnn_empty_batch2d",
2737
            ),
2738
            subtest(
2739
                (
2740
                    (2, 0, 7, 8),
2741
                    False,
2742
                    False,
2743
                    3,
2744
                    torch._mkldnn,
2745
                    torch._C._ConvBackend.MkldnnEmpty,
2746
                ),
2747
                decorators=[onlyCPU, skipCPUIfNoMkldnn],
2748
                name="mkldnn_empty_channel2d",
2749
            ),
2750
            subtest(
2751
                (
2752
                    (0, 0, 7, 8),
2753
                    False,
2754
                    False,
2755
                    3,
2756
                    torch._mkldnn,
2757
                    torch._C._ConvBackend.MkldnnEmpty,
2758
                ),
2759
                decorators=[onlyCPU, skipCPUIfNoMkldnn],
2760
                name="mkldnn_empty_batch_channel2d",
2761
            ),
2762
            subtest(
2763
                (
2764
                    (0, 6, 7, 8, 9),
2765
                    False,
2766
                    False,
2767
                    3,
2768
                    torch._mkldnn,
2769
                    torch._C._ConvBackend.MkldnnEmpty,
2770
                ),
2771
                decorators=[onlyCPU, skipCPUIfNoMkldnn],
2772
                name="mkldnn_empty_batch3d",
2773
            ),
2774
            subtest(
2775
                (
2776
                    (2, 0, 7, 8, 9),
2777
                    False,
2778
                    False,
2779
                    3,
2780
                    torch._mkldnn,
2781
                    torch._C._ConvBackend.MkldnnEmpty,
2782
                ),
2783
                decorators=[onlyCPU, skipCPUIfNoMkldnn],
2784
                name="mkldnn_empty_channel3d",
2785
            ),
2786
            subtest(
2787
                (
2788
                    (0, 0, 7, 8, 9),
2789
                    False,
2790
                    False,
2791
                    3,
2792
                    torch._mkldnn,
2793
                    torch._C._ConvBackend.MkldnnEmpty,
2794
                ),
2795
                decorators=[onlyCPU, skipCPUIfNoMkldnn],
2796
                name="mkldnn_empty_batch_channel3d",
2797
            ),
2798
            # Note: Tests for mobile backends are not currently supported. This comprises
2799
            # NnpackSpatial, Winograd3x3Depthwise, and Xnnpack2d backends. Testing these
2800
            # requires the ability to gate tests by whether PyTorch is built with USE_MOBILE=1.
2801
        ],
2802
    )
2803
    # Test with both bias and no bias.
2804
    @parametrize_test("has_bias", [False, True])
2805
    # Test with both stride=1 and stride>1 cases.
2806
    @parametrize_test("strided", [False, True])
2807
    # Test with both contiguous and non-contiguous inputs.
2808
    @parametrize_test("contiguous", [False, True])
2809
    def test_conv_backend(
        self,
        device,
        input_shape,
        has_bias,
        strided,
        contiguous,
        transposed,
        dilated,
        groups,
        layout,
        backend_expected,
    ):
        """Verify backend selection and differentiability for one conv configuration.

        Float32 input / weight / optional bias tensors are created for
        ``input_shape`` (12 output channels, kernel size 3 in every spatial
        dim). The test then:

        1. asserts ``torch._C._select_conv_backend`` returns ``backend_expected``;
        2. runs ``torch.ops.aten.convolution`` forward and calls ``backward``;
        3. for non-mkldnn layouts, runs a forward-AD / forward-over-reverse
           smoke test (skipped for the ``Empty`` backend) and finally
           ``gradcheck`` and ``gradgradcheck`` on float64 copies.

        ``strided`` / ``dilated`` select a stride / dilation of 2 instead of 1,
        and ``contiguous=False`` replaces each tensor with a non-contiguous
        equivalent before use.
        """
        fp32 = torch.float32
        out_channels = 12
        ksize = 3
        in_channels = input_shape[1]
        n_spatial = len(input_shape) - 2
        use_mkldnn = layout is torch._mkldnn

        def _decontiguify(t):
            # Duplicate every element along the last dim, then take every other
            # one: same values, but the result is a strided view.
            if t is None:
                return None
            needs_grad = t.requires_grad
            doubled = torch.repeat_interleave(t, 2, dim=-1)
            return doubled[..., ::2].detach().requires_grad_(needs_grad)

        # --- Build the tensors (creation order matters for RNG consumption). ---
        x = torch.randn(*input_shape, device=device, dtype=fp32, requires_grad=True)

        if transposed:
            weight_shape = (in_channels, out_channels // groups)
        else:
            weight_shape = (out_channels, in_channels // groups)
        weight_shape = weight_shape + (ksize,) * n_spatial
        weight = torch.randn(
            *weight_shape, device=device, dtype=fp32, requires_grad=True
        )

        bias = None
        if has_bias:
            bias = torch.randn(
                out_channels, device=device, dtype=fp32, requires_grad=True
            )

        if not contiguous:
            x = _decontiguify(x)
            weight = _decontiguify(weight)
            bias = _decontiguify(bias)

        if use_mkldnn:
            # Note that weight and bias are not supported as mkldnn tensors during training.
            x = x.to_mkldnn()

        # --- Non-tensor convolution arguments. ---
        stride = (2 if strided else 1,) * n_spatial
        padding = (0,) * n_spatial
        dilation = (2 if dilated else 1,) * n_spatial
        output_padding = (0,) * n_spatial

        def _conv_args(inp, w, b):
            # Positional argument list shared by backend selection and the op call.
            return [
                inp,
                w,
                b,
                stride,
                padding,
                dilation,
                transposed,
                output_padding,
                groups,
            ]

        conv_args = _conv_args(x, weight, bias)

        # The dispatcher must pick the expected backend.
        backend_actual = torch._C._select_conv_backend(*conv_args)
        self.assertEqual(backend_actual, backend_expected)

        # Forward + backward must run without error.
        conv_op = torch.ops.aten.convolution
        out = conv_op(*conv_args)
        grad_out = torch.randn(out.shape, device=device, dtype=fp32)
        if not contiguous:
            grad_out = _decontiguify(grad_out)
        if use_mkldnn:
            grad_out = grad_out.to_mkldnn()
        out.backward(grad_out)

        # mkldnn doesn't support gradcheck :(
        if use_mkldnn:
            return

        if backend_actual != torch._C._ConvBackend.Empty:  # FIXME: forward AD fails
            # Forward AD and forward-over-reverse AD smoke test in float32
            # TODO: remove this if we introduce per-op gradient tests for float32
            with fwAD.dual_level():
                dual_args = []
                for arg in conv_args:
                    if isinstance(arg, torch.Tensor):
                        arg = fwAD.make_dual(arg, torch.rand_like(arg))
                    dual_args.append(arg)
                # Forward AD
                dual_out = conv_op(*dual_args)
                # Forward over reverse AD
                dual_grad = fwAD.make_dual(
                    torch.rand_like(dual_out), torch.rand_like(dual_out)
                )
                wrt = [x, weight, bias] if has_bias else [x, weight]
                torch.autograd.grad(dual_out, wrt, dual_grad)

        # gradcheck runs on float64 leaf copies of the tensors.
        def _as_double_leaf(t):
            return t.to(torch.float64).detach().requires_grad_(True)

        x = _as_double_leaf(x)
        weight = _as_double_leaf(weight)
        if bias is not None:
            bias = _as_double_leaf(bias)
        conv_args = _conv_args(x, weight, bias)

        # cuDNN introduces non-determinism, so allow a tolerance whenever it is available.
        nondet_tol = (
            GRADCHECK_NONDET_TOL if torch.backends.cudnn.is_available() else 0.0
        )

        self.assertTrue(gradcheck(conv_op, conv_args, nondet_tol=nondet_tol))

        # double backward doesn't support bias gradients
        if bias is not None:
            bias.requires_grad_(False)
        self.assertTrue(gradgradcheck(conv_op, conv_args, nondet_tol=nondet_tol))
2945

2946
    @onlyCPU
2947
    def test_conv_contiguous_for_oneDNN(self):
2948
        # See https://github.com/pytorch/pytorch/issues/80837.
2949
        for dtype in [torch.float, torch.bfloat16, torch.half]:
2950
            conv = nn.Conv2d(
2951
                1,
2952
                128,
2953
                kernel_size=(5, 2),
2954
                stride=(2, 1),
2955
                padding=(0, 1),
2956
                dilation=(1, 1),
2957
                groups=1,
2958
                bias=True,
2959
                padding_mode="zeros",
2960
            ).to(dtype=dtype)
2961

2962
            x = torch.rand([1, 2, 321, 201, 1]).to(dtype=dtype)
2963
            x = torch.transpose(x, 1, 4)
2964
            x2 = x[..., 0]
2965
            inputs = [
2966
                x2,
2967
                conv.weight,
2968
                conv.bias,
2969
                (2, 1),
2970
                (0, 1),
2971
                (1, 1),
2972
                False,
2973
                (0, 1),
2974
                1,
2975
            ]
2976
            if torch.backends.mkldnn.is_available():
2977
                y = conv(x2)
2978
                # Disable MKLDNN explicitly
2979
                with torch.backends.mkldnn.flags(enabled=False):
2980
                    y_ = conv(x2)
2981
                    self.assertEqual(y, y_)
2982

2983
    @onlyCPU
2984
    def test_conv_ic1_channels_last_for_oneDNN(self):
2985
        # See https://github.com/pytorch/pytorch/issues/82060, N > 1 will call in OneDNN path.
2986
        for dtype in [torch.float, torch.bfloat16, torch.half]:
2987
            conv = torch.nn.Conv2d(
2988
                1, 64, kernel_size=(3, 3), padding=(1, 1), bias=False
2989
            )
2990
            conv = conv.to(memory_format=torch.channels_last).to(dtype=dtype)
2991
            x = torch.rand(2, 1, 100, 100).to(dtype=dtype)
2992
            if torch.backends.mkldnn.is_available():
2993
                y = conv(x)
2994
                # Disable MKLDNN explicitly
2995
                with torch.backends.mkldnn.flags(enabled=False):
2996
                    y_ = conv(x)
2997
                    self.assertEqual(y, y_)
2998

2999
    @dtypes(torch.float, torch.cfloat)
3000
    def test_conv_empty_channel(self, device, dtype):
        """Exercise Conv1d/2d/3d modules built with zero input channels.

        For each dimensionality, a module with ``in_channels == 0`` is passed
        through ``_test_module_empty_input`` using an input whose channel dim
        is 0. Afterwards, feeding the same module an input with 1 channel (and
        one zero-sized spatial dim) must raise a ``RuntimeError`` matching
        ``"Given groups=1, weight"``.
        """
        in_channels = 0
        # (module class, positional ctor args, zero-channel input shape, mismatched input shape)
        cases = (
            (torch.nn.Conv1d, (in_channels, 8, 2), (2, 0, 15), (2, 1, 0)),
            (torch.nn.Conv2d, (in_channels, 33, 3), (2, 0, 50, 100), (2, 1, 40, 0)),
            (
                torch.nn.Conv3d,
                (in_channels, 33, 3),
                (2, 0, 50, 20, 40),
                (2, 1, 50, 0, 40),
            ),
        )

        for conv_cls, ctor_args, empty_shape, mismatched_shape in cases:
            mod = conv_cls(*ctor_args, stride=2, dtype=dtype).to(device)

            empty_inp = torch.randn(*empty_shape, device=device, dtype=dtype)
            _test_module_empty_input(self, mod, empty_inp, check_size=False)

            # Channel count 1 does not match the module's 0 input channels.
            mismatched_inp = torch.randn(
                *mismatched_shape, device=device, dtype=dtype
            )
            with self.assertRaisesRegex(RuntimeError, "Given groups=1, weight"):
                mod(mismatched_inp)
3025

3026
    def test_group_conv_empty(self, device):
        """Run the empty-input module check on a grouped Conv2d with a zero-sized batch.

        When running on CUDA and ``self.has_cudnn()`` is true, the same check is
        repeated with cuDNN disabled.
        """
        conv = torch.nn.Conv2d(4, 4, stride=2, kernel_size=3, padding=1, groups=4)
        mod = conv.to(device)
        empty_batch = torch.randn(0, 4, 4, 4, device=device)
        _test_module_empty_input(self, mod, empty_batch, check_size=False)

        if self.device_type != "cuda" or not self.has_cudnn():
            return
        with torch.backends.cudnn.flags(enabled=False):
            _test_module_empty_input(self, mod, empty_batch, check_size=False)
3035

3036
    def test_group_convTranspose_empty(self, device):
        # Run a grouped (groups=4) ConvTranspose2d on a zero-batch input; on
        # CUDA with cuDNN available, repeat the check with cuDNN disabled.
        deconv = torch.nn.ConvTranspose2d(
            4, 4, kernel_size=3, stride=2, padding=1, groups=4
        )
        deconv = deconv.to(device)
        empty_batch = torch.randn(0, 4, 4, 4, device=device)
        _test_module_empty_input(self, deconv, empty_batch, check_size=False)
        if self.device_type != "cuda" or not self.has_cudnn():
            return
        with torch.backends.cudnn.flags(enabled=False):
            _test_module_empty_input(self, deconv, empty_batch, check_size=False)
3045

3046
    def test_convTranspose_empty(self, device):
        # Run an ungrouped ConvTranspose2d on a zero-batch input; on CUDA with
        # cuDNN available, repeat the check with cuDNN disabled.
        deconv = torch.nn.ConvTranspose2d(4, 4, kernel_size=3, stride=2, padding=1)
        deconv = deconv.to(device)
        empty_batch = torch.randn(0, 4, 4, 4, device=device)
        _test_module_empty_input(self, deconv, empty_batch, check_size=False)
        if self.device_type != "cuda" or not self.has_cudnn():
            return
        with torch.backends.cudnn.flags(enabled=False):
            _test_module_empty_input(self, deconv, empty_batch, check_size=False)
3055

3056
    @onlyCUDA
    @largeTensorTest("12GB")
    def test_conv_large_nosplit(self, device):
        # Smoke test only: check that these convolutions are routed to the
        # fallback implementation, i.e. that they do not crash. Correctness of
        # the fallback implementation should be covered by other tests.
        if self.device_type == "cuda":
            dtype = torch.half
        else:
            dtype = torch.float
        # (Conv2d positional args, input shape), run in this order.
        cases = (
            ((2, 2, 8, 8), (1, 2, 1024, 1024 * 1024)),
            ((1, 1024, 1, 1), (1, 1, 2048, 1024)),
        )
        for conv_args, input_shape in cases:
            conv = nn.Conv2d(*conv_args).to(device).to(dtype)
            conv(torch.randn(*input_shape, dtype=dtype, device=device))
3069

3070
    def test_conv_noncontig_weights(self, device):
3071
        for dim in (1, 2, 3):
3072
            for grouped in (False, True):
3073
                nc = 3
3074
                groups = 3 if grouped else 1
3075
                w = torch.randn([3] * dim, device=device)
3076
                w = w.expand([nc, int(nc / groups)] + list(w.shape))
3077
                w = w.detach().requires_grad_()
3078
                x = torch.randn(
3079
                    [1, nc] + ([5] * dim), device=device, requires_grad=True
3080
                )
3081
                y = getattr(F, f"conv{dim}d")(x, w, groups=groups)
3082
                y.sum().backward()
3083
                y = getattr(F, f"conv_transpose{dim}d")(x, w, groups=groups)
3084
                y.sum().backward()
3085

3086
    def test_conv_noncontig_weights_and_bias(self, device):
        # need floats to exercise https://github.com/pytorch/pytorch/issues/16018
        #
        # Input, weight and bias are taken as index 1 of a trailing size-2
        # dimension, so they are strided views. The conv output on those views
        # must equal the output on their .contiguous() copies.
        for use_bias in (True, False):
            conv = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=use_bias)
            conv = conv.to(device, torch.float)

            strided_input = torch.randn(
                (1, 3, 224, 224, 2), device=device, dtype=torch.float
            )[..., 1]
            dense_input = strided_input.contiguous()

            strided_weight = torch.randn(
                (64, 3, 7, 7, 2), device=device, dtype=torch.float
            )[..., 1]
            conv.weight = nn.Parameter(strided_weight)
            dense_weight = conv.weight.contiguous()

            if use_bias:
                strided_bias = torch.randn(
                    (64, 2), device=device, dtype=torch.float
                )[..., 1]
                conv.bias = nn.Parameter(strided_bias)
                dense_bias = conv.bias.contiguous()

            strided_out = conv(strided_input)

            # Swap in the contiguous copies and run again.
            conv.weight = nn.Parameter(dense_weight)
            if use_bias:
                conv.bias = nn.Parameter(dense_bias)
            dense_out = conv(dense_input)

            self.assertEqual(strided_out, dense_out)
3115

3116
    @onlyCUDA
    @largeTensorTest("12GB")
    @skipIfRocmVersionLessThan((6, 0))
    def test_conv_transposed_large(self, device):
        # A ConvTranspose2d over a 4096-sample batch must agree with running
        # the same module over each 1024-sample quarter of that batch.
        on_cuda = self.device_type == "cuda"
        dtype = torch.half if on_cuda else torch.float
        conv = nn.ConvTranspose2d(1, 1, 1, 1, bias=False).to(device).to(dtype)
        input_large = torch.randn(4096, 1, 512, 1024, dtype=dtype, device=device)
        # forward
        ret = conv(input_large)
        quarter = 1024
        # Max absolute difference per quarter. Each is computed as one chained
        # expression so no large temporary stays referenced between quarters.
        max_diffs = [
            (
                ret.narrow(0, start, quarter)
                - conv(input_large.narrow(0, start, quarter))
            )
            .abs_()
            .max()
            .item()
            for start in range(0, 4096, quarter)
        ]
        if on_cuda:
            # cuDNN may use algorithms such as FFT that don't guarantee a diff of 0
            tolerance = {"atol": 2e-3, "rtol": 1e-5}
        else:
            tolerance = {}
        for max_diff in max_diffs:
            self.assertEqual(max_diff, 0, **tolerance)
3160

3161
    @onlyCUDA
    @skipCUDAIfRocm
    @largeTensorTest("12GB")
    def test_conv_large(self, device):
        # Convolving a 4097-sample batch in one call must agree, forward and
        # backward, with convolving it in three batch slices.
        dtype = torch.half if self.device_type == "cuda" else torch.float
        batch = 4097
        # (start, stop) batch ranges: two slices of 2048 plus the last sample.
        pieces = ((0, 2048), (2048, 4096), (4096, batch))
        conv = nn.Conv2d(2, 2, 8, 8, bias=False).to(device).to(dtype)
        input_large = torch.randn(batch, 2, 512, 512, dtype=dtype, device=device)

        # forward
        ret = conv(input_large)
        for start, stop in pieces:
            self.assertEqual(ret[start:stop], conv(input_large[start:stop]))

        # backward
        conv.zero_grad()
        # `max(dim=1)` is used when computing the backward to create some
        # sparsity. Without this sparsity, the rounding error would be too
        # large (as large as 1e-5) to satisfy the criterion (1e-6) of
        # `assertEqual`.
        ret.view(batch, -1).max(dim=1).values.sum().backward()
        del ret
        grad1 = conv.weight.grad.detach().clone()

        # Accumulate the weight gradient again, one batch slice at a time.
        conv.zero_grad()
        for start, stop in pieces:
            conv(input_large[start:stop]).view(stop - start, -1).max(
                dim=1
            ).values.sum().backward()
        grad2 = conv.weight.grad.detach().clone()

        # gradients are at the order of hundreds, we need to scale it to
        # the order of one so that we can compare
        scale = 1 / grad2.abs().mean()
        self.assertEqual(grad1 * scale, grad2 * scale, atol=5e-2, rtol=5e-3)
3193

3194
    @onlyCUDA
    @skipCUDAIfRocm
    @largeTensorTest("20GB", "cpu")
    @largeTensorTest("60GB", "cuda")
    def test_conv_large_batch_1(self, device):
        # Compare a half-precision Conv2d on `device` over one large
        # (1 x 514 x 2048 x 2048) all-ones input against the same module run
        # in float32 on the CPU.
        in_channels = 514
        dim = 2048
        out_channels = 1
        kernel_size = 3
        stride = 1
        padding = 1

        # Allocate directly in half on the device under test instead of
        # building a float32 host tensor first and converting it.
        input_tensor = torch.ones(
            1, in_channels, dim, dim, dtype=torch.half, device=device
        )
        model = (
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
            .to(device)
            .half()
        )
        output = model(input_tensor)

        # Module.cpu()/.float() convert the module in place and return it, so
        # `model_cpu` is the same object as `model`; the reference is computed
        # through the explicitly named CPU handle.
        model_cpu = model.cpu().float()
        output_cpu = model_cpu(input_tensor.float().cpu())
        self.assertEqual(output.cpu().float(), output_cpu, atol=1e-3, rtol=1e-3)
3216

3217
    @onlyCUDA
    @skipCUDAIfRocm
    @largeTensorTest("24GB", "cpu")
    @largeTensorTest("20GB", "cuda")
    def test_conv3d_large_batch_1(self, device):
        # A 1x1x1, bias-free Conv3d over one large volume must give the same
        # result on `device` as on the CPU.
        cpu_input = torch.rand(1, 32, 512, 512, 256)
        conv = torch.nn.Conv3d(32, 1, kernel_size=1, padding=0, stride=1, bias=False)
        expected = conv(cpu_input)
        conv_on_device = conv.to(device=device)
        actual = conv_on_device(cpu_input.to(device=device))
        self.assertEqual(expected, actual.cpu())
3227

3228
    @onlyCUDA
    @skipCUDAIfNoCudnn
    def test_contig_wrong_stride_cudnn(self, device):
        # x has to have batch_size 1 to test contiguous checks
        x = torch.randn(1, 16, 5, 5, device=device)
        # Overwrite only the dim-0 stride. The tensor still reports itself as
        # contiguous because size[0] is 1.
        odd_stride = [20, *x.stride()[1:]]
        x.set_(x.storage(), 0, x.size(), odd_stride)
        self.assertTrue(x.is_contiguous())
        # Both calls only need to run without raising.
        transposed_weight = torch.randn(16, 1, 1, 1, device=device)
        F.conv_transpose2d(x, transposed_weight)
        regular_weight = torch.randn(1, 16, 1, 1, device=device)
        F.conv2d(x, regular_weight)
3240

3241
    @onlyCUDA
    @tf32_on_and_off(0.005)
    def test_Conv2d_size_1_kernel(self, device):
        # A kernel_size=1 Conv2d run on `device` with cuDNN disabled must
        # match the CPU module in output, bias grad and weight grad.
        x_cpu = torch.randn(2, 3, 5, 5)
        conv_cpu = torch.nn.Conv2d(3, 3, kernel_size=1)
        y_cpu = conv_cpu(x_cpu)
        grad_out = torch.rand_like(y_cpu)
        y_cpu.backward(grad_out)

        with cudnn.flags(enabled=False):
            conv_cuda = torch.nn.Conv2d(3, 3, kernel_size=1).to(device)
            # Give the device module the CPU module's parameters.
            for name in ("bias", "weight"):
                getattr(conv_cuda, name).data.copy_(getattr(conv_cpu, name).data)
            y_cuda = conv_cuda(x_cpu.to(device))
            y_cuda.backward(grad_out.to(device))

        tolerance = {"atol": 1e-5, "rtol": 0, "exact_device": False}
        self.assertEqual(y_cpu, y_cuda, **tolerance)
        for name in ("bias", "weight"):
            self.assertEqual(
                getattr(conv_cpu, name).grad.data,
                getattr(conv_cuda, name).grad.data,
                **tolerance,
            )
3272

3273
    @onlyCUDA
    @tf32_on_and_off(0.005)
    def test_ConvTranspose2d_size_1_kernel(self, device):
        # A kernel_size=1 ConvTranspose2d run on `device` with cuDNN disabled
        # must match the CPU module in output, bias grad and weight grad.
        x_cpu = torch.randn(2, 3, 5, 5)
        conv_cpu = torch.nn.ConvTranspose2d(3, 3, kernel_size=1)
        y_cpu = conv_cpu(x_cpu)
        grad_out = torch.rand_like(y_cpu)
        y_cpu.backward(grad_out)

        with cudnn.flags(enabled=False):
            conv_cuda = torch.nn.ConvTranspose2d(3, 3, kernel_size=1).to(device)
            # Give the device module the CPU module's parameters.
            for name in ("bias", "weight"):
                getattr(conv_cuda, name).data.copy_(getattr(conv_cpu, name).data)
            y_cuda = conv_cuda(x_cpu.to(device))
            y_cuda.backward(grad_out.to(device))

        tolerance = {"atol": 1e-5, "rtol": 0, "exact_device": False}
        self.assertEqual(y_cpu, y_cuda, **tolerance)
        for name in ("bias", "weight"):
            self.assertEqual(
                getattr(conv_cpu, name).grad.data,
                getattr(conv_cuda, name).grad.data,
                **tolerance,
            )
3304

3305
    @onlyCUDA
    def test_ConvTranspose3d_size_1_kernel(self, device):
        # A kernel_size=1 ConvTranspose3d run on `device` with cuDNN disabled
        # must match the CPU module in output, bias grad and weight grad.
        # Everything runs with double as the default dtype.
        with set_default_dtype(torch.double):
            x_cpu = torch.randn(2, 3, 3, 5, 5)
            conv_cpu = torch.nn.ConvTranspose3d(3, 3, kernel_size=1)
            y_cpu = conv_cpu(x_cpu)
            grad_out = torch.rand_like(y_cpu)
            y_cpu.backward(grad_out)

            with cudnn.flags(enabled=False):
                conv_cuda = torch.nn.ConvTranspose3d(3, 3, kernel_size=1).to(device)
                # Give the device module the CPU module's parameters.
                for name in ("bias", "weight"):
                    getattr(conv_cuda, name).data.copy_(getattr(conv_cpu, name).data)
                y_cuda = conv_cuda(x_cpu.to(device))
                y_cuda.backward(grad_out.to(device))

            tolerance = {"atol": 1e-5, "rtol": 0, "exact_device": False}
            self.assertEqual(y_cpu, y_cuda, **tolerance)
            for name in ("bias", "weight"):
                self.assertEqual(
                    getattr(conv_cpu, name).grad.data,
                    getattr(conv_cuda, name).grad.data,
                    **tolerance,
                )
3336

3337
    @dtypesIfCUDA(
        *floating_types_and(
            torch.half, *([torch.bfloat16] if AMPERE_OR_ROCM else [])
        )
    )
    @dtypes(torch.float)
    @torch.backends.cudnn.flags(enabled=True, benchmark=False)
    @unittest.skipIf(TEST_WITH_ROCM, "Skipped on ROCm, since it is failing on ROCm 5.7")
    def test_Conv2d_naive_groups(self, device, dtype):
        # Check that a grouped convolution (groups=2) matches two independent
        # convolutions, each over one half of the channels.
        grouped = nn.Conv2d(4, 4, kernel_size=3, groups=2).to(device, dtype)
        inp = torch.randn(2, 4, 6, 6, device=device, dtype=dtype, requires_grad=True)
        output = grouped(inp)
        grad_output = torch.randn(2, 4, 4, 4, device=device, dtype=dtype)
        output.backward(grad_output)

        half_convs = []
        half_inputs = []
        half_outputs = []
        for chans in (slice(0, 2), slice(2, 4)):
            # Build the stand-alone conv for this half from the grouped
            # module's parameters, then run it forward and backward.
            half = nn.Conv2d(2, 2, kernel_size=3).to(device, dtype)
            half.weight.data.copy_(grouped.weight.data[chans])
            half.bias.data.copy_(grouped.bias.data[chans])
            half_in = inp.data[:, chans].contiguous().requires_grad_(True)
            half_out = half(half_in)
            half_out.backward(grad_output[:, chans].contiguous())
            half_convs.append(half)
            half_inputs.append(half_in)
            half_outputs.append(half_out)

        self.assertEqual(output, torch.cat(half_outputs, 1))

        tolerance = {"atol": dtype2prec_DONTUSE[dtype], "rtol": 0}
        self.assertEqual(
            inp.grad.data,
            torch.cat([t.grad.data for t in half_inputs], 1),
            **tolerance,
        )
        self.assertEqual(
            grouped.bias.grad.data,
            torch.cat([c.bias.grad.data for c in half_convs], 0),
            **tolerance,
        )
        self.assertEqual(
            grouped.weight.grad.data,
            torch.cat([c.weight.grad.data for c in half_convs], 0),
            **tolerance,
        )
3384

3385
    @dtypes(torch.double, torch.cdouble)
    def test_Conv2d_backward_depthwise(self, device, dtype):
        # gradcheck a groups=2 conv2d (one input channel per group) with
        # stride (1, 10), once with cuDNN disabled and once with it enabled.
        leaf_kwargs = {"device": device, "dtype": dtype, "requires_grad": True}
        x = torch.randn(2, 2, 4, 20, **leaf_kwargs)
        weight = torch.randn(2, 1, 3, 5, **leaf_kwargs)

        def conv2d_depthwise(inp, kernel):
            return F.conv2d(inp, kernel, bias=None, stride=(1, 10), groups=2)

        for cudnn_enabled in (False, True):
            with torch.backends.cudnn.flags(enabled=cudnn_enabled):
                torch.autograd.gradcheck(conv2d_depthwise, (x, weight))
3398

3399
    @onlyCPU
    @dtypes(torch.float, torch.double)
    def test_conv_thnn_nhwc(self, device, dtype):
        # With mkldnn disabled, run (transposed) convolutions whose input
        # and/or weight are channels-last and compare output and gradients
        # with a contiguous-format reference module holding the same
        # parameters.
        def helper(
            mod,
            n,
            c,
            h,
            w,
            out_channels,
            kernel_size,
            dilation,
            groups,
            input_format,
            weight_format,
        ):
            # Module under test: small random integer values, requested formats.
            inp = torch.randint(-3, 3, (n, c, h, w), dtype=dtype, device=device)
            inp = inp.to(memory_format=input_format)
            inp.requires_grad_()
            conv = mod(c, out_channels, kernel_size, dilation=dilation, groups=groups)
            conv = conv.to(device="cpu", dtype=dtype, memory_format=weight_format)
            for param in conv.parameters():
                param.data = torch.randint_like(param, -3, 3)

            # Reference: same values, contiguous format.
            ref_input = inp.detach().clone().contiguous().requires_grad_()
            ref_conv = mod(
                c, out_channels, kernel_size, dilation=dilation, groups=groups
            )
            # load_state_dict will restore the stride & memory_layout on ref_conv.weight.
            ref_conv.load_state_dict(conv.state_dict())
            ref_conv = ref_conv.to(
                device="cpu", dtype=dtype, memory_format=torch.contiguous_format
            )

            out = conv(inp)
            ref_out = ref_conv(ref_input)

            grad = torch.randint_like(out, -3, 3)
            ref_grad = grad.detach().clone().contiguous()

            out.backward(grad)
            ref_out.backward(ref_grad)

            self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
            self.assertTrue(ref_out.is_contiguous())
            for actual, expected in (
                (out, ref_out),
                (conv.weight.grad, ref_conv.weight.grad),
                (conv.bias.grad, ref_conv.bias.grad),
                (inp.grad, ref_input.grad),
            ):
                self.assertEqual(actual, expected, exact_dtype=False)

        cl = torch.channels_last
        cf = torch.contiguous_format
        with torch.backends.mkldnn.flags(enabled=False):
            for input_format, weight_format in ((cl, cl), (cl, cf), (cf, cl)):
                # Row layout: (mod, n, c, h, w, out_channels, kernel_size,
                #              dilation, groups, input_format, weight_format)
                cases = (
                    # non-dilated conv: thnn_conv2d normal path (with im2col)
                    (nn.Conv2d, 2, 8, 4, 4, 4, 3, 1, 1, input_format, weight_format),
                    (nn.Conv2d, 2, 8, 4, 4, 8, 3, 1, 8, input_format, weight_format),
                    # test when input channels is 1 and not converted to channels last
                    (nn.Conv2d, 2, 1, 10, 10, 8, 3, 1, 1, cf, cl),
                    # non-dilated conv: thnn_conv2d fast path (skip im2col)
                    (nn.Conv2d, 1, 16, 56, 56, 16, 1, 1, 1, input_format, weight_format),
                    # ic == oc == 1 here, so need to stick input to CL to activate channels last
                    (nn.Conv2d, 1, 16, 56, 56, 16, 1, 1, 16, cl, weight_format),
                    # dilated conv: slow_conv_dilated2d
                    (nn.Conv2d, 2, 8, 11, 13, 16, 3, 2, 1, input_format, weight_format),
                    (nn.Conv2d, 2, 16, 11, 13, 32, 3, 2, 16, input_format, weight_format),
                    # transposed-conv: slow_conv_transpose2d
                    (nn.ConvTranspose2d, 2, 8, 4, 4, 4, 3, 1, 1, input_format, weight_format),
                    (nn.ConvTranspose2d, 2, 8, 4, 4, 8, 3, 1, 8, input_format, weight_format),
                    (nn.ConvTranspose2d, 1, 16, 56, 56, 16, 1, 1, 1, input_format, weight_format),
                    (nn.ConvTranspose2d, 1, 16, 56, 56, 32, 1, 1, 16, input_format, weight_format),
                )
                for case in cases:
                    helper(*case)
3607

3608
    @onlyCUDA
    @skipCUDAIfRocmVersionLessThan((4, 3))
    @skipCUDAIfNotMiopenSuggestNHWC
    @skipCUDAIfCudnnVersionLessThan(7603)
    @dtypes(torch.half, torch.float, torch.cfloat)
    def test_conv_cudnn_nhwc(self, device, dtype):
        # Run channels-last Conv2d modules and compare output and gradients
        # with an FP64 contiguous-format reference holding the same
        # parameters. Both modules are placed on `device` (the device the
        # tensors are created on) rather than on a hard-coded "cuda".
        def helper(n, c, h, w, out_channels, kernel_size, groups):
            input = torch.randint(-3, 3, (n, c, h, w), dtype=dtype, device=device).to(
                memory_format=torch.channels_last
            )
            input.requires_grad_()
            conv = nn.Conv2d(c, out_channels, kernel_size, groups=groups).to(
                device=device, dtype=dtype, memory_format=torch.channels_last
            )
            for p in conv.parameters():
                p.data = torch.randint_like(p, -3, 3)

            # use FP64 channels-first conv as reference
            ref_input = input.detach().clone().contiguous().double().requires_grad_()
            ref_conv = nn.Conv2d(c, out_channels, kernel_size, groups=groups)
            # load_state_dict will restore the stride & memory_layout on ref_conv.weight.
            ref_conv.load_state_dict(conv.state_dict())
            ref_conv = ref_conv.to(
                device=device, dtype=torch.double, memory_format=torch.contiguous_format
            )

            out = conv(input)
            ref_out = ref_conv(ref_input)

            grad = torch.randint_like(out, -3, 3)
            ref_grad = grad.detach().clone().double().contiguous()

            out.backward(grad)
            ref_out.backward(ref_grad)

            # The module under test must stay channels-last throughout ...
            self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
            self.assertTrue(input.grad.is_contiguous(memory_format=torch.channels_last))
            self.assertTrue(
                conv.weight.grad.is_contiguous(memory_format=torch.channels_last)
            )

            # ... while the reference stays in contiguous format.
            self.assertTrue(ref_out.is_contiguous())
            self.assertTrue(ref_input.grad.is_contiguous())
            self.assertTrue(ref_conv.weight.grad.is_contiguous())

            self.assertEqual(out, ref_out, exact_dtype=False)
            self.assertEqual(conv.weight.grad, ref_conv.weight.grad, exact_dtype=False)
            self.assertEqual(conv.bias.grad, ref_conv.bias.grad, exact_dtype=False)
            self.assertEqual(input.grad, ref_input.grad, exact_dtype=False)

        helper(2, 8, 4, 4, out_channels=4, kernel_size=3, groups=1)
        helper(2, 8, 4, 4, out_channels=8, kernel_size=3, groups=8)
        helper(1, 16, 56, 56, out_channels=16, kernel_size=3, groups=1)
        helper(1, 16, 56, 56, out_channels=16, kernel_size=3, groups=16)
3662

3663
    @onlyCUDA
    @skipCUDAIfRocm
    @skipCUDAIfCudnnVersionLessThan(8005)
    @dtypes(torch.half, torch.float)
    def test_conv_cudnn_ndhwc(self, device, dtype):
        # Run channels_last_3d Conv3d modules and compare output and gradients
        # with an FP64 contiguous-format reference holding the same
        # parameters. Both modules are placed on `device` (the device the
        # tensors are created on) rather than on a hard-coded "cuda".
        def helper(n, c, d, h, w, out_channels, kernel_size, groups):
            input = torch.randint(
                -2, 2, (n, c, d, h, w), dtype=dtype, device=device
            ).to(memory_format=torch.channels_last_3d)
            input.requires_grad_()
            conv = nn.Conv3d(c, out_channels, kernel_size, groups=groups).to(
                device=device, dtype=dtype, memory_format=torch.channels_last_3d
            )
            for p in conv.parameters():
                p.data = torch.randint_like(p, -2, 2)

            # use FP64 channels-first conv as reference
            ref_input = input.detach().clone().contiguous().double().requires_grad_()
            ref_conv = nn.Conv3d(c, out_channels, kernel_size, groups=groups)
            # load_state_dict will restore the stride & memory_layout on ref_conv.weight.
            ref_conv.load_state_dict(conv.state_dict())
            ref_conv = ref_conv.to(
                device=device, dtype=torch.double, memory_format=torch.contiguous_format
            )

            out = conv(input)
            ref_out = ref_conv(ref_input)

            grad = torch.randint_like(out, -2, 2)
            ref_grad = grad.detach().clone().double().contiguous()

            out.backward(grad)
            ref_out.backward(ref_grad)

            # The module under test must stay channels_last_3d throughout ...
            self.assertTrue(out.is_contiguous(memory_format=torch.channels_last_3d))
            self.assertTrue(
                input.grad.is_contiguous(memory_format=torch.channels_last_3d)
            )
            self.assertTrue(
                conv.weight.grad.is_contiguous(memory_format=torch.channels_last_3d)
            )

            # ... while the reference stays in contiguous format.
            self.assertTrue(ref_out.is_contiguous())
            self.assertTrue(ref_input.grad.is_contiguous())
            self.assertTrue(ref_conv.weight.grad.is_contiguous())

            self.assertEqual(out, ref_out, exact_dtype=False)
            self.assertEqual(conv.weight.grad, ref_conv.weight.grad, exact_dtype=False)
            self.assertEqual(conv.bias.grad, ref_conv.bias.grad, exact_dtype=False)
            self.assertEqual(input.grad, ref_input.grad, exact_dtype=False)

        helper(2, 8, 4, 4, 4, out_channels=4, kernel_size=3, groups=1)
        helper(2, 8, 4, 4, 4, out_channels=8, kernel_size=3, groups=8)
        helper(1, 16, 18, 18, 18, out_channels=16, kernel_size=3, groups=1)
        helper(1, 16, 18, 18, 18, out_channels=16, kernel_size=3, groups=16)
3718

3719
    def _run_conv(
        self,
        layer,
        device,
        inp,
        grad,
        ref_conv,
        ref_input,
        ref_out,
        input_format,
        weight_format,
        grad_format,
        output_format,
    ):
        # Build a float `layer` sized from `inp`, `grad` and `ref_conv`, load
        # ref_conv's parameters into it, put weight / input / grad into the
        # requested memory formats, run forward + backward, then check the
        # output's memory format and compare output and gradients with the
        # reference results.
        in_channels = inp.size(1)
        out_channels = grad.size(1)
        kernel_size = ref_conv.weight.size(2)
        conv = layer(in_channels, out_channels, kernel_size).float().to(device)
        # load_state_dict will restore the stride & memory_layout on ref_conv.weight.
        conv.load_state_dict(ref_conv.state_dict())

        formatted_weight = conv.weight.detach().clone()
        formatted_weight = formatted_weight.contiguous(memory_format=weight_format)
        conv.weight.data = formatted_weight.resize_(
            formatted_weight.size(), memory_format=weight_format
        )

        formatted_input = inp.clone().contiguous(memory_format=input_format)
        formatted_input.resize_(formatted_input.size(), memory_format=input_format)
        formatted_input = formatted_input.requires_grad_()

        grad = grad.contiguous(memory_format=grad_format)
        grad.resize_(grad.size(), memory_format=grad_format)

        out = conv(formatted_input)
        out.backward(grad)

        self.assertTrue(out.is_contiguous(memory_format=output_format))
        for actual, expected in (
            (out, ref_out),
            (conv.weight.grad, ref_conv.weight.grad),
            (conv.bias.grad, ref_conv.bias.grad),
            (formatted_input.grad, ref_input.grad),
        ):
            self.assertEqual(actual, expected)
3756

3757
    def _test_conv_cudnn_nhwc_nchw(self, layer, n, c, h, w, k, filter_size, device):
        """Check ``layer`` against a contiguous reference for every layout mix.

        A reference module is run once, forward and backward, on a contiguous
        ``(n, c, h, w)`` input. ``_run_conv`` is then invoked for every
        combination of contiguous / channels_last layouts for the weight, the
        output gradient and the input.

        Args:
            layer: module class under test; the expected output layout below
                special-cases ``nn.Conv2d`` and ``nn.ConvTranspose2d``.
            n, c, h, w: shape of the input tensor.
            k: number of output channels.
            filter_size: kernel size handed to ``layer``.
            device: device on which every tensor and module is created.
        """
        data = torch.randint(1, 10, (n, c, h, w), dtype=torch.float32, device=device)
        ref_input = data.clone().contiguous().requires_grad_(True)
        ref_conv = layer(c, k, filter_size).float().to(device)
        ref_out = ref_conv(ref_input)
        # The upstream gradient must live on the same device as ref_out, so it is
        # created on `device` like every other tensor in this helper.
        grad = torch.randint(1, 10, ref_out.size(), dtype=torch.float32, device=device)
        ref_out.backward(grad)

        # Older versions of CudNN have Channels Last support disabled.
        # The version does not change between iterations, so query it once.
        channels_last_supported = torch.backends.cudnn.version() >= 7603
        layouts = (torch.contiguous_format, torch.channels_last)

        # product(..., repeat=3) iterates weight layout outermost and input
        # layout innermost.
        for w_f, g_f, input_format in product(layouts, repeat=3):
            output_format = torch.contiguous_format
            if channels_last_supported:
                if input_format == torch.channels_last:
                    output_format = torch.channels_last
                # This is because we have N111 weight that cannot handle
                # the ambiguous memory_format
                if w_f == torch.channels_last:
                    if layer == nn.Conv2d and filter_size * c != 1:
                        output_format = torch.channels_last
                    if layer == nn.ConvTranspose2d and filter_size * k != 1:
                        output_format = torch.channels_last
            self._run_conv(
                layer,
                device,
                data,
                grad,
                ref_conv,
                ref_input,
                ref_out,
                input_format,
                w_f,
                g_f,
                output_format,
            )
3793

3794
    # Runs the contiguous / channels_last layout-mixing check for Conv2d and
    # ConvTranspose2d over a set of small shapes.
    @onlyCUDA
    @skipCUDAIfRocmVersionLessThan((4, 3))
    @skipCUDAIfNotMiopenSuggestNHWC
    @skipCUDAIfCudnnVersionLessThan(7603)
    @tf32_on_and_off(0.05)
    def test_conv_cudnn_mismatch_memory_format(self, device):
        # Each row is (n, c, h, w, k, filter_size).
        shape_configs = (
            (4, 2, 8, 8, 4, 2),
            (4, 1, 8, 8, 4, 2),
            (1, 1, 8, 8, 4, 2),
            (4, 2, 2, 8, 4, 1),
            (4, 2, 1, 8, 4, 1),
            (4, 2, 8, 8, 4, 1),
            (4, 1, 8, 8, 4, 1),
        )
        layer_types = (nn.Conv2d, nn.ConvTranspose2d)
        # For every shape, Conv2d is checked before ConvTranspose2d.
        for (n, c, h, w, k, filter_size), layer in product(shape_configs, layer_types):
            self._test_conv_cudnn_nhwc_nchw(layer, n, c, h, w, k, filter_size, device)
3816

3817
    # torch.half is erroring out on Windows with CUDA 10.1 + cuDNN 7.6.4
    # returning CUDNN_STATUS_BAD_PARAM
    # Disabling that specific test for now [see issue # 33918]
    @onlyCUDA
    @skipCUDAIfNoCudnn
    @dtypes(torch.float, torch.double)
    def test_conv_cudnn_nhwc_support(self, device, dtype):
        # Both tensors are created on the `device` supplied by the test
        # framework so the test runs on the device instance it was given.
        inp = torch.randn(
            (1, 16, 1, 1), dtype=dtype, device=device, requires_grad=True
        )
        weight = torch.randn(
            (8, 16, 3, 3), dtype=dtype, device=device, requires_grad=True
        )
        weight = weight.to(memory_format=torch.channels_last)
        o = torch.conv2d(
            inp,
            weight,
            bias=None,
            stride=(2, 1),
            padding=(1, 1),
            dilation=(1, 1),
            groups=1,
        )
        # A channels_last weight is expected to give a channels_last output.
        self.assertTrue(o.is_contiguous(memory_format=torch.channels_last))
        # The backward pass only has to run without raising.
        o.sum().backward()
3834

3835
    # Test that faster algorithms used for inference produce the same results
    # Validates depthwise3x3 bug reported in https://github.com/pytorch/pytorch/issues/60176
    @onlyCPU
    @dtypes(torch.float)
    def test_conv2d_no_grad(self, device, dtype):
        batch_sizes = (1, 2, 3)
        group_counts = (1, 2, 4)
        for batch, groups in product(batch_sizes, group_counts):
            x = torch.rand(batch, groups, 8, 8, dtype=dtype, device=device)
            conv = nn.Conv2d(
                groups,
                8,
                kernel_size=(3, 3),
                groups=groups,
                dtype=dtype,
                device=device,
            )
            # Same module, same input: once under no_grad, once with autograd on.
            with torch.no_grad():
                inference_out = conv(x)
            tracked_out = conv(x)
            self.assertEqual(tracked_out, inference_out, rtol=1e-2, atol=1e-5)
3855

3856
    # Compares the fused convolution+ReLU op (MIOpen on HIP builds, cuDNN
    # otherwise) with torch.conv2d followed by relu().
    @onlyCUDA
    @skipCUDAIfNoCudnn
    @dtypes(torch.float, torch.float16)
    @precisionOverride({torch.half: 0.002, torch.float: 1e-4})
    def test_cudnn_convolution_relu(self, device, dtype):
        # (bias, stride, padding, dilation, groups) shared by both code paths.
        conv_args = (None, (1, 1), (0, 0), (1, 1), 1)
        fused_conv_relu = (
            torch.miopen_convolution_relu
            if torch.version.hip
            else torch.cudnn_convolution_relu
        )
        # Explicit tolerances are used only for float32 when tf32_is_not_fp32()
        # holds; otherwise assertEqual falls back to its own defaults.
        tolerances = (
            {"atol": 4e-3, "rtol": 0.006}
            if tf32_is_not_fp32() and dtype == torch.float
            else {}
        )

        batch_sizes = (1, 2, 3)
        group_counts = (1, 2, 4)
        image_sizes = ((1, 1), (8, 8))
        kernel_sizes = ((1, 1), (3, 3))
        memory_formats = (torch.channels_last, torch.contiguous_format)

        for batch, groups, image_size, kernel_size, memory_format in product(
            batch_sizes, group_counts, image_sizes, kernel_sizes, memory_formats
        ):
            # Skip combinations where the kernel is taller than the image.
            if image_size[0] < kernel_size[0]:
                continue
            inp = torch.rand(batch, groups, *image_size, dtype=dtype, device=device)
            w = torch.randn(8, groups, *kernel_size, dtype=dtype, device=device)
            # Reference result, computed before the layout conversion.
            expected = torch.conv2d(inp, w, *conv_args).relu()

            inp = inp.to(memory_format=memory_format)
            w = w.to(memory_format=memory_format)
            cudnn_out = fused_conv_relu(inp, w, *conv_args)

            self.assertTrue(cudnn_out.is_contiguous(memory_format=memory_format))
            self.assertEqual(expected, cudnn_out, **tolerances)
3888

3889
    # Compares the fused convolution+add+ReLU op (MIOpen on HIP builds, cuDNN
    # otherwise) with relu(conv2d(inp, w) + alpha * z).
    @onlyCUDA
    @skipCUDAIfNoCudnn
    @dtypes(torch.float, torch.float16)
    @precisionOverride({torch.half: 0.002, torch.float: 1e-4})
    def test_cudnn_convolution_add_relu(self, device, dtype):
        # (bias, stride, padding, dilation, groups) shared by both code paths.
        conv_args = (None, (1, 1), (0, 0), (1, 1), 1)
        fused_conv_add_relu = (
            torch.miopen_convolution_add_relu
            if torch.version.hip
            else torch.cudnn_convolution_add_relu
        )
        # Explicit tolerances are used only for float32 when tf32_is_not_fp32()
        # holds; otherwise assertEqual falls back to its own defaults.
        tolerances = (
            {"atol": 2e-3, "rtol": 0.006}
            if tf32_is_not_fp32() and dtype == torch.float
            else {}
        )
        alpha = 2.0

        batch_sizes = (1, 2, 3)
        group_counts = (1, 2, 4)
        image_sizes = ((1, 1), (8, 8))
        kernel_sizes = ((1, 1), (3, 3))
        memory_formats = (torch.channels_last, torch.contiguous_format)

        for batch, groups, image_size, kernel_size, memory_format in product(
            batch_sizes, group_counts, image_sizes, kernel_sizes, memory_formats
        ):
            # Skip combinations where the kernel is taller than the image.
            if image_size[0] < kernel_size[0]:
                continue
            inp = torch.rand(batch, groups, *image_size, dtype=dtype, device=device)
            w = torch.randn(8, groups, *kernel_size, dtype=dtype, device=device)
            # Unfused convolution, computed before the layout conversion.
            conv2d_out = torch.conv2d(inp, w, *conv_args)
            z = torch.randn_like(conv2d_out)

            inp = inp.to(memory_format=memory_format)
            w = w.to(memory_format=memory_format)
            z = z.to(memory_format=memory_format)
            cudnn_out = fused_conv_add_relu(inp, w, z, alpha, *conv_args)

            self.assertTrue(cudnn_out.is_contiguous(memory_format=memory_format))
            self.assertEqual(
                F.relu(conv2d_out + alpha * z), cudnn_out, **tolerances
            )
3928

3929
    # After convert_conv2d_weight_memory_format, the output of a conv + batch-norm
    # model must be contiguous in the requested memory format.
    @onlyCUDA
    @skipCUDAIfRocm
    @skipCUDAIfCudnnVersionLessThan(7603)
    def test_convert_conv2d_weight_memory_format(self, device):
        x = torch.randint(1, 10, (2, 8, 4, 4), dtype=torch.float32, device=device)
        target_formats = (torch.channels_last, torch.contiguous_format)
        # Regular convolution first, then the transposed variant.
        for conv_cls in (nn.Conv2d, nn.ConvTranspose2d):
            model = nn.Sequential(conv_cls(8, 4, 3), nn.BatchNorm2d(4))
            model = model.to(device).float()
            # The same model is converted to channels_last and then back again.
            for memory_format in target_formats:
                model = nn.utils.convert_conv2d_weight_memory_format(
                    model, memory_format
                )
                out = model(x)
                self.assertTrue(out.is_contiguous(memory_format=memory_format))
3949

3950
    # After convert_conv3d_weight_memory_format, the output of a
    # ConvTranspose3d + BatchNorm3d model must be contiguous in the requested
    # memory format.
    @onlyCUDA
    @skipCUDAIfRocm
    @skipCUDAIfCudnnVersionLessThan(7603)
    def test_convert_conv3d_weight_memory_format(self, device):
        volume_shape = (2, 8, 4, 4, 4)
        x = torch.randint(1, 10, volume_shape, dtype=torch.float32, device=device)
        layers = [nn.ConvTranspose3d(8, 4, 3), nn.BatchNorm3d(4)]
        model = nn.Sequential(*layers).to(device).float()
        # The same model is converted to channels_last_3d and then back again.
        for memory_format in (torch.channels_last_3d, torch.contiguous_format):
            model = nn.utils.convert_conv3d_weight_memory_format(model, memory_format)
            result = model(x)
            self.assertTrue(result.is_contiguous(memory_format=memory_format))
3966

3967
    def test_conv_double_backward_strided_with_3D_input_and_weight(self, device):
3968
        # Test that _convolution_double_backward() outputs the correct grad shapes
3969
        # for 3D input / weight when stride > 1. This is an ad-hoc regression test for a
3970
        # specific case that was uncovered during the convolution consolidation effort.
3971
        # The test can be safely deleted if _convolution_double_backward() is removed.
3972

3973
        input = torch.randn(2, 3, 6, device=device)
3974
        weight = torch.randn(3, 3, 3, device=device)
3975
        bias = torch.randn(3, device=device)
3976
        stride = (2,)
3977
        padding = (1,)
3978
        dilation = (1,)
3979
        transposed = False
3980
        output_padding = (0,)
3981
        groups = 1
3982
        output = torch.ops.aten.convolution(
3983
            input,
3984
            weight,
3985
            bias,
3986
            stride,
3987
            padding,
3988
            dilation,
3989
            transposed,
3990
            output_padding,
3991
            groups,
3992
        )
3993

3994
        ggI = torch.randn(input.shape, device=device)
3995
        ggW = torch.randn(weight.shape, device=device)
3996
        ggB = torch.randn(bias.shape, device=device)
3997
        gO = torch.randn(output.shape, device=device)
3998
        output_mask = [True, True, True]
3999
        (
4000
            grad_grad_output,
4001
            grad_input,
4002
            grad_weight,
4003
        ) = torch.ops.aten._convolution_double_backward(
4004
            ggI,
4005
            ggW,
4006
            ggB,
4007
            gO,
4008
            weight,
4009
            input,
4010
            stride,
4011
            padding,
4012
            dilation,
4013
            transposed,
4014
            output_padding,
4015
            groups,
4016
            output_mask,
4017
        )
4018

4019
        # Make sure the correct shapes are computed.
4020
        self.assertEqual(grad_grad_output.shape, gO.shape)
4021
        self.assertEqual(grad_input.shape, input.shape)
4022
        self.assertEqual(grad_weight.shape, weight.shape)
4023

4024
    # Runs a 1x1x1 Conv3d on an input of 32 * 512 * 512 * 256 = 2**31 elements
    # and checks that the result on `device` matches the CPU result.
    @onlyCUDA
    @largeTensorTest("40GB")
    @largeTensorTest("24GB", "cpu")
    def test_conv3d_64bit_indexing(self, device):
        cpu_input = torch.rand(1, 32, 512, 512, 256)
        conv = torch.nn.Conv3d(32, 1, kernel_size=1, padding=0, stride=1, bias=False)
        # The CPU reference is computed before the module is moved to `device`.
        expected = conv(cpu_input)
        device_conv = conv.to(device=device)
        actual = device_conv(cpu_input.to(device=device))
        self.assertEqual(expected, actual)
4033

4034

4035
# Register the generated test classes. The device-type suite receives this
# module's globals(); presumably the per-device classes are injected there
# (confirm against common_device_type).
instantiate_device_type_tests(TestConvolutionNNDeviceType, globals())
instantiate_parametrized_tests(TestConvolutionNN)


# Run the suite when this file is executed as a script.
if __name__ == "__main__":
    run_tests()
4040

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.