# intel-extension-for-pytorch
# 54 lines · 1.8 KB
1import unittest
2import torch
3import intel_extension_for_pytorch as ipex
4from torch.testing._internal.common_utils import TestCase
5import copy
6
7
class MLP(torch.nn.Module):
    """Small 10 -> 10 -> 10 perceptron: Linear, ReLU, Linear, Sigmoid.

    The four layers are stored, in that order, in the ``mlp`` ModuleList
    (indices 0-3), so individual layers can be replaced by index.
    """

    def __init__(self):
        super().__init__()
        self.mlp = torch.nn.ModuleList(
            [
                torch.nn.Linear(10, 10),
                torch.nn.ReLU(),
                torch.nn.Linear(10, 10),
                torch.nn.Sigmoid(),
            ]
        )

    def forward(self, x):
        """Apply every layer of ``self.mlp`` to ``x`` in sequence."""
        out = x
        for module in self.mlp:
            out = module(out)
        return out
21
22
class TestLinearFuseEltwise(TestCase):
    """Compare IPEXLinearEltwise-fused MLPs against the unfused reference."""

    def test_linear_fuse_eltwise(self):
        """Fused Linear+ReLU / Linear+Sigmoid must match the unfused model.

        For each dtype (float32, and bfloat16 under CPU autocast) an ``MLP`` is
        prepared with ``ipex.optimize(..., auto_kernel_selection=True)``. A deep
        copy of it then has ``mlp[0]`` / ``mlp[2]`` wrapped in
        ``ipex.nn.modules.IPEXLinearEltwise`` ("relu" / "sigmoid") and the
        standalone activations ``mlp[1]`` / ``mlp[3]`` replaced by
        ``Identity``. The summed forward outputs and the gradients with
        respect to the input must agree between the two models.
        """
        x1 = torch.rand(5, 10).requires_grad_()
        # Separate leaf tensor with identical values, so each model's backward
        # pass writes its input gradient to its own tensor.
        x2 = copy.deepcopy(x1)
        for dtype in [torch.float, torch.bfloat16]:
            with self.subTest(dtype=dtype):
                # x1/x2 are reused across dtypes and backward() accumulates
                # into .grad; clear it so this iteration compares only its own
                # gradients rather than ones summed with the previous dtype's.
                x1.grad = None
                x2.grad = None
                use_amp = dtype == torch.bfloat16

                model = MLP()
                opt = torch.optim.SGD(model.parameters(), lr=0.01)
                model, opt = ipex.optimize(
                    model, optimizer=opt, dtype=dtype, auto_kernel_selection=True
                )
                with torch.cpu.amp.autocast(enabled=use_amp):
                    ref_out = model(x1).sum()
                ref_out.backward()

                # Fold each activation into the preceding Linear; the original
                # activation slots become no-ops.
                fused_model = copy.deepcopy(model)
                fused_model.mlp[0] = ipex.nn.modules.IPEXLinearEltwise(
                    fused_model.mlp[0], "relu"
                )
                fused_model.mlp[1] = torch.nn.Identity()
                fused_model.mlp[2] = ipex.nn.modules.IPEXLinearEltwise(
                    fused_model.mlp[2], "sigmoid"
                )
                fused_model.mlp[3] = torch.nn.Identity()
                with torch.cpu.amp.autocast(enabled=use_amp):
                    out = fused_model(x2).sum()
                out.backward()

                self.assertEqual(out, ref_out)
                self.assertEqual(x1.grad, x2.grad)
51
52
if __name__ == "__main__":
    # With its default exit=True, unittest.main() runs the tests and then
    # terminates the interpreter, so there is no return value to bind.
    unittest.main()
55