1
# Owner(s): ["module: mkldnn"]
7
from contextlib import nullcontext
11
HAS_TORCHVISION = True
13
HAS_TORCHVISION = False
15
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")
18
import torch.nn.functional as F
20
import torch.backends.mkldnn
21
from torch.utils import mkldnn as mkldnn_utils
22
from torch.testing._internal.common_utils import TestCase, \
23
run_tests, TemporaryFileName, gradcheck, gradgradcheck, IS_WINDOWS, \
25
from torch.testing._internal.common_device_type import (
26
instantiate_device_type_tests,
30
# batched grad doesn't support mkldnn
31
gradcheck = functools.partial(gradcheck, check_batched_grad=False)
32
gradgradcheck = functools.partial(gradgradcheck, check_batched_grad=False)
35
types = [torch.float, torch.bfloat16, torch.half]
37
# Comment the line below to find out the CI machines having MKL-DNN build disabled
38
@unittest.skipIf(not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled")
39
class TestMkldnn(TestCase):
40
def test_conversion(self):
41
for cpu_tensor in [torch.randn((1, 2, 3, 4),
42
dtype=torch.float, device=torch.device('cpu')),
43
torch.randn((1, 2, 3, 4, 5),
44
dtype=torch.float, device=torch.device('cpu'))[:, :, :, :, 1]]:
45
cpu_tensor.requires_grad_()
46
convert_dtypes = {torch.half: [torch.half, torch.float],
47
torch.bfloat16: [torch.bfloat16, torch.float],
48
torch.float: [torch.bfloat16, torch.half]}
49
# float/bfloat16/half cpu tensor to mkldnn tensor.
51
mkldnn_tensor = cpu_tensor.to_mkldnn(dtype1)
52
self.assertEqual(mkldnn_tensor.dtype, dtype1)
53
cpu_tensor_1 = mkldnn_tensor.to_dense()
54
# not given dtype for to_dense, mkldnn tensor has same dtype with cpu tensor
55
self.assertEqual(mkldnn_tensor.dtype, cpu_tensor_1.dtype)
56
# mkldnn float/bfloat tensor to cpu float or bfloat tensor
57
for dtype2 in convert_dtypes[dtype1]:
58
cpu_tensor_2 = mkldnn_tensor.to_dense(dtype2)
59
self.assertEqual(cpu_tensor_2.dtype, dtype2)
60
atol = 1e-5 if dtype1 == torch.float and dtype2 == torch.float else 1e-2
61
self.assertEqual(cpu_tensor, cpu_tensor_2.float(), atol=atol, rtol=0)
63
self.assertEqual(mkldnn_tensor.device, torch.device('cpu'))
64
self.assertEqual(mkldnn_tensor.size(), torch.Size([1, 2, 3, 4]))
65
self.assertEqual(mkldnn_tensor.numel(), cpu_tensor.numel())
66
if dtype1 == torch.float:
67
self.assertEqual(mkldnn_tensor.element_size(), cpu_tensor.element_size())
69
self.assertEqual(mkldnn_tensor.element_size(), cpu_tensor.element_size() / 2)
70
self.assertRaisesRegex(RuntimeError,
71
"Cannot access data pointer of Tensor that doesn't have storage",
72
lambda: mkldnn_tensor.data_ptr() != 0)
74
# bfloat cpu tensor to mkldnn float tensor or bfloat tensor.
75
for orig_dtype in [torch.half, torch.bfloat16]:
76
cpu_tensor_lower = cpu_tensor.to(dtype=orig_dtype)
77
for dtype1 in convert_dtypes[orig_dtype]:
78
mkldnn_tensor = cpu_tensor_lower.to_mkldnn(dtype1)
79
self.assertEqual(mkldnn_tensor.dtype, dtype1)
80
cpu_tensor_1 = mkldnn_tensor.to_dense()
81
# not given dtype for to_dense, mkldnn tensor has same dtype with cpu tensor
82
self.assertEqual(mkldnn_tensor.dtype, cpu_tensor_1.dtype)
83
# mkldnn float/bfloat/half tensor to cpu float/bfloat/half tensor
84
for dtype2 in convert_dtypes[cpu_tensor_lower.dtype]:
85
cpu_tensor_2 = mkldnn_tensor.to_dense(dtype2)
86
self.assertEqual(cpu_tensor_2.dtype, dtype2)
87
self.assertEqual(cpu_tensor_lower,
88
cpu_tensor_2.to(dtype=cpu_tensor_lower.dtype), atol=1e-5, rtol=0)
90
self.assertEqual(mkldnn_tensor.device, torch.device('cpu'))
91
self.assertEqual(mkldnn_tensor.size(), torch.Size([1, 2, 3, 4]))
92
self.assertEqual(mkldnn_tensor.numel(), cpu_tensor.numel())
93
if dtype1 in [torch.bfloat16, torch.half]:
94
self.assertEqual(mkldnn_tensor.element_size(), cpu_tensor_lower.element_size())
96
self.assertEqual(mkldnn_tensor.element_size(), cpu_tensor_lower.element_size() * 2)
97
self.assertRaisesRegex(RuntimeError,
98
"Cannot access data pointer of Tensor that doesn't have storage",
99
lambda: mkldnn_tensor.data_ptr() != 0)
101
def test_conversion_byte_char(self):
102
int8_types = [torch.int8, torch.uint8]
103
for int8_type in int8_types:
104
low = -100 if int8_type is torch.int8 else 0
106
for cpu_tensor in [torch.randint(
111
device=torch.device('cpu')),
115
size=(1, 2, 3, 4, 5),
117
device=torch.device('cpu'))[:, :, :, :, :]]:
119
cpu_tensor = cpu_tensor.to(dtype=int8_type)
120
mkldnn_tensor = cpu_tensor.to_mkldnn(int8_type)
121
self.assertEqual(mkldnn_tensor.dtype, int8_type)
122
cpu_tensor_1 = mkldnn_tensor.to_dense()
123
self.assertEqual(mkldnn_tensor.dtype, cpu_tensor_1.dtype)
124
self.assertEqual(cpu_tensor, cpu_tensor_1)
125
self.assertEqual(mkldnn_tensor.device, torch.device('cpu'))
126
self.assertEqual(mkldnn_tensor.size(), cpu_tensor.size())
127
self.assertEqual(mkldnn_tensor.numel(), cpu_tensor.numel())
128
self.assertEqual(mkldnn_tensor.element_size(), cpu_tensor.element_size())
129
self.assertRaisesRegex(RuntimeError,
130
"Cannot access data pointer of Tensor that doesn't have storage",
131
lambda: mkldnn_tensor.data_ptr() != 0)
134
x = torch.randn(4, 5, dtype=torch.float32)
135
mkldnn_x = x.to_mkldnn()
136
mkldnn_y = torch.randn(4, 5, dtype=torch.float32).to_mkldnn()
137
mkldnn_z = torch.randn(4, 10, dtype=torch.float32).to_mkldnn()
138
mkldnn_y.copy_(mkldnn_x)
139
self.assertEqual(x, mkldnn_y.to_dense())
140
self.assertRaisesRegex(RuntimeError,
141
"copy_mkldnn_: only support same size tensor.",
142
lambda: mkldnn_z.copy_(mkldnn_x))
143
self.assertRaisesRegex(RuntimeError,
144
"copy_mkldnn_: between mkldnn layout and dense Tensors is not implemented! "
145
"Found self type = torch.FloatTensor and src type = Mkldnntorch.FloatTensor",
146
lambda: x.copy_(mkldnn_x))
147
self.assertRaisesRegex(RuntimeError,
148
"copy_mkldnn_: between mkldnn layout and dense Tensors is not implemented! "
149
"Found self type = Mkldnntorch.FloatTensor and src type = torch.FloatTensor",
150
lambda: mkldnn_x.copy_(x))
152
def test_unsupported(self):
    """Expect ``RuntimeError`` from a set of unsupported MKLDNN requests.

    Three groups are exercised: ``randn(...).to_mkldnn()`` for each dtype in
    the list below on CPU (and on CUDA when it is available), a float tensor
    on CUDA, and factory functions called with ``layout=torch._mkldnn``.
    """
    shape = (1, 2, 3, 4)
    rejected_dtypes = (torch.double, torch.uint8, torch.int8,
                       torch.short, torch.int, torch.long)

    # unsupported types, on cpu and (when present) on gpu
    for dtype in rejected_dtypes:
        with self.assertRaises(RuntimeError):
            torch.randn(*shape, dtype=dtype, device=torch.device('cpu')).to_mkldnn()
        if not torch.cuda.is_available():
            continue
        with self.assertRaises(RuntimeError):
            torch.randn(*shape, dtype=dtype, device=torch.device('cuda')).to_mkldnn()

    # supported type, but living on gpu
    if torch.cuda.is_available():
        with self.assertRaises(RuntimeError):
            torch.randn(*shape, dtype=torch.float, device=torch.device('cuda')).to_mkldnn()

    # factory functions asked for the mkldnn layout directly
    for factory in (torch.ones, torch.randn, torch.rand):
        with self.assertRaises(RuntimeError):
            factory(*shape, dtype=torch.float, device=torch.device('cpu'), layout=torch._mkldnn)
170
def test_mkldnn_conv_shapecheck(self):
    """Expect ``torch.mkldnn_convolution`` to raise ``RuntimeError`` for each case.

    Every row of ``cases`` is ``(padding, stride, dilation, groups, weight,
    bias)``; the scalar padding/stride/dilation values are duplicated into
    two-element lists before the call.
    """
    def ones(*shape):
        # All tensors in this test are float32 and filled with 1.
        return torch.full(shape, 1, dtype=torch.float32)

    inp = ones(1, 1, 1, 24)
    w1 = ones(1, 1, 1, 24)
    b1 = ones(1)
    w2 = ones(1, 1, 2, 24)
    b2 = ones(2)

    cases = [
        # pad, stride, dilation, groups, weight, bias
        (-1, 1, 1, 1, w1, b1),
        (0, 0, 1, 1, w1, b1),
        (0, 1, 0, 1, w1, b1),
        (0, 1, 1, 0, w1, b1),
        (0, 1, 1, 2, w1, b1),
        (0, 1, 1, 1, w1, b2),
        (0, 1, 1, 1, w2, b1),
    ]
    for pad, stride, dilation, groups, weight, bias in cases:
        with self.assertRaises(RuntimeError):
            torch.mkldnn_convolution(
                inp, weight, bias,
                [pad, pad], [stride, stride], [dilation, dilation], groups)
186
def test_autograd_to_mkldnn(self):
187
# MKLDNN only supports float32
188
root = torch.randn(4, 5, dtype=torch.float32, requires_grad=True)
191
return root.to_mkldnn().to_dense()
193
# because MKLDNN only supports float32, we need to lessen the precision.
194
# these numbers are just empirical results that seem to work.
195
self.assertWarnsRegex(UserWarning,
196
'double precision floating point',
197
lambda: gradcheck(func, [root], atol=4e-2, rtol=1e-2))
198
self.assertWarnsRegex(UserWarning,
199
'double precision floating point',
200
lambda: gradgradcheck(func, [root], atol=4e-2, rtol=1e-2))
202
def test_autograd_from_mkldnn(self):
203
# MKLDNN only supports float32
204
root = torch.randn(4, 5, dtype=torch.float32).to_mkldnn().requires_grad_()
207
return root.to_dense()
209
# because MKLDNN only supports float32, we need to lessen the precision.
210
# these numbers are just empirical results that seem to work.
211
self.assertWarnsRegex(UserWarning,
212
'double precision floating point',
213
lambda: gradcheck(func, [root], atol=4e-2, rtol=1e-2))
215
def test_detach(self):
    """Check ``detach`` and ``detach_`` on an MKLDNN tensor that requires grad."""
    expected_size = (4, 5)
    source = torch.randn(4, 5, dtype=torch.float32).to_mkldnn().requires_grad_()

    # Out-of-place detach: the result drops requires_grad, the source keeps it.
    detached = source.detach()
    self.assertEqual(expected_size, detached.size())
    self.assertFalse(detached.requires_grad)
    self.assertTrue(source.requires_grad)

    # In-place detach: both the returned tensor and the source drop requires_grad.
    detached_inplace = source.detach_()
    self.assertEqual(expected_size, detached_inplace.size())
    self.assertFalse(detached_inplace.requires_grad)
    self.assertFalse(source.requires_grad)
229
self.assertTrue("layout=torch._mkldnn" in str(torch.randn((1, 2, 3, 4),
230
dtype=torch.float, device=torch.device('cpu')).to_mkldnn()))
232
def _test_conv_base(self, dim):
233
conv_module = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d, 3: torch.nn.Conv3d}
234
input_shapes = {1: (224,), 2: (224, 224), 3: (55, 55, 55)}
235
options = itertools.product([True, False], [True, False], [1, 2], [1, 4])
236
for train, bias, dilation, groups in options:
237
N = torch.randint(3, 10, (1,)).item()
238
M = torch.randint(1, 3, (1,)).item() * groups
239
C = torch.randint(1, 3, (1,)).item() * groups
240
x_shape = (N, C) + input_shapes[dim]
241
x = torch.randn(x_shape, dtype=torch.float32)
242
conv = conv_module[dim](in_channels=C,
249
groups=groups).float()
251
x2 = x.clone().to_mkldnn()
253
mkldnn_conv = mkldnn_utils.to_mkldnn(copy.deepcopy(conv))
254
elif train and dim != 1:
255
# TODO: enable conv1d training.
258
mkldnn_conv = copy.deepcopy(conv)
259
with torch.backends.mkldnn.flags(enabled=False):
261
if train and dim != 1:
264
if not train or (train and dim != 1):
265
y_mkldnn = mkldnn_conv(x2).to_dense()
266
self.assertEqual(y_aten, y_mkldnn)
268
self._test_serialization(mkldnn_conv, (x.to_mkldnn(),))
269
self._test_tracing(mkldnn_conv, (x.to_mkldnn(),))
271
loss2 = y_mkldnn.sum()
273
self.assertTrue(x2.grad.is_mkldnn)
274
self.assertEqual(x1.grad, x2.grad.to_dense())
275
self.assertEqual(conv.weight.grad,
276
mkldnn_conv.weight.grad,
280
self.assertEqual(conv.bias.grad, mkldnn_conv.bias.grad)
282
def test_conv1d(self):
    """Run the shared convolution check for 1-D convolution."""
    self._test_conv_base(1)
285
def test_conv2d(self):
    """Run the shared convolution check for 2-D convolution."""
    self._test_conv_base(2)
288
def test_conv3d(self):
    """Run the shared convolution check for 3-D convolution."""
    self._test_conv_base(3)
291
def _test_conv_deconv_lower_precision_base(self, dim, conv_module, dtype):
292
input_shapes = {1: (224,), 2: (224, 224), 3: (55, 55, 55)}
293
options = itertools.product([True, False], [1, 2], [1, 4])
294
for bias, dilation, groups in options:
295
N = torch.randint(1, 3, (1,)).item()
296
M = torch.randint(1, 3, (1,)).item() * groups
297
C = torch.randint(1, 3, (1,)).item() * groups
298
x_shape = (N, C) + input_shapes[dim]
299
x = torch.randn(x_shape, dtype=torch.float32)
300
# TODO: remove this when group depthwise is supported:
301
if conv_module in [torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d,
302
torch.nn.ConvTranspose3d] and groups > 1 and C == groups:
304
conv = conv_module(in_channels=C,
311
groups=groups).float()
312
x_lower = x.to(dtype=dtype)
313
if (dtype == torch.bfloat16 and torch.ops.mkldnn._is_mkldnn_bf16_supported()) or \
314
(dtype == torch.half and torch.ops.mkldnn._is_mkldnn_fp16_supported()):
315
mkldnn_conv = mkldnn_utils.to_mkldnn(copy.deepcopy(conv))
316
mkldnn_conv_lower = mkldnn_utils.to_mkldnn(copy.deepcopy(conv), dtype)
317
y = mkldnn_conv(x.to_mkldnn()).to_dense()
318
y_lower = mkldnn_conv_lower(x_lower.to_mkldnn()).to_dense(torch.float32)
319
self.assertEqual(y, y_lower, atol=1e-1, rtol=1e-3)
322
torch.bfloat16: r"bf16 path needs the cpu support avx_ne_convert or avx512bw, avx512vl and avx512dq",
323
torch.half: r"fp16 path needs the cpu support avx_ne_convert or avx512_fp16",
325
with self.assertRaisesRegex(RuntimeError, msg[dtype]):
326
mkldnn_conv_lower = mkldnn_utils.to_mkldnn(copy.deepcopy(conv), dtype)
327
y_lower = mkldnn_conv_lower(x_lower.to_mkldnn()).to_dense(torch.float32)
329
conv_lower = copy.deepcopy(conv).to(dtype=dtype)
330
conv_ref = copy.deepcopy(conv_lower).float()
331
with torch.backends.mkldnn.flags(enabled=False):
332
x_ref = x_lower.clone().float().detach().requires_grad_()
333
x_lower.requires_grad_()
335
y_lower = conv_lower(x_lower).float()
336
self.assertEqual(y, y_lower, atol=5e-2, rtol=5e-3)
338
@dtypes(torch.float16, torch.bfloat16)
def test_conv_deconv_1d_lower_precision(self, dtype):
    """Run the lower-precision check for ``Conv1d`` and then ``ConvTranspose1d``."""
    for module in (torch.nn.Conv1d, torch.nn.ConvTranspose1d):
        self._test_conv_deconv_lower_precision_base(1, module, dtype=dtype)
343
@dtypes(torch.float16, torch.bfloat16)
def test_conv_deconv_2d_lower_precision(self, dtype):
    """Run the lower-precision check for ``Conv2d`` and then ``ConvTranspose2d``."""
    for module in (torch.nn.Conv2d, torch.nn.ConvTranspose2d):
        self._test_conv_deconv_lower_precision_base(2, module, dtype=dtype)
348
@dtypes(torch.float16, torch.bfloat16)
def test_conv_deconv_3d_lower_precision(self, dtype):
    """Run the lower-precision check for ``Conv3d`` and then ``ConvTranspose3d``."""
    for module in (torch.nn.Conv3d, torch.nn.ConvTranspose3d):
        self._test_conv_deconv_lower_precision_base(3, module, dtype=dtype)
353
def _test_conv_deconv_nhwc_base(self, conv_module, weight_memory_format, dtype, prec=None):
354
input_shapes = {2: (55, 55), 3: (14, 14, 14)}
355
options = itertools.product([True, False], [True, False], [1, 2], [1, 4])
356
if conv_module in [torch.nn.Conv2d, torch.nn.ConvTranspose2d]:
357
cl_format = torch.channels_last
358
input_shape = input_shapes[2]
359
elif conv_module in [torch.nn.Conv3d, torch.nn.ConvTranspose3d]:
360
cl_format = torch.channels_last_3d
361
input_shape = input_shapes[3]
363
for train, bias, dilation, groups in options:
364
N = torch.randint(3, 10, (1,)).item()
365
M = torch.randint(1, 3, (1,)).item() * groups
366
C = torch.randint(1, 3, (1,)).item() * groups
367
x_shape = (N, C) + input_shape
368
x = torch.randn(x_shape, dtype=dtype)
370
# conv1: mkldnn conv/deconv in contiguous memory format (nchw)
371
# conv2: mkldnn conv/deconv in channels last memory format (nhwc)
372
conv1 = conv_module(in_channels=C,
379
groups=groups).to(dtype=dtype)
380
conv2 = copy.deepcopy(conv1).to(memory_format=weight_memory_format)
382
x2 = x.clone().to(memory_format=cl_format)
388
self.assertEqual(y1, y2, atol=prec, rtol=prec)
393
self.assertTrue(x2.grad.is_contiguous(memory_format=cl_format))
394
self.assertEqual(conv1.weight.grad,
399
self.assertEqual(conv1.bias.grad, conv2.bias.grad, atol=prec, rtol=prec)
400
self.assertEqual(x1.grad, x2.grad, atol=prec, rtol=prec)
402
def test_conv_nhwc_fp32(self):
    """Run the NHWC conv check in float32 for 2-D and 3-D convolution.

    Each module is checked with contiguous weights and with weights in the
    matching channels-last format.
    """
    cases = (
        (torch.nn.Conv2d, torch.contiguous_format),
        (torch.nn.Conv2d, torch.channels_last),
        (torch.nn.Conv3d, torch.contiguous_format),
        (torch.nn.Conv3d, torch.channels_last_3d),
    )
    for conv_module, weight_format in cases:
        self._test_conv_deconv_nhwc_base(conv_module, weight_format, dtype=torch.float32)
408
@dtypes(torch.float16, torch.bfloat16)
409
def test_conv_nhwc_lower_precision(self, dtype):
410
# when torch.ops.mkldnn._is_mkldnn_bf16_supported() or torch.ops.mkldnn._is_mkldnn_fp16_supported()
411
# returns false, bf16/fp16 CPU conv will fall back to thnn impl
413
torch.bfloat16: torch.ops.mkldnn._is_mkldnn_bf16_supported,
414
torch.float16: torch.ops.mkldnn._is_mkldnn_fp16_supported
416
if support_checks[dtype]():
417
self._test_conv_deconv_nhwc_base(torch.nn.Conv2d, torch.contiguous_format, dtype=dtype)
418
self._test_conv_deconv_nhwc_base(torch.nn.Conv2d, torch.channels_last, dtype=dtype)
419
self._test_conv_deconv_nhwc_base(torch.nn.Conv3d, torch.contiguous_format, dtype=dtype)
420
self._test_conv_deconv_nhwc_base(torch.nn.Conv3d, torch.channels_last_3d, dtype=dtype)
422
# BF16/FP16 fallback implementations are divided into two parts im2col+gemm,
423
# and the number of data type conversions in the middle is more than that of onednn's direct conv,
424
# resulting in additional accuracy loss.
426
torch.bfloat16: 1e-2,
429
prec = precisions[dtype]
430
with torch.backends.mkldnn.flags(enabled=False):
431
self._test_conv_deconv_nhwc_base(torch.nn.Conv2d, torch.contiguous_format, dtype=dtype, prec=prec)
432
self._test_conv_deconv_nhwc_base(torch.nn.Conv2d, torch.channels_last, dtype=dtype, prec=prec)
433
self._test_conv_deconv_nhwc_base(torch.nn.Conv3d, torch.contiguous_format, dtype=dtype, prec=prec)
434
self._test_conv_deconv_nhwc_base(torch.nn.Conv3d, torch.channels_last_3d, dtype=dtype, prec=prec)
437
def test_conv_transpose_nhwc_fp32(self):
    """Run the NHWC conv check in float32 for 2-D and 3-D transposed convolution.

    Each module is checked with contiguous weights and with weights in the
    matching channels-last format.
    """
    cases = (
        (torch.nn.ConvTranspose2d, torch.contiguous_format),
        (torch.nn.ConvTranspose2d, torch.channels_last),
        (torch.nn.ConvTranspose3d, torch.contiguous_format),
        (torch.nn.ConvTranspose3d, torch.channels_last_3d),
    )
    for deconv_module, weight_format in cases:
        self._test_conv_deconv_nhwc_base(deconv_module, weight_format, dtype=torch.float32)
443
@dtypes(torch.float16, torch.bfloat16)
444
def test_conv_transpose_nhwc_lower_precision(self, dtype):
445
# when torch.ops.mkldnn._is_mkldnn_bf16_supported() or torch.ops.mkldnn._is_mkldnn_fp16_supported()
446
# returns false, bf16/fp16 CPU conv will fall back to thnn impl
448
torch.bfloat16: torch.ops.mkldnn._is_mkldnn_bf16_supported,
449
torch.float16: torch.ops.mkldnn._is_mkldnn_fp16_supported
451
if support_checks[dtype]():
452
self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose2d, torch.contiguous_format, dtype=dtype)
453
self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose2d, torch.channels_last, dtype=dtype)
454
self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose3d, torch.contiguous_format, dtype=dtype)
455
self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose3d, torch.channels_last_3d, dtype=dtype)
457
# BF16/FP16 fallback implementations are divided into two parts col2im+gemm,
458
# and the number of data type conversions in the middle is more than that of onednn's direct conv,
459
# resulting in additional accuracy loss.
461
torch.bfloat16: 2e-2,
464
prec = precisions[dtype]
465
with torch.backends.mkldnn.flags(enabled=False):
466
self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose2d, torch.contiguous_format, dtype=dtype, prec=prec)
467
self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose2d, torch.channels_last, dtype=dtype, prec=prec)
468
self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose3d, torch.contiguous_format, dtype=dtype, prec=prec)
469
self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose3d, torch.channels_last_3d, dtype=dtype, prec=prec)
471
def _test_conv_transpose_base(self, dim):
473
1: torch.nn.ConvTranspose1d,
474
2: torch.nn.ConvTranspose2d,
475
3: torch.nn.ConvTranspose3d
477
input_shapes = {1: (55,), 2: (28, 28), 3: (14, 14, 14)}
478
options = itertools.product([True, False], [True, False], [1, 2], [1, 4])
479
for train, bias, dilation, groups in options:
480
N = torch.randint(3, 10, (1,)).item()
481
M = torch.randint(1, 3, (1,)).item() * groups
482
C = torch.randint(1, 3, (1,)).item() * groups
483
x_shape = (N, C) + input_shapes[dim]
484
data = torch.randn(x_shape, dtype=torch.float32)
485
# conv: mkldnn transpose conv fp32
486
# conv_ref: thnn transpose conv fp32
487
conv = conv_module[dim](in_channels=C,
494
groups=groups).to(dtype=torch.float32)
499
x_ref.requires_grad_()
501
conv_ref = copy.deepcopy(conv)
502
with torch.backends.mkldnn.flags(enabled=False):
503
y_ref = conv_ref(x_ref)
505
y_ref.sum().backward()
511
self.assertEqual(y, y_ref)
513
self.assertEqual(x.grad, x_ref.grad)
514
self.assertEqual(conv.weight.grad,
515
conv_ref.weight.grad,
519
self.assertEqual(conv.bias.grad, conv_ref.bias.grad)
521
def test_conv_transpose1d(self):
    """Run the shared transposed-convolution check for the 1-D case."""
    self._test_conv_transpose_base(1)
524
def test_conv_transpose2d(self):
    """Run the shared transposed-convolution check for the 2-D case."""
    self._test_conv_transpose_base(2)
527
def test_conv_transpose3d(self):
    """Run the shared transposed-convolution check for the 3-D case."""
    self._test_conv_transpose_base(3)
530
def test_conv2d_legacy_jit_model(self):
532
MKLDNN integration used to serialize models with 5d weight for grouped
533
convolutions, we'd like to preserve this behavior
536
conv2d = torch.nn.Conv2d(16, 16, 3, groups=g)
537
conv2d_mkldnn = torch.utils.mkldnn.to_mkldnn(conv2d)
539
# contrive legacy conv2d module with a 5-d weight
540
o, i, h, w = conv2d.weight.shape
541
weight_5d = conv2d.weight.reshape((g, o // g, i, h, w))
542
conv2d_mkldnn.weight = weight_5d.to_mkldnn()
544
x = torch.randn(1, 16, 8, 8)
546
with TemporaryFileName() as fname:
547
torch.jit.save(conv2d_mkldnn, fname)
548
conv2d_loaded = torch.jit.load(fname)
550
self.assertEqual(conv2d_mkldnn.weight.ndimension(), 5)
551
self.assertEqual(conv2d_loaded.weight.ndimension(), 4)
554
conv2d_loaded(x.to_mkldnn()).to_dense())
556
# This test is to check whether 1D conv is supported for mkldnn tensor,
557
# which is exposed by Issue https://github.com/pytorch/pytorch/issues/68034.
558
def test_conv1d_functional(self):
    """Call ``torch.nn.functional.conv1d`` on MKLDNN tensors and check the output size."""
    signal = torch.randn(2, 3, 10).to_mkldnn()
    kernel = torch.randn(3, 3, 3).to_mkldnn()
    offset = torch.randn(3).to_mkldnn()

    result = torch.nn.functional.conv1d(signal, kernel, offset)

    expected_size = torch.Size([2, 3, 8])
    self.assertEqual(result.size(), expected_size)
566
x = torch.randn((4, 5), dtype=torch.float32) * 10
567
x1 = x.clone().requires_grad_()
568
x2 = x.clone().to_mkldnn().requires_grad_()
570
y2 = torch.relu(x2).to_dense()
575
self.assertEqual(y1, y2)
576
self.assertEqual(x1.grad, x2.grad.to_dense())
578
def test_relu_(self):
579
x = torch.randn((4, 5), dtype=torch.float32) * 10
580
x1 = x.clone().requires_grad_()
581
x2 = x.clone().to_mkldnn().requires_grad_()
582
y1 = torch.relu_(x1.clone())
583
y2 = torch.relu_(x2.clone()).to_dense()
588
self.assertEqual(y1, y2)
589
self.assertEqual(x1.grad, x2.grad.to_dense())
591
@unittest.skipIf(IS_WINDOWS, "Limit support for bf16 path")
592
def _test_relu_bf16_base(self, name):
593
x = torch.randn((4, 5), dtype=torch.float32) * 10
594
x_bf16 = x.bfloat16()
595
fn = getattr(torch, name)
596
if torch.ops.mkldnn._is_mkldnn_bf16_supported():
597
y = fn(x.to_mkldnn()).to_dense()
598
y_bf16 = fn(x_bf16.to_mkldnn()).to_dense(torch.float32)
599
self.assertEqual(y, y_bf16, atol=1e-1, rtol=1e-3)
601
msg = r"bf16 path needs the cpu support avx512bw, avx512vl and avx512dq"
602
self.assertRaisesRegex(RuntimeError,
604
lambda: fn(x_bf16.to_mkldnn()))
606
def test_relu_bf16(self):
    """Run the bf16 relu check against ``torch.relu``."""
    self._test_relu_bf16_base(name="relu")
609
def test_relu_inplace_bf16(self):
    """Run the bf16 relu check against the in-place ``torch.relu_``."""
    self._test_relu_bf16_base(name="relu_")
614
x = torch.randn((4, 5), dtype=torch.float32) * 10
615
x1 = x.clone().requires_grad_()
616
x2 = x.clone().to_mkldnn().requires_grad_()
618
y2 = m(x2).to_dense()
623
self.assertEqual(y1, y2)
624
self.assertEqual(x1.grad, x2.grad.to_dense())
626
@unittest.skipIf(IS_WINDOWS, "Limit support for bf16 path")
627
def test_gelu_bf16(self):
629
x = torch.randn((4, 5), dtype=torch.float32) * 10
630
x1 = x.clone().to_mkldnn().requires_grad_()
631
x2 = x.clone().to_mkldnn(torch.bfloat16).requires_grad_()
632
if torch.ops.mkldnn._is_mkldnn_bf16_supported():
633
y1 = m(x1).to_dense()
634
y2 = m(x2).to_dense()
639
self.assertEqual(y1, y2.to(torch.float32), atol=1e-1, rtol=0)
640
self.assertEqual(x1.grad.to_dense(), x2.grad.to_dense(torch.float32), atol=1e-2, rtol=0)
642
msg = r"bf16 path needs the cpu support avx512bw, avx512vl and avx512dq"
643
self.assertRaisesRegex(RuntimeError,
647
def _test_prelu_base(self, size, num_channels):
648
x = torch.randn(size, dtype=torch.float32)
649
x1 = x.clone().requires_grad_()
650
x2 = x.clone().to_mkldnn().requires_grad_()
651
x3 = x.clone().to_mkldnn().requires_grad_()
652
m1 = torch.nn.PReLU(num_channels)
653
m2 = mkldnn_utils.to_mkldnn(copy.deepcopy(m1))
654
m3 = copy.deepcopy(m1)
656
y2 = m2(x2).to_dense()
657
y3 = m3(x3).to_dense() # Only convert data to mkldnn, weight is Aten tensor
664
self.assertEqual(y1, y2)
665
self.assertEqual(y1, y3)
666
self.assertEqual(x1.grad, x2.grad.to_dense())
667
self.assertEqual(x1.grad, x3.grad.to_dense())
669
def test_prelu(self):
    """Run the PReLU check over a fixed list of (input size, num_channels) pairs."""
    cases = [
        ([16], 1),
        ([16, 64], 1),
        ([16, 64], 64),
        ([16, 64, 112], 1),
        ([16, 64, 112], 64),
        ([16, 64, 112, 112], 1),
        ([16, 64, 112, 112], 64),
        ([16, 64, 112, 112, 1], 1),
        ([16, 64, 112, 112, 1], 64),
    ]
    for dims, num_channels in cases:
        self._test_prelu_base(torch.Size(dims), num_channels)
680
@unittest.skipIf(IS_WINDOWS, "Limit support for bf16 path")
681
def _test_prelu_bf16_base(self, size, num_channels):
682
if torch.ops.mkldnn._is_mkldnn_bf16_supported():
683
x = torch.randn(size, dtype=torch.float32)
684
x_fp32 = x.clone().to_mkldnn().requires_grad_()
685
x_bf16 = x.clone().to_mkldnn(torch.bfloat16).requires_grad_()
686
m = mkldnn_utils.to_mkldnn(torch.nn.PReLU())
687
m_bf16 = mkldnn_utils.to_mkldnn(torch.nn.PReLU(), torch.bfloat16)
689
y = m(x_fp32).to_dense()
690
y_bf16 = m_bf16(x_bf16).to_dense()
691
self.assertEqual(y, y_bf16.to(torch.float32), atol=1e-1, rtol=1e-3)
695
loss_bf16 = y_bf16.sum()
697
self.assertEqual(x_fp32.grad.to_dense(), x_bf16.grad.to_dense(torch.float32))
699
x_bf16 = torch.randn(size, dtype=torch.bfloat16).requires_grad_()
700
m_bf16 = mkldnn_utils.to_mkldnn(torch.nn.PReLU(), torch.bfloat16)
701
msg = r"bf16 path needs the cpu support avx512bw, avx512vl and avx512dq"
702
self.assertRaisesRegex(RuntimeError,
704
lambda: m_bf16(x_bf16))
706
def test_prelu_bf16(self):
    """Run the bf16 PReLU check over a fixed list of (input size, num_channels) pairs."""
    cases = [
        ([16], 1),
        ([16, 64], 1),
        ([16, 64], 64),
        ([16, 64, 112], 1),
        ([16, 64, 112], 64),
        ([16, 64, 112, 112, 1], 1),
        ([16, 64, 112, 112, 1], 64),
    ]
    for dims, num_channels in cases:
        self._test_prelu_bf16_base(torch.Size(dims), num_channels)
715
def _test_max_pool_base(self, dim, input):
716
pool_module = {2: torch.nn.MaxPool2d, 3: torch.nn.MaxPool3d}
717
for stride in [1, 2, 3]:
718
for ceil_mode in [False, True]:
719
max_pool = pool_module[dim](
720
kernel_size=3 if not ceil_mode else 7,
725
x1 = input.clone().requires_grad_()
726
x2 = input.clone().to_mkldnn().requires_grad_()
728
y2 = max_pool(x2).to_dense()
733
self.assertEqual(y1, y2)
734
self.assertEqual(x1.grad, x2.grad.to_dense())
736
def test_max_pool2d(self):
    """Run the max-pool check on 2-D inputs of several spatial sizes.

    Batch and channel counts are each drawn once from ``torch.randint(3, 10)``.
    """
    batch = torch.randint(3, 10, (1,)).item()
    channels = torch.randint(3, 10, (1,)).item()
    spatial_sizes = [(64, 64), (35, 39), (16, 19), [7, 8]]
    for height, width in spatial_sizes:
        sample = torch.randn(batch, channels, height, width, dtype=torch.float32) * 10
        self._test_max_pool_base(dim=2, input=sample)
743
def test_max_pool3d(self):
    """Run the max-pool check on 3-D inputs of several spatial sizes.

    Batch and channel counts are each drawn once from ``torch.randint(3, 10)``.
    """
    batch = torch.randint(3, 10, (1,)).item()
    channels = torch.randint(3, 10, (1,)).item()
    spatial_sizes = [(64, 64, 64), (35, 39, 35), (16, 19, 20), [7, 8, 9]]
    for depth, height, width in spatial_sizes:
        sample = torch.randn(batch, channels, depth, height, width, dtype=torch.float32) * 10
        self._test_max_pool_base(dim=3, input=sample)
751
@unittest.skipIf(IS_WINDOWS, "Limit support for bf16 path")
752
def _test_max_pool_bf16_base(self, dim, input):
753
pool_module = {2: torch.nn.MaxPool2d, 3: torch.nn.MaxPool3d}
754
x_bf16 = input.bfloat16()
755
for stride in [1, 2, 3]:
756
for ceil_mode in [False, True]:
757
max_pool = pool_module[dim](
758
kernel_size=3 if not ceil_mode else 7,
763
if torch.ops.mkldnn._is_mkldnn_bf16_supported():
764
y = max_pool(input.to_mkldnn()).to_dense()
765
y_bf16 = max_pool(x_bf16.to_mkldnn()).to_dense(torch.float32)
766
self.assertEqual(y, y_bf16, atol=0.1, rtol=1e-3)
768
msg = "mkldnn_max_pool%dd: bf16 path needs the cpu support avx512bw, avx512vl and avx512dq" % dim
769
self.assertRaisesRegex(RuntimeError,
771
lambda: max_pool(x_bf16.to_mkldnn()))
773
def test_max_pool2d_bf16(self):
    """Run the bf16 max-pool check on 2-D inputs of several spatial sizes.

    Batch and channel counts are each drawn once from ``torch.randint(3, 10)``.
    """
    batch = torch.randint(3, 10, (1,)).item()
    channels = torch.randint(3, 10, (1,)).item()
    spatial_sizes = [(64, 64), (35, 39), (16, 19), [7, 8]]
    for height, width in spatial_sizes:
        sample = torch.randn(batch, channels, height, width, dtype=torch.float32) * 10
        self._test_max_pool_bf16_base(dim=2, input=sample)
780
def test_max_pool3d_bf16(self):
    """Run the bf16 max-pool check on 3-D inputs of several spatial sizes.

    Batch and channel counts are each drawn once from ``torch.randint(3, 10)``.
    """
    batch = torch.randint(3, 10, (1,)).item()
    channels = torch.randint(3, 10, (1,)).item()
    spatial_sizes = [(64, 64, 64), (35, 39, 35), (16, 19, 20), [7, 8, 9]]
    for depth, height, width in spatial_sizes:
        sample = torch.randn(batch, channels, depth, height, width, dtype=torch.float32) * 10
        self._test_max_pool_bf16_base(dim=3, input=sample)
787
def test_max_pool2d_stride_none(self):
788
N = torch.randint(3, 10, (1,)).item()
789
C = torch.randint(3, 10, (1,)).item()
791
for H, W in [(64, 64), (35, 39), (16, 19), [7, 8]]:
792
x = torch.randn(N, C, H, W, dtype=torch.float32) * 10
793
for ceil_mode in [False, True]:
796
kernel_size=3 if not ceil_mode else 7,
803
kernel_size=3 if not ceil_mode else 7,
808
self.assertEqual(y1, y2.to_dense())
810
def test_max_pool_unsupported(self):
811
# OneDNN does not support dilated max pooling; it will be available in v2.0.
812
N = torch.randint(3, 10, (1,)).item()
813
C = torch.randint(3, 10, (1,)).item()
816
x = torch.randn(N, C, 7, 7, dtype=torch.float32).to_mkldnn()
817
max_pool2d = torch.nn.MaxPool2d(
822
self.assertRaisesRegex(RuntimeError,
823
'mkldnn_max_pool2d does not support dilation case',
824
lambda: max_pool2d(x))
827
x = torch.randn(N, C, 7, 7, 7, dtype=torch.float32).to_mkldnn()
828
max_pool3d = torch.nn.MaxPool3d(
833
self.assertRaisesRegex(RuntimeError,
834
'mkldnn_max_pool3d does not support dilation case',
835
lambda: max_pool3d(x))
837
def _test_avg_pool_base(self, dim, input):
838
avg_module = {2: torch.nn.AvgPool2d, 3: torch.nn.AvgPool3d}
839
for count_include_pad in [True, False]:
840
avg_pool = avg_module[dim](
844
count_include_pad=count_include_pad)
846
x1 = input.clone().requires_grad_()
847
x2 = input.clone().to_mkldnn().requires_grad_()
849
y2 = avg_pool(x2).to_dense()
854
self.assertEqual(y1, y2)
855
self.assertEqual(x1.grad, x2.grad.to_dense())
857
def test_avg_pool2d(self):
    """Run the average-pool check on a random ``(N, C, 64, 64)`` float32 input."""
    batch = torch.randint(3, 10, (1,)).item()
    channels = torch.randint(3, 10, (1,)).item()
    sample = torch.randn(batch, channels, 64, 64, dtype=torch.float32) * 10
    self._test_avg_pool_base(dim=2, input=sample)
863
def test_avg_pool3d(self):
    """Run the average-pool check on a random ``(N, C, 64, 64, 64)`` float32 input."""
    batch = torch.randint(3, 10, (1,)).item()
    channels = torch.randint(3, 10, (1,)).item()
    sample = torch.randn(batch, channels, 64, 64, 64, dtype=torch.float32) * 10
    self._test_avg_pool_base(dim=3, input=sample)
869
@unittest.skipIf(IS_WINDOWS, "Limit support for bf16 path")
870
def _test_avg_pool_bf16_base(self, dim, input):
871
avg_module = {2: torch.nn.AvgPool2d, 3: torch.nn.AvgPool3d}
872
x_bf16 = input.bfloat16()
873
for count_include_pad in [True, False]:
874
avg_pool = avg_module[dim](
878
count_include_pad=count_include_pad)
879
if torch.ops.mkldnn._is_mkldnn_bf16_supported():
880
y = avg_pool(input.to_mkldnn()).to_dense()
881
y_bf16 = avg_pool(x_bf16.to_mkldnn()).to_dense(torch.float)
882
self.assertEqual(y, y_bf16, atol=1e-1, rtol=1e-3)
884
msg = "mkldnn_avg_pool%dd: bf16 path needs the cpu support avx512bw, avx512vl and avx512dq" % dim
885
self.assertRaisesRegex(RuntimeError,
887
lambda: avg_pool(x_bf16.to_mkldnn()))
889
def test_avg_pool2d_bf16(self):
    """Run the bf16 average-pool check on a random ``(N, C, 64, 64)`` float32 input."""
    batch = torch.randint(3, 10, (1,)).item()
    channels = torch.randint(3, 10, (1,)).item()
    sample = torch.randn(batch, channels, 64, 64, dtype=torch.float32) * 10
    self._test_avg_pool_bf16_base(dim=2, input=sample)
895
def test_avg_pool3d_bf16(self):
    """Run the bf16 average-pool check on a random ``(N, C, 64, 64, 64)`` float32 input."""
    batch = torch.randint(3, 10, (1,)).item()
    channels = torch.randint(3, 10, (1,)).item()
    sample = torch.randn(batch, channels, 64, 64, 64, dtype=torch.float32) * 10
    self._test_avg_pool_bf16_base(dim=3, input=sample)
901
# For each count_include_pad setting, a dense result `y1` is compared with an
# mkldnn result `y2` on a random (N, C, 64, 64) float32 input.
# NOTE(review): going by the test name, the pooling calls presumably leave
# `stride` unset — confirm against the call sites.
def test_avg_pool2d_stride_none(self):
902
N = torch.randint(3, 10, (1,)).item()
903
C = torch.randint(3, 10, (1,)).item()
904
x = torch.randn(N, C, 64, 64, dtype=torch.float32) * 10
906
for count_include_pad in [True, False]:
912
count_include_pad=count_include_pad)
918
count_include_pad=count_include_pad)
920
self.assertEqual(y1, y2.to_dense())
922
# AdaptiveAvgPool2d(7) is applied to a dense copy (x1) and an mkldnn copy (x2)
# of the same random float32 input; the outputs and the input gradients of the
# two paths must be equal.
def test_adaptive_avg_pool2d(self):
923
N = torch.randint(3, 10, (1,)).item()
924
C = torch.randint(3, 10, (1,)).item()
925
x = torch.randn(N, C, 224, 224, dtype=torch.float32) * 100
927
adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d(7)
928
# both copies track gradients so x1.grad / x2.grad can be compared below
x1 = x.clone().requires_grad_()
929
x2 = x.clone().to_mkldnn().requires_grad_()
930
y1 = adaptive_avg_pool2d(x1)
931
y2 = adaptive_avg_pool2d(x2).to_dense()
938
self.assertEqual(y1, y2)
939
self.assertEqual(x1.grad, x2.grad.to_dense())
941
@unittest.skipIf(IS_WINDOWS, "Limit support for bf16 path")
def test_adaptive_avg_pool2d_bf16(self):
    """Compare AdaptiveAvgPool2d(7) on float32 and bfloat16 mkldnn inputs.

    When the CPU supports the mkldnn bf16 path, the bf16 result (converted
    back to float32) must match the float32 result within atol=1e-1,
    rtol=1e-3. Otherwise the bf16 call must raise a RuntimeError whose
    message names the missing CPU features.
    """
    N = torch.randint(3, 10, (1,)).item()
    C = torch.randint(3, 10, (1,)).item()
    x = torch.randn(N, C, 224, 224, dtype=torch.float32) * 100

    x_bf16 = x.bfloat16()
    adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d(7)

    if torch.ops.mkldnn._is_mkldnn_bf16_supported():
        y = adaptive_avg_pool2d(x.to_mkldnn()).to_dense()
        # Feed the bfloat16 copy here: pooling the float32 tensor twice would
        # compare the float kernel with itself and never run the bf16 path.
        y_bf16 = adaptive_avg_pool2d(x_bf16.to_mkldnn()).to_dense(torch.float32)
        self.assertEqual(y, y_bf16, atol=1e-1, rtol=1e-3)
    else:
        msg = "mkldnn_adaptive_avg_pool2d: bf16 path needs the cpu support avx512bw, avx512vl and avx512dq"
        self.assertRaisesRegex(RuntimeError,
                               msg,
                               lambda: adaptive_avg_pool2d(x_bf16.to_mkldnn()))
960
# Shared inference-mode batch-norm check: builds BatchNorm2d/BatchNorm3d
# (selected through `dim`) with train(False), converts a deep copy with
# mkldnn_utils.to_mkldnn, evaluates it on the mkldnn input, and then runs the
# serialization and tracing helpers on the converted module.
def _test_batch_norm_base(self, dim, channels, input):
961
bn_module = {2 : torch.nn.BatchNorm2d, 3 : torch.nn.BatchNorm3d}
962
bn = bn_module[dim](channels).float().train(False)
963
# convert a copy so `bn` itself stays a regular dense module
mkldnn_bn = mkldnn_utils.to_mkldnn(copy.deepcopy(bn))
966
mkldnn_bn(input.to_mkldnn()).to_dense())
968
self._test_serialization(mkldnn_bn, (input.to_mkldnn(),))
969
self._test_tracing(mkldnn_bn, (input.to_mkldnn(),))
971
# Shared training-mode batch-norm check. Only BatchNorm2d is listed in
# `bn_module`, and `options` only yields affine=True, combined with both
# track_running_stats settings. A deep copy of the module is fed the mkldnn
# input; outputs, input gradients, weight gradients and (when tracked) the
# running statistics of the two modules are compared.
def _test_batch_norm_train_base(self, dim, channels, input):
972
# TODO: support 3d batchnorm training.
973
bn_module = {2 : torch.nn.BatchNorm2d}
974
# TODO: support non-affine.
975
options = itertools.product([True], [True, False])
976
for affine, track_running_stats in options:
978
num_features=channels,
980
track_running_stats=track_running_stats).float().train(True)
981
mkldnn_bn = copy.deepcopy(bn)
982
x1 = input.clone().requires_grad_()
983
x2 = input.clone().to_mkldnn().requires_grad_()
985
y2 = mkldnn_bn(x2).to_dense()
990
self.assertEqual(y1, y2)
991
self.assertEqual(x1.grad, x2.grad.to_dense())
992
# weight gradients are compared with a looser tolerance (1e-3)
self.assertEqual(bn.weight.grad, mkldnn_bn.weight.grad, rtol=1e-3, atol=1e-3)
993
if track_running_stats:
994
self.assertEqual(bn.running_mean, mkldnn_bn.running_mean)
995
self.assertEqual(bn.running_var, mkldnn_bn.running_var, rtol=1e-5, atol=1e-5)
997
def test_batch_norm_2d(self):
998
N = torch.randint(3, 10, (1,)).item()
999
C = torch.randint(3, 100, (1,)).item()
1000
x = torch.randn(N, C, 35, 45, dtype=torch.float32) * 10
1001
self._test_batch_norm_base(dim=2, channels=C, input=x)
1002
self._test_batch_norm_train_base(dim=2, channels=C, input=x)
1004
def test_batch_norm_3d(self):
1005
N = torch.randint(3, 10, (1,)).item()
1006
C = torch.randint(3, 100, (1,)).item()
1007
x = torch.randn(N, C, 30, 30, 30, dtype=torch.float32) * 10
1008
self._test_batch_norm_base(dim=3, channels=C, input=x)
1010
# Shared bf16 batch-norm check for BatchNorm2d/BatchNorm3d (selected through
# `dim`); only inference mode (train=False) is iterated. `msg` is the
# RuntimeError pattern used with assertRaisesRegex for the bf16 call.
@unittest.skipIf(IS_WINDOWS, "Limit support for bf16 path")
1011
def _test_batch_norm_bf16_base(self, dim, channels, input):
1012
bn_module = {2 : torch.nn.BatchNorm2d, 3 : torch.nn.BatchNorm3d}
1013
x_bf16 = input.bfloat16()
1014
# TODO: support training
1015
for train in [False]:
1016
bn = bn_module[dim](channels).float().train(train)
1017
mkldnn_bn = mkldnn_utils.to_mkldnn(copy.deepcopy(bn))
1018
if torch.ops.mkldnn._is_mkldnn_bf16_supported():
1019
# NOTE(review): both `y` and `y_bf16` are produced by the dense `bn` from the
# float32 `input` round-tripped through mkldnn; neither `x_bf16` nor
# `mkldnn_bn` is used in this branch, so the bf16 kernel does not appear to be
# exercised by this comparison — confirm whether that is intended.
y = bn(input.to_mkldnn().to_dense())
1020
y_bf16 = bn(input.to_mkldnn().to_dense(torch.float))
1021
self.assertEqual(y, y_bf16, atol=1e-1, rtol=1e-3)
1023
msg = "mkldnn_batch_norm: bf16 path needs the cpu support avx512bw, avx512vl and avx512dq"
1024
self.assertRaisesRegex(RuntimeError,
1026
lambda: bn(x_bf16.to_mkldnn()))
1028
def test_batch_norm_2d_bf16(self):
1029
N = torch.randint(3, 10, (1,)).item()
1030
C = torch.randint(3, 100, (1,)).item()
1031
x = torch.randn(N, C, 35, 45, dtype=torch.float32) * 10
1032
self._test_batch_norm_bf16_base(dim=2, channels=C, input=x)
1034
def test_batch_norm_3d_bf16(self):
1035
N = torch.randint(3, 10, (1,)).item()
1036
C = torch.randint(3, 100, (1,)).item()
1037
x = torch.randn(N, C, 30, 30, 30, dtype=torch.float32) * 10
1038
self._test_batch_norm_bf16_base(dim=3, channels=C, input=x)
1041
N = torch.randint(3, 10, (1,)).item()
1042
C = torch.randint(3, 100, (1,)).item()
1043
alpha = torch.randn(1, dtype=torch.float32).item()
1045
x = torch.randn(N, C, 35, 45, dtype=torch.float32) * 10
1046
y = torch.randn(N, C, 35, 45, dtype=torch.float32) * 10
1053
(mx + my).to_dense())
1056
torch.add(x, y, alpha=alpha),
1057
torch.add(mx, my, alpha=alpha).to_dense())
1062
self.assertEqual(x, mx.to_dense())
1066
mkldnn_out = out.to_mkldnn()
1067
torch.add(x, y, alpha=alpha, out=out)
1068
torch.add(mx, my, alpha=alpha, out=mkldnn_out)
1069
self.assertEqual(out, mkldnn_out.to_dense())
1071
# add_out inplace case: first input
1072
torch.add(x, y, alpha=alpha, out=x)
1073
torch.add(mx, my, alpha=alpha, out=mx)
1074
self.assertEqual(x, mx.to_dense())
1076
# add_out inplace case: second input
1077
torch.add(x, y, alpha=alpha, out=y)
1078
torch.add(mx, my, alpha=alpha, out=my)
1079
self.assertEqual(y, my.to_dense())
1082
N = torch.randint(3, 10, (1,)).item()
1083
C = torch.randint(3, 100, (1,)).item()
1084
value = torch.randn(1, dtype=torch.float32).item()
1086
x = torch.randn(N, C, 35, 45, dtype=torch.float32) * 10
1087
y = torch.randn(N, C, 35, 45, dtype=torch.float32) * 10
1094
(mx * my).to_dense())
1098
(mx * value).to_dense())
1102
torch.mul(mx, my).to_dense())
1105
torch.mul(x, value),
1106
torch.mul(mx, value).to_dense())
1111
self.assertEqual(x, mx.to_dense())
1115
self.assertEqual(x, mx.to_dense())
1119
mkldnn_out = out.to_mkldnn()
1120
torch.mul(x, y, out=out)
1121
torch.mul(mx, my, out=mkldnn_out)
1122
self.assertEqual(out, mkldnn_out.to_dense())
1125
mkldnn_out = out.to_mkldnn()
1126
torch.mul(x, value, out=out)
1127
torch.mul(mx, value, out=mkldnn_out)
1128
self.assertEqual(out, mkldnn_out.to_dense())
1130
# Exercises mkldnn ops on a tensor that has a zero-sized dimension
# (`y` has shape [20, 20, 0, 1]): relu, mul and add results are compared with
# reference values, two error cases are checked via assertRaisesRegex, and a
# converted Conv2d is applied to an input whose batch dimension is 0.
def test_0_dimension_tensor(self):
1131
x = torch.rand([20, 20, 1, 1], dtype=torch.float)
1132
y = torch.rand([20, 20, 0, 1], dtype=torch.float)
1134
# unary ops work without modification
1135
out_relu = torch.relu(y)
1136
out_relu_mkldnn = torch.relu(y.to_mkldnn()).to_dense()
1137
self.assertEqual(out_relu, out_relu_mkldnn)
1140
out_mul_mkldnn = (x.to_mkldnn() * y.to_mkldnn()).to_dense()
1141
self.assertEqual(out_mul, out_mul_mkldnn)
1144
out_add_mkldnn = (x.to_mkldnn() + y.to_mkldnn()).to_dense()
1145
self.assertEqual(out_add, out_add_mkldnn)
1147
# once gradients are required, adding the mkldnn tensors is expected to raise
x.requires_grad_(True)
1148
y.requires_grad_(True)
1149
with self.assertRaisesRegex(RuntimeError, "0-dimension Tensor in training"):
1150
x.to_mkldnn() + y.to_mkldnn()
1152
# operands of size 5 and 0 are expected to be rejected with a "must match" error
with self.assertRaisesRegex(RuntimeError, "must match"):
1153
torch.rand([5]).to_mkldnn() + torch.rand([0]).to_mkldnn()
1156
m = torch.nn.Conv2d(C, C, 3)
1157
x = torch.randn(0, C, C, 8, dtype=torch.float)
1159
out_mkldnn = mkldnn_utils.to_mkldnn(m)(x)
1160
self.assertEqual(out_eager, out_mkldnn)
1162
def test_view(self):
1163
x = torch.randn(3, 4, 5, dtype=torch.float32).to_mkldnn()
1164
self.assertRaisesRegex(RuntimeError,
1165
"Change to use reshape",
1166
lambda: x.view(x.size(0), -1))
1168
# Reshapes an mkldnn tensor to (x.size(0), -1) and converts the result back to
# dense for comparison; the second part applies an in-place add_ to a reshaped
# `y` to probe whether reshape shares memory with its source.
def test_reshape(self):
1169
x = torch.randn(3, 4, 5, dtype=torch.float32) * 10
1170
size = (x.size(0), -1)
1174
x.to_mkldnn().reshape(size).to_dense(),
1176
# check whether reshape shares memory for a plain-format tensor
1178
z = y.reshape(size).add_(y.reshape(size))
1180
y.reshape(size).to_dense(),
1184
# Compares reshape(C, -1) applied to an mkldnn tensor (`y_block`) with the
# same reshape applied to its dense conversion (`y_plain`).
def test_reshape_blocked_format(self):
1185
# construct an mkldnn blocked tensor with mkldnn conv2d
1187
m = mkldnn_utils.to_mkldnn(torch.nn.Conv2d(C, C, 3))
1188
x = torch.randn(1, C, 8, 8).to_mkldnn()
1190
# mkldnn tensor w/ blocked format
1192
# aten tensor w/ plain format
1193
y_plain = y_block.to_dense()
1195
y_block_reshape = y_block.reshape(C, -1)
1196
y_plain_reshape = y_plain.reshape(C, -1)
1198
self.assertEqual(y_plain_reshape, y_block_reshape.to_dense())
1200
# Checks the input gradient through reshape: a dense copy (x1) and an mkldnn
# copy (x2) are reshaped to (x.size(0), -1), passed through the same Linear
# layer and summed; x1.grad must equal x2.grad converted to dense.
def test_reshape_backward(self):
1201
x = torch.randn(3, 4, 5, dtype=torch.float32) * 10
1202
size = (x.size(0), -1)
1204
x1 = x.clone().requires_grad_()
1205
x2 = x.clone().to_mkldnn().requires_grad_()
1207
out_features = torch.randint(3, 100, (1,)).item()
1208
linear = torch.nn.Linear(in_features, out_features).float()
1210
y1 = linear(x1.reshape(size)).sum()
1211
# the mkldnn branch is converted to dense before it reaches the Linear layer
y2 = linear(x2.reshape(size).to_dense()).sum()
1214
self.assertEqual(x1.grad, x2.grad.to_dense())
1216
# Clones an mkldnn tensor and converts it back to dense for comparison; the
# second part applies an in-place add_ to a clone of `y` and uses
# assertNotEqual, i.e. the clone is expected not to alias its source.
def test_clone(self):
1217
x = torch.randn(4, 5, dtype=torch.float32) * 10
1220
x.to_mkldnn().clone().to_dense(),
1222
# check whether the clone shares memory with its source
1224
z = y.clone().add_(y)
1225
self.assertNotEqual(
1230
# Iterates over every (dim1, dim2) pair of a 3-D tensor and pairs the dense
# transpose with the mkldnn transpose converted back to dense.
def test_transpose(self):
1231
x = torch.randn(3, 4, 5, dtype=torch.float32) * 10
1232
for dim1 in range(x.ndim):
1233
for dim2 in range(x.ndim):
1235
x.transpose(dim1, dim2),
1236
x.to_mkldnn().transpose(dim1, dim2).to_dense(),
1239
def test_transpose_invalid_dime(self):
1240
x = torch.randn(3, 4, 5, dtype=torch.float32).to_mkldnn()
1241
with self.assertRaisesRegex(IndexError, "Dimension out of range"):
1242
torch._mkldnn_transpose(x, 0, 12)
1244
# Linear forward/backward with a weight set to the transpose `w.t()` of a
# freshly created (in_features, out_features) tensor. A deep copy of the layer
# is fed the mkldnn input; input, weight and bias gradients of the two layers
# are compared.
def test_linear_non_contiguous_weight(self):
1245
in_features = torch.randint(3, 10, (1,)).item()
1246
out_features = torch.randint(3, 100, (1,)).item()
1247
x = torch.randn(3, in_features, dtype=torch.float32) * 10
1248
w = torch.randn(in_features, out_features, dtype=torch.float32)
1249
for bias in [True, False]:
1250
x1 = x.clone().requires_grad_()
1251
x2 = x.clone().to_mkldnn().requires_grad_()
1252
# NOTE(review): `bias` is not passed to the Linear constructor on this line,
# so both loop iterations appear to build a layer with the default bias
# setting — confirm whether `bias=bias` was intended.
linear = torch.nn.Linear(in_features, out_features).float()
1253
linear.weight = torch.nn.Parameter(w.t())
1254
mkldnn_linear = copy.deepcopy(linear)
1255
y1 = linear(x1).sum()
1256
y2 = mkldnn_linear(x2).to_dense().sum()
1259
self.assertEqual(x1.grad, x2.grad.to_dense())
1260
self.assertEqual(linear.weight.grad, mkldnn_linear.weight.grad)
1262
self.assertEqual(linear.bias.grad, mkldnn_linear.bias.grad)
1264
# Linear inference check, with and without bias: the layer is converted with
# mkldnn_utils.to_mkldnn, evaluated on the mkldnn input, and then passed to the
# serialization and tracing helpers.
def test_linear(self):
1265
in_features = torch.randint(3, 10, (1,)).item()
1266
out_features = torch.randint(3, 100, (1,)).item()
1267
x = torch.randn(3, in_features, dtype=torch.float32) * 10
1269
for bias in [True, False]:
1270
linear = torch.nn.Linear(in_features, out_features, bias=bias).float()
1271
mkldnn_linear = mkldnn_utils.to_mkldnn(copy.deepcopy(linear))
1274
mkldnn_linear(x.to_mkldnn()).to_dense())
1276
self._test_serialization(mkldnn_linear, (x.to_mkldnn(),))
1277
self._test_tracing(mkldnn_linear, (x.to_mkldnn(),))
1279
# Linear backward check: a deep copy of the layer is fed the mkldnn input and
# the input, weight and bias gradients of the two layers are compared.
def test_linear_backward(self):
1280
in_features = torch.randint(3, 10, (1,)).item()
1281
out_features = torch.randint(3, 100, (1,)).item()
1282
x = torch.randn(3, in_features, dtype=torch.float32) * 10
1283
for bias in [True, False]:
1284
x1 = x.clone().requires_grad_()
1285
x2 = x.clone().to_mkldnn().requires_grad_()
1286
# NOTE(review): `bias` is not passed to the Linear constructor on this line,
# so both loop iterations appear to build a layer with the default bias
# setting — confirm whether `bias=bias` was intended.
linear = torch.nn.Linear(in_features, out_features).float()
1287
mkldnn_linear = copy.deepcopy(linear)
1288
y1 = linear(x1).sum()
1289
y2 = mkldnn_linear(x2).to_dense().sum()
1292
self.assertEqual(x1.grad, x2.grad.to_dense())
1293
self.assertEqual(linear.weight.grad, mkldnn_linear.weight.grad)
1295
self.assertEqual(linear.bias.grad, mkldnn_linear.bias.grad)
1297
# Low-precision Linear check, parametrized over float16 and bfloat16.
# The layer is converted twice with mkldnn_utils.to_mkldnn: once with the
# default dtype and once with `dtype`. When the support probe for `dtype`
# returns true, the two outputs are compared (atol=1e-1 for bfloat16, 5e-3
# otherwise); the per-dtype raw strings are the RuntimeError patterns used
# with assertRaisesRegex for the low-precision call.
@dtypes(torch.float16, torch.bfloat16)
1298
def test_linear_lowp(self, dtype):
1299
in_features = torch.randint(3, 10, (1,)).item()
1300
out_features = torch.randint(3, 100, (1,)).item()
1301
x = torch.randn(3, in_features, dtype=torch.float32) * 10
1302
x_lowp = x.to(dtype=dtype)
1304
for bias in [True, False]:
1305
linear = torch.nn.Linear(in_features, out_features, bias=bias).float()
1306
mkldnn_linear = mkldnn_utils.to_mkldnn(copy.deepcopy(linear))
1307
mkldnn_linear_lowp = mkldnn_utils.to_mkldnn(
1308
copy.deepcopy(linear), dtype
1311
torch.bfloat16: torch.ops.mkldnn._is_mkldnn_bf16_supported,
1312
torch.half: torch.ops.mkldnn._is_mkldnn_fp16_supported,
1314
if lowp_support[dtype]():
1315
y = mkldnn_linear(x.to_mkldnn()).to_dense()
1316
y_lowp = mkldnn_linear_lowp(x_lowp.to_mkldnn()).to_dense(
1319
if dtype == torch.bfloat16:
1320
self.assertEqual(y, y_lowp, atol=1e-1, rtol=1e-3)
1322
self.assertEqual(y, y_lowp, atol=5e-3, rtol=1e-3)
1325
torch.bfloat16: r"bf16 path needs the cpu support avx_ne_convert or avx512bw, avx512vl and avx512dq",
1326
torch.half: r"fp16 path needs the cpu support avx_ne_convert or avx512_fp16",
1328
self.assertRaisesRegex(
1331
lambda: mkldnn_linear_lowp(x_lowp.to_mkldnn()),
1334
# Applies torch.nn.Softmax along every dimension of a 3-D tensor to the mkldnn
# input and converts the result back to dense for comparison.
def test_softmax(self):
1335
x = torch.randn(3, 4, 5, dtype=torch.float32) * 10
1336
for dim in range(x.ndim):
1337
softmax = torch.nn.Softmax(dim=dim)
1340
softmax(x.to_mkldnn()).to_dense())
1342
# Covers torch.sigmoid on an mkldnn tensor and the in-place torch.sigmoid_;
# after the in-place call the mkldnn tensor (converted to dense) must equal `x`.
def test_sigmoid(self):
1343
x = torch.randn(4, 5, dtype=torch.float32) * 10
1344
mkldnn_x = x.to_mkldnn()
1347
torch.sigmoid(mkldnn_x).to_dense(),
1351
torch.sigmoid_(mkldnn_x)
1352
self.assertEqual(x, mkldnn_x.to_dense())
1354
# Covers torch.tanh on an mkldnn tensor and the in-place torch.tanh_;
# after the in-place call the mkldnn tensor (converted to dense) must equal `x`.
def test_tanh(self):
1355
x = torch.randn(4, 5, dtype=torch.float32) * 10
1356
mkldnn_x = x.to_mkldnn()
1359
torch.tanh(mkldnn_x).to_dense(),
1363
torch.tanh_(mkldnn_x)
1364
self.assertEqual(x, mkldnn_x.to_dense())
1366
# Helper: saves `module` with torch.jit.save to a temporary file, loads it
# back with torch.jit.load, and pairs the dense outputs of the original and
# the loaded module on `inputs`.
def _test_serialization(self, module, inputs):
1367
with TemporaryFileName() as fname:
1368
torch.jit.save(module, fname)
1369
loaded = torch.jit.load(fname)
1371
module(*inputs).to_dense(),
1372
loaded(*inputs).to_dense())
1374
# Helper: traces `module` on `inputs` with torch.jit.trace and pairs the dense
# outputs of the original and the traced module.
def _test_tracing(self, module, inputs):
1375
traced = torch.jit.trace(module, inputs)
1377
module(*inputs).to_dense(),
1378
traced(*inputs).to_dense())
1380
# Expects a RuntimeError matching 'incompatible tensor type'; `x` is a dense
# CPU tensor and `x_mkldnn` its mkldnn conversion.
def test_set_data_tensorimpl_type(self):
1381
# A dense tensor has an impl of type `TensorImpl`, while an MKL-DNN tensor has
1382
# an impl of type `OpaqueTensorImpl<IDeepTensorWrapperPtr>`.
1383
x = torch.randn((1, 2), dtype=torch.float, device=torch.device('cpu'))
1384
x_mkldnn = x.to_mkldnn()
1385
with self.assertRaisesRegex(RuntimeError, 'incompatible tensor type'):
1388
def test_empty(self):
1389
x1 = torch.empty(4, 5, 2, 3, dtype=torch.float32)
1390
x2 = torch.empty(4, 5, 2, 3, dtype=torch.float32, layout=torch._mkldnn)
1391
self.assertEqual(x1.size(), x2.to_dense().size())
1392
self.assertEqual(x1.dtype, x2.to_dense().dtype)
1394
# Calls the in-place zero_() on an mkldnn copy of `x1` and converts the result
# back to dense for comparison.
def test_zero_(self):
1395
x1 = torch.randn(4, 5, dtype=torch.float32) * 10
1396
x2 = x1.clone().to_mkldnn()
1399
x2.zero_().to_dense(),
1402
def test_is_mkldnn(self):
1403
x = torch.randn(1, dtype=torch.float32)
1404
self.assertFalse(x.is_mkldnn)
1405
self.assertTrue(x.to_mkldnn().is_mkldnn)
1407
# The legacy constructor / Tensor.new is not supported for mkldnn tensors.
@skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1992")
def test_legacy_new_failure(self):
    """Check that every Tensor.new call form raises on an mkldnn tensor."""
    dense = torch.randn(1, dtype=torch.float32)
    opaque = dense.to_mkldnn()
    # One thunk per call form; each is expected to raise RuntimeError.
    rejected_calls = (
        lambda: opaque.new(device='cpu'),
        lambda: opaque.new(dense.storage()),
        lambda: opaque.new(dense),
        lambda: opaque.new(torch.Size([2, 3])),
        lambda: opaque.new([6]),
    )
    for call in rejected_calls:
        self.assertRaises(RuntimeError, call)
1418
# Defines a ScriptModule `EnsureMkldnn` and asserts that the result of calling
# the module instance `m` is an mkldnn tensor for both a dense input and an
# input that is already mkldnn.
def test_is_mkldnn_jit(self):
1419
class EnsureMkldnn(torch.jit.ScriptModule):
1420
@torch.jit.script_method
1421
def forward(self, x):
1427
x = torch.randn(1, dtype=torch.float32)
1428
self.assertTrue(m(x).is_mkldnn)
1429
self.assertTrue(m(x.to_mkldnn()).is_mkldnn)
1431
# Helper: puts `model` in inference mode (train(False)) as float, converts a
# deep copy with mkldnn_utils.to_mkldnn, and evaluates the converted model on
# a random (1, 3, 224, 224) float32 input under torch.no_grad().
def _test_imagenet_model(self, model):
1432
model = model.train(False).float()
1433
mkldnn_model = mkldnn_utils.to_mkldnn(copy.deepcopy(model))
1434
x = torch.randn(1, 3, 224, 224, dtype=torch.float32)
1435
with torch.no_grad():
1438
mkldnn_model(x.to_mkldnn()).to_dense(),
1441
@skipIfNoTorchVision
def test_resnet18(self):
    """Run the shared imagenet-model check on an untrained ResNet-18."""
    # weights=None: the network is built without loading pretrained weights.
    self._test_imagenet_model(torchvision.models.resnet.resnet18(weights=None))
1446
@skipIfNoTorchVision
def test_resnext50_32x4d(self):
    """Run the shared imagenet-model check on an untrained ResNeXt-50 32x4d."""
    # weights=None: the network is built without loading pretrained weights.
    self._test_imagenet_model(torchvision.models.resnet.resnext50_32x4d(weights=None))
1451
# Helper: lists the candidate values for each LSTM test parameter and
# collects the per-parameter value lists, in declaration order, into
# `params_list`.
def _lstm_params_list(self):
1453
"input_size": [1, 5],
1454
"hidden_size": [5, 16],
1455
"num_layers": [1, 3],
1456
"bidirectional": [False, True],
1457
"bias": [False, True],
1458
"batch_first": [False, True],
1459
"dropout": [0, 0.4, 0.7, 1],
1460
"batch_size": [1, 2],
1462
"training": [False, True]
1465
params_list = list(params_dict.values())
1468
# Helper: converts `input` to torch.bfloat16 via `.to(...)`.
# NOTE(review): presumably the cast is applied only when `bf16` is true —
# confirm against the guarding condition.
def _cast_dtype(self, input, bf16):
1470
input = input.to(torch.bfloat16)
1473
# LSTM check over the cartesian product of the lists from _lstm_params_list().
# For each combination two deep copies of one torch.nn.LSTM are run on cloned
# inputs and states: model1 inside torch.backends.mkldnn.flags(enabled=False)
# with its module and inputs passed through _cast_dtype, and model2 on the
# uncast inputs. Outputs, final states, and the gradients obtained from the
# backward() calls are compared with `rtol`/`atol`.
@unittest.skipIf(IS_WINDOWS, "Limit support for bf16 path")
1474
def test_lstm(self):
1476
torch.manual_seed(seed)
1478
params_list = self._lstm_params_list()
1480
# the bf16 variant is only enabled when the CPU reports mkldnn bf16 support
bf16 = True if dtype == torch.bfloat16 and torch.ops.mkldnn._is_mkldnn_bf16_supported() else False
1486
for input_size, hidden_size, num_layers, bidirectional, bias, batch_first, dropout, batch_size, seq_len, training \
1487
in itertools.product(*params_list):
1488
num_directions = 2 if bidirectional else 1
1490
input = torch.randn(batch_size, seq_len, input_size, dtype=torch.float32)
1492
input = torch.randn(seq_len, batch_size, input_size, dtype=torch.float32)
1493
h = torch.randn(num_layers * num_directions, batch_size, hidden_size, dtype=torch.float32)
1494
c = torch.randn(num_layers * num_directions, batch_size, hidden_size, dtype=torch.float32)
1496
model = torch.nn.LSTM(input_size, hidden_size, num_layers, bidirectional=bidirectional,
1497
bias=bias, dropout=dropout, batch_first=batch_first).float()
1498
model.train() if training else model.eval()
1499
# inputs and states are cloned per path; gradients are tracked only when training
input1 = input.clone().requires_grad_(training)
1500
input2 = input.clone().requires_grad_(training)
1502
h1 = h.clone().requires_grad_(training)
1503
h2 = h.clone().requires_grad_(training)
1504
c1 = c.clone().requires_grad_(training)
1505
c2 = c.clone().requires_grad_(training)
1507
model1 = copy.deepcopy(model)
1508
model2 = copy.deepcopy(model)
1509
with torch.cpu.amp.autocast(enabled=bf16, dtype=torch.bfloat16), torch.no_grad() if not training else nullcontext():
1510
with torch.backends.mkldnn.flags(enabled=False):
1511
# the RNG is re-seeded before each forward/backward call
# (presumably so dropout draws the same mask on both paths — confirm)
torch.manual_seed(seed)
1512
output1, (hn1, cn1) = self._cast_dtype(model1, bf16)(self._cast_dtype(input1, bf16),
1513
(self._cast_dtype(h1, bf16),
1514
self._cast_dtype(c1, bf16)))
1516
torch.manual_seed(seed)
1517
output2, (hn2, cn2) = model2(input2, (h2, c2))
1518
self.assertEqual(output1, output2, rtol=rtol, atol=atol)
1519
self.assertEqual(hn1, hn2, rtol=rtol, atol=atol)
1520
self.assertEqual(cn1, cn2, rtol=rtol, atol=atol)
1523
with torch.backends.mkldnn.flags(enabled=False):
1524
torch.manual_seed(seed)
1525
output1.sum().backward(retain_graph=True)
1527
torch.manual_seed(seed)
1528
output2.sum().backward(retain_graph=True)
1530
self.assertEqual(input1.grad, input2.grad, rtol=rtol, atol=atol)
1531
for name, para in model1.named_parameters():
1532
self.assertEqual(para, self._cast_dtype(getattr(model2, name), bf16))
1533
self.assertEqual(para.grad, self._cast_dtype(getattr(model2, name).grad, bf16), rtol=rtol, atol=atol)
1535
with torch.backends.mkldnn.flags(enabled=False):
1536
torch.manual_seed(seed)
1537
hn1.sum().backward(retain_graph=True)
1538
torch.manual_seed(seed)
1539
hn2.sum().backward(retain_graph=True)
1540
self.assertEqual(h1.grad, h2.grad, rtol=rtol, atol=atol)
1542
with torch.backends.mkldnn.flags(enabled=False):
1543
torch.manual_seed(seed)
1544
cn1.sum().backward(retain_graph=True)
1545
torch.manual_seed(seed)
1546
cn2.sum().backward(retain_graph=True)
1547
self.assertEqual(c1.grad, c2.grad, rtol=rtol, atol=atol)
1549
# Low-precision matmul/bmm check, parametrized over float16 and bfloat16.
# A per-dtype support probe gates the test body. The nested helper `common`
# creates random operands of the given shapes in `dtype` and compares the
# result `y` with a reference `y_ref` (dtype equality is not required). A bmm
# case with non-default strides on a contiguous operand is checked first, then
# `common` is called for a list of matmul and bmm shape pairs.
@dtypes(torch.float16, torch.bfloat16)
1550
def test_matmul_lower_precision(self, dtype):
1552
torch.bfloat16: torch.ops.mkldnn._is_mkldnn_bf16_supported,
1553
torch.float16: torch.ops.mkldnn._is_mkldnn_fp16_supported,
1556
def common(self, shape1, shape2, op, dtype):
1557
a = torch.randn(shape1, dtype=dtype)
1559
b = torch.randn(shape2, dtype=dtype)
1563
y_ref = op(a_ref, b_ref)
1564
self.assertEqual(y, y_ref, exact_dtype=False)
1566
if support_check[dtype]():
1567
a1 = torch.randn([64, 1, 33], dtype=dtype)
1568
# a2 is a contiguous tensor, but its strides
1569
# are not the default contiguous strides.
1570
a2 = torch.as_strided(a1.clone(), [64, 1, 33], [33, 3, 1])
1571
self.assertTrue(a2.is_contiguous())
1572
b = torch.randn(64, 33, 256).to(dtype=dtype)
1573
y1 = torch.ops.aten.bmm(a1, b)
1574
y2 = torch.bmm(a2, b)
1575
self.assertEqual(y1, y2)
1577
for shape1, shape2, op in [
1578
((33, 77), (77, 22), torch.matmul),
1579
((128, 256), (256, 10), torch.matmul),
1580
((7, 300), (300, 3), torch.matmul),
1581
((1, 100), (100, 60), torch.matmul),
1582
((100, 1), (1, 100), torch.matmul),
1583
((20, 54, 78), (20, 78, 10), torch.bmm),
1584
((1, 300, 1), (1, 1, 300), torch.bmm),
1586
common(self, shape1, shape2, op, dtype)
1589
instantiate_device_type_tests(TestMkldnn, globals(), only_for=('cpu',))
1591
if __name__ == '__main__':