intel-extension-for-pytorch

test_auto_channels_last.py 
281 lines · 10.8 KB
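"""Tests for the IPEX auto channels_last feature.

ipex.optimize() is expected to convert the weights of convolution-based
models (and, by propagation, their outputs) to the matching memory format
(channels_last_1d, channels_last, or channels_last_3d) unless the feature
is turned off via ipex.disable_auto_channels_last(). A minimal sketch of
the behavior under test (MyConv2dModel is illustrative):

    model = ipex.optimize(MyConv2dModel().eval())
    y = model(x)  # model.conv.weight and y are channels_last contiguous
"""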
import unittest
from common_utils import TestCase
import torch
import torch.nn as nn
import intel_extension_for_pytorch as ipex
from intel_extension_for_pytorch.utils.channels_last_1d import (
    is_contiguous_channels_last_1d,
)

try:
    import torchvision

    HAS_TORCHVISION = True
except ImportError:
    HAS_TORCHVISION = False
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")


class TestAutoChannelsLast(TestCase):
    def _get_convNd(self, dim):
        class ConvNd(nn.Module):
            def __init__(self, dim):
                super(ConvNd, self).__init__()
                if dim == 1:
                    self.conv = nn.Conv1d(16, 33, 3)
                elif dim == 2:
                    self.conv = nn.Conv2d(16, 33, 3)
                elif dim == 3:
                    self.conv = nn.Conv3d(16, 33, 3)

            def forward(self, x):
                x = self.conv(x)
                return x

        model = ConvNd(dim=dim)
        return model

    def _get_sequential_conv2d(self):
        class Conv2d(nn.Module):
            def __init__(self):
                super(Conv2d, self).__init__()
                self.conv1 = nn.Conv2d(16, 33, 3)
                self.conv2 = nn.Conv2d(33, 33, 3)

            def forward(self, x):
                x = self.conv1(x)
                x = self.conv2(x)
                return x

        model = Conv2d()
        return model

    def _get_convNd_relu(self, dim):
        class ConvNdReLU(nn.Module):
            def __init__(self, dim):
                super(ConvNdReLU, self).__init__()
                if dim == 1:
                    self.conv = nn.Conv1d(16, 33, 3)
                elif dim == 2:
                    self.conv = nn.Conv2d(16, 33, 3)
                elif dim == 3:
                    self.conv = nn.Conv3d(16, 33, 3)
                self.relu = nn.ReLU()

            def forward(self, x):
                x = self.conv(x)
                x = self.relu(x)
                return x

        model = ConvNdReLU(dim=dim)
        return model

    def _get_convNd_linear(self, dim):
        class ConvNdLinear(nn.Module):
            def __init__(self, dim):
                super(ConvNdLinear, self).__init__()
                if dim == 1:
                    self.conv = nn.Conv1d(16, 33, 3)
                elif dim == 2:
                    self.conv = nn.Conv2d(16, 33, 3)
                elif dim == 3:
                    self.conv = nn.Conv3d(16, 33, 3)
                self.linear = nn.Linear(48, 48)

            def forward(self, x):
                x = self.conv(x)
                x = self.linear(x)
                return x

        model = ConvNdLinear(dim=dim)
        return model

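    # Helper: switch the model to eval mode, build a dim-appropriate random
    # input, optionally disable auto channels_last, then run ipex.optimize()
    # (weight prepacking off, so the raw weight layout stays inspectable)
    # and a single forward pass.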
    def _get_ipex_optimized_model_and_output_tensor(
        self, model, dim, disable_auto_channels_last=False
    ):
        model.eval()

        if dim == 1:
            x = torch.randn(20, 16, 50)
        elif dim == 2:
            x = torch.randn(20, 16, 50, 50)
        elif dim == 3:
            x = torch.randn(20, 16, 50, 50, 50)

        if disable_auto_channels_last:
            ipex.disable_auto_channels_last()

        model = ipex.optimize(model, weights_prepack=False)
        output = model(x)
        return model, output

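    # Helper: collect the names of all parameters that are contiguous in
    # channels_last memory format.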
    def get_channels_last_modules(self, module):
        channels_last_modules = []
        for name, param in module.named_parameters():
            if param.is_contiguous(memory_format=torch.channels_last):
                channels_last_modules.append(name)
        return channels_last_modules

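    # Auto channels_last should convert conv weights (and outputs) to the
    # channels_last format matching the convolution dimensionality.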
    def test_auto_channels_last(self):
        model = self._get_convNd(dim=1)
        model, output = self._get_ipex_optimized_model_and_output_tensor(model, dim=1)
        self.assertTrue(is_contiguous_channels_last_1d(model.conv.weight))
        self.assertTrue(is_contiguous_channels_last_1d(output))

        model = self._get_convNd(dim=2)
        model, output = self._get_ipex_optimized_model_and_output_tensor(model, dim=2)
        self.assertTrue(
            model.conv.weight.is_contiguous(memory_format=torch.channels_last)
        )
        self.assertTrue(output.is_contiguous(memory_format=torch.channels_last))

        model = self._get_convNd(dim=3)
        model, output = self._get_ipex_optimized_model_and_output_tensor(model, dim=3)
        self.assertTrue(
            model.conv.weight.is_contiguous(memory_format=torch.channels_last_3d)
        )

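    # With auto channels_last disabled, weights and outputs should stay in
    # the default contiguous_format for all three dimensionalities.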
    def test_disable_auto_channels_last(self):
        model = self._get_convNd(dim=1)
        model, output = self._get_ipex_optimized_model_and_output_tensor(
            model, dim=1, disable_auto_channels_last=True
        )
        self.assertTrue(
            model.conv.weight.is_contiguous(memory_format=torch.contiguous_format)
        )
        self.assertTrue(output.is_contiguous(memory_format=torch.contiguous_format))

        model = self._get_convNd(dim=2)
        model, output = self._get_ipex_optimized_model_and_output_tensor(
            model, dim=2, disable_auto_channels_last=True
        )
        self.assertTrue(
            model.conv.weight.is_contiguous(memory_format=torch.contiguous_format)
        )
        self.assertTrue(output.is_contiguous(memory_format=torch.contiguous_format))

        model = self._get_convNd(dim=3)
        model, output = self._get_ipex_optimized_model_and_output_tensor(
            model, dim=3, disable_auto_channels_last=True
        )
        self.assertTrue(
            model.conv.weight.is_contiguous(memory_format=torch.contiguous_format)
        )
        self.assertTrue(output.is_contiguous(memory_format=torch.contiguous_format))

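    # The conversion should recurse into submodules: both convs of the
    # sequential model end up channels_last.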
    def test_auto_channels_last_recursion(self):
        model = self._get_sequential_conv2d()
        model, output = self._get_ipex_optimized_model_and_output_tensor(model, dim=2)

        self.assertTrue(
            model.conv1.weight.is_contiguous(memory_format=torch.channels_last)
        )
        self.assertTrue(
            model.conv2.weight.is_contiguous(memory_format=torch.channels_last)
        )
        self.assertTrue(output.is_contiguous(memory_format=torch.channels_last))

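    # Propagation: channels_last flows through compatible layers (ReLU) but
    # reverts to contiguous_format after an incompatible layer (Linear).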
    def test_auto_channels_last_memory_format_propagation(self):
        # memory format propagates through channels_last compatible layers
        model = self._get_convNd_relu(dim=1)
        model, output = self._get_ipex_optimized_model_and_output_tensor(model, dim=1)
        self.assertTrue(is_contiguous_channels_last_1d(model.conv.weight))
        self.assertTrue(is_contiguous_channels_last_1d(output))

        model = self._get_convNd_relu(dim=2)
        model, output = self._get_ipex_optimized_model_and_output_tensor(model, dim=2)
        self.assertTrue(
            model.conv.weight.is_contiguous(memory_format=torch.channels_last)
        )
        self.assertTrue(output.is_contiguous(memory_format=torch.channels_last))

        model = self._get_convNd_relu(dim=3)
        model, output = self._get_ipex_optimized_model_and_output_tensor(model, dim=3)
        self.assertTrue(
            model.conv.weight.is_contiguous(memory_format=torch.channels_last_3d)
        )

        # memory format reverts to contiguous_format because linear is channels_last incompatible
        model = self._get_convNd_linear(dim=1)
        model, output = self._get_ipex_optimized_model_and_output_tensor(model, dim=1)
        self.assertTrue(is_contiguous_channels_last_1d(model.conv.weight))
        self.assertTrue(output.is_contiguous(memory_format=torch.contiguous_format))

        model = self._get_convNd_linear(dim=2)
        model, output = self._get_ipex_optimized_model_and_output_tensor(model, dim=2)
        self.assertTrue(
            model.conv.weight.is_contiguous(memory_format=torch.channels_last)
        )
        self.assertTrue(output.is_contiguous(memory_format=torch.contiguous_format))

        model = self._get_convNd_linear(dim=3)
        model, output = self._get_ipex_optimized_model_and_output_tensor(model, dim=3)
        self.assertTrue(
            model.conv.weight.is_contiguous(memory_format=torch.channels_last_3d)
        )
        self.assertTrue(output.is_contiguous(memory_format=torch.contiguous_format))

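    # On a real model (ResNet-50), auto conversion should flip exactly the
    # same set of parameters as a manual .to(memory_format=torch.channels_last).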
    @skipIfNoTorchVision
    def test_auto_channels_last_resnet50(self):
        model = torchvision.models.resnet.resnet50(pretrained=False)
        model.eval()

        # manual conversion
        model_channels_last = model.to(memory_format=torch.channels_last)
        model_channels_last_modules = self.get_channels_last_modules(
            model_channels_last
        )

        # auto conversion
        model_ipex = ipex.optimize(model, weights_prepack=False)
        model_ipex_channels_last_modules = self.get_channels_last_modules(model_ipex)

        self.assertEqual(model_channels_last_modules, model_ipex_channels_last_modules)

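    # Auto channels_last should also take effect on the INT8 static
    # quantization path (prepare/convert plus TorchScript tracing).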
    def test_auto_channels_last_for_int8(self):
        conv_module = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d, 3: torch.nn.Conv3d}

        class ConvNd(torch.nn.Module):
            def __init__(self, dim, in_channels, out_channels, kernel_size, stride):
                super(ConvNd, self).__init__()
                self.conv = conv_module[dim](
                    in_channels, out_channels, kernel_size=kernel_size, stride=stride
                )

            def forward(self, x):
                return self.conv(x)

        def _test_conv(dim):
            input_shapes = {1: (224,), 2: (224, 224), 3: (55, 55, 55)}
            x_shape = (2, 3) + input_shapes[dim]
            x = torch.randn(x_shape, dtype=torch.float32)
            model = ConvNd(dim, 3, 4, 3, 2).eval()
            qconfig = ipex.quantization.default_static_qconfig
            prepared_model = ipex.quantization.prepare(model, qconfig, x)
            # do calibration
            y = prepared_model(x)
            convert_model = ipex.quantization.convert(prepared_model)
            with torch.no_grad():
                traced_model = torch.jit.trace(convert_model, x)
                traced_model = torch.jit.freeze(traced_model)
                for _ in range(3):
                    y = traced_model(x)
            return y

        # disable auto channels_last
        ipex.disable_auto_channels_last()
        self.assertTrue(
            _test_conv(2).is_contiguous(memory_format=torch.contiguous_format)
        )
        self.assertTrue(
            _test_conv(3).is_contiguous(memory_format=torch.contiguous_format)
        )

        # enable auto channels_last
        ipex.enable_auto_channels_last()

        self.assertTrue(_test_conv(2).is_contiguous(memory_format=torch.channels_last))
        # temporarily disabled until https://github.com/pytorch/pytorch/pull/74023 is merged
        # self.assertTrue(_test_conv(3).is_contiguous(memory_format=torch.channels_last_3d))


if __name__ == "__main__":
    test = unittest.main()