intel-extension-for-pytorch

Форк
0
613 строк · 23.5 Кб
1
import time
import unittest
from types import MappingProxyType
from typing import Mapping

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.testing._internal.jit_utils import JitTestCase
7

8

9
def get_rand_seed():
    """Return a time-based integer seed suitable for ``torch.manual_seed``.

    The value is the current wall-clock time in nanoseconds. ``time.time_ns()``
    yields it as an exact integer; the previous ``int(time.time() * 1e9)``
    went through a float64, which cannot represent values around 1e18 exactly,
    so the low bits of the seed were always lost.
    """
    return time.time_ns()
11

12

13
# Spatial dimensionality -> matching PyTorch convolution class.
conv_module = {
    1: torch.nn.Conv1d,
    2: torch.nn.Conv2d,
    3: torch.nn.Conv3d,
}
14

15
from typing import Dict, NamedTuple
16

17

18
class EltwiseFusionOp(NamedTuple):
    """Expected IPEX eltwise op for a PyTorch callable, plus its call arguments."""

    # Name of the IPEX eltwise op; in the op tables a trailing "_" marks the
    # in-place variant (e.g. "relu" vs "relu_").
    ipex_eltwise_op: str
    # Extra keyword arguments forwarded to the eltwise callable. NamedTuple
    # defaults are shared by every instance, so the default is a read-only
    # empty mapping instead of a mutable ``{}`` that could leak mutations
    # from one op entry into all the others.
    op_input_list: Mapping[str, object] = MappingProxyType({})
21

22

23
# Element-wise ops that need no extra arguments, keyed by the PyTorch callable
# (a torch function or an nn.Module instance) and mapped to the expected IPEX
# op name. Names ending in "_" are the in-place variants. The tests iterate
# this dict, so entries run in the order listed here.
unary_PyTorch_op_to_IPEX_op_map = {
    eltwise_fn: EltwiseFusionOp(ipex_name)
    for eltwise_fn, ipex_name in (
        (torch.relu, "relu"),
        (torch.relu_, "relu_"),
        (torch.abs, "abs"),
        (torch.abs_, "abs_"),
        (torch.exp, "exp"),
        (torch.exp_, "exp_"),
        (nn.Hardswish(inplace=False), "hardswish"),
        (nn.Hardswish(inplace=True), "hardswish_"),
        (torch.log, "log"),
        (torch.log_, "log_"),
        (nn.Mish(inplace=False), "mish"),
        (nn.Mish(inplace=True), "mish_"),
        (torch.sigmoid, "sigmoid"),
        (torch.sigmoid_, "sigmoid_"),
        (torch.round, "round"),
        (torch.round_, "round_"),
        (torch.sqrt, "sqrt"),
        (torch.sqrt_, "sqrt_"),
        (torch.square, "square"),
        (torch.square_, "square_"),
        (torch.tanh, "tanh"),
        (torch.tanh_, "tanh_"),
        (nn.SiLU(inplace=False), "silu"),
        (nn.SiLU(inplace=True), "silu_"),
        (nn.Hardsigmoid(inplace=False), "hardsigmoid"),
        (nn.Hardsigmoid(inplace=True), "hardsigmoid_"),
    )
}
51

52
# Element-wise ops that need extra arguments or configuration. Functional ops
# carry their keyword arguments in ``op_input_list``; module ops are configured
# when the module instance is constructed. Entries run in the order listed.
non_unary_PyTorch_op_to_IPEX_op_map = dict(
    [
        (torch.clamp, EltwiseFusionOp("clamp", op_input_list={"min": -2, "max": 3})),
        (torch.clamp_, EltwiseFusionOp("clamp_", op_input_list={"min": -2, "max": 3})),
        (nn.GELU(approximate="none"), EltwiseFusionOp("gelu(none)")),
        (nn.GELU(approximate="tanh"), EltwiseFusionOp("gelu(tanh)")),
        (nn.ELU(inplace=False), EltwiseFusionOp("elu")),
        (nn.ELU(inplace=True), EltwiseFusionOp("elu_")),
        (torch.pow, EltwiseFusionOp("pow", op_input_list={"exponent": 2})),
        # In-place pow; the exponent is bound inside the lambda rather than
        # passed through op_input_list.
        (lambda t: t.pow_(2), EltwiseFusionOp("pow_")),
        (
            nn.LeakyReLU(negative_slope=0.02, inplace=False),
            EltwiseFusionOp("leaky_relu"),
        ),
        (
            nn.LeakyReLU(negative_slope=0.02, inplace=True),
            EltwiseFusionOp("leaky_relu_"),
        ),
    ]
)
64

65

66
class ConvEltwise(nn.Module):
    """A convolution followed by an element-wise callable.

    Args:
        eltwise_fn: callable applied to the conv output as
            ``eltwise_fn(conv_out, **kwargs)``.
        dim: spatial dimensionality; selects the conv class via ``conv_module``.
        in_channels, out_channels, kernel_size: forwarded to the conv layer.
        image_size: accepted for call-site symmetry; not used by this module.
        **kwargs: extra keyword arguments passed to ``eltwise_fn`` on every call.
    """

    def __init__(
        self,
        eltwise_fn,
        dim,
        in_channels,
        out_channels,
        kernel_size,
        image_size,
        **kwargs
    ):
        super().__init__()
        conv_cls = conv_module[dim]
        self.conv = conv_cls(in_channels, out_channels, kernel_size)
        self.eltwise = eltwise_fn
        self.kwargs = kwargs

    def forward(self, x):
        return self.eltwise(self.conv(x), **self.kwargs)
86

87

88
class IPEXConvAdd(nn.Module):
    """Two bias-free Conv2d branches over the same input, summed in place.

    Extra keyword arguments (e.g. ``kernel_size``) are passed to both convs.
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()

        def make_conv():
            return torch.nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)

        self.conv1 = make_conv()
        self.conv2 = make_conv()

    def forward(self, x):
        lhs = self.conv1(x)
        rhs = self.conv2(x)
        # In-place add into the first branch's output.
        return lhs.add_(rhs)
98

99

100
class IPEXConvAddRelu(nn.Module):
    """relu(relu(conv1(x)) + conv2(x)) with bias-free Conv2d layers.

    Extra keyword arguments (e.g. ``kernel_size``) are passed to both convs.
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()

        def make_conv():
            return torch.nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)

        self.conv1 = make_conv()
        self.conv2 = make_conv()

    def forward(self, x):
        activated = F.relu(self.conv1(x))
        residual = self.conv2(x)
        summed = activated.add_(residual)
        return F.relu(summed, inplace=True)
110

111

112
class IPEXConvConvRelu(nn.Module):
    """Two stacked bias-free Conv2d layers followed by an in-place ReLU.

    The second conv maps ``out_channels`` -> ``out_channels``. Extra keyword
    arguments (e.g. ``kernel_size``) are passed to both convs.
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, bias=False, **kwargs)

    def forward(self, x):
        features = self.conv2(self.conv1(x))
        return F.relu(features, inplace=True)
122

123

124
class IPEXConvSigmoidMul(nn.Module):
    """conv(x) * sigmoid(conv(x)) with a bias-free Conv2d, multiplied in place.

    Extra keyword arguments (e.g. ``kernel_size``) are passed to the conv.
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)

    def forward(self, x):
        features = self.conv(x)
        # The gate is computed before features is modified in place.
        gate = torch.sigmoid(features)
        return features.mul_(gate)
133

134

135
class LinearEltwise(nn.Module):
    """Linear layer, halve the result, then apply an element-wise callable.

    Args:
        eltwise_fn: callable invoked as ``eltwise_fn(y, **kwargs)``.
        in_channels, out_channels: Linear input/output feature sizes.
        bias: whether the Linear layer has a bias.
        **kwargs: extra keyword arguments passed to ``eltwise_fn`` on every call.
    """

    def __init__(self, eltwise_fn, in_channels, out_channels, bias, **kwargs):
        super().__init__()
        self.linear = nn.Linear(in_channels, out_channels, bias=bias)
        self.eltwise = eltwise_fn
        self.kwargs = kwargs

    def forward(self, x):
        halved = self.linear(x) / 2
        return self.eltwise(halved, **self.kwargs)
147

148

149
class IPEXLinearAdd(nn.Module):
    """Two independent Linear branches over the same input, summed in place."""

    def __init__(self, in_channels, out_channels, bias):
        super().__init__()

        def make_linear():
            return nn.Linear(in_channels, out_channels, bias=bias)

        self.linear1 = make_linear()
        self.linear2 = make_linear()

    def forward(self, x):
        lhs = self.linear1(x)
        rhs = self.linear2(x)
        return lhs.add_(rhs)
159

160

161
class IPEXLinearAddRelu(nn.Module):
    """relu(relu(linear(x)) + linear(x)) using a single Linear layer.

    NOTE(review): unlike IPEXConvAddRelu (separate conv1/conv2), both branches
    here reuse ``self.linear`` and therefore compute the same values — confirm
    whether a second, independent Linear layer was intended.
    """

    def __init__(self, in_channels, out_channels, bias):
        super().__init__()
        self.linear = nn.Linear(in_channels, out_channels, bias=bias)

    def forward(self, x):
        activated = F.relu(self.linear(x))
        residual = self.linear(x)
        summed = activated.add_(residual)
        return F.relu(summed, inplace=True)
170

171

172
class IPEXLinearSigmoidMul(nn.Module):
    """linear(x) * sigmoid(linear(x)), multiplied in place."""

    def __init__(self, in_channels, out_channels, bias):
        super().__init__()
        self.linear = nn.Linear(in_channels, out_channels, bias=bias)

    def forward(self, x):
        features = self.linear(x)
        # The gate is computed before features is modified in place.
        gate = torch.sigmoid(features)
        return features.mul_(gate)
181

182

183
class IPEXMatmulDiv(nn.Module):
    """Computes ``matmul(x1, x2) / x3 + x3``.

    Constructing the module reseeds the global torch RNG with 2018.
    """

    def __init__(self):
        super().__init__()
        torch.manual_seed(2018)

    def forward(self, x1, x2, x3):
        product = torch.matmul(x1, x2)
        return product / x3 + x3
191

192

193
class TestTE(JitTestCase):
    """JIT fusion tests for conv/linear followed by element-wise ops.

    Each case builds a small eager module, traces (or scripts) and freezes it,
    runs it once as a warm-up, and then checks that the JIT output matches
    eager execution. Most cases disable fusion-group inlining for their
    duration; the previous setting is restored through ``addCleanup``, so it
    is put back even when an assertion fails (previously three linear tests
    never restored it at all, and no test restored it on failure).
    """

    # Seed pinned for "round" in the linear tests. The "round" op in ideep has
    # a numeric issue when an input is exactly 0.5, e.g. for
    # x = torch.Tensor([0.500]):
    #   ideep:    torch.round(x) == 1.0
    #   expected: torch.round(x) == 0.0
    # Seed 1665593217573048320 reproduces that mismatch, so the test pins the
    # seed below for any op whose IPEX name contains "round".
    _ROUND_SEED = 1665594679504775936

    def _disable_fusion_group_inlining(self):
        """Disable fusion-group inlining until the current test finishes.

        The prior value is restored via ``addCleanup``, which runs even if
        the test body raises.
        """
        old = torch._C._debug_get_fusion_group_inlining()
        torch._C._debug_set_fusion_group_inlining(False)
        self.addCleanup(torch._C._debug_set_fusion_group_inlining, old)

    @staticmethod
    def _trace_and_freeze(model, x):
        """Trace ``model`` on ``x``, freeze it, run it once and return it.

        The warm-up call is kept from the original tests (presumably so the
        JIT can specialise/fuse before the compared run).
        """
        traced = torch.jit.freeze(torch.jit.trace(model, x))
        traced(x)
        # self.assertAllFused(traced.graph_for(x))
        return traced

    def _assert_bf16_close(self, res_jit, res_imperative):
        """Compare JIT and eager results with the loose tolerances used for bf16."""
        self.assertEqual(
            res_jit,
            res_imperative,
            rtol=0.02,
            atol=0.01,
            msg="{}, {}".format(res_jit, res_imperative),
        )

    def _check_conv_eltwise(self, op_list, pass_op_inputs):
        """Check conv2d followed by each op in ``op_list`` against eager mode.

        Every op is run with contiguous and channels-last layouts and two
        (batch, image) sizes. When ``pass_op_inputs`` is true, the op's
        ``op_input_list`` is forwarded to the eltwise callable.
        """
        self._disable_fusion_group_inlining()
        dim = 2
        out_channels = 16
        in_channels = 3
        kernel_size = 3
        for eltwise, fusion_op in op_list.items():
            torch.manual_seed(get_rand_seed())
            print("TEST conv2d+%s" % fusion_op.ipex_eltwise_op)
            op_kwargs = fusion_op.op_input_list if pass_op_inputs else {}
            for use_channels_last in [0, 1]:
                for batch_size, image_size in [[8, 20], [3, 256]]:
                    x = torch.randn([batch_size, in_channels, image_size, image_size])
                    te_model = ConvEltwise(
                        eltwise,
                        dim,
                        in_channels,
                        out_channels,
                        kernel_size,
                        image_size,
                        **op_kwargs
                    ).eval()
                    if use_channels_last:
                        x = x.to(memory_format=torch.channels_last)
                        te_model = te_model.to(memory_format=torch.channels_last)
                    te_model_traced = self._trace_and_freeze(te_model, x)

                    res_jit = te_model_traced(x)
                    res_imperative = te_model(x)
                    self.assertEqual(
                        res_jit,
                        res_imperative,
                        "{}, {}".format(res_jit, res_imperative),
                    )

    def _check_linear_eltwise(self, op_list, pass_op_inputs, pin_round_seed):
        """Check linear followed by each op in ``op_list`` under bf16 autocast.

        When ``pass_op_inputs`` is true, the op's ``op_input_list`` is
        forwarded to the eltwise callable; when ``pin_round_seed`` is true,
        "round" ops use ``_ROUND_SEED`` instead of a time-based seed.
        """
        self._disable_fusion_group_inlining()
        batch_size = 3
        out_channels = 32
        in_channels = 3
        for eltwise, fusion_op in op_list.items():
            torch.manual_seed(get_rand_seed())
            ipex_eltwise_op = fusion_op.ipex_eltwise_op
            if pin_round_seed and "round" in ipex_eltwise_op:
                torch.manual_seed(self._ROUND_SEED)
            print("TEST linear+%s" % ipex_eltwise_op)
            op_kwargs = fusion_op.op_input_list if pass_op_inputs else {}
            for bias in [True, False]:
                x = torch.randn([batch_size, in_channels])
                # linear fusion only supports bf16
                with torch.cpu.amp.autocast(
                    enabled=True, dtype=torch.bfloat16
                ), torch.no_grad():
                    te_model = LinearEltwise(
                        eltwise, in_channels, out_channels, bias, **op_kwargs
                    ).eval()
                    te_model_traced = self._trace_and_freeze(te_model, x)

                    res_jit = te_model_traced(x)
                    res_imperative = te_model(x)
                self._assert_bf16_close(res_jit, res_imperative)

    def _check_conv_pair(self, model_cls, title):
        """Check a two-op conv pattern ``model_cls`` on two input shapes.

        Runs with contiguous and channels-last model/input; the second,
        larger input is passed as created by ``torch.randn``.
        """
        self._disable_fusion_group_inlining()
        print(title)
        torch.manual_seed(get_rand_seed())
        for use_channels_last in [0, 1]:
            te_model = model_cls(3, 2, kernel_size=(3, 3)).eval()
            x = torch.randn(1, 3, 10, 10)
            if use_channels_last:
                x = x.to(memory_format=torch.channels_last)
                te_model = te_model.to(memory_format=torch.channels_last)
            te_model_traced = self._trace_and_freeze(te_model, x)

            res_jit = te_model_traced(x)
            res_imperative = te_model(x)
            self.assertEqual(res_jit, res_imperative)

            x = torch.randn(3, 3, 20, 20)
            res_jit = te_model_traced(x)
            res_imperative = te_model(x)
            self.assertEqual(res_jit, res_imperative)

    def _check_linear_pair(self, model_cls, title):
        """Check a two-op linear pattern ``model_cls`` under bf16 autocast."""
        self._disable_fusion_group_inlining()
        print(title)
        torch.manual_seed(get_rand_seed())
        for bias in [True, False]:
            with torch.cpu.amp.autocast(
                enabled=True, dtype=torch.bfloat16
            ), torch.no_grad():
                te_model = model_cls(3, 32, bias).eval()
                x = torch.randn(3, 3)
                te_model_traced = self._trace_and_freeze(te_model, x)

                self._assert_bf16_close(te_model_traced(x), te_model(x))

                x = torch.randn(8, 3)
                self._assert_bf16_close(te_model_traced(x), te_model(x))

    def test_ipex_unary_conv_fusion(self, op_list=unary_PyTorch_op_to_IPEX_op_map):
        self._check_conv_eltwise(op_list, pass_op_inputs=False)

    def test_ipex_non_unary_conv_fusion(
        self, op_list=non_unary_PyTorch_op_to_IPEX_op_map
    ):
        self._check_conv_eltwise(op_list, pass_op_inputs=True)

    def test_ipex_conv_add(self):
        self._check_conv_pair(IPEXConvAdd, "TEST conv2d+add")

    def test_ipex_conv_add_relu(self):
        self._check_conv_pair(IPEXConvAddRelu, "TEST conv2d+add+relu")

    def test_ipex_conv_conv_relu(self):
        self._disable_fusion_group_inlining()
        print("TEST conv bottleneck")
        torch.manual_seed(get_rand_seed())
        for use_channels_last in [0, 1]:
            te_model = IPEXConvConvRelu(3, 10, kernel_size=(3, 3)).eval()
            x = torch.randn(1, 3, 224, 224)
            if use_channels_last:
                x = x.to(memory_format=torch.channels_last)
                te_model = te_model.to(memory_format=torch.channels_last)
            # Scripted (not traced) on purpose, unlike the other conv tests.
            te_model_traced = torch.jit.freeze(torch.jit.script(te_model))
            te_model_traced(x)
            # self.assertAllFused(te_model_traced.graph_for(x))

            res_jit = te_model_traced(x)
            res_imperative = te_model(x)
            self.assertEqual(res_jit, res_imperative)

            x = torch.randn(3, 3, 500, 500)
            res_jit = te_model_traced(x)
            res_imperative = te_model(x)
            self.assertEqual(res_jit, res_imperative)

    def test_ipex_conv_sigmoid_mul(self):
        self._check_conv_pair(IPEXConvSigmoidMul, "TEST conv2d+sigmoid+mul")

    def test_ipex_matmul_div(self):
        # Runs with whatever fusion-group-inlining setting is currently active.
        print("TEST conv matmul+div")
        # Built before seeding: the constructor reseeds the global RNG itself.
        te_matmul_div = IPEXMatmulDiv()
        torch.manual_seed(get_rand_seed())
        x1 = torch.randn(5, 5)
        x2 = torch.randn(5, 5)
        x3 = torch.randn(5, 5)
        te_matmul_div_traced = torch.jit.script(te_matmul_div).eval()
        te_matmul_div_traced = torch.jit.freeze(te_matmul_div_traced)
        te_matmul_div_traced(x1, x2, x3)
        # self.assertAllFused(te_matmul_div_traced.graph_for(x1, x2, x3))
        res_jit = te_matmul_div_traced(x1, x2, x3)
        res_imperative = te_matmul_div(x1, x2, x3)
        self.assertEqual(res_jit, res_imperative)

    def test_ipex_unary_linear_fusion(self, op_list=unary_PyTorch_op_to_IPEX_op_map):
        self._check_linear_eltwise(op_list, pass_op_inputs=False, pin_round_seed=True)

    def test_ipex_non_unary_linear_fusion(
        self, op_list=non_unary_PyTorch_op_to_IPEX_op_map
    ):
        self._check_linear_eltwise(op_list, pass_op_inputs=True, pin_round_seed=False)

    def test_ipex_linear_add(self):
        self._check_linear_pair(IPEXLinearAdd, "TEST linear+add")

    def test_ipex_linear_add_relu(self):
        self._check_linear_pair(IPEXLinearAddRelu, "TEST linear+add+relu")

    def test_ipex_linear_sigmoid_mul(self):
        self._check_linear_pair(IPEXLinearSigmoidMul, "TEST linear+sigmoid+mul")
609

610

611
if __name__ == "__main__":
    # Disabled hook kept for reference (``ipex`` is not imported here):
    #   ipex._C.enable_custom_op_2_nnc_fuser()
    unittest.main()
614

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.