intel-extension-for-pytorch

Форк
0
/
test_add_layernorm.py 
56 строк · 2.3 Кб
1
import unittest
2

3
import torch
4
from common_utils import TestCase
5

6

7
class add_layernorm(torch.nn.Module):
8
    def __init__(self, size):
9
        super(add_layernorm, self).__init__()
10
        self.layer_norm = torch.nn.LayerNorm(size)
11

12
    def forward(self, a, b):
13
        x = torch.add(a, b)
14
        x = self.layer_norm(x)
15
        return x
16

17

18
class AddLayerNormTester(TestCase):
19
    def test_add_layernorm(self):
20
        for size in [10, 16, 35]:
21
            for dim in [2, 3, 4, 5]:
22
                for dtype in [torch.bfloat16, torch.float16]:
23
                    with torch.cpu.amp.autocast(dtype=dtype), torch.no_grad():
24
                        input_size = [
25
                            3,
26
                        ]
27
                        for _ in range(dim - 1):
28
                            input_size.append(size)
29
                        # add_layernorm input is fp32
30
                        a = torch.randn(input_size)
31
                        b = torch.randn(input_size)
32
                        model = add_layernorm(size).eval()
33
                        trace_model = torch.jit.trace(model, (a, b))
34
                        y1_fp32 = model(a, b)
35
                        y2_fp32 = trace_model(a, b)
36
                        self.assertEqual(y1_fp32.dtype, torch.float32)
37
                        self.assertEqual(y2_fp32.dtype, torch.float32)
38
                        self.assertEqual(y1_fp32, y2_fp32)
39

40
                        # add_layernorm input is bfloat16/float16
41
                        a_lowp = a.to(dtype=dtype)
42
                        b_lowp = b.to(dtype=dtype)
43
                        model = model.to(dtype=dtype)
44
                        trace_model = torch.jit.trace(model, (a_lowp, b_lowp))
45
                        y1_lowp = model(a_lowp, b_lowp)
46
                        y2_lowp = trace_model(a_lowp, b_lowp)
47
                        self.assertEqual(y1_lowp.dtype, dtype)
48
                        self.assertEqual(y2_lowp.dtype, dtype)
49
                        # Add a custom threshold for bf16/fp16 test because of fused add_layernorm in jit has higher precision
50
                        # and causes mismatch with eager mode.
51
                        prec = 5e-2 if dtype == torch.bfloat16 else 5e-3
52
                        self.assertEqual(y1_lowp, y2_lowp, prec=prec)
53

54

55
if __name__ == "__main__":
56
    test = unittest.main()
57

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.