# intel-extension-for-pytorch
# 613 lines · 23.5 KB
import time
import unittest
from types import MappingProxyType
from typing import Dict, Mapping, NamedTuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.testing._internal.jit_utils import JitTestCase
8
def get_rand_seed():
    """Return a fresh RNG seed derived from the wall clock, in nanoseconds.

    ``time.time_ns()`` yields the timestamp as an exact integer. Computing it
    as ``int(time.time() * 1000000000)`` goes through a float, whose ~16
    significant digits cannot represent a full nanosecond timestamp, so the
    low digits of such a seed are rounding noise. The returned value is a
    non-negative ``int`` suitable for ``torch.manual_seed``.
    """
    return time.time_ns()
12
# Maps a spatial dimensionality (1, 2 or 3) to the matching torch.nn
# convolution class; ConvEltwise uses it to pick its convolution layer.
# (The typing imports formerly placed here live in the top-of-file import block.)
conv_module = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d, 3: torch.nn.Conv3d}
17
class EltwiseFusionOp(NamedTuple):
    """Describe the IPEX fused eltwise op that a PyTorch op is paired with."""

    # Label of the IPEX eltwise op, e.g. "relu" or "gelu(tanh)"; the tests
    # print it and check it for special cases such as "round".
    ipex_eltwise_op: str
    # Keyword arguments passed to the PyTorch op on every call
    # (e.g. {"min": -2, "max": 3} for clamp). A NamedTuple default is one
    # object shared by every instance, so the default is a read-only empty
    # mapping rather than a mutable ``{}`` that could leak edits between
    # entries. It still supports ``**`` unpacking like a dict.
    op_input_list: Mapping = MappingProxyType({})
22
# PyTorch eltwise ops that take only the input tensor, paired with the label
# of the corresponding IPEX fused op. Each op is listed in its out-of-place
# form followed by its in-place form (trailing "_" / inplace=True).
unary_PyTorch_op_to_IPEX_op_map = {
    pytorch_op: EltwiseFusionOp(ipex_label)
    for pytorch_op, ipex_label in (
        (torch.relu, "relu"),
        (torch.relu_, "relu_"),
        (torch.abs, "abs"),
        (torch.abs_, "abs_"),
        (torch.exp, "exp"),
        (torch.exp_, "exp_"),
        (nn.Hardswish(inplace=False), "hardswish"),
        (nn.Hardswish(inplace=True), "hardswish_"),
        (torch.log, "log"),
        (torch.log_, "log_"),
        (nn.Mish(inplace=False), "mish"),
        (nn.Mish(inplace=True), "mish_"),
        (torch.sigmoid, "sigmoid"),
        (torch.sigmoid_, "sigmoid_"),
        (torch.round, "round"),
        (torch.round_, "round_"),
        (torch.sqrt, "sqrt"),
        (torch.sqrt_, "sqrt_"),
        (torch.square, "square"),
        (torch.square_, "square_"),
        (torch.tanh, "tanh"),
        (torch.tanh_, "tanh_"),
        (nn.SiLU(inplace=False), "silu"),
        (nn.SiLU(inplace=True), "silu_"),
        (nn.Hardsigmoid(inplace=False), "hardsigmoid"),
        (nn.Hardsigmoid(inplace=True), "hardsigmoid_"),
    )
}
51
# PyTorch eltwise ops that need extra arguments or a configured module,
# paired with the IPEX fused-op label. Where present, the second element of
# each argument tuple becomes ``op_input_list``: keyword arguments passed to
# the op on every call.
non_unary_PyTorch_op_to_IPEX_op_map = {
    pytorch_op: EltwiseFusionOp(*fusion_args)
    for pytorch_op, fusion_args in (
        (torch.clamp, ("clamp", {"min": -2, "max": 3})),
        (torch.clamp_, ("clamp_", {"min": -2, "max": 3})),
        (nn.GELU(approximate="none"), ("gelu(none)",)),
        (nn.GELU(approximate="tanh"), ("gelu(tanh)",)),
        (nn.ELU(inplace=False), ("elu",)),
        (nn.ELU(inplace=True), ("elu_",)),
        (torch.pow, ("pow", {"exponent": 2})),
        # In-place pow has no functional alias, so the exponent is baked in.
        ((lambda t: t.pow_(2)), ("pow_",)),
        (nn.LeakyReLU(negative_slope=0.02, inplace=False), ("leaky_relu",)),
        (nn.LeakyReLU(negative_slope=0.02, inplace=True), ("leaky_relu_",)),
    )
}
64
65
class ConvEltwise(nn.Module):
    """Convolution followed by an eltwise op: ``eltwise_fn(conv(x), **kwargs)``.

    Args:
        eltwise_fn: callable (function or module) applied to the conv output.
        dim: spatial dimensionality; selects the conv class from ``conv_module``.
        in_channels: input channels of the convolution.
        out_channels: output channels of the convolution.
        kernel_size: kernel size of the convolution.
        image_size: not used by the module.
        **kwargs: extra keyword arguments passed to ``eltwise_fn`` on each call.
    """

    def __init__(
        self,
        eltwise_fn,
        dim,
        in_channels,
        out_channels,
        kernel_size,
        image_size,
        **kwargs
    ):
        super().__init__()
        conv_cls = conv_module[dim]
        self.conv = conv_cls(in_channels, out_channels, kernel_size)
        self.eltwise = eltwise_fn
        self.kwargs = kwargs

    def forward(self, x):
        return self.eltwise(self.conv(x), **self.kwargs)
87
class IPEXConvAdd(nn.Module):
    """Two bias-free 2-D convolutions of the same input, summed in place.

    ``kwargs`` (e.g. ``kernel_size``) are forwarded to both ``Conv2d`` layers.
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.conv2 = torch.nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)

    def forward(self, x):
        # conv1 runs first; its output buffer receives the in-place add.
        return self.conv1(x).add_(self.conv2(x))
99
class IPEXConvAddRelu(nn.Module):
    """``relu(relu(conv1(x)) + conv2(x))`` with bias-free 2-D convolutions.

    ``kwargs`` (e.g. ``kernel_size``) are forwarded to both ``Conv2d`` layers.
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.conv2 = torch.nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)

    def forward(self, x):
        activated = F.relu(self.conv1(x))
        summed = activated.add_(self.conv2(x))
        return F.relu(summed, inplace=True)
111
class IPEXConvConvRelu(nn.Module):
    """Two stacked bias-free 2-D convolutions followed by an in-place ReLU.

    ``conv2`` maps ``out_channels`` to ``out_channels``; ``kwargs`` (e.g.
    ``kernel_size``) are forwarded to both ``Conv2d`` layers.
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, bias=False, **kwargs)

    def forward(self, x):
        return F.relu(self.conv2(self.conv1(x)), inplace=True)
123
class IPEXConvSigmoidMul(nn.Module):
    """Bias-free 2-D convolution gated by its own sigmoid: ``y * sigmoid(y)``.

    The multiply is done in place on the convolution output.
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)

    def forward(self, x):
        features = self.conv(x)
        return features.mul_(torch.sigmoid(features))
134
135class LinearEltwise(nn.Module):136def __init__(self, eltwise_fn, in_channels, out_channels, bias, **kwargs):137super(LinearEltwise, self).__init__()138self.linear = nn.Linear(in_channels, out_channels, bias=bias)139self.eltwise = eltwise_fn140self.kwargs = kwargs141
142def forward(self, x):143a = self.linear(x)144a = a / 2145b = self.eltwise(a, **self.kwargs)146return b147
148
class IPEXLinearAdd(nn.Module):
    """Two independent linear layers on the same input, summed in place."""

    def __init__(self, in_channels, out_channels, bias):
        super().__init__()
        self.linear1 = nn.Linear(in_channels, out_channels, bias=bias)
        self.linear2 = nn.Linear(in_channels, out_channels, bias=bias)

    def forward(self, x):
        # linear1 runs first; its output buffer receives the in-place add.
        return self.linear1(x).add_(self.linear2(x))
160
class IPEXLinearAddRelu(nn.Module):
    """``relu(relu(linear(x)) + linear(x))``.

    Note: unlike IPEXConvAddRelu, a single ``linear`` layer is applied twice
    (same weights for both branches).
    """

    def __init__(self, in_channels, out_channels, bias):
        super().__init__()
        self.linear = nn.Linear(in_channels, out_channels, bias=bias)

    def forward(self, x):
        activated = F.relu(self.linear(x))
        summed = activated.add_(self.linear(x))
        return F.relu(summed, inplace=True)
171
class IPEXLinearSigmoidMul(nn.Module):
    """Linear layer gated by its own sigmoid: ``y * sigmoid(y)``, multiplied in place."""

    def __init__(self, in_channels, out_channels, bias):
        super().__init__()
        self.linear = nn.Linear(in_channels, out_channels, bias=bias)

    def forward(self, x):
        projected = self.linear(x)
        return projected.mul_(torch.sigmoid(projected))
182
class IPEXMatmulDiv(nn.Module):
    """Computes ``matmul(x1, x2) / x3 + x3``.

    Side effect: constructing the module reseeds torch's global RNG with 2018.
    """

    def __init__(self):
        super().__init__()
        torch.manual_seed(2018)

    def forward(self, x1, x2, x3):
        quotient = torch.matmul(x1, x2) / x3
        return quotient + x3
192
class TestTE(JitTestCase):
    """Check that IPEX fusion keeps JIT-compiled models numerically equal to eager.

    Each test builds a small module, compiles it with ``torch.jit.trace`` (or
    ``torch.jit.script``) followed by ``torch.jit.freeze``, and compares the
    compiled output with eager execution of the same module. Graph-level
    ``assertAllFused`` checks are disabled in this file, so only numerical
    agreement is verified.

    Seeds come from the wall clock and are printed next to each test label,
    so a failing run can be reproduced with ``torch.manual_seed(<seed>)``.
    """

    # ideep's "round" differs from torch.round when an input is exactly x.5:
    #     x = torch.Tensor([0.500])
    #     ideep:    torch.round(x) -> 1.0
    #     expected: torch.round(x) -> 0.0
    # (a seed reproducing the failure: 1665593217573048320). Linear tests of
    # ops whose label contains "round" therefore always use this fixed seed.
    _ROUND_SEED = 1665594679504775936

    # Tolerances for the bf16-autocast comparisons in the linear tests.
    _BF16_TOLERANCE = {"rtol": 0.02, "atol": 0.01}

    # ------------------------------------------------------------ helpers

    def _disable_fusion_group_inlining(self):
        """Disable fusion-group inlining until the current test finishes.

        The previous value is restored through ``addCleanup``, which also runs
        when the test fails, so this global JIT setting never leaks into
        later tests.
        """
        old = torch._C._debug_get_fusion_group_inlining()
        torch._C._debug_set_fusion_group_inlining(False)
        self.addCleanup(torch._C._debug_set_fusion_group_inlining, old)

    @staticmethod
    def _seed_rng(fixed_seed=None):
        """Seed torch's global RNG and return the seed that was used.

        Uses ``fixed_seed`` when given, otherwise a clock-based seed.
        """
        seed = int(get_rand_seed()) if fixed_seed is None else fixed_seed
        torch.manual_seed(seed)
        return seed

    @staticmethod
    def _freeze_compiled(model, *example_inputs, script=False):
        """Trace (or script) ``model``, freeze it, and run it once on the inputs.

        The extra call is a warm-up before the compared call (presumably so
        the JIT has finished optimizing the graph).
        """
        if script:
            # torch.jit.freeze expects an eval-mode module.
            compiled = torch.jit.script(model).eval()
        else:
            compiled = torch.jit.trace(model, example_inputs)
        compiled = torch.jit.freeze(compiled)
        compiled(*example_inputs)
        return compiled

    def _check_jit_vs_eager(self, compiled, model, *inputs, verbose=False, **tolerances):
        """Run ``compiled`` then ``model`` on ``inputs`` and assert equal outputs.

        ``verbose`` puts both results into the failure message; any other
        keyword arguments (e.g. ``rtol``/``atol``) go to ``assertEqual``.
        """
        res_jit = compiled(*inputs)
        res_imperative = model(*inputs)
        if verbose:
            tolerances["msg"] = "{}, {}".format(res_jit, res_imperative)
        self.assertEqual(res_jit, res_imperative, **tolerances)

    def _run_conv_eltwise_cases(self, op_list):
        """Check conv2d followed by each eltwise op in ``op_list``.

        Covers contiguous and channels-last layouts at two batch/image sizes.
        """
        self._disable_fusion_group_inlining()
        dim = 2
        in_channels = 3
        out_channels = 16
        kernel_size = 3
        for eltwise, fusion_op in op_list.items():
            rand_seed = self._seed_rng()
            print("TEST conv2d+%s (seed=%d)" % (fusion_op.ipex_eltwise_op, rand_seed))
            for use_channels_last in (False, True):
                for batch_size, image_size in ((8, 20), (3, 256)):
                    x = torch.randn(batch_size, in_channels, image_size, image_size)
                    te_model = ConvEltwise(
                        eltwise,
                        dim,
                        in_channels,
                        out_channels,
                        kernel_size,
                        image_size,
                        **fusion_op.op_input_list,
                    ).eval()
                    if use_channels_last:
                        x = x.to(memory_format=torch.channels_last)
                        te_model = te_model.to(memory_format=torch.channels_last)
                    te_model_traced = self._freeze_compiled(te_model, x)
                    self._check_jit_vs_eager(
                        te_model_traced, te_model, x, verbose=True
                    )

    def _run_linear_eltwise_cases(self, op_list):
        """Check linear (output halved) followed by each eltwise op, under bf16 autocast."""
        self._disable_fusion_group_inlining()
        batch_size = 3
        in_channels = 3
        out_channels = 32
        for eltwise, fusion_op in op_list.items():
            ipex_eltwise_op = fusion_op.ipex_eltwise_op
            rand_seed = self._seed_rng(
                self._ROUND_SEED if "round" in ipex_eltwise_op else None
            )
            print("TEST linear+%s (seed=%d)" % (ipex_eltwise_op, rand_seed))
            for bias in (True, False):
                x = torch.randn(batch_size, in_channels)
                # linear fusion only supports bf16
                with torch.cpu.amp.autocast(
                    enabled=True, dtype=torch.bfloat16
                ), torch.no_grad():
                    te_model = LinearEltwise(
                        eltwise,
                        in_channels,
                        out_channels,
                        bias,
                        **fusion_op.op_input_list,
                    ).eval()
                    te_model_traced = self._freeze_compiled(te_model, x)
                    self._check_jit_vs_eager(
                        te_model_traced,
                        te_model,
                        x,
                        verbose=True,
                        **self._BF16_TOLERANCE,
                    )

    def _run_conv_pattern(self, label, make_model, first_shape, second_shape, script=False):
        """Check a conv-based pattern module in contiguous and channels-last layouts.

        The model is compiled on an input of ``first_shape`` and then
        re-checked on a fresh contiguous input of ``second_shape``.
        """
        self._disable_fusion_group_inlining()
        rand_seed = self._seed_rng()
        print("TEST %s (seed=%d)" % (label, rand_seed))
        for use_channels_last in (False, True):
            te_model = make_model().eval()
            x = torch.randn(first_shape)
            if use_channels_last:
                x = x.to(memory_format=torch.channels_last)
                te_model = te_model.to(memory_format=torch.channels_last)
            te_model_traced = self._freeze_compiled(te_model, x, script=script)
            self._check_jit_vs_eager(te_model_traced, te_model, x)

            # Same compiled model, different shape, default memory format.
            x = torch.randn(second_shape)
            self._check_jit_vs_eager(te_model_traced, te_model, x)

    def _run_linear_pattern(self, label, model_cls):
        """Check a linear-based pattern module under bf16 autocast, with and without bias."""
        self._disable_fusion_group_inlining()
        rand_seed = self._seed_rng()
        print("TEST %s (seed=%d)" % (label, rand_seed))
        for bias in (True, False):
            with torch.cpu.amp.autocast(
                enabled=True, dtype=torch.bfloat16
            ), torch.no_grad():
                te_model = model_cls(3, 32, bias).eval()
                x = torch.randn(3, 3)
                te_model_traced = self._freeze_compiled(te_model, x)
                self._check_jit_vs_eager(
                    te_model_traced, te_model, x, verbose=True, **self._BF16_TOLERANCE
                )

                # Same compiled model, larger batch.
                x = torch.randn(8, 3)
                self._check_jit_vs_eager(
                    te_model_traced, te_model, x, verbose=True, **self._BF16_TOLERANCE
                )

    # -------------------------------------------------------------- tests

    def test_ipex_unary_conv_fusion(self, op_list=unary_PyTorch_op_to_IPEX_op_map):
        self._run_conv_eltwise_cases(op_list)

    def test_ipex_non_unary_conv_fusion(
        self, op_list=non_unary_PyTorch_op_to_IPEX_op_map
    ):
        self._run_conv_eltwise_cases(op_list)

    def test_ipex_conv_add(self):
        self._run_conv_pattern(
            "conv2d+add",
            lambda: IPEXConvAdd(3, 2, kernel_size=(3, 3)),
            (1, 3, 10, 10),
            (3, 3, 20, 20),
        )

    def test_ipex_conv_add_relu(self):
        self._run_conv_pattern(
            "conv2d+add+relu",
            lambda: IPEXConvAddRelu(3, 2, kernel_size=(3, 3)),
            (1, 3, 10, 10),
            (3, 3, 20, 20),
        )

    def test_ipex_conv_conv_relu(self):
        self._run_conv_pattern(
            "conv bottleneck",
            lambda: IPEXConvConvRelu(3, 10, kernel_size=(3, 3)),
            (1, 3, 224, 224),
            (3, 3, 500, 500),
            script=True,
        )

    def test_ipex_conv_sigmoid_mul(self):
        self._run_conv_pattern(
            "conv2d+sigmoid+mul",
            lambda: IPEXConvSigmoidMul(3, 2, kernel_size=(3, 3)),
            (1, 3, 10, 10),
            (3, 3, 20, 20),
        )

    def test_ipex_matmul_div(self):
        te_matmul_div = IPEXMatmulDiv()
        rand_seed = self._seed_rng()
        print("TEST matmul+div (seed=%d)" % rand_seed)
        x1 = torch.randn(5, 5)
        x2 = torch.randn(5, 5)
        x3 = torch.randn(5, 5)
        te_matmul_div_traced = self._freeze_compiled(
            te_matmul_div, x1, x2, x3, script=True
        )
        self._check_jit_vs_eager(te_matmul_div_traced, te_matmul_div, x1, x2, x3)

    def test_ipex_unary_linear_fusion(self, op_list=unary_PyTorch_op_to_IPEX_op_map):
        self._run_linear_eltwise_cases(op_list)

    def test_ipex_non_unary_linear_fusion(
        self, op_list=non_unary_PyTorch_op_to_IPEX_op_map
    ):
        self._run_linear_eltwise_cases(op_list)

    def test_ipex_linear_add(self):
        self._run_linear_pattern("linear+add", IPEXLinearAdd)

    def test_ipex_linear_add_relu(self):
        self._run_linear_pattern("linear+add+relu", IPEXLinearAddRelu)

    def test_ipex_linear_sigmoid_mul(self):
        self._run_linear_pattern("linear+sigmoid+mul", IPEXLinearSigmoidMul)
610
if __name__ == "__main__":
    # NOTE(review): optional custom-op routing, left disabled here; `ipex` is
    # not imported by this file:
    # ipex._C.enable_custom_op_2_nnc_fuser()
    unittest.main()