pytorch

Форк
0
/
test_mkldnn.py 
1592 строки · 69.1 Кб
1
# Owner(s): ["module: mkldnn"]
2

3
import copy
4
import itertools
5
import functools
6
import unittest
7
from contextlib import nullcontext
8

9
try:
10
    import torchvision
11
    HAS_TORCHVISION = True
12
except ImportError:
13
    HAS_TORCHVISION = False
14

15
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")
16

17
import torch
18
import torch.nn.functional as F
19
import torch.jit
20
import torch.backends.mkldnn
21
from torch.utils import mkldnn as mkldnn_utils
22
from torch.testing._internal.common_utils import TestCase, \
23
    run_tests, TemporaryFileName, gradcheck, gradgradcheck, IS_WINDOWS, \
24
    skipIfTorchDynamo
25
from torch.testing._internal.common_device_type import (
26
    instantiate_device_type_tests,
27
    dtypes,
28
)
29

30
# gradcheck's batched-gradient checking does not support mkldnn tensors, so disable it
31
gradcheck = functools.partial(gradcheck, check_batched_grad=False)
32
gradgradcheck = functools.partial(gradgradcheck, check_batched_grad=False)
33

34

35
types = [torch.float, torch.bfloat16, torch.half]
36

37
# Comment out the decorator below to find the CI machines whose build has MKL-DNN disabled
38
@unittest.skipIf(not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled")
39
class TestMkldnn(TestCase):
40
    def test_conversion(self):
        """Round-trip dense CPU tensors through the MKL-DNN layout.

        Two float inputs are used: a freshly created 4-d tensor and a 4-d slice
        of a 5-d tensor. Each is converted with every dtype in ``types``, and
        its half/bfloat16 copies are converted as well. For every conversion
        the dtype, device, size, numel and element_size of the MKL-DNN tensor
        are checked, the dense round trip is compared with the source, and
        ``data_ptr()`` is expected to raise.
        """
        no_storage_msg = "Cannot access data pointer of Tensor that doesn't have storage"
        # Dtypes tried for each starting dtype. The mapping does not depend on
        # the input tensor, so it is built once rather than once per tensor.
        convert_dtypes = {torch.half: [torch.half, torch.float],
                          torch.bfloat16: [torch.bfloat16, torch.float],
                          torch.float: [torch.bfloat16, torch.half]}
        for cpu_tensor in [torch.randn((1, 2, 3, 4),
                                       dtype=torch.float, device=torch.device('cpu')),
                           torch.randn((1, 2, 3, 4, 5),
                                       dtype=torch.float, device=torch.device('cpu'))[:, :, :, :, 1]]:
            cpu_tensor.requires_grad_()
            # float cpu tensor to float/bfloat16/half mkldnn tensor.
            for dtype1 in types:
                mkldnn_tensor = cpu_tensor.to_mkldnn(dtype1)
                self.assertEqual(mkldnn_tensor.dtype, dtype1)
                cpu_tensor_1 = mkldnn_tensor.to_dense()
                # no dtype given to to_dense: the dense tensor keeps the mkldnn tensor's dtype
                self.assertEqual(mkldnn_tensor.dtype, cpu_tensor_1.dtype)
                # mkldnn float/bfloat16/half tensor to the dense dtypes listed for dtype1
                for dtype2 in convert_dtypes[dtype1]:
                    cpu_tensor_2 = mkldnn_tensor.to_dense(dtype2)
                    self.assertEqual(cpu_tensor_2.dtype, dtype2)
                    atol = 1e-5 if dtype1 == torch.float and dtype2 == torch.float else 1e-2
                    self.assertEqual(cpu_tensor, cpu_tensor_2.float(), atol=atol, rtol=0)

                self.assertEqual(mkldnn_tensor.device, torch.device('cpu'))
                self.assertEqual(mkldnn_tensor.size(), torch.Size([1, 2, 3, 4]))
                self.assertEqual(mkldnn_tensor.numel(), cpu_tensor.numel())
                if dtype1 == torch.float:
                    self.assertEqual(mkldnn_tensor.element_size(), cpu_tensor.element_size())
                else:
                    # half and bfloat16 elements are half the size of float elements
                    self.assertEqual(mkldnn_tensor.element_size(), cpu_tensor.element_size() / 2)
                # context-manager form avoids a lambda closing over the loop variable
                with self.assertRaisesRegex(RuntimeError, no_storage_msg):
                    mkldnn_tensor.data_ptr()

            # half/bfloat16 cpu tensor to mkldnn tensor of the same dtype or float.
            for orig_dtype in [torch.half, torch.bfloat16]:
                cpu_tensor_lower = cpu_tensor.to(dtype=orig_dtype)
                for dtype1 in convert_dtypes[orig_dtype]:
                    mkldnn_tensor = cpu_tensor_lower.to_mkldnn(dtype1)
                    self.assertEqual(mkldnn_tensor.dtype, dtype1)
                    cpu_tensor_1 = mkldnn_tensor.to_dense()
                    # no dtype given to to_dense: the dense tensor keeps the mkldnn tensor's dtype
                    self.assertEqual(mkldnn_tensor.dtype, cpu_tensor_1.dtype)
                    # mkldnn tensor to the dense dtypes listed for the source dtype
                    for dtype2 in convert_dtypes[cpu_tensor_lower.dtype]:
                        cpu_tensor_2 = mkldnn_tensor.to_dense(dtype2)
                        self.assertEqual(cpu_tensor_2.dtype, dtype2)
                        self.assertEqual(cpu_tensor_lower,
                                         cpu_tensor_2.to(dtype=cpu_tensor_lower.dtype), atol=1e-5, rtol=0)

                    self.assertEqual(mkldnn_tensor.device, torch.device('cpu'))
                    self.assertEqual(mkldnn_tensor.size(), torch.Size([1, 2, 3, 4]))
                    self.assertEqual(mkldnn_tensor.numel(), cpu_tensor.numel())
                    if dtype1 in [torch.bfloat16, torch.half]:
                        self.assertEqual(mkldnn_tensor.element_size(), cpu_tensor_lower.element_size())
                    else:
                        # float elements are twice the size of half/bfloat16 elements
                        self.assertEqual(mkldnn_tensor.element_size(), cpu_tensor_lower.element_size() * 2)
                    with self.assertRaisesRegex(RuntimeError, no_storage_msg):
                        mkldnn_tensor.data_ptr()
100

101
    def test_conversion_byte_char(self):
        """Round-trip int8 and uint8 CPU tensors through the MKL-DNN layout.

        For each dtype a 4-d and a 5-d random integer tensor are converted with
        ``to_mkldnn`` and back with ``to_dense``; values and metadata must
        match, and ``data_ptr()`` on the MKL-DNN tensor must raise.
        """
        cpu = torch.device('cpu')
        for int8_type in (torch.int8, torch.uint8):
            # only the signed type is sampled from a range that includes negatives
            low = -100 if int8_type is torch.int8 else 0
            high = 100
            source_4d = torch.randint(low=low, high=high, size=(1, 2, 3, 4),
                                      dtype=torch.int64, device=cpu)
            source_5d = torch.randint(low=low, high=high, size=(1, 2, 3, 4, 5),
                                      dtype=torch.int64, device=cpu)[:, :, :, :, :]
            for source in (source_4d, source_5d):
                cpu_tensor = source.to(dtype=int8_type)
                mkldnn_tensor = cpu_tensor.to_mkldnn(int8_type)
                self.assertEqual(mkldnn_tensor.dtype, int8_type)

                dense_again = mkldnn_tensor.to_dense()
                self.assertEqual(mkldnn_tensor.dtype, dense_again.dtype)
                self.assertEqual(cpu_tensor, dense_again)

                self.assertEqual(mkldnn_tensor.device, cpu)
                self.assertEqual(mkldnn_tensor.size(), cpu_tensor.size())
                self.assertEqual(mkldnn_tensor.numel(), cpu_tensor.numel())
                self.assertEqual(mkldnn_tensor.element_size(), cpu_tensor.element_size())
                with self.assertRaisesRegex(
                        RuntimeError,
                        "Cannot access data pointer of Tensor that doesn't have storage"):
                    mkldnn_tensor.data_ptr()
132

133
    def test_copy(self):
        """Check ``copy_`` between MKL-DNN tensors and its rejected combinations.

        A same-size MKL-DNN to MKL-DNN copy must reproduce the source values;
        a size mismatch and any mix of dense and MKL-DNN tensors must raise
        ``RuntimeError`` with the expected message.
        """
        x = torch.randn(4, 5, dtype=torch.float32)
        mkldnn_x = x.to_mkldnn()
        mkldnn_y = torch.randn(4, 5, dtype=torch.float32).to_mkldnn()
        mkldnn_z = torch.randn(4, 10, dtype=torch.float32).to_mkldnn()

        mkldnn_y.copy_(mkldnn_x)
        self.assertEqual(x, mkldnn_y.to_dense())

        # destination of a different size
        with self.assertRaisesRegex(RuntimeError,
                                    "copy_mkldnn_: only support same size tensor."):
            mkldnn_z.copy_(mkldnn_x)

        # dense destination, mkldnn source
        with self.assertRaisesRegex(
                RuntimeError,
                "copy_mkldnn_: between mkldnn layout and dense Tensors is not implemented! "
                "Found self type = torch.FloatTensor and src type = Mkldnntorch.FloatTensor"):
            x.copy_(mkldnn_x)

        # mkldnn destination, dense source
        with self.assertRaisesRegex(
                RuntimeError,
                "copy_mkldnn_: between mkldnn layout and dense Tensors is not implemented! "
                "Found self type = Mkldnntorch.FloatTensor and src type = torch.FloatTensor"):
            mkldnn_x.copy_(x)
151

152
    def test_unsupported(self):
        """Check that unsupported dtype, device and factory requests raise RuntimeError.

        Covers ``to_mkldnn()`` on CPU (and CUDA when available) tensors of
        several dtypes, ``to_mkldnn()`` on a float CUDA tensor, and factory
        functions called with ``layout=torch._mkldnn``.
        """
        # queried once; the answer is not expected to change during the test
        has_cuda = torch.cuda.is_available()

        # unsupported types and unsupported types with gpu
        # NOTE(review): torch.randn may itself reject the integer dtypes, in which
        # case the RuntimeError comes from the factory rather than to_mkldnn() --
        # confirm which call is really being exercised for those dtypes.
        for dtype in [torch.double, torch.uint8, torch.int8,
                      torch.short, torch.int, torch.long]:
            with self.assertRaises(RuntimeError):
                torch.randn(1, 2, 3, 4, dtype=dtype, device=torch.device('cpu')).to_mkldnn()
            if has_cuda:
                with self.assertRaises(RuntimeError):
                    torch.randn(1, 2, 3, 4, dtype=dtype, device=torch.device('cuda')).to_mkldnn()
        # supported type with gpu
        if has_cuda:
            with self.assertRaises(RuntimeError):
                torch.randn(1, 2, 3, 4, dtype=torch.float, device=torch.device('cuda')).to_mkldnn()
        # some factory functions
        for creator in [torch.ones, torch.randn, torch.rand]:
            with self.assertRaises(RuntimeError):
                creator(1, 2, 3, 4, dtype=torch.float, device=torch.device('cpu'), layout=torch._mkldnn)
169

170
    def test_mkldnn_conv_shapecheck(self):
        """Check that ``torch.mkldnn_convolution`` rejects invalid arguments.

        Each case differs from an all-valid call in exactly one argument
        (padding, stride, dilation, groups, weight or bias) and must raise
        ``RuntimeError``.
        """
        inp = torch.full((1, 1, 1, 24,), 1, dtype=torch.float32)
        w1 = torch.full((1, 1, 1, 24,), 1, dtype=torch.float32)
        b1 = torch.full((1,), 1, dtype=torch.float32)
        w2 = torch.full((1, 1, 2, 24,), 1, dtype=torch.float32)
        b2 = torch.full((2,), 1, dtype=torch.float32)
        # (padding, stride, dilation, groups, weight, bias)
        bad_cases = [
            (-1, 1, 1, 1, w1, b1),
            (0, 0, 1, 1, w1, b1),
            (0, 1, 0, 1, w1, b1),
            (0, 1, 1, 0, w1, b1),
            (0, 1, 1, 2, w1, b1),
            (0, 1, 1, 1, w1, b2),
            (0, 1, 1, 1, w2, b1),
        ]
        for pad, st, dil, gr, w, b in bad_cases:
            with self.assertRaises(RuntimeError):
                torch.mkldnn_convolution(inp, w, b, [pad] * 2, [st] * 2, [dil] * 2, gr)
185

186
    def test_autograd_to_mkldnn(self):
        """Run gradcheck and gradgradcheck through a dense -> mkldnn -> dense round trip.

        Both checks are expected to emit the double-precision ``UserWarning``
        because the input is float32.
        """
        # MKLDNN only supports float32
        root = torch.randn(4, 5, dtype=torch.float32, requires_grad=True)

        def round_trip(t):
            return t.to_mkldnn().to_dense()

        # because MKLDNN only supports float32, we need to lessen the precision.
        # these numbers are just empirical results that seem to work.
        with self.assertWarnsRegex(UserWarning, 'double precision floating point'):
            gradcheck(round_trip, [root], atol=4e-2, rtol=1e-2)
        with self.assertWarnsRegex(UserWarning, 'double precision floating point'):
            gradgradcheck(round_trip, [root], atol=4e-2, rtol=1e-2)
201

202
    def test_autograd_from_mkldnn(self):
        """Run gradcheck on ``to_dense`` starting from an MKL-DNN leaf tensor.

        The check is expected to emit the double-precision ``UserWarning``
        because the input is float32.
        """
        # MKLDNN only supports float32
        root = torch.randn(4, 5, dtype=torch.float32).to_mkldnn().requires_grad_()

        def densify(t):
            return t.to_dense()

        # because MKLDNN only supports float32, we need to lessen the precision.
        # these numbers are just empirical results that seem to work.
        with self.assertWarnsRegex(UserWarning, 'double precision floating point'):
            gradcheck(densify, [root], atol=4e-2, rtol=1e-2)
214

215
    def test_detach(self):
        """Check ``detach`` and ``detach_`` on an MKL-DNN tensor that requires grad."""
        expected_size = (4, 5)
        root = torch.randn(4, 5, dtype=torch.float32).to_mkldnn().requires_grad_()

        # detach(): the result does not require grad, the source still does
        detached = root.detach()
        self.assertEqual(expected_size, detached.size())
        self.assertFalse(detached.requires_grad)
        self.assertTrue(root.requires_grad)

        # detach_(): afterwards neither the result nor the source requires grad
        detached_inplace = root.detach_()
        self.assertEqual(expected_size, detached_inplace.size())
        self.assertFalse(detached_inplace.requires_grad)
        self.assertFalse(root.requires_grad)
227

228
    def test_repr(self):
        """Check that the string form of an MKL-DNN tensor names its layout."""
        mkldnn_tensor = torch.randn((1, 2, 3, 4),
                                    dtype=torch.float, device=torch.device('cpu')).to_mkldnn()
        # assertIn reports the actual string on failure, unlike assertTrue(a in b)
        self.assertIn("layout=torch._mkldnn", str(mkldnn_tensor))
231

232
    def _test_conv_base(self, dim):
        """Compare an MKL-DNN convolution with the reference run made with MKL-DNN disabled.

        Args:
            dim: Spatial dimensionality (1, 2 or 3); selects the ``ConvNd``
                module and the input spatial shape.

        Iterates over train/bias/dilation/groups combinations. In inference
        mode the module is converted with ``mkldnn_utils.to_mkldnn`` and also
        goes through the serialization and tracing helpers. In training mode
        (skipped for ``dim == 1``) outputs and input/weight/bias gradients are
        compared.
        """
        conv_module = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d, 3: torch.nn.Conv3d}
        input_shapes = {1: (224,), 2: (224, 224), 3: (55, 55, 55)}
        options = itertools.product([True, False], [True, False], [1, 2], [1, 4])
        for train, bias, dilation, groups in options:
            N = torch.randint(3, 10, (1,)).item()
            # channel counts are multiples of `groups`
            M = torch.randint(1, 3, (1,)).item() * groups
            C = torch.randint(1, 3, (1,)).item() * groups
            x = torch.randn((N, C) + input_shapes[dim], dtype=torch.float32)
            conv = conv_module[dim](in_channels=C,
                                    out_channels=M,
                                    kernel_size=3,
                                    stride=2,
                                    padding=1,
                                    dilation=dilation,
                                    bias=bias,
                                    groups=groups).float()
            x1 = x.clone()
            x2 = x.clone().to_mkldnn()

            # TODO: enable conv1d training.
            check_backward = train and dim != 1
            if not train:
                mkldnn_conv = mkldnn_utils.to_mkldnn(copy.deepcopy(conv))
            elif check_backward:
                x1.requires_grad_()
                x2.requires_grad_()
                mkldnn_conv = copy.deepcopy(conv)

            # reference result with MKL-DNN switched off
            with torch.backends.mkldnn.flags(enabled=False):
                y_aten = conv(x1)
                if check_backward:
                    loss1 = y_aten.sum()
                    loss1.backward()

            if train and not check_backward:
                # conv1d training: only the reference forward pass is run
                continue

            y_mkldnn = mkldnn_conv(x2).to_dense()
            self.assertEqual(y_aten, y_mkldnn)

            if not train:
                self._test_serialization(mkldnn_conv, (x.to_mkldnn(),))
                self._test_tracing(mkldnn_conv, (x.to_mkldnn(),))
                continue

            loss2 = y_mkldnn.sum()
            loss2.backward()
            self.assertTrue(x2.grad.is_mkldnn)
            self.assertEqual(x1.grad, x2.grad.to_dense())
            self.assertEqual(conv.weight.grad,
                             mkldnn_conv.weight.grad,
                             atol=1e-3,
                             rtol=1e-3)
            if bias:
                self.assertEqual(conv.bias.grad, mkldnn_conv.bias.grad)
281

282
    def test_conv1d(self):
        """Run the shared convolution comparison for 1-d convolutions."""
        return self._test_conv_base(dim=1)
284

285
    def test_conv2d(self):
        """Run the shared convolution comparison for 2-d convolutions."""
        return self._test_conv_base(dim=2)
287

288
    def test_conv3d(self):
        """Run the shared convolution comparison for 3-d convolutions."""
        return self._test_conv_base(dim=3)
290

291
    def _test_conv_deconv_lower_precision_base(self, dim, conv_module, dtype):
        """Compare bf16/fp16 conv or transposed-conv results against fp32.

        Args:
            dim: Spatial dimensionality (1, 2 or 3); selects the input shape.
            conv_module: The ``torch.nn`` conv or transposed-conv class to build.
            dtype: ``torch.bfloat16`` or ``torch.half``.

        When the matching MKL-DNN support probe returns true, the fp32 and
        lower-precision MKL-DNN modules are compared; otherwise the conversion
        is expected to raise ``RuntimeError``. Finally the non-MKL-DNN path is
        checked by comparing the lower-precision module with its float copy
        while MKL-DNN is disabled.
        """
        input_shapes = {1: (224,), 2: (224, 224), 3: (55, 55, 55)}
        transposed_modules = [torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d,
                              torch.nn.ConvTranspose3d]
        unsupported_msg = {
            torch.bfloat16: r"bf16 path needs the cpu support avx_ne_convert or avx512bw, avx512vl and avx512dq",
            torch.half: r"fp16 path needs the cpu support avx_ne_convert or avx512_fp16",
        }
        for bias, dilation, groups in itertools.product([True, False], [1, 2], [1, 4]):
            N = torch.randint(1, 3, (1,)).item()
            M = torch.randint(1, 3, (1,)).item() * groups
            C = torch.randint(1, 3, (1,)).item() * groups
            x = torch.randn((N, C) + input_shapes[dim], dtype=torch.float32)
            # TODO: remove this when group depthwise is supported:
            if conv_module in transposed_modules and groups > 1 and C == groups:
                continue
            conv = conv_module(in_channels=C,
                               out_channels=M,
                               kernel_size=3,
                               stride=2,
                               padding=1,
                               dilation=dilation,
                               bias=bias,
                               groups=groups).float()
            x_lower = x.to(dtype=dtype)

            # only the probe that matches `dtype` is queried
            if dtype == torch.bfloat16:
                lower_supported = torch.ops.mkldnn._is_mkldnn_bf16_supported()
            elif dtype == torch.half:
                lower_supported = torch.ops.mkldnn._is_mkldnn_fp16_supported()
            else:
                lower_supported = False

            if lower_supported:
                mkldnn_conv = mkldnn_utils.to_mkldnn(copy.deepcopy(conv))
                mkldnn_conv_lower = mkldnn_utils.to_mkldnn(copy.deepcopy(conv), dtype)
                y = mkldnn_conv(x.to_mkldnn()).to_dense()
                y_lower = mkldnn_conv_lower(x_lower.to_mkldnn()).to_dense(torch.float32)
                self.assertEqual(y, y_lower, atol=1e-1, rtol=1e-3)
            else:
                with self.assertRaisesRegex(RuntimeError, unsupported_msg[dtype]):
                    mkldnn_conv_lower = mkldnn_utils.to_mkldnn(copy.deepcopy(conv), dtype)
                    y_lower = mkldnn_conv_lower(x_lower.to_mkldnn()).to_dense(torch.float32)

            # test thnn impl
            conv_lower = copy.deepcopy(conv).to(dtype=dtype)
            conv_ref = copy.deepcopy(conv_lower).float()
            with torch.backends.mkldnn.flags(enabled=False):
                x_ref = x_lower.clone().float().detach().requires_grad_()
                x_lower.requires_grad_()
                y = conv_ref(x_ref)
                y_lower = conv_lower(x_lower).float()
                self.assertEqual(y, y_lower, atol=5e-2, rtol=5e-3)
337

338
    @dtypes(torch.float16, torch.bfloat16)
339
    def test_conv_deconv_1d_lower_precision(self, dtype):
        """Run the lower-precision comparison for 1-d conv and transposed conv."""
        for module in (torch.nn.Conv1d, torch.nn.ConvTranspose1d):
            self._test_conv_deconv_lower_precision_base(1, module, dtype=dtype)
342

343
    @dtypes(torch.float16, torch.bfloat16)
344
    def test_conv_deconv_2d_lower_precision(self, dtype):
        """Run the lower-precision comparison for 2-d conv and transposed conv."""
        for module in (torch.nn.Conv2d, torch.nn.ConvTranspose2d):
            self._test_conv_deconv_lower_precision_base(2, module, dtype=dtype)
347

348
    @dtypes(torch.float16, torch.bfloat16)
349
    def test_conv_deconv_3d_lower_precision(self, dtype):
        """Run the lower-precision comparison for 3-d conv and transposed conv."""
        for module in (torch.nn.Conv3d, torch.nn.ConvTranspose3d):
            self._test_conv_deconv_lower_precision_base(3, module, dtype=dtype)
352

353
    def _test_conv_deconv_nhwc_base(self, conv_module, weight_memory_format, dtype, prec=None):
        """Compare a conv/deconv run on contiguous input with one on channels-last input.

        Args:
            conv_module: One of ``Conv2d``, ``ConvTranspose2d``, ``Conv3d`` or
                ``ConvTranspose3d`` from ``torch.nn``.
            weight_memory_format: Memory format applied to the second module's
                parameters.
            dtype: Dtype of the input and of the module parameters.
            prec: Value passed as both ``atol`` and ``rtol`` for the output,
                bias-gradient and input-gradient comparisons; ``None`` is
                forwarded unchanged to ``assertEqual``.

        Raises:
            ValueError: If ``conv_module`` is not one of the supported classes.
        """
        input_shapes = {2: (55, 55), 3: (14, 14, 14)}
        options = itertools.product([True, False], [True, False], [1, 2], [1, 4])
        if conv_module in [torch.nn.Conv2d, torch.nn.ConvTranspose2d]:
            cl_format = torch.channels_last
            input_shape = input_shapes[2]
        elif conv_module in [torch.nn.Conv3d, torch.nn.ConvTranspose3d]:
            cl_format = torch.channels_last_3d
            input_shape = input_shapes[3]
        else:
            # Fail up front with a clear message instead of an unbound-variable
            # error inside the loop below.
            raise ValueError(f"unsupported conv_module: {conv_module!r}")

        for train, bias, dilation, groups in options:
            N = torch.randint(3, 10, (1,)).item()
            M = torch.randint(1, 3, (1,)).item() * groups
            C = torch.randint(1, 3, (1,)).item() * groups
            x_shape = (N, C) + input_shape
            x = torch.randn(x_shape, dtype=dtype)

            # conv1: mkldnn conv/deconv in contiguous memory format (nchw)
            # conv2: mkldnn conv/deconv in channels last memory format (nhwc)
            conv1 = conv_module(in_channels=C,
                                out_channels=M,
                                kernel_size=3,
                                stride=2,
                                padding=1,
                                dilation=dilation,
                                bias=bias,
                                groups=groups).to(dtype=dtype)
            conv2 = copy.deepcopy(conv1).to(memory_format=weight_memory_format)
            x1 = x.clone()
            x2 = x.clone().to(memory_format=cl_format)
            if train:
                x1.requires_grad_()
                x2.requires_grad_()
            y1 = conv1(x1)
            y2 = conv2(x2)
            self.assertEqual(y1, y2, atol=prec, rtol=prec)

            if train:
                y1.sum().backward()
                y2.sum().backward()
                # the gradient of the channels-last input must itself be channels-last
                self.assertTrue(x2.grad.is_contiguous(memory_format=cl_format))
                self.assertEqual(conv1.weight.grad,
                                 conv2.weight.grad,
                                 atol=1e-3,
                                 rtol=1e-3)
                if bias:
                    self.assertEqual(conv1.bias.grad, conv2.bias.grad, atol=prec, rtol=prec)
                self.assertEqual(x1.grad, x2.grad, atol=prec, rtol=prec)
401

402
    def test_conv_nhwc_fp32(self):
        """Run the NHWC comparison for fp32 Conv2d and Conv3d with both weight layouts."""
        variants = (
            (torch.nn.Conv2d, torch.contiguous_format),
            (torch.nn.Conv2d, torch.channels_last),
            (torch.nn.Conv3d, torch.contiguous_format),
            (torch.nn.Conv3d, torch.channels_last_3d),
        )
        for module, weight_format in variants:
            self._test_conv_deconv_nhwc_base(module, weight_format, dtype=torch.float32)
407

408
    @dtypes(torch.float16, torch.bfloat16)
409
    def test_conv_nhwc_lower_precision(self, dtype):
        """Run the NHWC comparison for fp16/bf16 Conv2d and Conv3d.

        The comparison runs with default tolerances when the dtype's MKL-DNN
        support probe returns true, and always runs again with MKL-DNN disabled
        using a looser per-dtype tolerance.
        """
        variants = (
            (torch.nn.Conv2d, torch.contiguous_format),
            (torch.nn.Conv2d, torch.channels_last),
            (torch.nn.Conv3d, torch.contiguous_format),
            (torch.nn.Conv3d, torch.channels_last_3d),
        )
        # when torch.ops.mkldnn._is_mkldnn_bf16_supported() or torch.ops.mkldnn._is_mkldnn_fp16_supported()
        # returns false, bf16/fp16 CPU conv will fall back to thnn impl
        support_checks = {
            torch.bfloat16: torch.ops.mkldnn._is_mkldnn_bf16_supported,
            torch.float16: torch.ops.mkldnn._is_mkldnn_fp16_supported
        }
        if support_checks[dtype]():
            for module, weight_format in variants:
                self._test_conv_deconv_nhwc_base(module, weight_format, dtype=dtype)

        # BF16/FP16 fallback implementations are divided into two parts im2col+gemm,
        # and the number of data type conversions in the middle is more than that of onednn's direct conv,
        # resulting in additional accuracy loss.
        precisions = {
            torch.bfloat16: 1e-2,
            torch.float16: 2e-3,
        }
        prec = precisions[dtype]
        with torch.backends.mkldnn.flags(enabled=False):
            for module, weight_format in variants:
                self._test_conv_deconv_nhwc_base(module, weight_format, dtype=dtype, prec=prec)
435

436

437
    def test_conv_transpose_nhwc_fp32(self):
        """Run the NHWC comparison for fp32 ConvTranspose2d and ConvTranspose3d with both weight layouts."""
        variants = (
            (torch.nn.ConvTranspose2d, torch.contiguous_format),
            (torch.nn.ConvTranspose2d, torch.channels_last),
            (torch.nn.ConvTranspose3d, torch.contiguous_format),
            (torch.nn.ConvTranspose3d, torch.channels_last_3d),
        )
        for module, weight_format in variants:
            self._test_conv_deconv_nhwc_base(module, weight_format, dtype=torch.float32)
442

443
    @dtypes(torch.float16, torch.bfloat16)
444
    def test_conv_transpose_nhwc_lower_precision(self, dtype):
        """Run the NHWC comparison for fp16/bf16 ConvTranspose2d and ConvTranspose3d.

        The comparison runs with default tolerances when the dtype's MKL-DNN
        support probe returns true, and always runs again with MKL-DNN disabled
        using a looser per-dtype tolerance.
        """
        variants = (
            (torch.nn.ConvTranspose2d, torch.contiguous_format),
            (torch.nn.ConvTranspose2d, torch.channels_last),
            (torch.nn.ConvTranspose3d, torch.contiguous_format),
            (torch.nn.ConvTranspose3d, torch.channels_last_3d),
        )
        # when torch.ops.mkldnn._is_mkldnn_bf16_supported() or torch.ops.mkldnn._is_mkldnn_fp16_supported()
        # returns false, bf16/fp16 CPU conv will fall back to thnn impl
        support_checks = {
            torch.bfloat16: torch.ops.mkldnn._is_mkldnn_bf16_supported,
            torch.float16: torch.ops.mkldnn._is_mkldnn_fp16_supported
        }
        if support_checks[dtype]():
            for module, weight_format in variants:
                self._test_conv_deconv_nhwc_base(module, weight_format, dtype=dtype)

        # BF16/FP16 fallback implementations are divided into two parts col2im+gemm,
        # and the number of data type conversions in the middle is more than that of onednn's direct conv,
        # resulting in additional accuracy loss.
        precisions = {
            torch.bfloat16: 2e-2,
            torch.float16: 3e-3,
        }
        prec = precisions[dtype]
        with torch.backends.mkldnn.flags(enabled=False):
            for module, weight_format in variants:
                self._test_conv_deconv_nhwc_base(module, weight_format, dtype=dtype, prec=prec)
470

471
    def _test_conv_transpose_base(self, dim):
        """Compare an fp32 transposed convolution with MKL-DNN enabled against a run with it disabled.

        Args:
            dim: Spatial dimensionality (1, 2 or 3); selects the
                ``ConvTransposeNd`` module and the input spatial shape.

        Iterates over train/bias/dilation/groups combinations; in training mode
        input, weight and bias gradients are compared in addition to outputs.
        """
        module_by_dim = {
            1: torch.nn.ConvTranspose1d,
            2: torch.nn.ConvTranspose2d,
            3: torch.nn.ConvTranspose3d
        }
        input_shapes = {1: (55,), 2: (28, 28), 3: (14, 14, 14)}
        for train, bias, dilation, groups in itertools.product(
                [True, False], [True, False], [1, 2], [1, 4]):
            N = torch.randint(3, 10, (1,)).item()
            M = torch.randint(1, 3, (1,)).item() * groups
            C = torch.randint(1, 3, (1,)).item() * groups
            data = torch.randn((N, C) + input_shapes[dim], dtype=torch.float32)
            # conv: mkldnn transpose conv fp32
            # conv_ref: thnn transpose conv fp32
            conv = module_by_dim[dim](in_channels=C,
                                      out_channels=M,
                                      kernel_size=3,
                                      stride=1,
                                      padding=1,
                                      dilation=dilation,
                                      bias=bias,
                                      groups=groups).to(dtype=torch.float32)
            x = data.clone()
            x_ref = x.clone()
            if train:
                x.requires_grad_()
                x_ref.requires_grad_()

            # reference forward (and backward) with MKL-DNN switched off
            conv_ref = copy.deepcopy(conv)
            with torch.backends.mkldnn.flags(enabled=False):
                y_ref = conv_ref(x_ref)
                if train:
                    y_ref.sum().backward()

            y = conv(x)
            if train:
                y.sum().backward()

            self.assertEqual(y, y_ref)
            if not train:
                continue
            self.assertEqual(x.grad, x_ref.grad)
            self.assertEqual(conv.weight.grad,
                             conv_ref.weight.grad,
                             atol=1e-3,
                             rtol=1e-3)
            if bias:
                self.assertEqual(conv.bias.grad, conv_ref.bias.grad)
520

521
    def test_conv_transpose1d(self):
        """Run the shared transposed-convolution comparison for 1-d inputs."""
        return self._test_conv_transpose_base(dim=1)
523

524
    def test_conv_transpose2d(self):
        """Run the shared transposed-convolution comparison for 2-d inputs."""
        return self._test_conv_transpose_base(dim=2)
526

527
    def test_conv_transpose3d(self):
        """Run the shared transposed-convolution comparison for 3-d inputs."""
        return self._test_conv_transpose_base(dim=3)
529

530
    def test_conv2d_legacy_jit_model(self):
        """
        MKLDNN integration used to serialize grouped-convolution models with a
        5-d weight, and that behavior should be preserved. The test installs a
        5-d weight on a converted module, saves and reloads it with torch.jit,
        and expects the loaded weight to be 4-d and the output to match the
        original dense conv.
        """
        groups = 4
        conv2d = torch.nn.Conv2d(16, 16, 3, groups=groups)
        conv2d_mkldnn = torch.utils.mkldnn.to_mkldnn(conv2d)

        # contrive legacy conv2d module with a 5-d weight
        out_ch, in_ch, k_h, k_w = conv2d.weight.shape
        legacy_shape = (groups, out_ch // groups, in_ch, k_h, k_w)
        conv2d_mkldnn.weight = conv2d.weight.reshape(legacy_shape).to_mkldnn()

        x = torch.randn(1, 16, 8, 8)

        with TemporaryFileName() as fname:
            torch.jit.save(conv2d_mkldnn, fname)
            conv2d_loaded = torch.jit.load(fname)

            self.assertEqual(conv2d_mkldnn.weight.ndimension(), 5)
            self.assertEqual(conv2d_loaded.weight.ndimension(), 4)
            expected = conv2d(x)
            actual = conv2d_loaded(x.to_mkldnn()).to_dense()
            self.assertEqual(expected, actual)
555

556
    # Checks that 1D convolution works on mkldnn tensors; the missing support
    # was reported in https://github.com/pytorch/pytorch/issues/68034.
558
    def test_conv1d_functional(self):
        """Call the functional ``conv1d`` on MKL-DNN input, weight and bias and check the output size."""
        inp = torch.randn(2, 3, 10).to_mkldnn()
        weight = torch.randn(3, 3, 3).to_mkldnn()
        bias = torch.randn(3).to_mkldnn()
        result = torch.nn.functional.conv1d(inp, weight, bias)
        self.assertEqual(result.size(), torch.Size([2, 3, 8]))
564

565
    def test_relu(self):
        """Compare ``torch.relu`` outputs and input gradients for dense and MKL-DNN tensors."""
        x = torch.randn((4, 5), dtype=torch.float32) * 10
        dense_in = x.clone().requires_grad_()
        mkldnn_in = x.clone().to_mkldnn().requires_grad_()

        dense_out = torch.relu(dense_in)
        mkldnn_out = torch.relu(mkldnn_in).to_dense()

        dense_loss = dense_out.sum()
        mkldnn_loss = mkldnn_out.sum()
        dense_loss.backward()
        mkldnn_loss.backward()

        self.assertEqual(dense_out, mkldnn_out)
        self.assertEqual(dense_in.grad, mkldnn_in.grad.to_dense())
577

578
    def test_relu_(self):
        """Compare in-place ``torch.relu_`` outputs and input gradients for dense and MKL-DNN tensors."""
        x = torch.randn((4, 5), dtype=torch.float32) * 10
        dense_in = x.clone().requires_grad_()
        mkldnn_in = x.clone().to_mkldnn().requires_grad_()

        # relu_ is applied to clones of the leaf tensors, not to the leaves themselves
        dense_out = torch.relu_(dense_in.clone())
        mkldnn_out = torch.relu_(mkldnn_in.clone()).to_dense()

        dense_loss = dense_out.sum()
        mkldnn_loss = mkldnn_out.sum()
        dense_loss.backward()
        mkldnn_loss.backward()

        self.assertEqual(dense_out, mkldnn_out)
        self.assertEqual(dense_in.grad, mkldnn_in.grad.to_dense())
590

591
    @unittest.skipIf(IS_WINDOWS, "Limit support for bf16 path")
    def _test_relu_bf16_base(self, name):
        # Shared bf16 check for the torch function called `name`:
        # - bf16 supported: the bf16 MKL-DNN result must match the fp32 MKL-DNN result loosely;
        # - otherwise: calling it on a bf16 MKL-DNN tensor must raise a RuntimeError.
        data = torch.randn((4, 5), dtype=torch.float32) * 10
        data_bf16 = data.bfloat16()
        op = getattr(torch, name)

        if not torch.ops.mkldnn._is_mkldnn_bf16_supported():
            expected_error = r"bf16 path needs the cpu support avx512bw, avx512vl and avx512dq"
            with self.assertRaisesRegex(RuntimeError, expected_error):
                op(data_bf16.to_mkldnn())
            return

        reference = op(data.to_mkldnn()).to_dense()
        low_precision = op(data_bf16.to_mkldnn()).to_dense(torch.float32)
        self.assertEqual(reference, low_precision, atol=1e-1, rtol=1e-3)
605

606
    def test_relu_bf16(self):
        # bf16 check for the out-of-place torch.relu.
        self._test_relu_bf16_base(name="relu")
608

609
    def test_relu_inplace_bf16(self):
        # bf16 check for the in-place torch.relu_.
        self._test_relu_bf16_base(name="relu_")
611

612
    def test_gelu(self):
        # nn.GELU forward output and input gradient: dense tensor vs. MKL-DNN tensor.
        gelu = torch.nn.GELU()
        base = torch.randn((4, 5), dtype=torch.float32) * 10
        dense_in = base.clone().requires_grad_()
        mkldnn_in = base.clone().to_mkldnn().requires_grad_()

        dense_out = gelu(dense_in)
        mkldnn_out = gelu(mkldnn_in).to_dense()

        dense_out.sum().backward()
        mkldnn_out.sum().backward()

        self.assertEqual(dense_out, mkldnn_out)
        self.assertEqual(dense_in.grad, mkldnn_in.grad.to_dense())
625

626
    @unittest.skipIf(IS_WINDOWS, "Limit support for bf16 path")
    def test_gelu_bf16(self):
        # nn.GELU on a bf16 MKL-DNN tensor vs. an fp32 MKL-DNN tensor (outputs and input
        # gradients), or the expected RuntimeError when bf16 is not supported.
        gelu = torch.nn.GELU()
        base = torch.randn((4, 5), dtype=torch.float32) * 10
        fp32_in = base.clone().to_mkldnn().requires_grad_()
        bf16_in = base.clone().to_mkldnn(torch.bfloat16).requires_grad_()

        if not torch.ops.mkldnn._is_mkldnn_bf16_supported():
            expected_error = r"bf16 path needs the cpu support avx512bw, avx512vl and avx512dq"
            with self.assertRaisesRegex(RuntimeError, expected_error):
                gelu(bf16_in)
            return

        fp32_out = gelu(fp32_in).to_dense()
        bf16_out = gelu(bf16_in).to_dense()

        fp32_out.sum().backward()
        bf16_out.sum().backward()

        self.assertEqual(fp32_out, bf16_out.to(torch.float32), atol=1e-1, rtol=0)
        self.assertEqual(
            fp32_in.grad.to_dense(),
            bf16_in.grad.to_dense(torch.float32),
            atol=1e-2,
            rtol=0,
        )
646

647
    def _test_prelu_base(self, size, num_channels):
        # Compares three PReLU setups on the same data:
        #   1. dense module on a dense input (reference),
        #   2. module converted with mkldnn_utils.to_mkldnn on an MKL-DNN input,
        #   3. unconverted copy of the module on an MKL-DNN input.
        base = torch.randn(size, dtype=torch.float32)
        dense_in = base.clone().requires_grad_()
        converted_in = base.clone().to_mkldnn().requires_grad_()
        mixed_in = base.clone().to_mkldnn().requires_grad_()

        dense_prelu = torch.nn.PReLU(num_channels)
        converted_prelu = mkldnn_utils.to_mkldnn(copy.deepcopy(dense_prelu))
        mixed_prelu = copy.deepcopy(dense_prelu)

        dense_out = dense_prelu(dense_in)
        converted_out = converted_prelu(converted_in).to_dense()
        # Only the data is converted to mkldnn here; the weight stays an ATen tensor.
        mixed_out = mixed_prelu(mixed_in).to_dense()

        for out in (dense_out, converted_out, mixed_out):
            out.sum().backward()

        self.assertEqual(dense_out, converted_out)
        self.assertEqual(dense_out, mixed_out)
        self.assertEqual(dense_in.grad, converted_in.grad.to_dense())
        self.assertEqual(dense_in.grad, mixed_in.grad.to_dense())
668

669
    def test_prelu(self):
        # (input shape, PReLU channel count) pairs, from 1-d up to 5-d inputs.
        cases = (
            ((16,), 1),
            ((16, 64), 1),
            ((16, 64), 64),
            ((16, 64, 112), 1),
            ((16, 64, 112), 64),
            ((16, 64, 112, 112), 1),
            ((16, 64, 112, 112), 64),
            ((16, 64, 112, 112, 1), 1),
            ((16, 64, 112, 112, 1), 64),
        )
        for shape, channels in cases:
            self._test_prelu_base(torch.Size(shape), channels)
679

680
    @unittest.skipIf(IS_WINDOWS, "Limit support for bf16 path")
    def _test_prelu_bf16_base(self, size, num_channels):
        # Compares a bf16 MKL-DNN PReLU against an fp32 MKL-DNN PReLU (outputs and input
        # gradients) for an input of shape `size` and a PReLU with `num_channels` parameters.
        # When bf16 is not supported, the bf16 module call must raise a RuntimeError.
        #
        # `num_channels` is forwarded to torch.nn.PReLU so that the multi-channel cases
        # requested by the callers are really exercised (a bare PReLU() always has a
        # single parameter).
        if torch.ops.mkldnn._is_mkldnn_bf16_supported():
            x = torch.randn(size, dtype=torch.float32)
            x_fp32 = x.clone().to_mkldnn().requires_grad_()
            x_bf16 = x.clone().to_mkldnn(torch.bfloat16).requires_grad_()
            m = mkldnn_utils.to_mkldnn(torch.nn.PReLU(num_channels))
            m_bf16 = mkldnn_utils.to_mkldnn(torch.nn.PReLU(num_channels), torch.bfloat16)

            y = m(x_fp32).to_dense()
            y_bf16 = m_bf16(x_bf16).to_dense()
            self.assertEqual(y, y_bf16.to(torch.float32), atol=1e-1, rtol=1e-3)

            loss = y.sum()
            loss.backward()
            loss_bf16 = y_bf16.sum()
            loss_bf16.backward()
            self.assertEqual(x_fp32.grad.to_dense(), x_bf16.grad.to_dense(torch.float32))
        else:
            x_bf16 = torch.randn(size, dtype=torch.bfloat16).requires_grad_()
            m_bf16 = mkldnn_utils.to_mkldnn(torch.nn.PReLU(num_channels), torch.bfloat16)
            msg = r"bf16 path needs the cpu support avx512bw, avx512vl and avx512dq"
            self.assertRaisesRegex(RuntimeError,
                                   msg,
                                   lambda: m_bf16(x_bf16))
705

706
    def test_prelu_bf16(self):
        # (input shape, PReLU channel count) pairs handed to the bf16 PReLU helper.
        cases = (
            ((16,), 1),
            ((16, 64), 1),
            ((16, 64), 64),
            ((16, 64, 112), 1),
            ((16, 64, 112), 64),
            ((16, 64, 112, 112, 1), 1),
            ((16, 64, 112, 112, 1), 64),
        )
        for shape, channels in cases:
            self._test_prelu_bf16_base(torch.Size(shape), channels)
714

715
    def _test_max_pool_base(self, dim, input):
        # Max pooling (2-d or 3-d, chosen by `dim`) on a dense tensor vs. an MKL-DNN tensor,
        # over every stride in {1, 2, 3} and both ceil_mode settings; outputs and input
        # gradients must agree.
        pool_module = {2: torch.nn.MaxPool2d, 3: torch.nn.MaxPool3d}
        for stride, ceil_mode in itertools.product([1, 2, 3], [False, True]):
            # ceil_mode runs use the larger 7-wide kernel, the others a 3-wide one.
            kernel_size = 7 if ceil_mode else 3
            max_pool = pool_module[dim](
                kernel_size=kernel_size,
                stride=stride,
                padding=1,
                ceil_mode=ceil_mode,
            )

            dense_in = input.clone().requires_grad_()
            mkldnn_in = input.clone().to_mkldnn().requires_grad_()
            dense_out = max_pool(dense_in)
            mkldnn_out = max_pool(mkldnn_in).to_dense()

            dense_out.sum().backward()
            mkldnn_out.sum().backward()

            self.assertEqual(dense_out, mkldnn_out)
            self.assertEqual(dense_in.grad, mkldnn_in.grad.to_dense())
735

736
    def test_max_pool2d(self):
        # Random batch/channel counts, several square and non-square spatial sizes.
        batch = torch.randint(3, 10, (1,)).item()
        channels = torch.randint(3, 10, (1,)).item()
        spatial_sizes = ((64, 64), (35, 39), (16, 19), (7, 8))
        for height, width in spatial_sizes:
            sample = torch.randn(batch, channels, height, width, dtype=torch.float32) * 10
            self._test_max_pool_base(dim=2, input=sample)
742

743
    def test_max_pool3d(self):
        # Random batch/channel counts, several cubic and non-cubic volume sizes.
        batch = torch.randint(3, 10, (1,)).item()
        channels = torch.randint(3, 10, (1,)).item()
        volume_sizes = ((64, 64, 64), (35, 39, 35), (16, 19, 20), (7, 8, 9))
        for depth, height, width in volume_sizes:
            sample = torch.randn(batch, channels, depth, height, width, dtype=torch.float32) * 10
            self._test_max_pool_base(dim=3, input=sample)
749

750

751
    @unittest.skipIf(IS_WINDOWS, "Limit support for bf16 path")
    def _test_max_pool_bf16_base(self, dim, input):
        # bf16 MKL-DNN max pooling vs. fp32 MKL-DNN max pooling over every stride in
        # {1, 2, 3} and both ceil_mode settings, or the expected RuntimeError when bf16
        # is not supported.
        pool_module = {2: torch.nn.MaxPool2d, 3: torch.nn.MaxPool3d}
        input_bf16 = input.bfloat16()
        for stride, ceil_mode in itertools.product([1, 2, 3], [False, True]):
            kernel_size = 7 if ceil_mode else 3
            max_pool = pool_module[dim](
                kernel_size=kernel_size,
                stride=stride,
                padding=1,
                ceil_mode=ceil_mode,
            )

            if not torch.ops.mkldnn._is_mkldnn_bf16_supported():
                msg = "mkldnn_max_pool%dd: bf16 path needs the cpu support avx512bw, avx512vl and avx512dq" % dim
                with self.assertRaisesRegex(RuntimeError, msg):
                    max_pool(input_bf16.to_mkldnn())
                continue

            reference = max_pool(input.to_mkldnn()).to_dense()
            low_precision = max_pool(input_bf16.to_mkldnn()).to_dense(torch.float32)
            self.assertEqual(reference, low_precision, atol=0.1, rtol=1e-3)
772

773
    def test_max_pool2d_bf16(self):
        # Random batch/channel counts, several spatial sizes, bf16 2-d max pooling.
        batch = torch.randint(3, 10, (1,)).item()
        channels = torch.randint(3, 10, (1,)).item()
        spatial_sizes = ((64, 64), (35, 39), (16, 19), (7, 8))
        for height, width in spatial_sizes:
            sample = torch.randn(batch, channels, height, width, dtype=torch.float32) * 10
            self._test_max_pool_bf16_base(dim=2, input=sample)
779

780
    def test_max_pool3d_bf16(self):
        # Random batch/channel counts, several volume sizes, bf16 3-d max pooling.
        batch = torch.randint(3, 10, (1,)).item()
        channels = torch.randint(3, 10, (1,)).item()
        volume_sizes = ((64, 64, 64), (35, 39, 35), (16, 19, 20), (7, 8, 9))
        for depth, height, width in volume_sizes:
            sample = torch.randn(batch, channels, depth, height, width, dtype=torch.float32) * 10
            self._test_max_pool_bf16_base(dim=3, input=sample)
786

787
    def test_max_pool2d_stride_none(self):
        # F.max_pool2d called with stride=None: dense result vs. MKL-DNN result.
        batch = torch.randint(3, 10, (1,)).item()
        channels = torch.randint(3, 10, (1,)).item()

        spatial_sizes = ((64, 64), (35, 39), (16, 19), (7, 8))
        for height, width in spatial_sizes:
            sample = torch.randn(batch, channels, height, width, dtype=torch.float32) * 10
            for ceil_mode in (False, True):
                # Same keyword arguments for the dense and the MKL-DNN call.
                pool_kwargs = {
                    "kernel_size": 7 if ceil_mode else 3,
                    "stride": None,
                    "padding": 1,
                    "ceil_mode": ceil_mode,
                }
                dense_out = F.max_pool2d(sample, **pool_kwargs)
                mkldnn_out = F.max_pool2d(sample.to_mkldnn(), **pool_kwargs)

                self.assertEqual(dense_out, mkldnn_out.to_dense())
809

810
    def test_max_pool_unsupported(self):
        # Dilated max pooling on MKL-DNN tensors is expected to be rejected with a
        # RuntimeError, for both the 2-d and the 3-d module.
        batch = torch.randint(3, 10, (1,)).item()
        channels = torch.randint(3, 10, (1,)).item()

        # (pooling module, spatial shape, expected error message)
        dilation_cases = (
            (torch.nn.MaxPool2d, (7, 7), 'mkldnn_max_pool2d does not support dilation case'),
            (torch.nn.MaxPool3d, (7, 7, 7), 'mkldnn_max_pool3d does not support dilation case'),
        )
        for pool_cls, spatial, message in dilation_cases:
            sample = torch.randn(batch, channels, *spatial, dtype=torch.float32).to_mkldnn()
            dilated_pool = pool_cls(kernel_size=3, stride=3, padding=1, dilation=2)
            with self.assertRaisesRegex(RuntimeError, message):
                dilated_pool(sample)
836

837
    def _test_avg_pool_base(self, dim, input):
        # Average pooling (2-d or 3-d, chosen by `dim`) on a dense tensor vs. an MKL-DNN
        # tensor, with and without count_include_pad; outputs and input gradients must agree.
        avg_module = {2: torch.nn.AvgPool2d, 3: torch.nn.AvgPool3d}
        for include_pad in (True, False):
            avg_pool = avg_module[dim](
                kernel_size=3, stride=2, padding=1, count_include_pad=include_pad
            )

            dense_in = input.clone().requires_grad_()
            mkldnn_in = input.clone().to_mkldnn().requires_grad_()
            dense_out = avg_pool(dense_in)
            mkldnn_out = avg_pool(mkldnn_in).to_dense()

            dense_out.sum().backward()
            mkldnn_out.sum().backward()

            self.assertEqual(dense_out, mkldnn_out)
            self.assertEqual(dense_in.grad, mkldnn_in.grad.to_dense())
856

857
    def test_avg_pool2d(self):
        # 64x64 input with random batch/channel counts through the 2-d average-pool helper.
        batch = torch.randint(3, 10, (1,)).item()
        channels = torch.randint(3, 10, (1,)).item()
        sample = torch.randn(batch, channels, 64, 64, dtype=torch.float32) * 10
        self._test_avg_pool_base(dim=2, input=sample)
862

863
    def test_avg_pool3d(self):
        # 64x64x64 input with random batch/channel counts through the 3-d average-pool helper.
        batch = torch.randint(3, 10, (1,)).item()
        channels = torch.randint(3, 10, (1,)).item()
        sample = torch.randn(batch, channels, 64, 64, 64, dtype=torch.float32) * 10
        self._test_avg_pool_base(dim=3, input=sample)
868

869
    @unittest.skipIf(IS_WINDOWS, "Limit support for bf16 path")
    def _test_avg_pool_bf16_base(self, dim, input):
        # bf16 MKL-DNN average pooling vs. fp32 MKL-DNN average pooling, with and without
        # count_include_pad, or the expected RuntimeError when bf16 is not supported.
        avg_module = {2: torch.nn.AvgPool2d, 3: torch.nn.AvgPool3d}
        input_bf16 = input.bfloat16()
        for include_pad in (True, False):
            avg_pool = avg_module[dim](
                kernel_size=3, stride=2, padding=1, count_include_pad=include_pad
            )
            if not torch.ops.mkldnn._is_mkldnn_bf16_supported():
                msg = "mkldnn_avg_pool%dd: bf16 path needs the cpu support avx512bw, avx512vl and avx512dq" % dim
                with self.assertRaisesRegex(RuntimeError, msg):
                    avg_pool(input_bf16.to_mkldnn())
                continue

            reference = avg_pool(input.to_mkldnn()).to_dense()
            low_precision = avg_pool(input_bf16.to_mkldnn()).to_dense(torch.float)
            self.assertEqual(reference, low_precision, atol=1e-1, rtol=1e-3)
888

889
    def test_avg_pool2d_bf16(self):
        # 64x64 input with random batch/channel counts through the bf16 2-d average-pool helper.
        batch = torch.randint(3, 10, (1,)).item()
        channels = torch.randint(3, 10, (1,)).item()
        sample = torch.randn(batch, channels, 64, 64, dtype=torch.float32) * 10
        self._test_avg_pool_bf16_base(dim=2, input=sample)
894

895
    def test_avg_pool3d_bf16(self):
        # 64x64x64 input with random batch/channel counts through the bf16 3-d average-pool helper.
        batch = torch.randint(3, 10, (1,)).item()
        channels = torch.randint(3, 10, (1,)).item()
        sample = torch.randn(batch, channels, 64, 64, 64, dtype=torch.float32) * 10
        self._test_avg_pool_bf16_base(dim=3, input=sample)
900

901
    def test_avg_pool2d_stride_none(self):
        # F.avg_pool2d called with stride=None: dense result vs. MKL-DNN result.
        batch = torch.randint(3, 10, (1,)).item()
        channels = torch.randint(3, 10, (1,)).item()
        sample = torch.randn(batch, channels, 64, 64, dtype=torch.float32) * 10

        for include_pad in (True, False):
            # Same keyword arguments for the dense and the MKL-DNN call.
            pool_kwargs = {
                "kernel_size": 3,
                "stride": None,
                "padding": 1,
                "count_include_pad": include_pad,
            }
            dense_out = F.avg_pool2d(sample, **pool_kwargs)
            mkldnn_out = F.avg_pool2d(sample.to_mkldnn(), **pool_kwargs)

            self.assertEqual(dense_out, mkldnn_out.to_dense())
921

922
    def test_adaptive_avg_pool2d(self):
        # AdaptiveAvgPool2d(7) on a 224x224 input: dense vs. MKL-DNN outputs and input gradients.
        batch = torch.randint(3, 10, (1,)).item()
        channels = torch.randint(3, 10, (1,)).item()
        sample = torch.randn(batch, channels, 224, 224, dtype=torch.float32) * 100

        pool = torch.nn.AdaptiveAvgPool2d(7)
        dense_in = sample.clone().requires_grad_()
        mkldnn_in = sample.clone().to_mkldnn().requires_grad_()
        dense_out = pool(dense_in)
        mkldnn_out = pool(mkldnn_in).to_dense()

        dense_out.sum().backward()
        mkldnn_out.sum().backward()

        self.assertEqual(dense_out, mkldnn_out)
        self.assertEqual(dense_in.grad, mkldnn_in.grad.to_dense())
940

941
    @unittest.skipIf(IS_WINDOWS, "Limit support for bf16 path")
    def test_adaptive_avg_pool2d_bf16(self):
        # AdaptiveAvgPool2d(7) on a bf16 MKL-DNN tensor vs. the same pooling on the fp32
        # MKL-DNN tensor, or the expected RuntimeError when bf16 is not supported.
        N = torch.randint(3, 10, (1,)).item()
        C = torch.randint(3, 10, (1,)).item()
        x = torch.randn(N, C, 224, 224, dtype=torch.float32) * 100

        x_bf16 = x.bfloat16()
        adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d(7)

        if torch.ops.mkldnn._is_mkldnn_bf16_supported():
            y = adaptive_avg_pool2d(x.to_mkldnn()).to_dense()
            # The bf16 tensor must be the one pooled here; pooling the fp32 input a second
            # time would only compare the fp32 result with itself.
            y_bf16 = adaptive_avg_pool2d(x_bf16.to_mkldnn()).to_dense(torch.float32)
            self.assertEqual(y, y_bf16, atol=1e-1, rtol=1e-3)
        else:
            msg = "mkldnn_adaptive_avg_pool2d: bf16 path needs the cpu support avx512bw, avx512vl and avx512dq"
            self.assertRaisesRegex(RuntimeError,
                                   msg,
                                   lambda: adaptive_avg_pool2d(x_bf16.to_mkldnn()))
959

960
    def _test_batch_norm_base(self, dim, channels, input):
        # Inference-mode batch norm (2-d or 3-d, chosen by `dim`): dense module vs. its
        # mkldnn_utils.to_mkldnn conversion, followed by the serialization and tracing
        # checks on the converted module.
        bn_cls = {2 : torch.nn.BatchNorm2d, 3 : torch.nn.BatchNorm3d}[dim]
        dense_bn = bn_cls(channels).float().train(False)
        converted_bn = mkldnn_utils.to_mkldnn(copy.deepcopy(dense_bn))

        dense_out = dense_bn(input)
        converted_out = converted_bn(input.to_mkldnn()).to_dense()
        self.assertEqual(dense_out, converted_out)

        self._test_serialization(converted_bn, (input.to_mkldnn(),))
        self._test_tracing(converted_bn, (input.to_mkldnn(),))
970

971
    def _test_batch_norm_train_base(self, dim, channels, input):
        # Training-mode batch norm: a dense module on a dense input vs. a deep copy of that
        # module on an MKL-DNN input. Compares outputs, input gradients, weight gradients
        # and, when tracked, the running statistics.
        # TODO: support 3d batchnorm training.
        bn_module = {2 : torch.nn.BatchNorm2d}
        # TODO: support none affine.
        affine = True
        for track_running_stats in (True, False):
            dense_bn = bn_module[dim](
                num_features=channels,
                affine=affine,
                track_running_stats=track_running_stats,
            ).float().train(True)
            mkldnn_side_bn = copy.deepcopy(dense_bn)

            dense_in = input.clone().requires_grad_()
            mkldnn_in = input.clone().to_mkldnn().requires_grad_()
            dense_out = dense_bn(dense_in)
            mkldnn_out = mkldnn_side_bn(mkldnn_in).to_dense()

            dense_out.sum().backward()
            mkldnn_out.sum().backward()

            self.assertEqual(dense_out, mkldnn_out)
            self.assertEqual(dense_in.grad, mkldnn_in.grad.to_dense())
            self.assertEqual(dense_bn.weight.grad, mkldnn_side_bn.weight.grad, rtol=1e-3, atol=1e-3)

            if not track_running_stats:
                continue
            self.assertEqual(dense_bn.running_mean, mkldnn_side_bn.running_mean)
            self.assertEqual(dense_bn.running_var, mkldnn_side_bn.running_var, rtol=1e-5, atol=1e-5)
996

997
    def test_batch_norm_2d(self):
        # 2-d batch norm on a 35x45 input: inference check first, then the training check.
        batch = torch.randint(3, 10, (1,)).item()
        channels = torch.randint(3, 100, (1,)).item()
        sample = torch.randn(batch, channels, 35, 45, dtype=torch.float32) * 10
        for check in (self._test_batch_norm_base, self._test_batch_norm_train_base):
            check(dim=2, channels=channels, input=sample)
1003

1004
    def test_batch_norm_3d(self):
        # 3-d batch norm on a 30x30x30 input: inference check only.
        batch = torch.randint(3, 10, (1,)).item()
        channels = torch.randint(3, 100, (1,)).item()
        sample = torch.randn(batch, channels, 30, 30, 30, dtype=torch.float32) * 10
        self._test_batch_norm_base(dim=3, channels=channels, input=sample)
1009

1010
    @unittest.skipIf(IS_WINDOWS, "Limit support for bf16 path")
    def _test_batch_norm_bf16_base(self, dim, channels, input):
        # Inference-mode batch norm (2-d or 3-d, chosen by `dim`) on a bf16 MKL-DNN input
        # vs. the fp32 MKL-DNN input, both through the module converted with
        # mkldnn_utils.to_mkldnn. When bf16 is not supported, feeding a bf16 MKL-DNN tensor
        # must raise a RuntimeError.
        bn_module = {2 : torch.nn.BatchNorm2d, 3 : torch.nn.BatchNorm3d}
        x_bf16 = input.bfloat16()
        # TODO: support training
        for train in [False]:
            bn = bn_module[dim](channels).float().train(train)
            mkldnn_bn = mkldnn_utils.to_mkldnn(copy.deepcopy(bn))
            if torch.ops.mkldnn._is_mkldnn_bf16_supported():
                # Both results go through the MKL-DNN module; the bf16 one must start from
                # x_bf16, otherwise two identical fp32 computations would be compared.
                y = mkldnn_bn(input.to_mkldnn()).to_dense()
                y_bf16 = mkldnn_bn(x_bf16.to_mkldnn()).to_dense(torch.float)
                self.assertEqual(y, y_bf16, atol=1e-1, rtol=1e-3)
            else:
                msg = "mkldnn_batch_norm: bf16 path needs the cpu support avx512bw, avx512vl and avx512dq"
                self.assertRaisesRegex(RuntimeError,
                                       msg,
                                       lambda: bn(x_bf16.to_mkldnn()))
1027

1028
    def test_batch_norm_2d_bf16(self):
        # 2-d bf16 batch norm check on a 35x45 input with random batch/channel counts.
        batch = torch.randint(3, 10, (1,)).item()
        channels = torch.randint(3, 100, (1,)).item()
        sample = torch.randn(batch, channels, 35, 45, dtype=torch.float32) * 10
        self._test_batch_norm_bf16_base(dim=2, channels=channels, input=sample)
1033

1034
    def test_batch_norm_3d_bf16(self):
        # 3-d bf16 batch norm check on a 30x30x30 input with random batch/channel counts.
        batch = torch.randint(3, 10, (1,)).item()
        channels = torch.randint(3, 100, (1,)).item()
        sample = torch.randn(batch, channels, 30, 30, 30, dtype=torch.float32) * 10
        self._test_batch_norm_bf16_base(dim=3, channels=channels, input=sample)
1039

1040
    def test_add(self):
        # Addition on dense vs. MKL-DNN tensors: operator +, torch.add with alpha,
        # in-place +=, and the out= variants (separate output, output aliasing the first
        # operand, output aliasing the second operand).
        batch = torch.randint(3, 10, (1,)).item()
        channels = torch.randint(3, 100, (1,)).item()
        alpha = torch.randn(1, dtype=torch.float32).item()

        lhs = torch.randn(batch, channels, 35, 45, dtype=torch.float32) * 10
        rhs = torch.randn(batch, channels, 35, 45, dtype=torch.float32) * 10
        mkldnn_lhs = lhs.to_mkldnn()
        mkldnn_rhs = rhs.to_mkldnn()

        # add
        dense_sum = lhs + rhs
        mkldnn_sum = (mkldnn_lhs + mkldnn_rhs).to_dense()
        self.assertEqual(dense_sum, mkldnn_sum)

        dense_scaled = torch.add(lhs, rhs, alpha=alpha)
        mkldnn_scaled = torch.add(mkldnn_lhs, mkldnn_rhs, alpha=alpha).to_dense()
        self.assertEqual(dense_scaled, mkldnn_scaled)

        # add_
        lhs += rhs
        mkldnn_lhs += mkldnn_rhs
        self.assertEqual(lhs, mkldnn_lhs.to_dense())

        # add_out
        dense_dst = lhs.clone()
        mkldnn_dst = dense_dst.to_mkldnn()
        torch.add(lhs, rhs, alpha=alpha, out=dense_dst)
        torch.add(mkldnn_lhs, mkldnn_rhs, alpha=alpha, out=mkldnn_dst)
        self.assertEqual(dense_dst, mkldnn_dst.to_dense())

        # add_out inplace case: first input
        torch.add(lhs, rhs, alpha=alpha, out=lhs)
        torch.add(mkldnn_lhs, mkldnn_rhs, alpha=alpha, out=mkldnn_lhs)
        self.assertEqual(lhs, mkldnn_lhs.to_dense())

        # add_out inplace case: second input
        torch.add(lhs, rhs, alpha=alpha, out=rhs)
        torch.add(mkldnn_lhs, mkldnn_rhs, alpha=alpha, out=mkldnn_rhs)
        self.assertEqual(rhs, mkldnn_rhs.to_dense())
1080

1081
    def test_mul(self):
        # Multiplication on dense vs. MKL-DNN tensors, by another tensor and by a Python
        # scalar: operator *, torch.mul, in-place *=, and the out= variants.
        batch = torch.randint(3, 10, (1,)).item()
        channels = torch.randint(3, 100, (1,)).item()
        scalar = torch.randn(1, dtype=torch.float32).item()

        lhs = torch.randn(batch, channels, 35, 45, dtype=torch.float32) * 10
        rhs = torch.randn(batch, channels, 35, 45, dtype=torch.float32) * 10
        mkldnn_lhs = lhs.to_mkldnn()
        mkldnn_rhs = rhs.to_mkldnn()

        # mul
        dense_prod = lhs * rhs
        mkldnn_prod = (mkldnn_lhs * mkldnn_rhs).to_dense()
        self.assertEqual(dense_prod, mkldnn_prod)

        dense_scaled = lhs * scalar
        mkldnn_scaled = (mkldnn_lhs * scalar).to_dense()
        self.assertEqual(dense_scaled, mkldnn_scaled)

        dense_prod = torch.mul(lhs, rhs)
        mkldnn_prod = torch.mul(mkldnn_lhs, mkldnn_rhs).to_dense()
        self.assertEqual(dense_prod, mkldnn_prod)

        dense_scaled = torch.mul(lhs, scalar)
        mkldnn_scaled = torch.mul(mkldnn_lhs, scalar).to_dense()
        self.assertEqual(dense_scaled, mkldnn_scaled)

        # mul_
        lhs *= rhs
        mkldnn_lhs *= mkldnn_rhs
        self.assertEqual(lhs, mkldnn_lhs.to_dense())

        lhs *= scalar
        mkldnn_lhs *= scalar
        self.assertEqual(lhs, mkldnn_lhs.to_dense())

        # mul_out
        dense_dst = lhs.clone()
        mkldnn_dst = dense_dst.to_mkldnn()
        torch.mul(lhs, rhs, out=dense_dst)
        torch.mul(mkldnn_lhs, mkldnn_rhs, out=mkldnn_dst)
        self.assertEqual(dense_dst, mkldnn_dst.to_dense())

        dense_dst = lhs.clone()
        mkldnn_dst = dense_dst.to_mkldnn()
        torch.mul(lhs, scalar, out=dense_dst)
        torch.mul(mkldnn_lhs, scalar, out=mkldnn_dst)
        self.assertEqual(dense_dst, mkldnn_dst.to_dense())
1129

1130
    def test_0_dimension_tensor(self):
        # Ops where one operand has a zero-sized dimension: relu, mul and add must match
        # the dense results; the add with requires_grad inputs and the add with mismatched
        # sizes must raise; a Conv2d on an empty batch must match its MKL-DNN conversion.
        unit = torch.rand([20, 20, 1, 1], dtype=torch.float)
        empty = torch.rand([20, 20, 0, 1], dtype=torch.float)

        # unary ops work without modification
        dense_relu = torch.relu(empty)
        mkldnn_relu = torch.relu(empty.to_mkldnn()).to_dense()
        self.assertEqual(dense_relu, mkldnn_relu)

        dense_prod = unit * empty
        mkldnn_prod = (unit.to_mkldnn() * empty.to_mkldnn()).to_dense()
        self.assertEqual(dense_prod, mkldnn_prod)

        dense_sum = unit + empty
        mkldnn_sum = (unit.to_mkldnn() + empty.to_mkldnn()).to_dense()
        self.assertEqual(dense_sum, mkldnn_sum)

        unit.requires_grad_(True)
        empty.requires_grad_(True)
        self.assertRaisesRegex(
            RuntimeError,
            "0-dimension Tensor in training",
            lambda: unit.to_mkldnn() + empty.to_mkldnn(),
        )

        self.assertRaisesRegex(
            RuntimeError,
            "must match",
            lambda: torch.rand([5]).to_mkldnn() + torch.rand([0]).to_mkldnn(),
        )

        C = 7
        conv = torch.nn.Conv2d(C, C, 3)
        empty_batch = torch.randn(0, C, C, 8, dtype=torch.float)
        eager_out = conv(empty_batch)
        mkldnn_out = mkldnn_utils.to_mkldnn(conv)(empty_batch)
        self.assertEqual(eager_out, mkldnn_out)
1161

1162
    def test_view(self):
        # view() on an MKL-DNN tensor must raise an error that points the user to reshape.
        mkldnn_x = torch.randn(3, 4, 5, dtype=torch.float32).to_mkldnn()
        with self.assertRaisesRegex(RuntimeError, "Change to use reshape"):
            mkldnn_x.view(mkldnn_x.size(0), -1)
1167

1168
    def test_reshape(self):
        # reshape on an MKL-DNN tensor must match the dense reshape, and an in-place add on
        # one reshaped result must be visible through a fresh reshape of the same tensor.
        dense = torch.randn(3, 4, 5, dtype=torch.float32) * 10
        target = (dense.size(0), -1)

        dense_reshaped = dense.reshape(target)
        mkldnn_reshaped = dense.to_mkldnn().reshape(target).to_dense()
        self.assertEqual(dense_reshaped, mkldnn_reshaped)

        # test whether share same memory for plain format tensor
        source = dense.to_mkldnn()
        first_view = source.reshape(target)
        second_view = source.reshape(target)
        doubled = first_view.add_(second_view)
        self.assertEqual(source.reshape(target).to_dense(), doubled.to_dense())
1183

1184
    def test_reshape_blocked_format(self):
        # Reshape the output of an MKL-DNN Conv2d and compare it with reshaping the dense
        # copy of that same output.
        C = 7
        # construct an mkldnn blocked tensor with mkldnn conv2d
        conv = mkldnn_utils.to_mkldnn(torch.nn.Conv2d(C, C, 3))
        mkldnn_in = torch.randn(1, C, 8, 8).to_mkldnn()

        # mkldnn tensor w/ blocked format
        blocked_out = conv(mkldnn_in)
        # aten tensor w/ plain format
        plain_out = blocked_out.to_dense()

        blocked_reshaped = blocked_out.reshape(C, -1)
        plain_reshaped = plain_out.reshape(C, -1)

        self.assertEqual(plain_reshaped, blocked_reshaped.to_dense())
1199

1200
    def test_reshape_backward(self):
        # Gradient flowing back through reshape: dense input vs. MKL-DNN input, both fed
        # (after reshape) into the same dense Linear layer.
        base = torch.randn(3, 4, 5, dtype=torch.float32) * 10
        target = (base.size(0), -1)

        dense_in = base.clone().requires_grad_()
        mkldnn_in = base.clone().to_mkldnn().requires_grad_()
        # 4 * 5 features per row once the trailing dimensions are flattened.
        in_features = 20
        out_features = torch.randint(3, 100, (1,)).item()
        linear = torch.nn.Linear(in_features, out_features).float()

        dense_loss = linear(dense_in.reshape(target)).sum()
        mkldnn_loss = linear(mkldnn_in.reshape(target).to_dense()).sum()
        dense_loss.backward()
        mkldnn_loss.backward()

        self.assertEqual(dense_in.grad, mkldnn_in.grad.to_dense())
1215

1216
    def test_clone(self):
        # clone() on an MKL-DNN tensor must equal the dense clone, and an in-place add on
        # the clone must leave the source tensor different from the result.
        dense = torch.randn(4, 5, dtype=torch.float32) * 10
        dense_clone = dense.clone()
        mkldnn_clone = dense.to_mkldnn().clone().to_dense()
        self.assertEqual(dense_clone, mkldnn_clone)

        # test whether share same memory
        source = dense.to_mkldnn()
        doubled = source.clone().add_(source)
        self.assertNotEqual(source.to_dense(), doubled.to_dense())
1229

1230
    def test_transpose(self):
        # transpose over every ordered pair of dimensions: dense vs. MKL-DNN.
        dense = torch.randn(3, 4, 5, dtype=torch.float32) * 10
        for dim1, dim2 in itertools.product(range(dense.ndim), repeat=2):
            dense_t = dense.transpose(dim1, dim2)
            mkldnn_t = dense.to_mkldnn().transpose(dim1, dim2).to_dense()
            self.assertEqual(dense_t, mkldnn_t)
1238

1239
    def test_transpose_invalid_dime(self):
        # An out-of-range dimension passed to torch._mkldnn_transpose must raise IndexError.
        mkldnn_x = torch.randn(3, 4, 5, dtype=torch.float32).to_mkldnn()
        self.assertRaisesRegex(
            IndexError,
            "Dimension out of range",
            lambda: torch._mkldnn_transpose(mkldnn_x, 0, 12),
        )
1243

1244
    def test_linear_non_contiguous_weight(self):
        # Backward pass of a Linear layer whose weight is a transposed (non-contiguous)
        # tensor: dense input vs. MKL-DNN input, with and without a bias term.
        in_features = torch.randint(3, 10, (1,)).item()
        out_features = torch.randint(3, 100, (1,)).item()
        x = torch.randn(3, in_features, dtype=torch.float32) * 10
        w = torch.randn(in_features, out_features, dtype=torch.float32)
        for bias in [True, False]:
            x1 = x.clone().requires_grad_()
            x2 = x.clone().to_mkldnn().requires_grad_()
            # `bias` has to be forwarded to the layer; without it both iterations build a
            # layer with a bias and the bias-free case is never covered.
            linear = torch.nn.Linear(in_features, out_features, bias=bias).float()
            # w.t() has the (out_features, in_features) shape Linear expects, but is a
            # transposed view of `w` rather than a contiguous tensor.
            linear.weight = torch.nn.Parameter(w.t())
            mkldnn_linear = copy.deepcopy(linear)
            y1 = linear(x1).sum()
            y2 = mkldnn_linear(x2).to_dense().sum()
            y1.backward()
            y2.backward()
            self.assertEqual(x1.grad, x2.grad.to_dense())
            self.assertEqual(linear.weight.grad, mkldnn_linear.weight.grad)
            if bias:
                self.assertEqual(linear.bias.grad, mkldnn_linear.bias.grad)
1263

1264
    def test_linear(self):
        # Forward pass of Linear (with and without bias): dense module vs. its
        # mkldnn_utils.to_mkldnn conversion, plus the serialization and tracing checks.
        in_features = torch.randint(3, 10, (1,)).item()
        out_features = torch.randint(3, 100, (1,)).item()
        sample = torch.randn(3, in_features, dtype=torch.float32) * 10

        for bias in (True, False):
            dense_linear = torch.nn.Linear(in_features, out_features, bias=bias).float()
            converted_linear = mkldnn_utils.to_mkldnn(copy.deepcopy(dense_linear))

            dense_out = dense_linear(sample)
            converted_out = converted_linear(sample.to_mkldnn()).to_dense()
            self.assertEqual(dense_out, converted_out)

            self._test_serialization(converted_linear, (sample.to_mkldnn(),))
            self._test_tracing(converted_linear, (sample.to_mkldnn(),))
1278

1279
    def test_linear_backward(self):
        # Backward pass of a Linear layer: dense input vs. MKL-DNN input, with and without
        # a bias term. Compares input, weight and (when present) bias gradients.
        in_features = torch.randint(3, 10, (1,)).item()
        out_features = torch.randint(3, 100, (1,)).item()
        x = torch.randn(3, in_features, dtype=torch.float32) * 10
        for bias in [True, False]:
            x1 = x.clone().requires_grad_()
            x2 = x.clone().to_mkldnn().requires_grad_()
            # `bias` has to be forwarded to the layer; without it both iterations build a
            # layer with a bias and the bias-free case is never covered.
            linear = torch.nn.Linear(in_features, out_features, bias=bias).float()
            mkldnn_linear = copy.deepcopy(linear)
            y1 = linear(x1).sum()
            y2 = mkldnn_linear(x2).to_dense().sum()
            y1.backward()
            y2.backward()
            self.assertEqual(x1.grad, x2.grad.to_dense())
            self.assertEqual(linear.weight.grad, mkldnn_linear.weight.grad)
            if bias:
                self.assertEqual(linear.bias.grad, mkldnn_linear.bias.grad)
1296

1297
    @dtypes(torch.float16, torch.bfloat16)
    def test_linear_lowp(self, dtype):
        # Low-precision (fp16 / bf16) MKL-DNN Linear vs. fp32 MKL-DNN Linear, with and
        # without bias, or the expected RuntimeError when the dtype is not supported.
        in_features = torch.randint(3, 10, (1,)).item()
        out_features = torch.randint(3, 100, (1,)).item()
        sample = torch.randn(3, in_features, dtype=torch.float32) * 10
        sample_lowp = sample.to(dtype=dtype)

        support_checks = {
            torch.bfloat16: torch.ops.mkldnn._is_mkldnn_bf16_supported,
            torch.half: torch.ops.mkldnn._is_mkldnn_fp16_supported,
        }
        unsupported_msgs = {
            torch.bfloat16: r"bf16 path needs the cpu support avx_ne_convert or avx512bw, avx512vl and avx512dq",
            torch.half: r"fp16 path needs the cpu support avx_ne_convert or avx512_fp16",
        }
        # bf16 is compared with a looser absolute tolerance than fp16.
        atol = 1e-1 if dtype == torch.bfloat16 else 5e-3

        for bias in (True, False):
            dense_linear = torch.nn.Linear(in_features, out_features, bias=bias).float()
            fp32_linear = mkldnn_utils.to_mkldnn(copy.deepcopy(dense_linear))
            lowp_linear = mkldnn_utils.to_mkldnn(copy.deepcopy(dense_linear), dtype)

            if not support_checks[dtype]():
                with self.assertRaisesRegex(RuntimeError, unsupported_msgs[dtype]):
                    lowp_linear(sample_lowp.to_mkldnn())
                continue

            fp32_out = fp32_linear(sample.to_mkldnn()).to_dense()
            lowp_out = lowp_linear(sample_lowp.to_mkldnn()).to_dense(torch.float32)
            self.assertEqual(fp32_out, lowp_out, atol=atol, rtol=1e-3)
1333

1334
    def test_softmax(self):
1335
        x = torch.randn(3, 4, 5, dtype=torch.float32) * 10
1336
        for dim in range(x.ndim):
1337
            softmax = torch.nn.Softmax(dim=dim)
1338
            self.assertEqual(
1339
                softmax(x),
1340
                softmax(x.to_mkldnn()).to_dense())
1341

1342
    def test_sigmoid(self):
1343
        x = torch.randn(4, 5, dtype=torch.float32) * 10
1344
        mkldnn_x = x.to_mkldnn()
1345
        self.assertEqual(
1346
            torch.sigmoid(x),
1347
            torch.sigmoid(mkldnn_x).to_dense(),
1348
        )
1349
        # inplace
1350
        torch.sigmoid_(x)
1351
        torch.sigmoid_(mkldnn_x)
1352
        self.assertEqual(x, mkldnn_x.to_dense())
1353

1354
    def test_tanh(self):
1355
        x = torch.randn(4, 5, dtype=torch.float32) * 10
1356
        mkldnn_x = x.to_mkldnn()
1357
        self.assertEqual(
1358
            torch.tanh(x),
1359
            torch.tanh(mkldnn_x).to_dense(),
1360
        )
1361
        # inplace
1362
        torch.tanh_(x)
1363
        torch.tanh_(mkldnn_x)
1364
        self.assertEqual(x, mkldnn_x.to_dense())
1365

1366
    def _test_serialization(self, module, inputs):
1367
        with TemporaryFileName() as fname:
1368
            torch.jit.save(module, fname)
1369
            loaded = torch.jit.load(fname)
1370
            self.assertEqual(
1371
                module(*inputs).to_dense(),
1372
                loaded(*inputs).to_dense())
1373

1374
    def _test_tracing(self, module, inputs):
1375
        traced = torch.jit.trace(module, inputs)
1376
        self.assertEqual(
1377
            module(*inputs).to_dense(),
1378
            traced(*inputs).to_dense())
1379

1380
    def test_set_data_tensorimpl_type(self):
1381
        # Dense tensor has impl of type `TensorImpl`, while MKL-DNN tensor has impl
1382
        # of type `OpaqueTensorImpl<IDeepTensorWrapperPtr>`.
1383
        x = torch.randn((1, 2), dtype=torch.float, device=torch.device('cpu'))
1384
        x_mkldnn = x.to_mkldnn()
1385
        with self.assertRaisesRegex(RuntimeError, 'incompatible tensor type'):
1386
            x.data = x_mkldnn
1387

1388
    def test_empty(self):
1389
        x1 = torch.empty(4, 5, 2, 3, dtype=torch.float32)
1390
        x2 = torch.empty(4, 5, 2, 3, dtype=torch.float32, layout=torch._mkldnn)
1391
        self.assertEqual(x1.size(), x2.to_dense().size())
1392
        self.assertEqual(x1.dtype, x2.to_dense().dtype)
1393

1394
    def test_zero_(self):
1395
        x1 = torch.randn(4, 5, dtype=torch.float32) * 10
1396
        x2 = x1.clone().to_mkldnn()
1397
        self.assertEqual(
1398
            x1.zero_(),
1399
            x2.zero_().to_dense(),
1400
        )
1401

1402
    def test_is_mkldnn(self):
1403
        x = torch.randn(1, dtype=torch.float32)
1404
        self.assertFalse(x.is_mkldnn)
1405
        self.assertTrue(x.to_mkldnn().is_mkldnn)
1406

1407
    # legacy constructor/new doesn't support mkldnn tensors
1408
    @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1992")
1409
    def test_legacy_new_failure(self):
1410
        x = torch.randn(1, dtype=torch.float32)
1411
        x_mkldnn = x.to_mkldnn()
1412
        self.assertRaises(RuntimeError, lambda: x_mkldnn.new(device='cpu'))
1413
        self.assertRaises(RuntimeError, lambda: x_mkldnn.new(x.storage()))
1414
        self.assertRaises(RuntimeError, lambda: x_mkldnn.new(x))
1415
        self.assertRaises(RuntimeError, lambda: x_mkldnn.new(torch.Size([2, 3])))
1416
        self.assertRaises(RuntimeError, lambda: x_mkldnn.new([6]))
1417

1418
    def test_is_mkldnn_jit(self):
1419
        class EnsureMkldnn(torch.jit.ScriptModule):
1420
            @torch.jit.script_method
1421
            def forward(self, x):
1422
                if not x.is_mkldnn:
1423
                    x = x.to_mkldnn()
1424
                return x
1425

1426
        m = EnsureMkldnn()
1427
        x = torch.randn(1, dtype=torch.float32)
1428
        self.assertTrue(m(x).is_mkldnn)
1429
        self.assertTrue(m(x.to_mkldnn()).is_mkldnn)
1430

1431
    def _test_imagenet_model(self, model):
1432
        model = model.train(False).float()
1433
        mkldnn_model = mkldnn_utils.to_mkldnn(copy.deepcopy(model))
1434
        x = torch.randn(1, 3, 224, 224, dtype=torch.float32)
1435
        with torch.no_grad():
1436
            self.assertEqual(
1437
                model(x),
1438
                mkldnn_model(x.to_mkldnn()).to_dense(),
1439
            )
1440

1441
    @skipIfNoTorchVision
    def test_resnet18(self):
        """Run the dense-vs-MKL-DNN comparison on an untrained torchvision ResNet-18."""
        self._test_imagenet_model(torchvision.models.resnet.resnet18(weights=None))
1445

1446
    @skipIfNoTorchVision
    def test_resnext50_32x4d(self):
        """Run the dense-vs-MKL-DNN comparison on an untrained torchvision ResNeXt-50 32x4d."""
        self._test_imagenet_model(torchvision.models.resnet.resnext50_32x4d(weights=None))
1450

1451
    def _lstm_params_list(self):
1452
        params_dict = {
1453
            "input_size": [1, 5],
1454
            "hidden_size": [5, 16],
1455
            "num_layers": [1, 3],
1456
            "bidirectional": [False, True],
1457
            "bias": [False, True],
1458
            "batch_first": [False, True],
1459
            "dropout": [0, 0.4, 0.7, 1],
1460
            "batch_size": [1, 2],
1461
            "seq_len": [1, 3],
1462
            "training": [False, True]
1463
        }
1464

1465
        params_list = list(params_dict.values())
1466
        return params_list
1467

1468
    def _cast_dtype(self, input, bf16):
1469
        if bf16:
1470
            input = input.to(torch.bfloat16)
1471
        return input
1472

1473
    @unittest.skipIf(IS_WINDOWS, "Limit support for bf16 path")
    def test_lstm(self):
        """Compare ``nn.LSTM`` with MKL-DNN disabled against the default backend setting.

        For every combination produced by ``_lstm_params_list`` two deep
        copies of the same LSTM are evaluated: ``model1`` inside
        ``torch.backends.mkldnn.flags(enabled=False)`` and ``model2`` outside
        of it. Outputs, final hidden/cell states and, in training mode, the
        gradients of inputs, parameters and initial states must agree within
        ``rtol``/``atol``.
        """
        seed = 2023
        torch.manual_seed(seed)

        params_list = self._lstm_params_list()
        # NOTE(review): ``dtype`` is only used to derive the ``bf16`` flag, so
        # the torch.float and torch.half iterations both run the same fp32
        # comparison; confirm whether a dedicated fp16 path was intended.
        for dtype in types:
            bf16 = True if dtype == torch.bfloat16 and torch.ops.mkldnn._is_mkldnn_bf16_supported() else False
            # Tight tolerances for fp32, relaxed ones for the bf16 runs.
            rtol = 1.3e-6
            atol = 1e-5
            if bf16:
                rtol = 0.02
                atol = 0.02
            for input_size, hidden_size, num_layers, bidirectional, bias, batch_first, dropout, batch_size, seq_len, training \
                    in itertools.product(*params_list):
                num_directions = 2 if bidirectional else 1
                # The position of the batch axis in the input follows ``batch_first``.
                if batch_first:
                    input = torch.randn(batch_size, seq_len, input_size, dtype=torch.float32)
                else:
                    input = torch.randn(seq_len, batch_size, input_size, dtype=torch.float32)
                h = torch.randn(num_layers * num_directions, batch_size, hidden_size, dtype=torch.float32)
                c = torch.randn(num_layers * num_directions, batch_size, hidden_size, dtype=torch.float32)

                model = torch.nn.LSTM(input_size, hidden_size, num_layers, bidirectional=bidirectional,
                                      bias=bias, dropout=dropout, batch_first=batch_first).float()
                model.train() if training else model.eval()
                # Independent clones per path so gradients accumulate separately.
                input1 = input.clone().requires_grad_(training)
                input2 = input.clone().requires_grad_(training)

                h1 = h.clone().requires_grad_(training)
                h2 = h.clone().requires_grad_(training)
                c1 = c.clone().requires_grad_(training)
                c2 = c.clone().requires_grad_(training)

                model1 = copy.deepcopy(model)
                model2 = copy.deepcopy(model)
                # Autocast is active only for bf16 runs; grad recording is
                # disabled when the model is not in training mode.
                with torch.cpu.amp.autocast(enabled=bf16, dtype=torch.bfloat16), torch.no_grad() if not training else nullcontext():
                    # Path 1: MKL-DNN disabled; model and inputs are cast
                    # explicitly to bfloat16 when ``bf16`` is set.
                    with torch.backends.mkldnn.flags(enabled=False):
                        # Reseed before each forward/backward, presumably so
                        # both paths draw identical dropout masks.
                        torch.manual_seed(seed)
                        output1, (hn1, cn1) = self._cast_dtype(model1, bf16)(self._cast_dtype(input1, bf16),
                                                                             (self._cast_dtype(h1, bf16),
                                                                             self._cast_dtype(c1, bf16)))

                    # Path 2: backend flags left at their defaults, fp32
                    # model and inputs (autocast applies when ``bf16``).
                    torch.manual_seed(seed)
                    output2, (hn2, cn2) = model2(input2, (h2, c2))
                    self.assertEqual(output1, output2, rtol=rtol, atol=atol)
                    self.assertEqual(hn1, hn2, rtol=rtol, atol=atol)
                    self.assertEqual(cn1, cn2, rtol=rtol, atol=atol)

                    if training:
                        # Backward through the sequence output.
                        with torch.backends.mkldnn.flags(enabled=False):
                            torch.manual_seed(seed)
                            output1.sum().backward(retain_graph=True)

                        torch.manual_seed(seed)
                        output2.sum().backward(retain_graph=True)

                        self.assertEqual(input1.grad, input2.grad, rtol=rtol, atol=atol)
                        # model2's parameters (and grads) are cast to bf16
                        # when needed to match model1, which was cast above.
                        for name, para in model1.named_parameters():
                            self.assertEqual(para, self._cast_dtype(getattr(model2, name), bf16))
                            self.assertEqual(para.grad, self._cast_dtype(getattr(model2, name).grad, bf16), rtol=rtol, atol=atol)

                        # Backward through the final hidden state.
                        with torch.backends.mkldnn.flags(enabled=False):
                            torch.manual_seed(seed)
                            hn1.sum().backward(retain_graph=True)
                        torch.manual_seed(seed)
                        hn2.sum().backward(retain_graph=True)
                        self.assertEqual(h1.grad, h2.grad, rtol=rtol, atol=atol)

                        # Backward through the final cell state.
                        with torch.backends.mkldnn.flags(enabled=False):
                            torch.manual_seed(seed)
                            cn1.sum().backward(retain_graph=True)
                        torch.manual_seed(seed)
                        cn2.sum().backward(retain_graph=True)
                        self.assertEqual(c1.grad, c2.grad, rtol=rtol, atol=atol)
1548

1549
    @dtypes(torch.float16, torch.bfloat16)
1550
    def test_matmul_lower_precision(self, dtype):
1551
        support_check = {
1552
            torch.bfloat16: torch.ops.mkldnn._is_mkldnn_bf16_supported,
1553
            torch.float16: torch.ops.mkldnn._is_mkldnn_fp16_supported,
1554
        }
1555

1556
        def common(self, shape1, shape2, op, dtype):
1557
            a = torch.randn(shape1, dtype=dtype)
1558
            a_ref = a.float()
1559
            b = torch.randn(shape2, dtype=dtype)
1560
            b_ref = b.float()
1561

1562
            y = op(a, b)
1563
            y_ref = op(a_ref, b_ref)
1564
            self.assertEqual(y, y_ref, exact_dtype=False)
1565

1566
        if support_check[dtype]():
1567
            a1 = torch.randn([64, 1, 33], dtype=dtype)
1568
            # a2 is contiguous tensor but it's strides
1569
            # is not default contiguous strides.
1570
            a2 = torch.as_strided(a1.clone(), [64, 1, 33], [33, 3, 1])
1571
            self.assertTrue(a2.is_contiguous())
1572
            b = torch.randn(64, 33, 256).to(dtype=dtype)
1573
            y1 = torch.ops.aten.bmm(a1, b)
1574
            y2 = torch.bmm(a2, b)
1575
            self.assertEqual(y1, y2)
1576

1577
            for shape1, shape2, op in [
1578
                ((33, 77), (77, 22), torch.matmul),
1579
                ((128, 256), (256, 10), torch.matmul),
1580
                ((7, 300), (300, 3), torch.matmul),
1581
                ((1, 100), (100, 60), torch.matmul),
1582
                ((100, 1), (1, 100), torch.matmul),
1583
                ((20, 54, 78), (20, 78, 10), torch.bmm),
1584
                ((1, 300, 1), (1, 1, 300), torch.bmm),
1585
            ]:
1586
                common(self, shape1, shape2, op, dtype)
1587

1588

1589
# Generate the device-specific variants of TestMkldnn; restricted to CPU.
instantiate_device_type_tests(
    TestMkldnn,
    globals(),
    only_for=("cpu",),
)

# Run the suite when this file is executed as a script.
if __name__ == "__main__":
    run_tests()
1593

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.