intel-extension-for-pytorch / test_instance_norm.py
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import torch

import unittest
from common_utils import TestCase
from torch.nn import InstanceNorm2d, InstanceNorm3d, BatchNorm2d, BatchNorm3d

bn_m = {2: BatchNorm2d, 3: BatchNorm3d}
inst_m = {2: InstanceNorm2d, 3: InstanceNorm3d}
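
# These tests rely on a numerical identity: InstanceNorm over an (N, C, ...)
# input normalizes each (n, c) slice with its own statistics, which is exactly
# what BatchNorm computes on the same data reshaped to (1, N * C, ...). With
# freshly initialized affine parameters (weight=1, bias=0), the two modules
# produce matching outputs and input gradients, so BatchNorm serves as the
# reference implementation below.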


class InstanceNormTester(TestCase):
    def test_instance_norm(self):
        for dim in [2, 3]:
            batch = 10
            channel = 100

            input_size = [batch, channel]
            bn_size = [1, batch * channel]

            if dim == 2:
                memory_format = torch.channels_last
            else:
                memory_format = torch.channels_last_3d

            if dim == 2:
                input_size += [45, 35]
                bn_size += [45, 35]
            if dim == 3:
                input_size += [45, 35, 100]
                bn_size += [45, 35, 100]

            input = torch.randn(input_size)
            x = input.clone().detach().requires_grad_()
            x1 = input.clone().detach().requires_grad_()
            x1r = x1.reshape(bn_size)

            m = inst_m[dim](channel, affine=True)
            m1 = bn_m[dim](batch * channel, affine=True)

            y = m(x)
            y1 = m1(x1r).reshape_as(x1)
            self.assertTrue(y.dtype == torch.float32)
            self.assertEqual(y, y1)

            # backward
            y.mean().backward()
            y1.mean().backward()
            self.assertTrue(x.grad.dtype == torch.float32)
            self.assertEqual(x.grad, x1.grad)

            # test channels last
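            # A channels-last input should match the reference numerically;
            # the output is expected in contiguous format, while the input
            # gradient keeps the channels-last layout.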
            x2 = input.clone().detach().to(memory_format=memory_format).requires_grad_()
            y2 = m(x2)
            self.assertTrue(y2.dtype == torch.float32)
            self.assertEqual(y2, y1)
            self.assertTrue(y2.is_contiguous(memory_format=torch.contiguous_format))

            y2.mean().backward()
            self.assertTrue(x2.grad.dtype == torch.float32)
            self.assertEqual(x2.grad, x1.grad)
            self.assertTrue(x2.grad.is_contiguous(memory_format=memory_format))

    def test_instance_norm_bfloat16(self):
        for dim in [2, 3]:
            batch = 10
            channel = 100

            input_size = [batch, channel]
            bn_size = [1, batch * channel]

            if dim == 2:
                memory_format = torch.channels_last
            else:
                memory_format = torch.channels_last_3d

            if dim == 2:
                input_size += [45, 35]
                bn_size += [45, 35]
            if dim == 3:
                input_size += [45, 35, 100]
                bn_size += [45, 35, 100]

            m = inst_m[dim](channel, affine=True)
            m1 = bn_m[dim](batch * channel, affine=True)
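
            # Run the bf16 path under CPU autocast; outputs and gradients are
            # expected to stay bfloat16, and the looser tolerance (prec=0.1)
            # absorbs bf16 rounding error relative to the BatchNorm reference.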
            with torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16):
                input = torch.randn(input_size).bfloat16()
                x = input.clone().detach().requires_grad_()
                x1 = input.clone().detach().requires_grad_()
                x1r = x1.reshape(bn_size)

                y = m(x)
                y1 = m1(x1r).reshape_as(x1)
                self.assertTrue(y.dtype == torch.bfloat16)
                self.assertEqual(y, y1, prec=0.1)

                # backward
                y.mean().backward()
                y1.mean().backward()
                self.assertTrue(x.grad.dtype == torch.bfloat16)
                self.assertEqual(x.grad, x1.grad)

                # test channels last
                x2 = (
                    input.clone()
                    .detach()
                    .to(memory_format=memory_format)
                    .requires_grad_()
                )
                y2 = m(x2)
                self.assertTrue(y2.dtype == torch.bfloat16)
                self.assertTrue(y2.is_contiguous(memory_format=torch.contiguous_format))
                self.assertEqual(y2, y1, prec=0.1)

                y2.mean().backward()
                self.assertTrue(x2.grad.dtype == torch.bfloat16)
                self.assertTrue(x2.grad.is_contiguous(memory_format=memory_format))
                self.assertEqual(x2.grad, x1.grad)


if __name__ == "__main__":
    test = unittest.main()
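# Runnable standalone with: python test_instance_norm.py
# (assumes common_utils.py from the repository's test directory is importable)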
