intel-extension-for-pytorch

Форк
0
/
test_linear_fuse_eltwise.py 
54 строки · 1.8 Кб
1
import unittest
2
import torch
3
import intel_extension_for_pytorch as ipex
4
from torch.testing._internal.common_utils import TestCase
5
import copy
6

7

8
class MLP(torch.nn.Module):
9
    def __init__(self):
10
        super(MLP, self).__init__()
11
        self.mlp = torch.nn.ModuleList()
12
        self.mlp.append(torch.nn.Linear(10, 10))
13
        self.mlp.append(torch.nn.ReLU())
14
        self.mlp.append(torch.nn.Linear(10, 10))
15
        self.mlp.append(torch.nn.Sigmoid())
16

17
    def forward(self, x):
18
        for layer in self.mlp:
19
            x = layer(x)
20
        return x
21

22

23
class TestLinearFuseEltwise(TestCase):
24
    def test_linear_fuse_eltwise(self):
25
        x1 = torch.rand(5, 10).requires_grad_()
26
        x2 = copy.deepcopy(x1)
27
        for dtype in [torch.float, torch.bfloat16]:
28
            model = MLP()
29
            opt = torch.optim.SGD(model.parameters(), lr=0.01)
30
            model, opt = ipex.optimize(
31
                model, optimizer=opt, dtype=dtype, auto_kernel_selection=True
32
            )
33
            with torch.cpu.amp.autocast(enabled=(dtype == torch.bfloat16)):
34
                ref_out = model(x1).sum()
35
            ref_out.backward()
36

37
            fused_model = copy.deepcopy(model)
38
            fused_model.mlp[0] = ipex.nn.modules.IPEXLinearEltwise(
39
                fused_model.mlp[0], "relu"
40
            )
41
            fused_model.mlp[1] = torch.nn.Identity()
42
            fused_model.mlp[2] = ipex.nn.modules.IPEXLinearEltwise(
43
                fused_model.mlp[2], "sigmoid"
44
            )
45
            fused_model.mlp[3] = torch.nn.Identity()
46
            with torch.cpu.amp.autocast(enabled=(dtype == torch.bfloat16)):
47
                out = fused_model(x2).sum()
48
            out.backward()
49
            self.assertEqual(out, ref_out)
50
            self.assertEqual(x1.grad, x2.grad)
51

52

53
if __name__ == "__main__":
54
    test = unittest.main()
55

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.