intel-extension-for-pytorch

import unittest
import itertools
import torch
import intel_extension_for_pytorch as ipex
from torch.testing._internal.common_utils import TestCase
import copy
from intel_extension_for_pytorch.cpu._auto_kernel_selection import (
    _enable_tpp,
    _disable_tpp,
)
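
# Unit tests for the TPP (Tensor Processing Primitives) linear kernels that
# intel_extension_for_pytorch exposes under torch.ops.torch_ipex.  Each helper
# module below wraps a single torch.nn.Linear, optionally followed by an
# elementwise epilogue (GELU, SiLU, ReLU, mul, add), and serves as the
# eager-mode reference for the corresponding fused TPP kernel.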

class Linear_with_bias(torch.nn.Module):
    def __init__(self):
        super(Linear_with_bias, self).__init__()
        self.mlp = torch.nn.Linear(4096, 4096)

    def forward(self, x):
        return self.mlp(x)


class Linear_without_bias(torch.nn.Module):
    def __init__(self):
        super(Linear_without_bias, self).__init__()
        self.mlp = torch.nn.Linear(4096, 4096, bias=False)

    def forward(self, x):
        return self.mlp(x)


class Linear_gelu(torch.nn.Module):
    def __init__(self):
        super(Linear_gelu, self).__init__()
        self.mlp = torch.nn.Linear(4096, 4096)

    def forward(self, x):
        return torch.nn.functional.gelu(self.mlp(x))


class Linear_silu(torch.nn.Module):
    def __init__(self):
        super(Linear_silu, self).__init__()
        self.mlp = torch.nn.Linear(4096, 4096, bias=False)

    def forward(self, x):
        return torch.nn.functional.silu(self.mlp(x))


class Linear_Gate_Up(torch.nn.Module):
    def __init__(self, in_feature, out_feature, bias_gate, bias_up):
        super(Linear_Gate_Up, self).__init__()
        self.gate_proj = torch.nn.Linear(in_feature, out_feature, bias=bias_gate)
        self.up_proj = torch.nn.Linear(in_feature, out_feature, bias=bias_up)

    def forward(self, x):
        return torch.nn.functional.silu(self.gate_proj(x)) * self.up_proj(x)


class Linear_relu(torch.nn.Module):
    def __init__(self):
        super(Linear_relu, self).__init__()
        self.mlp = torch.nn.Linear(4096, 4096, bias=False)

    def forward(self, x):
        return torch.nn.functional.relu(self.mlp(x))


class Linear_mul(torch.nn.Module):
    def __init__(self):
        super(Linear_mul, self).__init__()
        self.mlp = torch.nn.Linear(4096, 4096, bias=False)

    def forward(self, x):
        return self.mlp(x) * x


class Linear_add(torch.nn.Module):
    def __init__(self):
        super(Linear_add, self).__init__()
        self.mlp = torch.nn.Linear(4096, 4096, bias=False)

    def forward(self, x):
        return self.mlp(x) + x


class Linear_add_add(torch.nn.Module):
    def __init__(self):
        super(Linear_add_add, self).__init__()
        self.mlp = torch.nn.Linear(4096, 4096)

    def forward(self, x):
        return self.mlp(x) + x + x
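
# The 4097 in/out feature size below appears to be chosen so that the TPP
# kernel cannot handle the shape and the linear falls back to the oneDNN
# (dnnl) path, as exercised by test_tpp_linear_fallback.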
class Linear_tpp_fallback_dnnl(torch.nn.Module):
    def __init__(self):
        super(Linear_tpp_fallback_dnnl, self).__init__()
        self.mlp = torch.nn.Linear(4097, 4097)

    def forward(self, x):
        return self.mlp(x)
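
# Common pattern in the tests below: compute an eager-mode reference output,
# call _enable_tpp() before ipex.optimize() so the TPP linear kernels are
# selected, compare the optimized/fused output against the reference, and
# finally call _disable_tpp() to restore the global kernel-selection state.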
class TestTPPlinear(TestCase):
    def test_tpp_linear_fallback(self):
        x1 = torch.rand(1, 1, 4097)
        x2 = copy.deepcopy(x1)
        for dtype in [torch.float, torch.bfloat16]:
            model = Linear_tpp_fallback_dnnl().eval()

            with torch.no_grad(), torch.cpu.amp.autocast(
                enabled=True if dtype is torch.bfloat16 else False
            ):
                ref_out = model(x1)

            _enable_tpp()
            model = ipex.optimize(model, dtype=dtype)
            with torch.no_grad(), torch.cpu.amp.autocast(
                enabled=True if dtype is torch.bfloat16 else False
            ):
                out = model(x2)
            self.assertEqual(out, ref_out)
            _disable_tpp()

    def test_tpp_linear(self):
        x1 = torch.rand(1, 1, 4096)
        x2 = copy.deepcopy(x1)
        for dtype in [torch.float, torch.bfloat16]:
            model = Linear_with_bias().eval()
            model_nb = Linear_without_bias().eval()
            if dtype is torch.bfloat16:
                x1 = x1.to(torch.bfloat16)
                x2 = x2.to(torch.bfloat16)
                model = model.to(torch.bfloat16)
                model_nb = model_nb.to(torch.bfloat16)
            ref_out = model(x1)
            ref_out_nb = model_nb(x1)

            _enable_tpp()
            model = ipex.optimize(model, dtype=dtype)
            model_nb = ipex.optimize(model_nb, dtype=dtype)
            out = model(x2)
            out_nb = model_nb(x2)
            self.assertEqual(out, ref_out)
            self.assertEqual(out_nb, ref_out_nb)
            _disable_tpp()
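
    # test_tpp_fused_gate_up_proj checks the fused gate/up projection kernel
    # against both the eager reference (silu(gate_proj(x)) * up_proj(x)) and a
    # composition of the separate tpp_linear_silu and tpp_linear_mul ops.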
    def test_tpp_fused_gate_up_proj(self):
        in_feature = 64
        out_feature = 32

        x = torch.randn(1, 4, in_feature)
        x_tpp = copy.deepcopy(x)

        with torch.no_grad():
            for dtype, bias_gate, bias_up in itertools.product(
                [torch.float, torch.bfloat16], [False, True], [False, True]
            ):
                model = Linear_Gate_Up(
                    in_feature, out_feature, bias_gate, bias_up
                ).eval()
                if dtype == torch.bfloat16:
                    x = x.to(torch.bfloat16)
                    x_tpp = x_tpp.to(torch.bfloat16)
                    model = model.to(torch.bfloat16)
                ref_out = model(x)

                _enable_tpp()
                model = ipex.optimize(model, dtype=dtype)
                out = torch.ops.torch_ipex.tpp_fused_gate_up_proj(
                    x_tpp,
                    model.gate_proj.weight,
                    model.gate_proj.bias,
                    model.up_proj.weight,
                    model.up_proj.bias,
                )

                out_linear_silu = torch.ops.torch_ipex.tpp_linear_silu(
                    x_tpp, model.gate_proj.weight, model.gate_proj.bias
                )
                out_tpp_ref = torch.ops.torch_ipex.tpp_linear_mul(
                    x_tpp, out_linear_silu, model.up_proj.weight, model.up_proj.bias
                )
                self.assertEqual(out, out_tpp_ref)
                self.assertEqual(out, ref_out)
                _disable_tpp()

    def test_tpp_linear_gelu(self):
        x1 = torch.rand(1, 4, 4096)
        x2 = copy.deepcopy(x1)
        with torch.no_grad():
            for dtype in [torch.bfloat16]:
                model = Linear_gelu().eval()
                if dtype is torch.bfloat16:
                    x1 = x1.to(torch.bfloat16)
                    x2 = x2.to(torch.bfloat16)
                    model = model.to(torch.bfloat16)
                ref_out = model(x1)

                _enable_tpp()
                model = ipex.optimize(model, dtype=dtype)
                out = torch.ops.torch_ipex.tpp_linear_gelu(
                    x2, model.mlp.weight, model.mlp.bias
                )
                self.assertEqual(out, ref_out)
                _disable_tpp()
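
    # In the bias-free variants below, an empty tensor (x2.new_empty(0)) is
    # passed where the bias would go, apparently serving as the "no bias"
    # placeholder expected by these TPP ops.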
    def test_tpp_linear_silu(self):
        x1 = torch.rand(1, 4, 4096)
        x2 = copy.deepcopy(x1)
        with torch.no_grad():
            for dtype in [torch.bfloat16]:
                model = Linear_silu().eval()
                if dtype is torch.bfloat16:
                    x1 = x1.to(torch.bfloat16)
                    x2 = x2.to(torch.bfloat16)
                    model = model.to(torch.bfloat16)
                ref_out = model(x1)

                _enable_tpp()
                model = ipex.optimize(model, dtype=dtype)
                out = torch.ops.torch_ipex.tpp_linear_silu(
                    x2, model.mlp.weight, x2.new_empty(0)
                )
                self.assertEqual(out, ref_out)
                _disable_tpp()

    def test_tpp_linear_relu(self):
        x1 = torch.rand(1, 4, 4096)
        x2 = copy.deepcopy(x1)
        with torch.no_grad():
            for dtype in [torch.bfloat16]:
                model = Linear_relu().eval()
                if dtype is torch.bfloat16:
                    x1 = x1.to(torch.bfloat16)
                    x2 = x2.to(torch.bfloat16)
                    model = model.to(torch.bfloat16)
                ref_out = model(x1)

                _enable_tpp()
                model = ipex.optimize(model, dtype=dtype)
                out = torch.ops.torch_ipex.tpp_linear_relu(
                    x2, model.mlp.weight, x2.new_empty(0)
                )
                self.assertEqual(out, ref_out)
                _disable_tpp()

    def test_tpp_linear_mul(self):
        x1 = torch.rand(1, 4, 4096)
        x2 = copy.deepcopy(x1)
        with torch.no_grad():
            for dtype in [torch.bfloat16]:
                model = Linear_mul().eval()
                if dtype is torch.bfloat16:
                    x1 = x1.to(torch.bfloat16)
                    x2 = x2.to(torch.bfloat16)
                    model = model.to(torch.bfloat16)
                ref_out = model(x1)

                _enable_tpp()
                model = ipex.optimize(model, dtype=dtype)
                out = torch.ops.torch_ipex.tpp_linear_mul(
                    x2, x2, model.mlp.weight, x2.new_empty(0)
                )
                self.assertEqual(out, ref_out)
                _disable_tpp()

    def test_tpp_linear_add(self):
        x1 = torch.rand(1, 4, 4096)
        x2 = copy.deepcopy(x1)
        with torch.no_grad():
            for dtype in [torch.bfloat16]:
                model = Linear_add().eval()
                if dtype is torch.bfloat16:
                    x1 = x1.to(torch.bfloat16)
                    x2 = x2.to(torch.bfloat16)
                    model = model.to(torch.bfloat16)
                ref_out = model(x1)

                _enable_tpp()
                model = ipex.optimize(model, dtype=dtype)
                out = torch.ops.torch_ipex.tpp_linear_add(
                    x2, x2, model.mlp.weight, x2.new_empty(0), 1.0
                )
                self.assertEqual(out, ref_out)
                _disable_tpp()

    def test_tpp_linear_add2(self):
        x1 = torch.rand(1, 4, 4096)
        x2 = copy.deepcopy(x1)
        with torch.no_grad():
            for dtype in [torch.bfloat16]:
                model = Linear_add_add().eval()
                if dtype is torch.bfloat16:
                    x1 = x1.to(torch.bfloat16)
                    x2 = x2.to(torch.bfloat16)
                    model = model.to(torch.bfloat16)
                ref_out = model(x1)

                _enable_tpp()
                model = ipex.optimize(model, dtype=dtype)
                out = torch.ops.torch_ipex.tpp_linear_add_add(
                    x2, x2, x2, model.mlp.weight, model.mlp.bias, 1.0
                )
                self.assertEqual(out, ref_out)
                _disable_tpp()


if __name__ == "__main__":
    test = unittest.main()