intel-extension-for-pytorch

Форк
0
/
test_runtime_api.py 
689 строк · 23.1 Кб
1
import unittest
2
import torch
3
import intel_extension_for_pytorch as ipex
4
from common_utils import TestCase
5

6
from common_ipex_conf import runtime_thread_affinity_test_env
7
import subprocess
8
import os
9

10

11
class SimpleNet(torch.nn.Module):
    """One strided 3x3 convolution (64 -> 128 channels, no bias), then flatten.

    ``forward`` returns the convolution output flattened from dim 1 onward,
    so each sample along dim 0 becomes a single row.
    """

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(
            in_channels=64,
            out_channels=128,
            kernel_size=(3, 3),
            stride=(2, 2),
            padding=(1, 1),
            bias=False,
        )

    def forward(self, x):
        return self.conv(x).flatten(start_dim=1)
22

23

24
class SimpleNet_v2(torch.nn.Module):
    """Two stacked strided 3x3 convolutions (3 -> 64 -> 64 channels), then flatten."""

    def __init__(self):
        super().__init__()
        shared = dict(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        # Creation order (conv before conv2) is kept so RNG-based weight init matches.
        self.conv = torch.nn.Conv2d(3, 64, **shared)
        self.conv2 = torch.nn.Conv2d(64, 64, **shared)

    def forward(self, x):
        features = self.conv2(self.conv(x))
        return torch.flatten(features, 1)
39

40

41
class SimpleNet_dict(torch.nn.Module):
    """Shared-weight convolution over two keyword inputs, returning a dict.

    ``forward`` reads keyword tensors ``x1`` and ``x2`` (any other keywords
    are ignored), runs both through the same ``self.conv``, and returns
    ``{"y1": sum, "y2": sum flattened from dim 1}``.
    """

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(
            64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False
        )

    def forward(self, **x_dict):
        summed = self.conv(x_dict["x1"]) + self.conv(x_dict["x2"])
        return {"y1": summed, "y2": summed.flatten(start_dim=1)}
55

56

57
class SimpleNet_tensor_dict(torch.nn.Module):
    """Like a dict-returning shared-conv net, but returns ``(Tensor, dict)``.

    ``forward`` reads keyword tensors ``x1`` and ``x2`` (other keywords are
    ignored). It returns the summed convolution outputs together with
    ``{"y1": sum, "y2": sum flattened from dim 1}``.
    """

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(
            64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False
        )

    def forward(self, **x_dict):
        summed = self.conv(x_dict["x1"]) + self.conv(x_dict["x2"])
        # Mixed structure on purpose: a bare tensor followed by a dict.
        return summed, {"y1": summed, "y2": summed.flatten(start_dim=1)}
72

73

74
class TestInputOutputModule(torch.nn.Module):
    """Pass-through module returning its positional arguments as a tuple.

    Keyword arguments are accepted but discarded.
    """

    def __init__(self):
        super().__init__()

    def forward(self, *args, **kwargs):
        del kwargs  # accepted only so callers may pass keywords; unused
        return args
80

81

82
class TestInputOutputModule2(torch.nn.Module):
    """Identity module: ``forward`` returns its single argument unchanged."""

    def __init__(self):
        super().__init__()

    def forward(self, param1):
        # The (possibly nested tuple/dict) input is handed back as the same object.
        result = param1
        return result
88

89

90
class TestCPUPool(TestCase):
    """Sanity checks for ``ipex.cpu.runtime.CPUPool`` construction."""

    @unittest.skipIf(
        not ipex.cpu.runtime.is_runtime_ext_enabled(),
        "Skip when IPEX Runtime extension is not enabled",
    )
    def test_cpupool_get_core_list(self):
        requested_cores = [0, 1]
        pool = ipex.cpu.runtime.CPUPool(requested_cores)
        # The inner `cpu_pool` attribute (presumably the native pool object)
        # should report exactly the cores it was built from.
        self.assertEqual(pool.cpu_pool.get_core_list(), requested_cores)
99

100

101
class TestCoreBinding(TestCase):
    """Running under ``ipex.cpu.runtime.pin`` must give the same results as eager.

    ``pin`` is exercised as a function decorator, as a context manager, and as
    nested context managers over two different pools.
    """

    @unittest.skipIf(
        not ipex.cpu.runtime.is_runtime_ext_enabled(),
        "Skip when IPEX Runtime extension is not enabled",
    )
    @runtime_thread_affinity_test_env
    def test_decorator_imperative_model(self):
        model = SimpleNet()
        model.eval()
        x = torch.rand(64, 64, 3, 3)
        cpu_pool = ipex.cpu.runtime.CPUPool([1, 2, 3, 4])

        # `pin` used as a decorator: each call of `test` runs under cpu_pool.
        @ipex.cpu.runtime.pin(cpu_pool)
        def test(model, x):
            return model(x)

        y_runtime = test(model, x)
        y = model(x)
        self.assertEqual(y, y_runtime)

    @unittest.skipIf(
        not ipex.cpu.runtime.is_runtime_ext_enabled(),
        "Skip when IPEX Runtime extension is not enabled",
    )
    @runtime_thread_affinity_test_env
    def test_with_context_imperative_model(self):
        model = SimpleNet()
        model.eval()
        x = torch.rand(64, 64, 3, 3)
        cpu_pool = ipex.cpu.runtime.CPUPool([1, 2, 3, 4])
        # `pin` used as a context manager around a single forward call.
        with ipex.cpu.runtime.pin(cpu_pool):
            y_runtime = model(x)
        y = model(x)
        self.assertEqual(y, y_runtime)

    @unittest.skipIf(
        not ipex.cpu.runtime.is_runtime_ext_enabled(),
        "Skip when IPEX Runtime extension is not enabled",
    )
    @runtime_thread_affinity_test_env
    def test_nested_with_context_imperative_model(self):
        model = torch.nn.Softmax(dim=-1)
        model.eval()
        x = torch.rand(100, 8276)
        cpu_pool = ipex.cpu.runtime.CPUPool([1, 2])
        cpu_pool2 = ipex.cpu.runtime.CPUPool([3, 4])
        # Keep the result from every stage of the nesting so each one is
        # verified, not only the last one.
        y_runtimes = []
        with ipex.cpu.runtime.pin(cpu_pool):
            y_runtimes.append(model(x))  # inside the outer pin
            with ipex.cpu.runtime.pin(cpu_pool2):
                y_runtimes.append(model(x))  # inside the inner pin
            y_runtimes.append(model(x))  # after the inner pin has exited
        y = model(x)
        for y_runtime in y_runtimes:
            self.assertEqual(y, y_runtime)
154

155

156
class TestRuntimeAPI(TestCase):
    """Exercise ``ipex.cpu.runtime.Task`` in async, sync and plain-function modes.

    Every test compares the task's result with the eager result of the same
    callable on the same input.
    """

    def _reference_setup(self):
        """Return an eval-mode SimpleNet, a random input and its eager output."""
        net = SimpleNet().eval()
        inputs = torch.rand(64, 64, 3, 3)
        return net, inputs, net(inputs)

    @unittest.skipIf(
        not ipex.cpu.runtime.is_runtime_ext_enabled(),
        "Skip when IPEX Runtime extension is not enabled",
    )
    @runtime_thread_affinity_test_env
    def test_task_async_api_imperative_model(self):
        net, inputs, expected = self._reference_setup()

        pool = ipex.cpu.runtime.CPUPool(node_id=0)
        task = ipex.cpu.runtime.Task(net, pool)

        # Calling the task submits the work; .get() waits for the result.
        future = task(inputs)
        self.assertEqual(expected, future.get())

    @unittest.skipIf(
        not ipex.cpu.runtime.is_runtime_ext_enabled(),
        "Skip when IPEX Runtime extension is not enabled",
    )
    @runtime_thread_affinity_test_env
    def test_task_sync_api_imperative_model(self):
        net, inputs, expected = self._reference_setup()

        pool = ipex.cpu.runtime.CPUPool(node_id=0)
        task = ipex.cpu.runtime.Task(net, pool)

        # Synchronous submission returns the result directly.
        self.assertEqual(expected, task.run_sync(inputs))

    @unittest.skipIf(
        not ipex.cpu.runtime.is_runtime_ext_enabled(),
        "Skip when IPEX Runtime extension is not enabled",
    )
    @runtime_thread_affinity_test_env
    def test_task_async_api_native_function(self):
        net = SimpleNet().eval()
        inputs = torch.rand(64, 64, 3, 3)

        def forward_fn(model, x):
            return model(x)

        expected = forward_fn(net, inputs)

        # A plain Python function (not an nn.Module) wrapped as a Task.
        pool = ipex.cpu.runtime.CPUPool(node_id=0)
        task = ipex.cpu.runtime.Task(forward_fn, pool)

        future = task(net, inputs)
        self.assertEqual(expected, future.get())

    @unittest.skipIf(
        not ipex.cpu.runtime.is_runtime_ext_enabled(),
        "Skip when IPEX Runtime extension is not enabled",
    )
    @runtime_thread_affinity_test_env
    def test_task_copy(self):
        net, inputs, expected = self._reference_setup()

        pool = ipex.cpu.runtime.CPUPool(node_id=0)
        task = ipex.cpu.runtime.Task(net, pool)

        # NOTE(review): plain assignment binds a second name to the same Task
        # object; it does not copy it. Confirm whether copy.copy(task) was
        # intended here.
        task2 = task

        first = task(inputs).get()
        second = task2(inputs).get()
        self.assertEqual(expected, first)
        self.assertEqual(expected, second)
249

250

251
class TestMultiStreamModule(TestCase):
    """Compare ``ipex.cpu.runtime.MultiStreamModule`` outputs with eager execution."""

    def _run_on_three_cores(self, num_streams, batch_size):
        """Run SimpleNet through MultiStreamModule on a pool of cores [0, 1, 2].

        Returns ``(eager_output, concatenated_output, per_stream_outputs)``.
        The last item comes from a module built with ``concat_output=False``.
        """
        net = SimpleNet().eval()
        inputs = torch.rand(batch_size, 64, 3, 3)
        expected = net(inputs)

        pool = ipex.cpu.runtime.CPUPool(core_ids=[0, 1, 2])
        joined_module = ipex.cpu.runtime.MultiStreamModule(
            net, num_streams=num_streams, cpu_pool=pool
        )
        split_module = ipex.cpu.runtime.MultiStreamModule(
            net, num_streams=num_streams, cpu_pool=pool, concat_output=False
        )
        return expected, joined_module(inputs), split_module(inputs)

    @unittest.skipIf(
        not ipex.cpu.runtime.is_runtime_ext_enabled(),
        "Skip when IPEX Runtime extension is not enabled",
    )
    @runtime_thread_affinity_test_env
    def test_multi_stream_module(self):
        net = SimpleNet().eval()
        batch_size = len(ipex.cpu.runtime.get_core_list_of_node_id(0))
        inputs = torch.rand(batch_size, 64, 3, 3)
        expected = net(inputs)

        pool = ipex.cpu.runtime.CPUPool(node_id=0)
        ms_module = ipex.cpu.runtime.MultiStreamModule(
            net, num_streams=2, cpu_pool=pool
        )

        self.assertEqual(expected, ms_module(inputs))

    @unittest.skipIf(
        not ipex.cpu.runtime.is_runtime_ext_enabled(),
        "Skip when IPEX Runtime extension is not enabled",
    )
    @runtime_thread_affinity_test_env
    def test_multi_stream_module_with_dict_return_type(self):
        net = SimpleNet_dict().eval()
        batch_size = len(ipex.cpu.runtime.get_core_list_of_node_id(0))
        feeds = {
            "x1": torch.rand(batch_size, 64, 3, 3),
            "x2": torch.rand(batch_size, 64, 3, 3),
        }
        expected = net(**feeds)

        pool = ipex.cpu.runtime.CPUPool(node_id=0)
        # One hint entry per keyword input / dict output; the value 0 is
        # presumably the dimension to split and concatenate along.
        input_hint = ipex.cpu.runtime.MultiStreamModuleHint(x1=0, x2=0)
        output_hint = ipex.cpu.runtime.MultiStreamModuleHint(y1=0, y2=0)

        ms_module = ipex.cpu.runtime.MultiStreamModule(
            net,
            num_streams=2,
            cpu_pool=pool,
            input_split_hint=input_hint,
            output_concat_hint=output_hint,
        )

        actual = ms_module(**feeds)
        self.assertEqual(expected["y1"], actual["y1"])
        self.assertEqual(expected["y2"], actual["y2"])

    @unittest.skipIf(
        not ipex.cpu.runtime.is_runtime_ext_enabled(),
        "Skip when IPEX Runtime extension is not enabled",
    )
    @runtime_thread_affinity_test_env
    def test_multi_stream_module_with_tensor_and_dict_return_type(self):
        net = SimpleNet_tensor_dict().eval()
        batch_size = len(ipex.cpu.runtime.get_core_list_of_node_id(0))
        feeds = {
            "x1": torch.rand(batch_size, 64, 3, 3),
            "x2": torch.rand(batch_size, 64, 3, 3),
        }
        expected_tensor, expected_dict = net(**feeds)

        pool = ipex.cpu.runtime.CPUPool(node_id=0)
        input_hint = ipex.cpu.runtime.MultiStreamModuleHint(x1=0, x2=0)
        # The output hint mirrors the (Tensor, dict) return structure.
        output_hint = ipex.cpu.runtime.MultiStreamModuleHint((0, {"y1": 0, "y2": 0}))

        ms_module = ipex.cpu.runtime.MultiStreamModule(
            net,
            num_streams=2,
            cpu_pool=pool,
            input_split_hint=input_hint,
            output_concat_hint=output_hint,
        )

        actual_tensor, actual_dict = ms_module(**feeds)
        self.assertEqual(expected_tensor, actual_tensor)
        self.assertEqual(expected_dict["y1"], actual_dict["y1"])
        self.assertEqual(expected_dict["y2"], actual_dict["y2"])

    @unittest.skipIf(
        not ipex.cpu.runtime.is_runtime_ext_enabled(),
        "Skip when IPEX Runtime extension is not enabled",
    )
    @runtime_thread_affinity_test_env
    def test_single_stream_module(self):
        net = SimpleNet().eval()
        batch_size = len(ipex.cpu.runtime.get_core_list_of_node_id(0))
        inputs = torch.rand(batch_size, 64, 3, 3)
        expected = net(inputs)

        pool = ipex.cpu.runtime.CPUPool(node_id=0)
        joined_module = ipex.cpu.runtime.MultiStreamModule(
            net, num_streams=1, cpu_pool=pool
        )
        split_module = ipex.cpu.runtime.MultiStreamModule(
            net, num_streams=1, cpu_pool=pool, concat_output=False
        )

        joined = joined_module(inputs)
        parts = split_module(inputs)
        self.assertEqual(expected, joined)
        # With a single stream, the only per-stream output is the whole result.
        self.assertEqual(expected, parts[0])

    @unittest.skipIf(
        not ipex.cpu.runtime.is_runtime_ext_enabled(),
        "Skip when IPEX Runtime extension is not enabled",
    )
    @runtime_thread_affinity_test_env
    def test_core_number_not_divisible_by_stream_number(self):
        # 3 cores shared by 2 streams, batch size equal to the stream count.
        expected, joined, parts = self._run_on_three_cores(num_streams=2, batch_size=2)
        self.assertEqual(expected, joined)
        self.assertEqual(expected, torch.cat(parts))

    @unittest.skipIf(
        not ipex.cpu.runtime.is_runtime_ext_enabled(),
        "Skip when IPEX Runtime extension is not enabled",
    )
    @runtime_thread_affinity_test_env
    def test_batchsize_less_than_stream_number(self):
        # Batch size 2, 3 cores, 3 streams.
        expected, joined, parts = self._run_on_three_cores(num_streams=3, batch_size=2)
        self.assertEqual(expected, joined)
        self.assertEqual(expected, torch.cat(parts))

    @unittest.skipIf(
        not ipex.cpu.runtime.is_runtime_ext_enabled(),
        "Skip when IPEX Runtime extension is not enabled",
    )
    @runtime_thread_affinity_test_env
    def test_batchsize_not_divisible_by_stream_number(self):
        # Batch size 4, 3 cores, 3 streams: the chunks are expected to be 2/1/1.
        expected, joined, parts = self._run_on_three_cores(num_streams=3, batch_size=4)
        self.assertEqual(expected, joined)
        self.assertEqual(expected, torch.cat(parts))
        self.assertEqual(parts[0].size(0), 2)
        self.assertEqual(parts[1].size(0), 1)
        self.assertEqual(parts[2].size(0), 1)
473

474

475
class TestModuleMultiStreamModuleHint(TestCase):
    """MultiStreamModuleHint with input/output structures jit.trace cannot handle.

    The models here run imperatively (untraced) through MultiStreamModule,
    with one stream per core of a default CPUPool.
    """

    def init_set_up(self):
        """Return ``(batch_size, num_streams, cpu_pool)`` for a default CPUPool.

        The batch size and the stream count both equal the pool's core count.
        """
        cpu_pool = ipex.cpu.runtime.CPUPool()
        core_count = len(cpu_pool.core_ids)
        return core_count, core_count, cpu_pool

    def create_multi_stream_module(
        self,
        traced_model,
        num_streams,
        cpu_pool,
        multi_stream_input_hint,
        multi_stream_output_hint=None,
        concat_output=True,
    ):
        """Wrap ``traced_model`` in a MultiStreamModule.

        When ``concat_output`` is False, the module is built with
        ``concat_output=False`` and ``multi_stream_output_hint`` is not passed.
        """
        options = {
            "num_streams": num_streams,
            "cpu_pool": cpu_pool,
            "input_split_hint": multi_stream_input_hint,
        }
        if concat_output:
            options["output_concat_hint"] = multi_stream_output_hint
        else:
            options["concat_output"] = False
        return ipex.cpu.runtime.MultiStreamModule(traced_model, **options)

    @unittest.skipIf(
        not ipex.cpu.runtime.is_runtime_ext_enabled(),
        "Skip when IPEX Runtime extension is not enabled",
    )
    @runtime_thread_affinity_test_env
    def test_mix_tensor_bool_input_output_hint(self):
        # The module echoes its positional args: (Tensor, bool, Tensor) in and out.
        # Tuple[Tensor, bool, Tensor] cannot be traced, so it runs imperatively.
        model = TestInputOutputModule().eval()
        batch_size, num_streams, cpu_pool = self.init_set_up()

        lhs = torch.rand(batch_size, 1)
        rhs = torch.rand(batch_size, 3)
        expected = model(lhs, False, rhs)

        # `None` is the hint for the bool entry; 0 is used for the tensors.
        input_hint = ipex.cpu.runtime.MultiStreamModuleHint(0, None, 0)
        output_hint = ipex.cpu.runtime.MultiStreamModuleHint((0, None, 0))

        ms_model = self.create_multi_stream_module(
            model, num_streams, cpu_pool, input_hint, output_hint, concat_output=True
        )
        self.assertEqual(expected, ms_model(lhs, False, rhs))

    @unittest.skipIf(
        not ipex.cpu.runtime.is_runtime_ext_enabled(),
        "Skip when IPEX Runtime extension is not enabled",
    )
    @runtime_thread_affinity_test_env
    def test_tuple_input_output_hint(self):
        # Identity module: a tuple of 3 tensors in, the same tuple out.
        model = TestInputOutputModule2().eval()
        batch_size, num_streams, cpu_pool = self.init_set_up()

        inputs = tuple(torch.rand(batch_size, width) for width in (1, 2, 3))
        expected = model(inputs)

        input_hint = ipex.cpu.runtime.MultiStreamModuleHint((0, 0, 0))
        output_hint = ipex.cpu.runtime.MultiStreamModuleHint((0, 0, 0))

        ms_model = self.create_multi_stream_module(
            model, num_streams, cpu_pool, input_hint, output_hint, concat_output=True
        )
        self.assertEqual(expected, ms_model(inputs))

    @unittest.skipIf(
        not ipex.cpu.runtime.is_runtime_ext_enabled(),
        "Skip when IPEX Runtime extension is not enabled",
    )
    @runtime_thread_affinity_test_env
    def test_dict_input_output_hint(self):
        # Identity module: a dict of 3 tensors in, the same dict out.
        model = TestInputOutputModule2().eval()
        batch_size, num_streams, cpu_pool = self.init_set_up()

        first = torch.rand(batch_size, 1)
        second = torch.rand(batch_size, 2)
        third = torch.rand(batch_size, 3)
        inputs = {"key1": first, "key2": second, "key3": third}
        expected = model(inputs)

        input_hint = ipex.cpu.runtime.MultiStreamModuleHint(
            {"key1": 0, "key2": 0, "key3": 0}
        )
        output_hint = ipex.cpu.runtime.MultiStreamModuleHint(
            {"key1": 0, "key2": 0, "key3": 0}
        )

        ms_model = self.create_multi_stream_module(
            model, num_streams, cpu_pool, input_hint, output_hint, concat_output=True
        )
        self.assertEqual(expected, ms_model(inputs))

    @unittest.skipIf(
        not ipex.cpu.runtime.is_runtime_ext_enabled(),
        "Skip when IPEX Runtime extension is not enabled",
    )
    @runtime_thread_affinity_test_env
    def test_nested_tuple_input_output_hint(self):
        # Identity module: ((t1, t2), t3) in, the same nested tuple out.
        model = TestInputOutputModule2().eval()
        batch_size, num_streams, cpu_pool = self.init_set_up()

        first = torch.rand(batch_size, 1)
        second = torch.rand(batch_size, 2)
        third = torch.rand(batch_size, 3)
        inputs = ((first, second), third)
        expected = model(inputs)

        input_hint = ipex.cpu.runtime.MultiStreamModuleHint(((0, 0), 0))
        output_hint = ipex.cpu.runtime.MultiStreamModuleHint(((0, 0), 0))

        ms_model = self.create_multi_stream_module(
            model, num_streams, cpu_pool, input_hint, output_hint, concat_output=True
        )
        self.assertEqual(expected, ms_model(inputs))
642

643

644
def is_numactl_available():
    """Return True if ``numactl -C 0 -m 0 ls`` runs and exits with status 0.

    A failure to launch the command (e.g. ``numactl`` is not installed)
    yields False instead of raising. This lets the result be used directly
    in ``unittest.skipIf`` when the test class is defined.
    """
    cmd = ["numactl", "-C", "0", "-m", "0", "ls"]
    try:
        # Only the exit status matters, so discard the command's output
        # rather than dumping a directory listing into the test log.
        result = subprocess.run(
            cmd,
            env=os.environ,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )
    except (OSError, subprocess.SubprocessError):
        # Narrow on purpose: a missing or non-executable binary means "not
        # available", while KeyboardInterrupt/SystemExit must still propagate.
        return False
    return result.returncode == 0
654

655

656
class TestRuntimeExtensionWithNumactl(TestCase):
    """Check CPUPool creation in a child process that is bound with numactl."""

    @unittest.skipIf(
        not (is_numactl_available() and ipex.cpu.runtime.is_runtime_ext_enabled()),
        "Skip when numactl is not available or IPEX Runtime extension is not enabled",
    )
    @runtime_thread_affinity_test_env
    def test_cpupool_creation_with_numactl(self):
        import sys

        marker = "The created CPUPool has core is:"
        loc = os.path.dirname(os.path.abspath(__file__))
        # Use the interpreter running this test (not whatever `python` is on
        # PATH) so the helper script sees the same installed packages. An
        # argument list with no shell also avoids quoting problems in `loc`.
        cmd = [
            "numactl",
            "-C",
            "0-1",
            "-m",
            "0",
            sys.executable,
            "-u",
            os.path.join(loc, "runtime.py"),
            "--case-name=create_cpu_pool",
        ]
        # First run inherits the current environment; the second adds explicit
        # OpenMP thread-count/affinity settings on top of it.
        envs = [
            dict(os.environ),
            dict(
                os.environ,
                OMP_NUM_THREADS="1",
                KMP_AFFINITY="granularity=fine,compact,1,0",
            ),
        ]
        for env in envs:
            match = False
            with subprocess.Popen(
                cmd, env=env, stdout=subprocess.PIPE, stderr=subprocess.STDOUT
            ) as p:
                for raw_line in p.stdout:
                    line = raw_line.decode("utf-8", errors="replace").strip()
                    if marker in line:
                        # Take the text after the marker itself, so a log
                        # prefix containing ':' cannot shift the split.
                        cores = line.split(marker, 1)[1]
                        self.assertIn(
                            "[1]",
                            cores,
                            "The core ids in test_cpupool_creation with numactl is not as expected.",
                        )
                        match = True
            self.assertTrue(match, "Test Case Failed to create CPUPool")
686

687

688
if __name__ == "__main__":
    # Discover and run every TestCase in this module; unittest.main() exits
    # the process itself, so its return value is never used.
    unittest.main()
690

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.