intel-extension-for-pytorch
281 lines · 10.8 KB
import unittest
from common_utils import TestCase
import torch
import torch.nn as nn
import intel_extension_for_pytorch as ipex
from intel_extension_for_pytorch.utils.channels_last_1d import (
    is_contiguous_channels_last_1d,
)

# torchvision is optional: only the resnet50 test needs it, so probe for it
# here and build a skip decorator instead of failing at import time.
try:
    import torchvision

    HAS_TORCHVISION = True
except ImportError:
    HAS_TORCHVISION = False
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")
class TestAutoChannelsLast(TestCase):
    """Tests for IPEX's automatic channels-last memory-format selection.

    ``ipex.optimize`` is expected to convert conv weights (and thus the model
    output) to channels_last / channels_last_1d / channels_last_3d depending on
    the conv dimensionality, unless the feature is explicitly disabled via
    ``ipex.disable_auto_channels_last()``.
    """

    def _get_convNd(self, dim):
        """Build a model holding a single ConvNd layer for ``dim`` in {1, 2, 3}."""

        class ConvNd(nn.Module):
            def __init__(self, dim):
                super().__init__()
                if dim == 1:
                    self.conv = nn.Conv1d(16, 33, 3)
                elif dim == 2:
                    self.conv = nn.Conv2d(16, 33, 3)
                elif dim == 3:
                    self.conv = nn.Conv3d(16, 33, 3)

            def forward(self, x):
                return self.conv(x)

        return ConvNd(dim=dim)

    def _get_sequential_conv2d(self):
        """Build a model with two chained Conv2d layers (recursion test)."""

        class Conv2d(nn.Module):
            def __init__(self):
                super().__init__()
                self.conv1 = nn.Conv2d(16, 33, 3)
                self.conv2 = nn.Conv2d(33, 33, 3)

            def forward(self, x):
                x = self.conv1(x)
                x = self.conv2(x)
                return x

        return Conv2d()

    def _get_convNd_relu(self, dim):
        """Build ConvNd -> ReLU; ReLU is channels-last compatible, so the
        memory format should propagate through it."""

        class ConvNdReLU(nn.Module):
            def __init__(self, dim):
                super().__init__()
                if dim == 1:
                    self.conv = nn.Conv1d(16, 33, 3)
                elif dim == 2:
                    self.conv = nn.Conv2d(16, 33, 3)
                elif dim == 3:
                    self.conv = nn.Conv3d(16, 33, 3)
                self.relu = nn.ReLU()

            def forward(self, x):
                return self.relu(self.conv(x))

        return ConvNdReLU(dim=dim)

    def _get_convNd_linear(self, dim):
        """Build ConvNd -> Linear; Linear is channels-last incompatible, so the
        output should revert to contiguous_format."""

        class ConvNdLinear(nn.Module):
            def __init__(self, dim):
                super().__init__()
                if dim == 1:
                    self.conv = nn.Conv1d(16, 33, 3)
                elif dim == 2:
                    self.conv = nn.Conv2d(16, 33, 3)
                elif dim == 3:
                    self.conv = nn.Conv3d(16, 33, 3)
                # 48 matches the last spatial dim of conv output for the
                # (20, 16, 50, ...) inputs used below (50 - 3 + 1 = 48).
                self.linear = nn.Linear(48, 48)

            def forward(self, x):
                return self.linear(self.conv(x))

        return ConvNdLinear(dim=dim)

    def _get_ipex_optimized_model_and_output_tensor(
        self, model, dim, disable_auto_channels_last=False
    ):
        """Run ``ipex.optimize`` (weight prepacking off, so raw weight layout
        is observable) and one forward pass.

        Args:
            model: module to optimize; put into eval mode here.
            dim: conv dimensionality in {1, 2, 3}; selects the input shape.
            disable_auto_channels_last: if True, turn the feature off first.
                NOTE(review): this flips process-global state and is only
                re-enabled by test_auto_channels_last_for_int8 — test-order
                dependent.

        Returns:
            (optimized_model, output_tensor)

        Raises:
            ValueError: if ``dim`` is not 1, 2, or 3 (previously an
                accidental NameError on unbound ``x``).
        """
        model.eval()

        if dim == 1:
            x = torch.randn(20, 16, 50)
        elif dim == 2:
            x = torch.randn(20, 16, 50, 50)
        elif dim == 3:
            x = torch.randn(20, 16, 50, 50, 50)
        else:
            raise ValueError(f"unsupported dim: {dim!r} (expected 1, 2, or 3)")

        if disable_auto_channels_last:
            ipex.disable_auto_channels_last()

        model = ipex.optimize(model, weights_prepack=False)
        output = model(x)
        return model, output

    def get_channels_last_modules(self, module):
        """Return the names of all parameters contiguous in channels_last."""
        return [
            name
            for name, param in module.named_parameters()
            if param.is_contiguous(memory_format=torch.channels_last)
        ]

    def test_auto_channels_last(self):
        """Conv weights/outputs adopt the channels-last variant matching dim."""
        model = self._get_convNd(dim=1)
        model, output = self._get_ipex_optimized_model_and_output_tensor(model, dim=1)
        self.assertTrue(is_contiguous_channels_last_1d(model.conv.weight))
        self.assertTrue(is_contiguous_channels_last_1d(output))

        model = self._get_convNd(dim=2)
        model, output = self._get_ipex_optimized_model_and_output_tensor(model, dim=2)
        self.assertTrue(
            model.conv.weight.is_contiguous(memory_format=torch.channels_last)
        )
        self.assertTrue(output.is_contiguous(memory_format=torch.channels_last))

        model = self._get_convNd(dim=3)
        model, output = self._get_ipex_optimized_model_and_output_tensor(model, dim=3)
        self.assertTrue(
            model.conv.weight.is_contiguous(memory_format=torch.channels_last_3d)
        )

    def test_disable_auto_channels_last(self):
        """With the feature disabled, everything stays contiguous_format."""
        model = self._get_convNd(dim=1)
        model, output = self._get_ipex_optimized_model_and_output_tensor(
            model, dim=1, disable_auto_channels_last=True
        )
        self.assertTrue(
            model.conv.weight.is_contiguous(memory_format=torch.contiguous_format)
        )
        self.assertTrue(output.is_contiguous(memory_format=torch.contiguous_format))

        model = self._get_convNd(dim=2)
        model, output = self._get_ipex_optimized_model_and_output_tensor(
            model, dim=2, disable_auto_channels_last=True
        )
        self.assertTrue(
            model.conv.weight.is_contiguous(memory_format=torch.contiguous_format)
        )
        self.assertTrue(output.is_contiguous(memory_format=torch.contiguous_format))

        model = self._get_convNd(dim=3)
        model, output = self._get_ipex_optimized_model_and_output_tensor(
            model, dim=3, disable_auto_channels_last=True
        )
        self.assertTrue(
            model.conv.weight.is_contiguous(memory_format=torch.contiguous_format)
        )
        self.assertTrue(output.is_contiguous(memory_format=torch.contiguous_format))

    def test_auto_channels_last_recursion(self):
        """The conversion reaches nested/sequential conv submodules."""
        model = self._get_sequential_conv2d()
        model, output = self._get_ipex_optimized_model_and_output_tensor(model, dim=2)

        self.assertTrue(
            model.conv1.weight.is_contiguous(memory_format=torch.channels_last)
        )
        self.assertTrue(
            model.conv2.weight.is_contiguous(memory_format=torch.channels_last)
        )
        self.assertTrue(output.is_contiguous(memory_format=torch.channels_last))

    def test_auto_channels_last_memory_format_propagation(self):
        """Memory format propagates through compatible layers (ReLU) and
        reverts to contiguous_format through incompatible ones (Linear)."""
        # memory format propagates through channels_last compatible layers
        model = self._get_convNd_relu(dim=1)
        model, output = self._get_ipex_optimized_model_and_output_tensor(model, dim=1)
        self.assertTrue(is_contiguous_channels_last_1d(model.conv.weight))
        self.assertTrue(is_contiguous_channels_last_1d(output))

        model = self._get_convNd_relu(dim=2)
        model, output = self._get_ipex_optimized_model_and_output_tensor(model, dim=2)
        self.assertTrue(
            model.conv.weight.is_contiguous(memory_format=torch.channels_last)
        )
        self.assertTrue(output.is_contiguous(memory_format=torch.channels_last))

        model = self._get_convNd_relu(dim=3)
        model, output = self._get_ipex_optimized_model_and_output_tensor(model, dim=3)
        self.assertTrue(
            model.conv.weight.is_contiguous(memory_format=torch.channels_last_3d)
        )

        # memory format reverts back to contiguous_format as linear is
        # channels_last incompatible
        model = self._get_convNd_linear(dim=1)
        model, output = self._get_ipex_optimized_model_and_output_tensor(model, dim=1)
        self.assertTrue(is_contiguous_channels_last_1d(model.conv.weight))
        self.assertTrue(output.is_contiguous(memory_format=torch.contiguous_format))

        model = self._get_convNd_linear(dim=2)
        model, output = self._get_ipex_optimized_model_and_output_tensor(model, dim=2)
        self.assertTrue(
            model.conv.weight.is_contiguous(memory_format=torch.channels_last)
        )
        self.assertTrue(output.is_contiguous(memory_format=torch.contiguous_format))

        model = self._get_convNd_linear(dim=3)
        model, output = self._get_ipex_optimized_model_and_output_tensor(model, dim=3)
        self.assertTrue(
            model.conv.weight.is_contiguous(memory_format=torch.channels_last_3d)
        )
        self.assertTrue(output.is_contiguous(memory_format=torch.contiguous_format))

    @skipIfNoTorchVision
    def test_auto_channels_last_resnet50(self):
        """Auto conversion matches a manual .to(memory_format=channels_last)."""
        model = torchvision.models.resnet.resnet50(pretrained=False)
        model.eval()

        # manual conversion (baseline)
        model_manual = model.to(memory_format=torch.channels_last)
        manual_channels_last_modules = self.get_channels_last_modules(model_manual)

        # auto conversion via ipex.optimize
        model_ipex = ipex.optimize(model, weights_prepack=False)
        ipex_channels_last_modules = self.get_channels_last_modules(model_ipex)

        self.assertEqual(manual_channels_last_modules, ipex_channels_last_modules)

    def test_auto_channels_last_for_int8(self):
        """Auto channels_last also applies on the INT8 static-quantization path."""
        conv_module = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d, 3: torch.nn.Conv3d}

        class ConvNd(torch.nn.Module):
            def __init__(self, dim, in_channels, out_channels, kernel_size, stride):
                super().__init__()
                self.conv = conv_module[dim](
                    in_channels, out_channels, kernel_size=kernel_size, stride=stride
                )

            def forward(self, x):
                return self.conv(x)

        def _test_conv(dim):
            """Quantize, trace, freeze, run a ConvNd; return the last output."""
            input_shapes = {1: (224,), 2: (224, 224), 3: (55, 55, 55)}
            x_shape = (2, 3) + input_shapes[dim]
            x = torch.randn(x_shape, dtype=torch.float32)
            model = ConvNd(dim, 3, 4, 3, 2).eval()
            qconfig = ipex.quantization.default_static_qconfig
            prepared_model = ipex.quantization.prepare(model, qconfig, x)
            # do calibration
            y = prepared_model(x)
            convert_model = ipex.quantization.convert(prepared_model)
            with torch.no_grad():
                traced_model = torch.jit.trace(convert_model, x)
                traced_model = torch.jit.freeze(traced_model)
                # run a few iterations so JIT profiling/fusion settles
                for _ in range(3):
                    y = traced_model(x)
            return y

        # disable auto channels_last
        ipex.disable_auto_channels_last()
        self.assertTrue(
            _test_conv(2).is_contiguous(memory_format=torch.contiguous_format)
        )
        self.assertTrue(
            _test_conv(3).is_contiguous(memory_format=torch.contiguous_format)
        )

        # enable auto channels_last (also restores the process-global default)
        ipex.enable_auto_channels_last()

        self.assertTrue(_test_conv(2).is_contiguous(memory_format=torch.channels_last))
        # temporarily disabled until https://github.com/pytorch/pytorch/pull/74023
        # is merged:
        # self.assertTrue(
        #     _test_conv(3).is_contiguous(memory_format=torch.channels_last_3d)
        # )
if __name__ == "__main__":
    unittest.main()