# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import unittest

import torch
from common_utils import TestCase
from torch.nn import InstanceNorm2d, InstanceNorm3d, BatchNorm2d, BatchNorm3d

bn_m = {2: BatchNorm2d, 3: BatchNorm3d}
inst_m = {2: InstanceNorm2d, 3: InstanceNorm3d}
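
# NOTE: InstanceNorm over an (N, C, ...) input normalizes each
# (sample, channel) plane independently, which is mathematically equivalent
# to BatchNorm applied to the same data reshaped to (1, N * C, ...): with
# batch size 1, BatchNorm's per-channel statistics reduce over exactly one
# plane. The tests below use this reshaped BatchNorm as the reference path.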


class InstanceNormTester(TestCase):
    def test_instance_norm(self):
        for dim in [2, 3]:
            batch = 10
            channel = 100

            input_size = [batch, channel]
            bn_size = [1, batch * channel]

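            # choose the channels-last variant matching the input rank:
            # 4-D NCHW uses channels_last, 5-D NCDHW uses channels_last_3d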
            if dim == 2:
                memory_format = torch.channels_last
            else:
                memory_format = torch.channels_last_3d

            if dim == 2:
                input_size += [45, 35]
                bn_size += [45, 35]
            if dim == 3:
                input_size += [45, 35, 100]
                bn_size += [45, 35, 100]

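            # x feeds the InstanceNorm under test; x1 is an independent leaf
            # that feeds the reshape + BatchNorm reference path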
            input = torch.randn(input_size)
            x = input.clone().detach().requires_grad_()
            x1 = input.clone().detach().requires_grad_()
            x1r = x1.reshape(bn_size)

            m = inst_m[dim](channel, affine=True)
            m1 = bn_m[dim](batch * channel, affine=True)

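            # forward: the InstanceNorm path and the reshaped BatchNorm
            # reference should agree in float32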
            y = m(x)
            y1 = m1(x1r).reshape_as(x1)
            self.assertTrue(y.dtype == torch.float32)
            self.assertEqual(y, y1)

            # backward
            y.mean().backward()
            y1.mean().backward()
            self.assertTrue(x.grad.dtype == torch.float32)
            self.assertEqual(x.grad, x1.grad)

            # test channels last
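            # (output is expected back in contiguous format, while the input
            # gradient keeps the channels-last layout; the is_contiguous
            # asserts below check both)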
            x2 = input.clone().detach().to(memory_format=memory_format).requires_grad_()
            y2 = m(x2)
            self.assertTrue(y2.dtype == torch.float32)
            self.assertEqual(y2, y1)
            self.assertTrue(y2.is_contiguous(memory_format=torch.contiguous_format))

            y2.mean().backward()
            self.assertTrue(x2.grad.dtype == torch.float32)
            self.assertEqual(x2.grad, x1.grad)
            self.assertTrue(x2.grad.is_contiguous(memory_format=memory_format))

    def test_instance_norm_bfloat16(self):
        for dim in [2, 3]:
            batch = 10
            channel = 100

            input_size = [batch, channel]
            bn_size = [1, batch * channel]

            if dim == 2:
                memory_format = torch.channels_last
            else:
                memory_format = torch.channels_last_3d

            if dim == 2:
                input_size += [45, 35]
                bn_size += [45, 35]
            if dim == 3:
                input_size += [45, 35, 100]
                bn_size += [45, 35, 100]

            m = inst_m[dim](channel, affine=True)
            m1 = bn_m[dim](batch * channel, affine=True)

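            # same checks as the float32 test, but with bfloat16 inputs under
            # CPU autocast; forward results are compared with a looser
            # tolerance (prec=0.1) to absorb bfloat16 rounding error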
            with torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16):
                input = torch.randn(input_size).bfloat16()
                x = input.clone().detach().requires_grad_()
                x1 = input.clone().detach().requires_grad_()
                x1r = x1.reshape(bn_size)

                y = m(x)
                y1 = m1(x1r).reshape_as(x1)
                self.assertTrue(y.dtype == torch.bfloat16)
                self.assertEqual(y, y1, prec=0.1)

                # backward
                y.mean().backward()
                y1.mean().backward()
                self.assertTrue(x.grad.dtype == torch.bfloat16)
                self.assertEqual(x.grad, x1.grad)

                # test channels last
                x2 = (
                    input.clone()
                    .detach()
                    .to(memory_format=memory_format)
                    .requires_grad_()
                )
                y2 = m(x2)
                self.assertTrue(y2.dtype == torch.bfloat16)
                self.assertTrue(y2.is_contiguous(memory_format=torch.contiguous_format))
                self.assertEqual(y2, y1, prec=0.1)

                y2.mean().backward()
                self.assertTrue(x2.grad.dtype == torch.bfloat16)
                self.assertTrue(x2.grad.is_contiguous(memory_format=memory_format))
                self.assertEqual(x2.grad, x1.grad)


if __name__ == "__main__":
    test = unittest.main()
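
# To run this file directly (assuming common_utils.py is importable, e.g.
# from the repository's test directory):
#
#   python test_instance_norm.py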