# intel-extension-for-pytorch — unit tests for TPP linear kernels (310 lines, 10.3 KB)
1import unittest2import itertools3import torch4import intel_extension_for_pytorch as ipex5from torch.testing._internal.common_utils import TestCase6import copy7from intel_extension_for_pytorch.cpu._auto_kernel_selection import (8_enable_tpp,9_disable_tpp,10)
11
12
class Linear_with_bias(torch.nn.Module):
    """A single 4096x4096 linear layer with bias (default for nn.Linear)."""

    def __init__(self):
        # Zero-arg super() is the Python 3 idiom (was super(Cls, self)).
        super().__init__()
        self.mlp = torch.nn.Linear(4096, 4096)

    def forward(self, x):
        """Apply the linear layer to *x* and return the result."""
        return self.mlp(x)
class Linear_without_bias(torch.nn.Module):
    """A single 4096x4096 linear layer with no bias term."""

    def __init__(self):
        # Zero-arg super() is the Python 3 idiom (was super(Cls, self)).
        super().__init__()
        self.mlp = torch.nn.Linear(4096, 4096, bias=False)

    def forward(self, x):
        """Apply the bias-free linear layer to *x*."""
        return self.mlp(x)
class Linear_gelu(torch.nn.Module):
    """4096x4096 linear layer (with bias) followed by GELU activation."""

    def __init__(self):
        # Zero-arg super() is the Python 3 idiom (was super(Cls, self)).
        super().__init__()
        self.mlp = torch.nn.Linear(4096, 4096)

    def forward(self, x):
        """Return gelu(Wx + b)."""
        return torch.nn.functional.gelu(self.mlp(x))
class Linear_silu(torch.nn.Module):
    """Bias-free 4096x4096 linear layer followed by SiLU activation."""

    def __init__(self):
        # Zero-arg super() is the Python 3 idiom (was super(Cls, self)).
        super().__init__()
        self.mlp = torch.nn.Linear(4096, 4096, bias=False)

    def forward(self, x):
        """Return silu(Wx)."""
        return torch.nn.functional.silu(self.mlp(x))
class Linear_Gate_Up(torch.nn.Module):
    """Gated MLP block: silu(gate_proj(x)) * up_proj(x).

    Args:
        in_feature: input feature size of both projections.
        out_feature: output feature size of both projections.
        bias_gate: whether gate_proj has a bias term.
        bias_up: whether up_proj has a bias term.
    """

    def __init__(self, in_feature, out_feature, bias_gate, bias_up):
        # Zero-arg super() is the Python 3 idiom (was super(Cls, self)).
        super().__init__()
        self.gate_proj = torch.nn.Linear(in_feature, out_feature, bias=bias_gate)
        self.up_proj = torch.nn.Linear(in_feature, out_feature, bias=bias_up)

    def forward(self, x):
        """Return the SiLU-gated product of the two projections."""
        return torch.nn.functional.silu(self.gate_proj(x)) * self.up_proj(x)
class Linear_relu(torch.nn.Module):
    """Bias-free 4096x4096 linear layer followed by ReLU activation."""

    def __init__(self):
        # Zero-arg super() is the Python 3 idiom (was super(Cls, self)).
        super().__init__()
        self.mlp = torch.nn.Linear(4096, 4096, bias=False)

    def forward(self, x):
        """Return relu(Wx)."""
        return torch.nn.functional.relu(self.mlp(x))
class Linear_mul(torch.nn.Module):
    """Bias-free 4096x4096 linear layer whose output is multiplied by the input."""

    def __init__(self):
        # Zero-arg super() is the Python 3 idiom (was super(Cls, self)).
        super().__init__()
        self.mlp = torch.nn.Linear(4096, 4096, bias=False)

    def forward(self, x):
        """Return (Wx) * x (elementwise)."""
        return self.mlp(x) * x
class Linear_add(torch.nn.Module):
    """Bias-free 4096x4096 linear layer with a residual add of the input."""

    def __init__(self):
        # Zero-arg super() is the Python 3 idiom (was super(Cls, self)).
        super().__init__()
        self.mlp = torch.nn.Linear(4096, 4096, bias=False)

    def forward(self, x):
        """Return (Wx) + x."""
        return self.mlp(x) + x
class Linear_add_add(torch.nn.Module):
    """4096x4096 linear layer (with bias) plus two residual adds of the input."""

    def __init__(self):
        # Zero-arg super() is the Python 3 idiom (was super(Cls, self)).
        super().__init__()
        self.mlp = torch.nn.Linear(4096, 4096)

    def forward(self, x):
        """Return (Wx + b) + x + x."""
        return self.mlp(x) + x + x
class Linear_tpp_fallback_dnnl(torch.nn.Module):
    """Linear layer with an odd size (4097) so the TPP kernel cannot be used
    and the implementation must fall back to the DNNL path."""

    def __init__(self):
        # Zero-arg super() is the Python 3 idiom (was super(Cls, self)).
        super().__init__()
        self.mlp = torch.nn.Linear(4097, 4097)

    def forward(self, x):
        """Apply the linear layer to *x*."""
        return self.mlp(x)
class TestTPPlinear(TestCase):
    """Checks that ipex TPP linear kernels match eager-mode reference outputs.

    Each test computes a reference result with the plain module, then enables
    TPP kernel selection, runs the corresponding fused torch_ipex op, and
    compares. TPP state is restored with _disable_tpp() after each test.
    """

    def test_tpp_linear_fallback(self):
        """4097-wide layer cannot use TPP; verify the DNNL fallback matches."""
        x1 = torch.rand(1, 1, 4097)
        x2 = copy.deepcopy(x1)
        for dtype in [torch.float, torch.bfloat16]:
            model = Linear_tpp_fallback_dnnl().eval()

            # enabled=... simplified from "True if ... else False".
            with torch.no_grad(), torch.cpu.amp.autocast(
                enabled=dtype is torch.bfloat16
            ):
                ref_out = model(x1)

            _enable_tpp()
            model = ipex.optimize(model, dtype=dtype)
            with torch.no_grad(), torch.cpu.amp.autocast(
                enabled=dtype is torch.bfloat16
            ):
                out = model(x2)
                self.assertEqual(out, ref_out)
            _disable_tpp()

    def test_tpp_linear(self):
        """Plain linear, with and without bias."""
        x1 = torch.rand(1, 1, 4096)
        x2 = copy.deepcopy(x1)
        for dtype in [torch.float, torch.bfloat16]:
            model = Linear_with_bias().eval()
            model_nb = Linear_without_bias().eval()
            if dtype is torch.bfloat16:
                x1 = x1.to(torch.bfloat16)
                x2 = x2.to(torch.bfloat16)
                model = model.to(torch.bfloat16)
                model_nb = model_nb.to(torch.bfloat16)
            ref_out = model(x1)
            ref_out_nb = model_nb(x1)

            _enable_tpp()
            model = ipex.optimize(model, dtype=dtype)
            model_nb = ipex.optimize(model_nb, dtype=dtype)
            out = model(x2)
            out_nb = model_nb(x2)
            self.assertEqual(out, ref_out)
            self.assertEqual(out_nb, ref_out_nb)
            _disable_tpp()

    def test_tpp_fused_gate_up_proj(self):
        """Fused gate/up projection vs both the eager module and the
        two-op TPP composition (linear_silu then linear_mul)."""
        in_feature = 64
        out_feature = 32

        x = torch.randn(1, 4, in_feature)
        x_tpp = copy.deepcopy(x)

        with torch.no_grad():
            for dtype, bias_gate, bias_up in itertools.product(
                [torch.float, torch.bfloat16], [False, True], [False, True]
            ):
                model = Linear_Gate_Up(
                    in_feature, out_feature, bias_gate, bias_up
                ).eval()
                # "is" used for dtype identity, consistent with the other tests
                # (was "=="; equivalent for torch.dtype singletons).
                if dtype is torch.bfloat16:
                    x = x.to(torch.bfloat16)
                    x_tpp = x_tpp.to(torch.bfloat16)
                    model = model.to(torch.bfloat16)
                ref_out = model(x)

                _enable_tpp()
                model = ipex.optimize(model, dtype=dtype)
                out = torch.ops.torch_ipex.tpp_fused_gate_up_proj(
                    x_tpp,
                    model.gate_proj.weight,
                    model.gate_proj.bias,
                    model.up_proj.weight,
                    model.up_proj.bias,
                )

                out_linear_silu = torch.ops.torch_ipex.tpp_linear_silu(
                    x_tpp, model.gate_proj.weight, model.gate_proj.bias
                )
                out_tpp_ref = torch.ops.torch_ipex.tpp_linear_mul(
                    x_tpp, out_linear_silu, model.up_proj.weight, model.up_proj.bias
                )
                self.assertEqual(out, out_tpp_ref)
                self.assertEqual(out, ref_out)
                _disable_tpp()

    def test_tpp_linear_gelu(self):
        """tpp_linear_gelu vs eager linear + gelu (bf16 only)."""
        x1 = torch.rand(1, 4, 4096)
        x2 = copy.deepcopy(x1)
        with torch.no_grad():
            for dtype in [torch.bfloat16]:
                model = Linear_gelu().eval()
                if dtype is torch.bfloat16:
                    x1 = x1.to(torch.bfloat16)
                    x2 = x2.to(torch.bfloat16)
                    model = model.to(torch.bfloat16)
                ref_out = model(x1)

                _enable_tpp()
                model = ipex.optimize(model, dtype=dtype)
                out = torch.ops.torch_ipex.tpp_linear_gelu(
                    x2, model.mlp.weight, model.mlp.bias
                )
                self.assertEqual(out, ref_out)
                _disable_tpp()

    def test_tpp_linear_silu(self):
        """tpp_linear_silu vs eager linear + silu; empty tensor stands in
        for the missing bias."""
        x1 = torch.rand(1, 4, 4096)
        x2 = copy.deepcopy(x1)
        with torch.no_grad():
            for dtype in [torch.bfloat16]:
                model = Linear_silu().eval()
                if dtype is torch.bfloat16:
                    x1 = x1.to(torch.bfloat16)
                    x2 = x2.to(torch.bfloat16)
                    model = model.to(torch.bfloat16)
                ref_out = model(x1)

                _enable_tpp()
                model = ipex.optimize(model, dtype=dtype)
                out = torch.ops.torch_ipex.tpp_linear_silu(
                    x2, model.mlp.weight, x2.new_empty(0)
                )
                self.assertEqual(out, ref_out)
                _disable_tpp()

    def test_tpp_linear_relu(self):
        """tpp_linear_relu vs eager linear + relu (bias-free)."""
        x1 = torch.rand(1, 4, 4096)
        x2 = copy.deepcopy(x1)
        with torch.no_grad():
            for dtype in [torch.bfloat16]:
                model = Linear_relu().eval()
                if dtype is torch.bfloat16:
                    x1 = x1.to(torch.bfloat16)
                    x2 = x2.to(torch.bfloat16)
                    model = model.to(torch.bfloat16)
                ref_out = model(x1)

                _enable_tpp()
                model = ipex.optimize(model, dtype=dtype)
                out = torch.ops.torch_ipex.tpp_linear_relu(
                    x2, model.mlp.weight, x2.new_empty(0)
                )
                self.assertEqual(out, ref_out)
                _disable_tpp()

    def test_tpp_linear_mul(self):
        """tpp_linear_mul vs eager linear * input (bias-free)."""
        x1 = torch.rand(1, 4, 4096)
        x2 = copy.deepcopy(x1)
        with torch.no_grad():
            for dtype in [torch.bfloat16]:
                model = Linear_mul().eval()
                if dtype is torch.bfloat16:
                    x1 = x1.to(torch.bfloat16)
                    x2 = x2.to(torch.bfloat16)
                    model = model.to(torch.bfloat16)
                ref_out = model(x1)

                _enable_tpp()
                model = ipex.optimize(model, dtype=dtype)
                out = torch.ops.torch_ipex.tpp_linear_mul(
                    x2, x2, model.mlp.weight, x2.new_empty(0)
                )
                self.assertEqual(out, ref_out)
                _disable_tpp()

    def test_tpp_linear_add(self):
        """tpp_linear_add (scale 1.0) vs eager linear + input (bias-free)."""
        x1 = torch.rand(1, 4, 4096)
        x2 = copy.deepcopy(x1)
        with torch.no_grad():
            for dtype in [torch.bfloat16]:
                model = Linear_add().eval()
                if dtype is torch.bfloat16:
                    x1 = x1.to(torch.bfloat16)
                    x2 = x2.to(torch.bfloat16)
                    model = model.to(torch.bfloat16)
                ref_out = model(x1)

                _enable_tpp()
                model = ipex.optimize(model, dtype=dtype)
                out = torch.ops.torch_ipex.tpp_linear_add(
                    x2, x2, model.mlp.weight, x2.new_empty(0), 1.0
                )
                self.assertEqual(out, ref_out)
                _disable_tpp()

    def test_tpp_linear_add2(self):
        """tpp_linear_add_add (scale 1.0) vs eager linear + input + input."""
        x1 = torch.rand(1, 4, 4096)
        x2 = copy.deepcopy(x1)
        with torch.no_grad():
            for dtype in [torch.bfloat16]:
                model = Linear_add_add().eval()
                if dtype is torch.bfloat16:
                    x1 = x1.to(torch.bfloat16)
                    x2 = x2.to(torch.bfloat16)
                    model = model.to(torch.bfloat16)
                ref_out = model(x1)

                _enable_tpp()
                model = ipex.optimize(model, dtype=dtype)
                out = torch.ops.torch_ipex.tpp_linear_add_add(
                    x2, x2, x2, model.mlp.weight, model.mlp.bias, 1.0
                )
                self.assertEqual(out, ref_out)
                _disable_tpp()
if __name__ == "__main__":
    # Run the test suite; the unused "test = " binding was dropped.
    unittest.main()