# intel-extension-for-pytorch: test_runtime_api_jit.py
import unittest
import torch
import intel_extension_for_pytorch as ipex
from torch.testing._internal.jit_utils import JitTestCase
from test_ao_jit_llga_utils import JitLlgaTestCase
from test_runtime_api import TestInputOutputModule
from common_ipex_conf import runtime_thread_affinity_test_env
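
# These tests exercise the IPEX runtime extension's Task and multi-stream APIs
# against jit-traced models. A minimal sketch of the asynchronous Task pattern
# they all share (a sketch only; it assumes the runtime extension is built in
# and that NUMA node 0 exists on the machine):
#
#     cpu_pool = ipex.cpu.runtime.CPUPool(node_id=0)    # pin to node 0's cores
#     task = ipex.cpu.runtime.Task(traced_model, cpu_pool)
#     future = task(x)     # non-blocking submit
#     y = future.get()     # block until the result is ready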


class SimpleNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(
            64, 128, (3, 3), stride=(2, 2), padding=(1, 1), bias=False
        )

    def forward(self, x):
        x1 = self.conv(x)
        y = torch.flatten(x1, start_dim=1)
        return y


class SimpleNet_v2(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(
            3, 64, (3, 3), stride=(2, 2), padding=(1, 1), bias=False
        )
        self.conv2 = torch.nn.Conv2d(
            64, 64, (3, 3), stride=(2, 2), padding=(1, 1), bias=False
        )

    def forward(self, x):
        x1 = self.conv(x)
        x1 = self.conv2(x1)
        y = torch.flatten(x1, start_dim=1)
        return y


class TestJitRuntimeAPI(JitTestCase):
    @unittest.skipIf(
        not ipex.cpu.runtime.is_runtime_ext_enabled(),
        "Skip when IPEX Runtime extension is not enabled",
    )
    @runtime_thread_affinity_test_env
    def test_task_async_api_fp32_jit_model(self):
        model = SimpleNet()
        model.eval()
        x = torch.rand(64, 64, 3, 3)

        # Calculate the reference result
        traced_model = torch.jit.trace(model, x)
        y = traced_model(x)

        # Create task
        cpu_pool = ipex.cpu.runtime.CPUPool(node_id=0)
        task = ipex.cpu.runtime.Task(traced_model, cpu_pool)

        # Submit the task and get the result
        y_runtime_future = task(x)
        y_runtime = y_runtime_future.get()
        self.assertEqual(y, y_runtime)

    @unittest.skipIf(
        not ipex.cpu.runtime.is_runtime_ext_enabled(),
        "Skip when IPEX Runtime extension is not enabled",
    )
    @runtime_thread_affinity_test_env
    def test_task_sync_api_fp32_jit_model(self):
        model = SimpleNet()
        model.eval()
        x = torch.rand(64, 64, 3, 3)

        # Calculate the reference result
        traced_model = torch.jit.trace(model, x)
        y = traced_model(x)

        # Create task
        cpu_pool = ipex.cpu.runtime.CPUPool(node_id=0)
        task = ipex.cpu.runtime.Task(traced_model, cpu_pool)

        # Run the task synchronously
        y_runtime = task.run_sync(x)
        self.assertEqual(y, y_runtime)

    @unittest.skipIf(
        not ipex.cpu.runtime.is_runtime_ext_enabled(),
        "Skip when IPEX Runtime extension is not enabled",
    )
    @runtime_thread_affinity_test_env
    def test_task_async_api_bf16_jit_model(self):
        model = SimpleNet()
        model.eval()
        x = torch.rand(64, 64, 3, 3)

        # Calculate the reference result
        with torch.cpu.amp.autocast(
            enabled=True, dtype=torch.bfloat16
        ), torch.no_grad():
            traced_model = torch.jit.trace(model, x)
        y = traced_model(x)

        # Create task
        cpu_pool = ipex.cpu.runtime.CPUPool(node_id=0)
        task = ipex.cpu.runtime.Task(traced_model, cpu_pool)

        # Submit the task and wait for the result
        y_runtime_future = task(x)
        y_runtime = y_runtime_future.get()
        self.assertEqual(y, y_runtime)

    @unittest.skipIf(
        not ipex.cpu.runtime.is_runtime_ext_enabled(),
        "Skip when IPEX Runtime extension is not enabled",
    )
    @runtime_thread_affinity_test_env
    def test_task_async_api_bf16_jit_model_multi_submission(self):
        model = SimpleNet()
        model.eval()
        x = torch.rand(64, 64, 3, 3)

        # Calculate the reference result
        with torch.cpu.amp.autocast(
            enabled=True, dtype=torch.bfloat16
        ), torch.no_grad():
            traced_model = torch.jit.trace(model, x)
        y = traced_model(x)

        # Create task
        cpu_pool = ipex.cpu.runtime.CPUPool(node_id=0)
        task = ipex.cpu.runtime.Task(traced_model, cpu_pool)

        # Submit the task 3 times, then wait for the results
        y_runtime = []
        y_runtime_future = []
        for i in range(3):
            y_runtime_future.append(task(x))
        for item in y_runtime_future:
            y_runtime.append(item.get())

        self.assertEqual(y, y_runtime[0])
        self.assertEqual(y, y_runtime[1])
        self.assertEqual(y, y_runtime[2])

    @unittest.skipIf(
        not ipex.cpu.runtime.is_runtime_ext_enabled(),
        "Skip when IPEX Runtime extension is not enabled",
    )
    @runtime_thread_affinity_test_env
    def test_task_copy_bf16_jit_model(self):
        model = SimpleNet()
        model.eval()
        x = torch.rand(64, 64, 3, 3)

        # Calculate the reference result
        with torch.cpu.amp.autocast(
            enabled=True, dtype=torch.bfloat16
        ), torch.no_grad():
            traced_model = torch.jit.trace(model, x)
        y = traced_model(x)

        # Create task
        cpu_pool = ipex.cpu.runtime.CPUPool(node_id=0)
        task = ipex.cpu.runtime.Task(traced_model, cpu_pool)

        # Bind a second reference to the same task
        task2 = task

        # Submit through both references and wait for the results
        y_runtime_future = task(x)
        y_runtime = y_runtime_future.get()
        y_runtime_future2 = task2(x)
        y_runtime2 = y_runtime_future2.get()
        self.assertEqual(y, y_runtime)
        self.assertEqual(y, y_runtime2)
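

# MultiStreamModule wraps a traced model and runs it over several streams, each
# pinned to a slice of the CPU pool. The tests below rely on the behaviour that
# the input batch is split along dim 0 across streams and, with the default
# concat_output=True, the per-stream outputs are concatenated back into one
# tensor; with concat_output=False the caller receives one output per stream.
# A minimal sketch of the pattern (a sketch only, assuming `traced_model` and
# `x` as in the tests above):
#
#     cpu_pool = ipex.cpu.runtime.CPUPool()
#     mod = ipex.cpu.runtime.MultiStreamModule(
#         traced_model, num_streams=2, cpu_pool=cpu_pool
#     )
#     y = mod(x)  # x is split batch-wise; y is the concatenated result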
class TestJITMultiStreamModule(JitTestCase):
    @unittest.skipIf(
        not ipex.cpu.runtime.is_runtime_ext_enabled(),
        "Skip when IPEX Runtime extension is not enabled",
    )
    @runtime_thread_affinity_test_env
    def test_multi_stream_module_bf16_jit_model(self):
        model = SimpleNet()
        model.eval()
        cpu_pool = ipex.cpu.runtime.CPUPool()
        batch_size = len(cpu_pool.core_ids)
        x = torch.rand(batch_size, 64, 3, 3)
        num_streams = batch_size

        # Calculate the reference result
        with torch.cpu.amp.autocast(
            enabled=True, dtype=torch.bfloat16
        ), torch.no_grad():
            traced_model = torch.jit.trace(model, x)
        y = traced_model(x)

        # Create MultiStreamModule
        multi_stream_model = ipex.cpu.runtime.MultiStreamModule(
            traced_model, num_streams=num_streams, cpu_pool=cpu_pool
        )

        y_runtime = multi_stream_model(x)
        self.assertEqual(y, y_runtime)

    @unittest.skipIf(
        not ipex.cpu.runtime.is_runtime_ext_enabled(),
        "Skip when IPEX Runtime extension is not enabled",
    )
    @runtime_thread_affinity_test_env
    def test_multi_stream_module_bf16_jit_model_concat_output(self):
        model = SimpleNet()
        model.eval()

        cpu_pool = ipex.cpu.runtime.CPUPool(node_id=0)
        batch_size = len(cpu_pool.core_ids)
        x = torch.rand(batch_size, 64, 3, 3)
        num_streams = batch_size

        # Trace the model with BF16 autocast
        with torch.cpu.amp.autocast(
            enabled=True, dtype=torch.bfloat16
        ), torch.no_grad():
            traced_model = torch.jit.trace(model, x)

        # Create MultiStreamModule
        multi_stream_model = ipex.cpu.runtime.MultiStreamModule(
            traced_model, num_streams=num_streams, cpu_pool=cpu_pool
        )
        y_runtime = multi_stream_model(x)

        # Create MultiStreamModule with concat_output=False
        multi_stream_model2 = ipex.cpu.runtime.MultiStreamModule(
            traced_model, num_streams=num_streams, cpu_pool=cpu_pool, concat_output=False
        )
        y_runtime2 = multi_stream_model2(x)
        self.assertEqual(len(y_runtime2), num_streams)
        self.assertEqual(y_runtime, torch.cat(y_runtime2))

    @unittest.skipIf(
        not ipex.cpu.runtime.is_runtime_ext_enabled(),
        "Skip when IPEX Runtime extension is not enabled",
    )
    @runtime_thread_affinity_test_env
    def test_single_stream_module_bf16_jit_model(self):
        model = SimpleNet()
        model.eval()
        batch_size = len(ipex.cpu.runtime.get_core_list_of_node_id(0))
        x = torch.rand(batch_size, 64, 3, 3)

        # Calculate the reference result
        with torch.cpu.amp.autocast(
            enabled=True, dtype=torch.bfloat16
        ), torch.no_grad():
            traced_model = torch.jit.trace(model, x)

        y = traced_model(x)

        # Create MultiStreamModule with a single stream
        cpu_pool = ipex.cpu.runtime.CPUPool(node_id=0)
        multi_stream_model = ipex.cpu.runtime.MultiStreamModule(
            traced_model, num_streams=1, cpu_pool=cpu_pool
        )
        y_runtime = multi_stream_model(x)

        # Create MultiStreamModule with concat_output=False
        multi_stream_model2 = ipex.cpu.runtime.MultiStreamModule(
            traced_model, num_streams=1, cpu_pool=cpu_pool, concat_output=False
        )
        y_runtime2 = multi_stream_model2(x)
        self.assertEqual(y, y_runtime)
        self.assertEqual(y, y_runtime2[0])

    @unittest.skipIf(
        not ipex.cpu.runtime.is_runtime_ext_enabled(),
        "Skip when IPEX Runtime extension is not enabled",
    )
    @runtime_thread_affinity_test_env
    def test_core_number_not_divisible_stream_number_bf16_jit_model(self):
        model = SimpleNet()
        model.eval()
        num_streams = 2
        batch_size = num_streams
        x = torch.rand(batch_size, 64, 3, 3)

        # Trace and freeze the model
        with torch.cpu.amp.autocast(
            enabled=True, dtype=torch.bfloat16
        ), torch.no_grad():
            traced_model = torch.jit.trace(model, x)
        traced_model = torch.jit.freeze(traced_model)

        # Calculate the reference result
        y = traced_model(x)

        # Create MultiStreamModule
        # Core number is 3, stream number is 2
        cpu_pool = ipex.cpu.runtime.CPUPool(core_ids=[0, 1, 2])
        multi_stream_model = ipex.cpu.runtime.MultiStreamModule(
            traced_model, num_streams=num_streams, cpu_pool=cpu_pool
        )
        multi_stream_model2 = ipex.cpu.runtime.MultiStreamModule(
            traced_model,
            num_streams=num_streams,
            cpu_pool=cpu_pool,
            concat_output=False,
        )

        y_runtime = multi_stream_model(x)
        y_runtime2 = multi_stream_model2(x)
        self.assertEqual(y, y_runtime)
        self.assertEqual(y, torch.cat(y_runtime2))

    @unittest.skipIf(
        not ipex.cpu.runtime.is_runtime_ext_enabled(),
        "Skip when IPEX Runtime extension is not enabled",
    )
    @runtime_thread_affinity_test_env
    def test_batchsize_less_than_stream_number_bf16_jit_model(self):
        model = SimpleNet()
        model.eval()
        num_streams = 3
        batch_size = 2
        x = torch.rand(batch_size, 64, 3, 3)

        # Trace and freeze the model
        with torch.cpu.amp.autocast(
            enabled=True, dtype=torch.bfloat16
        ), torch.no_grad():
            traced_model = torch.jit.trace(model, x)
        traced_model = torch.jit.freeze(traced_model)

        # Calculate the reference result
        y = traced_model(x)

        # Create MultiStreamModule
        # Batch size is 2, core number is 3, stream number is 3
        cpu_pool = ipex.cpu.runtime.CPUPool(core_ids=[0, 1, 2])
        multi_stream_model = ipex.cpu.runtime.MultiStreamModule(
            traced_model, num_streams=num_streams, cpu_pool=cpu_pool
        )
        multi_stream_model2 = ipex.cpu.runtime.MultiStreamModule(
            traced_model,
            num_streams=num_streams,
            cpu_pool=cpu_pool,
            concat_output=False,
        )

        y_runtime = multi_stream_model(x)
        y_runtime2 = multi_stream_model2(x)
        self.assertEqual(y, y_runtime)
        self.assertEqual(y, torch.cat(y_runtime2))

    @unittest.skipIf(
        not ipex.cpu.runtime.is_runtime_ext_enabled(),
        "Skip when IPEX Runtime extension is not enabled",
    )
    @runtime_thread_affinity_test_env
    def test_batchsize_not_divisible_stream_number_bf16_jit_model(self):
        model = SimpleNet()
        model.eval()
        num_streams = 3
        batch_size = 4
        x = torch.rand(batch_size, 64, 3, 3)

        # Trace and freeze the model
        with torch.cpu.amp.autocast(
            enabled=True, dtype=torch.bfloat16
        ), torch.no_grad():
            traced_model = torch.jit.trace(model, x)
        traced_model = torch.jit.freeze(traced_model)

        # Calculate the reference result
        y = traced_model(x)

        # Create MultiStreamModule
        # Batch size is 4, core number is 3, stream number is 3
        cpu_pool = ipex.cpu.runtime.CPUPool(core_ids=[0, 1, 2])
        multi_stream_model = ipex.cpu.runtime.MultiStreamModule(
            traced_model, num_streams=num_streams, cpu_pool=cpu_pool
        )
        multi_stream_model2 = ipex.cpu.runtime.MultiStreamModule(
            traced_model,
            num_streams=num_streams,
            cpu_pool=cpu_pool,
            concat_output=False,
        )

        y_runtime = multi_stream_model(x)
        y_runtime2 = multi_stream_model2(x)
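
        # The batch is split as evenly as possible across streams:
        # batch_size=4 over num_streams=3 yields per-stream batches of 2, 1, 1,
        # with the remainder going to the leading stream(s).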
        self.assertEqual(y, y_runtime)
        self.assertEqual(y, torch.cat(y_runtime2))
        self.assertEqual(y_runtime2[0].size(0), 2)
        self.assertEqual(y_runtime2[1].size(0), 1)
        self.assertEqual(y_runtime2[2].size(0), 1)

    @unittest.skipIf(
        not ipex.cpu.runtime.is_runtime_ext_enabled(),
        "Skip when IPEX Runtime extension is not enabled",
    )
    @runtime_thread_affinity_test_env
    def test_stream_number_auto_bf16_jit_model(self):
        model = torch.nn.Softmax(dim=-1)
        model.eval()
        for i in range(len(ipex.cpu.runtime.get_core_list_of_node_id(0))):
            batch_size = i + 1
            x = torch.rand(batch_size, 64)

            # Trace and freeze the model
            with torch.cpu.amp.autocast(
                enabled=True, dtype=torch.bfloat16
            ), torch.no_grad():
                traced_model = torch.jit.trace(model, x)
            traced_model = torch.jit.freeze(traced_model)

            # Warm up to finish the JIT pass optimizations
            for _ in range(3):
                traced_model(x)

            # Calculate the reference result
            y = traced_model(x)

            cpu_pool = ipex.cpu.runtime.CPUPool(core_ids=list(range(i + 1)))

            # The stream number will be determined automatically.
            multi_stream_model = ipex.cpu.runtime.MultiStreamModule(
                traced_model, cpu_pool=cpu_pool
            )
            y_runtime = multi_stream_model(x)
            stream_num_ground_truth = ipex.cpu.runtime.get_default_num_streams(cpu_pool)
            self.assertEqual(y, y_runtime)
            self.assertEqual(
                multi_stream_model.get_stream_number(), stream_num_ground_truth
            )

    @unittest.skipIf(
        not ipex.cpu.runtime.is_runtime_ext_enabled(),
        "Skip when IPEX Runtime extension is not enabled",
    )
    @runtime_thread_affinity_test_env
    def test_stream_number_larger_than_core_number(self):
        model = torch.nn.Softmax(dim=-1)
        model.eval()

        cpu_pool = ipex.cpu.runtime.CPUPool()
        batch_size = len(cpu_pool.core_ids)
        num_streams = batch_size + 1
        x = torch.rand(batch_size, 64)

        # Trace and freeze the model
        with torch.cpu.amp.autocast(
            enabled=True, dtype=torch.bfloat16
        ), torch.no_grad():
            traced_model = torch.jit.trace(model, x)
        traced_model = torch.jit.freeze(traced_model)

        # Warm up
        for _ in range(3):
            traced_model(x)

        # Calculate the reference result
        y = traced_model(x)

        # Request one stream more than there are cores in the pool
        multi_stream_model = ipex.cpu.runtime.MultiStreamModule(
            traced_model, num_streams=num_streams, cpu_pool=cpu_pool
        )
        y_runtime = multi_stream_model(x)
        stream_num_ground_truth = ipex.cpu.runtime.get_default_num_streams(cpu_pool)
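        # Requesting more streams than cores falls back to one stream per core,
        # which is what the second assertion below checks.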
        self.assertEqual(y, y_runtime)
        self.assertEqual(
            multi_stream_model.get_stream_number(), len(cpu_pool.core_ids)
        )


class TestLLGARuntimeAPI(JitLlgaTestCase):
    @unittest.skipIf(
        not ipex.cpu.runtime.is_runtime_ext_enabled(),
        "Skip when IPEX Runtime extension is not enabled",
    )
    @runtime_thread_affinity_test_env
    def test_task_async_api_int8_jit_model(self):
        with torch.no_grad():
            model = SimpleNet_v2()
            model.eval()
            x = torch.rand(2, 3, 224, 224).contiguous(memory_format=torch.channels_last)

            # Calculate the reference result
            graph, m_llga, m_cpu = self.prepareModel(model, [x])
            y = m_llga(x)

            # Create task
            cpu_pool = ipex.cpu.runtime.CPUPool(node_id=0)
            task = ipex.cpu.runtime.Task(m_llga, cpu_pool)

            # Submit the task and wait for the result
            y_runtime_future = task(x)
            y_runtime = y_runtime_future.get()
            self.assertEqual(y, y_runtime)

    @unittest.skipIf(
        not ipex.cpu.runtime.is_runtime_ext_enabled(),
        "Skip when IPEX Runtime extension is not enabled",
    )
    @runtime_thread_affinity_test_env
    def test_multi_stream_module_int8_jit_model(self):
        with torch.no_grad():
            model = SimpleNet_v2()
            model.eval()
            x = torch.rand(2, 3, 224, 224).contiguous(memory_format=torch.channels_last)

            # Calculate the reference result
            graph, m_llga, m_cpu = self.prepareModel(model, [x])
            y = m_llga(x)

            # Create MultiStreamModules
            cpu_pool = ipex.cpu.runtime.CPUPool(node_id=0)
            multi_stream_model = ipex.cpu.runtime.MultiStreamModule(
                m_llga, num_streams=1, cpu_pool=cpu_pool
            )
            multi_stream_model2 = ipex.cpu.runtime.MultiStreamModule(
                m_llga, num_streams=1, cpu_pool=cpu_pool, concat_output=False
            )

            # Run the multi-stream modules
            y_runtime = multi_stream_model(x)
            y_runtime2 = multi_stream_model2(x)
            self.assertEqual(y, y_runtime)
            self.assertEqual(y, torch.cat(y_runtime2))

    @unittest.skipIf(
        not ipex.cpu.runtime.is_runtime_ext_enabled(),
        "Skip when IPEX Runtime extension is not enabled",
    )
    @runtime_thread_affinity_test_env
    def test_core_number_not_divisible_stream_number_int8_jit_model(self):
        with torch.no_grad():
            model = SimpleNet_v2()
            model.eval()
            num_streams = 2
            batch_size = num_streams
            x = torch.rand(batch_size, 3, 16, 16).contiguous(
                memory_format=torch.channels_last
            )

            # Calculate the reference result
            graph, m_llga, m_cpu = self.prepareModel(model, [x])
            y = m_llga(x)

            # Create MultiStreamModule
            # Core number is 3, stream number is 2
            cpu_pool = ipex.cpu.runtime.CPUPool(core_ids=[0, 1, 2])
            multi_stream_model = ipex.cpu.runtime.MultiStreamModule(
                m_llga, num_streams=num_streams, cpu_pool=cpu_pool
            )
            multi_stream_model2 = ipex.cpu.runtime.MultiStreamModule(
                m_llga, num_streams=num_streams, cpu_pool=cpu_pool, concat_output=False
            )

            # Run the multi-stream modules
            y_runtime = multi_stream_model(x)
            y_runtime2 = multi_stream_model2(x)
            self.assertEqual(y, y_runtime)
            self.assertEqual(y, torch.cat(y_runtime2))

    @unittest.skipIf(
        not ipex.cpu.runtime.is_runtime_ext_enabled(),
        "Skip when IPEX Runtime extension is not enabled",
    )
    @runtime_thread_affinity_test_env
    def test_batchsize_less_than_stream_number_int8_jit_model(self):
        with torch.no_grad():
            model = SimpleNet_v2()
            model.eval()
            num_streams = 3
            batch_size = 2
            x = torch.rand(batch_size, 3, 16, 16).contiguous(
                memory_format=torch.channels_last
            )

            # Calculate the reference result
            graph, m_llga, m_cpu = self.prepareModel(model, [x])
            y = m_llga(x)

            # Create MultiStreamModule
            # Batch size is 2, core number is 3, stream number is 3
            cpu_pool = ipex.cpu.runtime.CPUPool(core_ids=[0, 1, 2])
            multi_stream_model = ipex.cpu.runtime.MultiStreamModule(
                m_llga, num_streams=num_streams, cpu_pool=cpu_pool
            )
            multi_stream_model2 = ipex.cpu.runtime.MultiStreamModule(
                m_llga, num_streams=num_streams, cpu_pool=cpu_pool, concat_output=False
            )

            # Run the multi-stream modules
            y_runtime = multi_stream_model(x)
            y_runtime2 = multi_stream_model2(x)
            self.assertEqual(y, y_runtime)
            self.assertEqual(y, torch.cat(y_runtime2))
            # Only as many streams as batch elements are used here
            self.assertEqual(len(y_runtime2), batch_size)


class TestMultiStreamModuleHint(JitTestCase):
    def init_set_up(self):
        # Use one stream (and one batch element) per core in the default pool
        cpu_pool = ipex.cpu.runtime.CPUPool()
        batch_size = len(cpu_pool.core_ids)
        num_streams = len(cpu_pool.core_ids)
        return batch_size, num_streams, cpu_pool

    def create_jit_traced_model(self, model, inputs):
        traced_model = torch.jit.trace(model, inputs).eval()
        traced_model = torch.jit.freeze(traced_model)
        return traced_model

    def create_multi_stream_module(
        self,
        traced_model,
        num_streams,
        cpu_pool,
        multi_stream_input_hint,
        multi_stream_output_hint=None,
        concat_output=True,
    ):
        if not concat_output:
            return ipex.cpu.runtime.MultiStreamModule(
                traced_model,
                num_streams=num_streams,
                cpu_pool=cpu_pool,
                concat_output=False,
                input_split_hint=multi_stream_input_hint,
            )
        else:
            return ipex.cpu.runtime.MultiStreamModule(
                traced_model,
                num_streams=num_streams,
                cpu_pool=cpu_pool,
                input_split_hint=multi_stream_input_hint,
                output_concat_hint=multi_stream_output_hint,
            )
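
    # MultiStreamModuleHint maps each positional or keyword input to the
    # dimension along which it is split across streams; output_concat_hint
    # plays the same role for how per-stream outputs are concatenated back.
    # For example, MultiStreamModuleHint(1, 0, 1) splits the first and third
    # inputs along dim 1 and the second along dim 0. This reading is inferred
    # from the torch.cat(..., dim) calls in the tests below rather than from
    # the API documentation.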

    @unittest.skipIf(
        not ipex.cpu.runtime.is_runtime_ext_enabled(),
        "Skip when IPEX Runtime extension is not enabled",
    )
    @runtime_thread_affinity_test_env
    def test_input_output_hint(self):
        batch_size, num_streams, cpu_pool = self.init_set_up()

        # This module:
        #   * Accepts 3 tensors as input
        #   * Returns a tuple of 3 tensors as output
        model = TestInputOutputModule().eval()
        for batch_size in (num_streams - 1, num_streams):
            # Also covers the case where batch_size is less than num_streams
            input_tensor1 = torch.rand(batch_size, 1)
            input_tensor2 = torch.rand(batch_size, 1)
            input_tensor3 = torch.rand(batch_size, 1)

            # torch.jit.trace only accepts a single tensor or a tuple of tensors
            # as input:
            # https://pytorch.org/docs/stable/generated/torch.jit.trace.html#torch-jit-trace
            jit_input = (input_tensor1, input_tensor2, input_tensor3)

            traced_model = self.create_jit_traced_model(model, jit_input)

            # Warm up in the main thread to finish the JIT pass optimizations
            for _ in range(3):
                traced_model(input_tensor1, input_tensor2, input_tensor3)

            # Calculate the reference result
            y_ref = traced_model(input_tensor1, input_tensor2, input_tensor3)

            multi_stream_input_hint = ipex.cpu.runtime.MultiStreamModuleHint(0, 0, 0)

            multi_stream_model = self.create_multi_stream_module(
                traced_model,
                num_streams,
                cpu_pool,
                multi_stream_input_hint,
                concat_output=False,
            )
            y_runtime = multi_stream_model(input_tensor1, input_tensor2, input_tensor3)

            # Manually concatenate the outputs
            y_runtime_res1 = []
            y_runtime_res2 = []
            y_runtime_res3 = []
            for stream_id in range(
                num_streams if ((batch_size // num_streams) >= 1) else batch_size
            ):
                y_runtime_res1.append(y_runtime[stream_id][0])
                y_runtime_res2.append(y_runtime[stream_id][1])
                y_runtime_res3.append(y_runtime[stream_id][2])
            y_runtime_res = (
                torch.cat(y_runtime_res1),
                torch.cat(y_runtime_res2),
                torch.cat(y_runtime_res3),
            )
            self.assertEqual(y_ref, y_runtime_res)

            # Create Multi Stream Module with concat output
            multi_stream_output_hint = ipex.cpu.runtime.MultiStreamModuleHint((0, 0, 0))

            multi_stream_model2 = self.create_multi_stream_module(
                traced_model,
                num_streams,
                cpu_pool,
                multi_stream_input_hint,
                multi_stream_output_hint,
                concat_output=True,
            )
            y_runtime_res2 = multi_stream_model2(
                input_tensor1, input_tensor2, input_tensor3
            )
            self.assertEqual(y_ref, y_runtime_res2)

    @unittest.skipIf(
        not ipex.cpu.runtime.is_runtime_ext_enabled(),
        "Skip when IPEX Runtime extension is not enabled",
    )
    @runtime_thread_affinity_test_env
    def test_simulate_bert_large_input_output(self):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, key1, key2, key3):
                return key1 * 2, key2 * 2

        # This module simulates the behaviour of Bert Large LZ models:
        #   * Accepts 3 tensors (passed by keyword) as input
        #   * Returns a tuple of 2 tensors as output
        model = TestModule().eval()

        batch_size, num_streams, cpu_pool = self.init_set_up()
        jit_input = (
            torch.rand(batch_size, 1),
            torch.rand(batch_size, 2),
            torch.rand(batch_size, 3),
        )
        traced_model = self.create_jit_traced_model(model, jit_input)

        input_tensor1 = torch.rand(batch_size, 1)
        input_tensor2 = torch.rand(batch_size, 1)
        input_tensor3 = torch.rand(batch_size, 1)

        # Warm up
        for _ in range(3):
            traced_model(key1=input_tensor1, key2=input_tensor2, key3=input_tensor3)

        # Calculate the reference result
        y_ref = traced_model(key1=input_tensor1, key2=input_tensor2, key3=input_tensor3)

        multi_stream_input_hint = ipex.cpu.runtime.MultiStreamModuleHint(
            key1=0, key2=0, key3=0
        )

        multi_stream_model = self.create_multi_stream_module(
            traced_model,
            num_streams,
            cpu_pool,
            multi_stream_input_hint,
            concat_output=False,
        )
        y_runtime = multi_stream_model(
            key1=input_tensor1, key2=input_tensor2, key3=input_tensor3
        )

        # Manually concatenate the outputs
        y_runtime_res1 = []
        y_runtime_res2 = []
        for i in range(num_streams):
            y_runtime_res1.append(y_runtime[i][0])
            y_runtime_res2.append(y_runtime[i][1])
        y_runtime_res = (torch.cat(y_runtime_res1), torch.cat(y_runtime_res2))
        self.assertEqual(y_ref, y_runtime_res)

        multi_stream_output_hint = ipex.cpu.runtime.MultiStreamModuleHint((0, 0))

        multi_stream_model2 = self.create_multi_stream_module(
            traced_model,
            num_streams,
            cpu_pool,
            multi_stream_input_hint,
            multi_stream_output_hint,
            concat_output=True,
        )
        y_runtime_res2 = multi_stream_model2(
            key1=input_tensor1, key2=input_tensor2, key3=input_tensor3
        )

        self.assertEqual(y_ref, y_runtime_res2)

    @unittest.skipIf(
        not ipex.cpu.runtime.is_runtime_ext_enabled(),
        "Skip when IPEX Runtime extension is not enabled",
    )
    @runtime_thread_affinity_test_env
    def test_mix_position_keyword_input_output_hint(self):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, param1, param2, key1=None):
                return param1, param2, key1

        batch_size, num_streams, cpu_pool = self.init_set_up()
        # This module simulates the behaviour of Bert Large LZ models:
        #   * Accepts 3 tensors (2 positional parameters and 1 keyword
        #     parameter) as input
        #   * Returns a tuple of 3 tensors as output
        model = TestModule().eval()

        jit_input = (
            torch.rand(batch_size, 1),
            torch.rand(batch_size, 2),
            torch.rand(batch_size, 3),
        )

        traced_model = self.create_jit_traced_model(model, jit_input)

        input_tensor1 = torch.rand(batch_size, 1)
        input_tensor2 = torch.rand(batch_size, 2)
        input_tensor3 = torch.rand(batch_size, 3)
        inputs = (input_tensor1, input_tensor2)
        k_input = {"key1": input_tensor3}

        # Warm up
        for _ in range(3):
            traced_model(input_tensor1, input_tensor2, key1=input_tensor3)

        # Calculate the reference result
        y_ref = traced_model(*inputs, **k_input)
        y_ref2 = traced_model(input_tensor1, input_tensor2, input_tensor3)
        y_ref3 = traced_model(input_tensor1, input_tensor2, key1=input_tensor3)
        self.assertEqual(y_ref, y_ref2)
        self.assertEqual(y_ref, y_ref3)

        # Be careful: a jit-traced model may change how inputs are accepted
        # (positional vs keyword)
        multi_stream_input_hint = ipex.cpu.runtime.MultiStreamModuleHint(0, 0, key1=0)

        # Create Multi Stream Module with concat output
        multi_stream_output_hint = ipex.cpu.runtime.MultiStreamModuleHint((0, 0, 0))

        multi_stream_model = self.create_multi_stream_module(
            traced_model,
            num_streams,
            cpu_pool,
            multi_stream_input_hint,
            multi_stream_output_hint,
            concat_output=True,
        )
        # Both calling styles are equivalent
        y_runtime_res = multi_stream_model(
            input_tensor1, input_tensor2, key1=input_tensor3
        )
        y_runtime_res2 = multi_stream_model(*inputs, **k_input)
        self.assertEqual(y_ref, y_runtime_res)
        self.assertEqual(y_ref, y_runtime_res2)

    @unittest.skipIf(
        not ipex.cpu.runtime.is_runtime_ext_enabled(),
        "Skip when IPEX Runtime extension is not enabled",
    )
    @runtime_thread_affinity_test_env
    def test_input_output_hint_not_along_dim_zero(self):
        batch_size, num_streams, cpu_pool = self.init_set_up()

        # This module:
        #   * Accepts 3 tensors as input
        #   * Returns a tuple of 3 tensors as output
        model = TestInputOutputModule().eval()

        input_tensor1 = torch.rand(1, batch_size)
        input_tensor2 = torch.rand(batch_size, 2)
        input_tensor3 = torch.rand(3, batch_size)

        # torch.jit.trace only accepts a single tensor or a tuple of tensors
        # as input:
        # https://pytorch.org/docs/stable/generated/torch.jit.trace.html#torch-jit-trace
        jit_input = (input_tensor1, input_tensor2, input_tensor3)

        traced_model = self.create_jit_traced_model(model, jit_input)

        # Warm up in the main thread to finish the JIT pass optimizations
        for _ in range(3):
            traced_model(input_tensor1, input_tensor2, input_tensor3)

        # Calculate the reference result
        y_ref = traced_model(input_tensor1, input_tensor2, input_tensor3)

        # Split the first and third inputs along dim 1, the second along dim 0
        multi_stream_input_hint = ipex.cpu.runtime.MultiStreamModuleHint(1, 0, 1)

        multi_stream_model = self.create_multi_stream_module(
            traced_model,
            num_streams,
            cpu_pool,
            multi_stream_input_hint,
            concat_output=False,
        )
        y_runtime = multi_stream_model(input_tensor1, input_tensor2, input_tensor3)

        # Manually concatenate the outputs along the matching dims
        y_runtime_res1 = []
        y_runtime_res2 = []
        y_runtime_res3 = []
        for stream_id in range(
            num_streams if ((batch_size // num_streams) >= 1) else batch_size
        ):
            y_runtime_res1.append(y_runtime[stream_id][0])
            y_runtime_res2.append(y_runtime[stream_id][1])
            y_runtime_res3.append(y_runtime[stream_id][2])
        y_runtime_res = (
            torch.cat(y_runtime_res1, 1),
            torch.cat(y_runtime_res2, 0),
            torch.cat(y_runtime_res3, 1),
        )
        self.assertEqual(y_ref, y_runtime_res)

        # Create Multi Stream Module with concat output
        multi_stream_output_hint = ipex.cpu.runtime.MultiStreamModuleHint((1, 0, 1))

        multi_stream_model2 = self.create_multi_stream_module(
            traced_model,
            num_streams,
            cpu_pool,
            multi_stream_input_hint,
            multi_stream_output_hint,
            concat_output=True,
        )
        y_runtime_res2 = multi_stream_model2(
            input_tensor1, input_tensor2, input_tensor3
        )
        self.assertEqual(y_ref, y_runtime_res2)


class TestMultiStreamBenchmarkModule(JitTestCase):
    @unittest.skipIf(
        not ipex.cpu.runtime.is_runtime_ext_enabled(),
        "Skip when IPEX Runtime extension is not enabled",
    )
    @runtime_thread_affinity_test_env
    def test_multi_stream_benchmark_module_bf16_jit_model(self):
        model = SimpleNet().eval()
        batch_size = 1
        x = torch.rand(batch_size, 64, 3, 3)

        # Trace the model with BF16 autocast
        with torch.cpu.amp.autocast(
            enabled=True, dtype=torch.bfloat16
        ), torch.no_grad():
            traced_model = torch.jit.trace(model, x)
        # Warm up
        for _ in range(3):
            traced_model(x)

        # Create the benchmark module and run it once
        multi_stream_model = ipex.cpu.runtime._MultiStreamBenchmarkModule(traced_model)
        multi_stream_model(x)


if __name__ == "__main__":
    unittest.main()