import unittest

import torch
from common_utils import TestCase


class add_layernorm(torch.nn.Module):
    def __init__(self, size):
        super(add_layernorm, self).__init__()
        self.layer_norm = torch.nn.LayerNorm(size)

    def forward(self, a, b):
        # Elementwise add followed by LayerNorm; this is the pattern the
        # JIT fuses into a single add_layernorm kernel.
        x = torch.add(a, b)
        x = self.layer_norm(x)
        return x
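

# A minimal eager-mode sketch of what add_layernorm computes: LayerNorm over
# the trailing dimension of a + b. This helper is illustrative and not part
# of the original test suite; its name and shapes are arbitrary assumptions.
def _demo_add_layernorm_semantics():
    m = add_layernorm(16).eval()
    a, b = torch.randn(3, 16), torch.randn(3, 16)
    with torch.no_grad():
        out = m(a, b)
        # At init, LayerNorm's affine parameters are weight=1 and bias=0, so
        # the module output matches the plain functional form.
        ref = torch.nn.functional.layer_norm(a + b, (16,))
    assert torch.allclose(out, ref)
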
class AddLayerNormTester(TestCase):
    def test_add_layernorm(self):
        for size in [10, 16, 35]:
            for dim in [2, 3, 4, 5]:
                for dtype in [torch.bfloat16, torch.float16]:
                    with torch.cpu.amp.autocast(dtype=dtype), torch.no_grad():
                        # Build an input shape of rank `dim`: [3, size, ..., size].
                        input_size = [3]
                        for _ in range(dim - 1):
                            input_size.append(size)
                        # Case 1: fp32 inputs. The outputs should stay fp32
                        # under autocast, and the traced model should agree
                        # with eager mode.
                        a = torch.randn(input_size)
                        b = torch.randn(input_size)
                        model = add_layernorm(size).eval()
                        trace_model = torch.jit.trace(model, (a, b))
                        y1_fp32 = model(a, b)
                        y2_fp32 = trace_model(a, b)
                        self.assertEqual(y1_fp32.dtype, torch.float32)
                        self.assertEqual(y2_fp32.dtype, torch.float32)
                        self.assertEqual(y1_fp32, y2_fp32)

                        # Case 2: bf16/fp16 inputs.
                        a_lowp = a.to(dtype=dtype)
                        b_lowp = b.to(dtype=dtype)
                        model = model.to(dtype=dtype)
                        trace_model = torch.jit.trace(model, (a_lowp, b_lowp))
                        y1_lowp = model(a_lowp, b_lowp)
                        y2_lowp = trace_model(a_lowp, b_lowp)
                        self.assertEqual(y1_lowp.dtype, dtype)
                        self.assertEqual(y2_lowp.dtype, dtype)
                        # Use a custom threshold for the bf16/fp16 checks:
                        # the fused add_layernorm in the JIT runs at higher
                        # precision than eager mode, which causes a small
                        # mismatch between the two outputs.
                        prec = 5e-2 if dtype == torch.bfloat16 else 5e-3
                        self.assertEqual(y1_lowp, y2_lowp, prec=prec)
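                        # A rough rationale for the thresholds (added here,
                        # not from the original source): bf16 keeps 7
                        # fraction bits (eps = 2**-7 ~ 7.8e-3) and fp16
                        # keeps 10 (eps = 2**-10 ~ 9.8e-4), so 5e-2 and
                        # 5e-3 leave a few ulps of headroom at unit scale.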


if __name__ == "__main__":
    test = unittest.main()