intel-extension-for-pytorch
689 lines · 23.1 KB
1import unittest2import torch3import intel_extension_for_pytorch as ipex4from common_utils import TestCase5
6from common_ipex_conf import runtime_thread_affinity_test_env7import subprocess8import os9
10
class SimpleNet(torch.nn.Module):
    """Single-convolution module: Conv2d(64 -> 128, 3x3, stride 2, pad 1,
    no bias) followed by a flatten of all non-batch dimensions."""

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(
            64, 128, (3, 3), stride=(2, 2), padding=(1, 1), bias=False
        )

    def forward(self, x):
        # Convolve, then collapse (C, H, W) into a single feature dimension.
        return torch.flatten(self.conv(x), start_dim=1)
23
class SimpleNet_v2(torch.nn.Module):
    """Two stacked stride-2 convolutions (3 -> 64 -> 64) followed by flatten."""

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(
            3, 64, (3, 3), stride=(2, 2), padding=(1, 1), bias=False
        )
        self.conv2 = torch.nn.Conv2d(
            64, 64, (3, 3), stride=(2, 2), padding=(1, 1), bias=False
        )

    def forward(self, x):
        # Each conv halves the spatial size (stride 2, pad 1, 3x3 kernel).
        features = self.conv2(self.conv(x))
        return torch.flatten(features, start_dim=1)
40
class SimpleNet_dict(torch.nn.Module):
    """Shared-convolution module with keyword-dict input and dict output.

    ``forward`` expects tensors under keys ``"x1"`` and ``"x2"``, pushes both
    through the same Conv2d(64 -> 128), and returns ``{"y1": sum,
    "y2": flattened sum}``.
    """

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(
            64, 128, (3, 3), stride=(2, 2), padding=(1, 1), bias=False
        )

    def forward(self, **x_dict):
        # Both inputs share the same convolution weights before summation.
        summed = self.conv(x_dict["x1"]) + self.conv(x_dict["x2"])
        return {"y1": summed, "y2": torch.flatten(summed, start_dim=1)}
56
class SimpleNet_tensor_dict(torch.nn.Module):
    """Shared-convolution module returning a (Tensor, dict) pair.

    Same computation as ``SimpleNet_dict`` but the summed tensor is returned
    twice: once positionally and once inside the result dict under ``"y1"``.
    """

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(
            64, 128, (3, 3), stride=(2, 2), padding=(1, 1), bias=False
        )

    def forward(self, **x_dict):
        summed = self.conv(x_dict["x1"]) + self.conv(x_dict["x2"])
        flattened = torch.flatten(summed, start_dim=1)
        # Return a tuple of (Tensor, dict)
        return summed, {"y1": summed, "y2": flattened}
73
class TestInputOutputModule(torch.nn.Module):
    """Echo module: returns its positional arguments unchanged.

    Keyword arguments are accepted but deliberately ignored — only ``args``
    comes back, as a tuple.
    """

    def __init__(self):
        super().__init__()

    def forward(self, *args, **kwargs):
        # Pass-through of positional inputs; kwargs are dropped on purpose.
        return args
81
class TestInputOutputModule2(torch.nn.Module):
    """Identity module: returns its single argument untouched."""

    def __init__(self):
        super().__init__()

    def forward(self, param1):
        return param1
89
class TestCPUPool(TestCase):
    """Sanity check for ipex.cpu.runtime.CPUPool construction."""

    @unittest.skipIf(
        not ipex.cpu.runtime.is_runtime_ext_enabled(),
        "Skip when IPEX Runtime extension is not enabled",
    )
    def test_cpupool_get_core_list(self):
        """A pool built from an explicit core list reports that same list back."""
        requested_cores = [0, 1]
        pool = ipex.cpu.runtime.CPUPool(requested_cores)
        self.assertEqual(pool.cpu_pool.get_core_list(), requested_cores)
100
class TestCoreBinding(TestCase):
    """Thread-affinity pinning must never change numerical results."""

    @unittest.skipIf(
        not ipex.cpu.runtime.is_runtime_ext_enabled(),
        "Skip when IPEX Runtime extension is not enabled",
    )
    @runtime_thread_affinity_test_env
    def test_decorator_imperative_model(self):
        """Pinning via the @pin decorator matches unpinned eager execution."""
        net = SimpleNet()
        net.eval()
        sample = torch.rand(64, 64, 3, 3)
        pool = ipex.cpu.runtime.CPUPool([1, 2, 3, 4])

        @ipex.cpu.runtime.pin(pool)
        def pinned_forward(module, data):
            return module(data)

        pinned_result = pinned_forward(net, sample)
        reference = net(sample)
        self.assertEqual(reference, pinned_result)

    @unittest.skipIf(
        not ipex.cpu.runtime.is_runtime_ext_enabled(),
        "Skip when IPEX Runtime extension is not enabled",
    )
    @runtime_thread_affinity_test_env
    def test_with_context_imperative_model(self):
        """Pinning via the context manager matches unpinned eager execution."""
        net = SimpleNet()
        net.eval()
        sample = torch.rand(64, 64, 3, 3)
        pool = ipex.cpu.runtime.CPUPool([1, 2, 3, 4])
        with ipex.cpu.runtime.pin(pool):
            pinned_result = net(sample)
        reference = net(sample)
        self.assertEqual(reference, pinned_result)

    @unittest.skipIf(
        not ipex.cpu.runtime.is_runtime_ext_enabled(),
        "Skip when IPEX Runtime extension is not enabled",
    )
    @runtime_thread_affinity_test_env
    def test_nested_with_context_imperative_model(self):
        """Nested pin contexts (inner pool inside outer pool) still compute
        the same result after the inner context unwinds."""
        net = torch.nn.Softmax(dim=-1)
        net.eval()
        sample = torch.rand(100, 8276)
        outer_pool = ipex.cpu.runtime.CPUPool([1, 2])
        inner_pool = ipex.cpu.runtime.CPUPool([3, 4])
        with ipex.cpu.runtime.pin(outer_pool):
            pinned_result = net(sample)
            with ipex.cpu.runtime.pin(inner_pool):
                pinned_result = net(sample)
            # Back on the outer pool after the inner context exits.
            pinned_result = net(sample)
        reference = net(sample)
        self.assertEqual(reference, pinned_result)
155
class TestRuntimeAPI(TestCase):
    """Tests for the ipex.cpu.runtime.Task submission API (async future,
    sync call, plain-function payloads, and shared Task references)."""

    @unittest.skipIf(
        not ipex.cpu.runtime.is_runtime_ext_enabled(),
        "Skip when IPEX Runtime extension is not enabled",
    )
    @runtime_thread_affinity_test_env
    def test_task_async_api_imperative_model(self):
        """Async submission of an eager nn.Module returns a future whose
        result matches direct execution."""
        model = SimpleNet()
        model.eval()
        x = torch.rand(64, 64, 3, 3)
        # Calculate the reference result
        y = model(x)

        # Create task
        cpu_pool = ipex.cpu.runtime.CPUPool(node_id=0)
        task = ipex.cpu.runtime.Task(model, cpu_pool)

        # Task submit and wait
        y_runtime_future = task(x)
        y_runtime = y_runtime_future.get()
        self.assertEqual(y, y_runtime)

    @unittest.skipIf(
        not ipex.cpu.runtime.is_runtime_ext_enabled(),
        "Skip when IPEX Runtime extension is not enabled",
    )
    @runtime_thread_affinity_test_env
    def test_task_sync_api_imperative_model(self):
        """run_sync() returns the result directly (no future) and matches
        direct execution."""
        model = SimpleNet()
        model.eval()
        x = torch.rand(64, 64, 3, 3)
        # Calculate the reference result
        y = model(x)

        # Create task
        cpu_pool = ipex.cpu.runtime.CPUPool(node_id=0)
        task = ipex.cpu.runtime.Task(model, cpu_pool)

        # Task sync submit
        y_runtime = task.run_sync(x)
        self.assertEqual(y, y_runtime)

    @unittest.skipIf(
        not ipex.cpu.runtime.is_runtime_ext_enabled(),
        "Skip when IPEX Runtime extension is not enabled",
    )
    @runtime_thread_affinity_test_env
    def test_task_async_api_native_function(self):
        """A plain Python function (not an nn.Module) can also be wrapped in a
        Task; extra call arguments are forwarded through submission."""
        model = SimpleNet()
        model.eval()
        x = torch.rand(64, 64, 3, 3)

        def test(model, x):
            return model(x)

        # Calculate the reference result
        y = test(model, x)

        # Create task
        cpu_pool = ipex.cpu.runtime.CPUPool(node_id=0)
        task = ipex.cpu.runtime.Task(test, cpu_pool)

        # Task submit and wait
        y_runtime_future = task(model, x)
        y_runtime = y_runtime_future.get()
        self.assertEqual(y, y_runtime)

    @unittest.skipIf(
        not ipex.cpu.runtime.is_runtime_ext_enabled(),
        "Skip when IPEX Runtime extension is not enabled",
    )
    @runtime_thread_affinity_test_env
    def test_task_copy(self):
        """Two handles to one Task can both submit work and agree with eager."""
        model = SimpleNet()
        model.eval()
        x = torch.rand(64, 64, 3, 3)
        # Calculate the reference result
        y = model(x)

        # Create task
        cpu_pool = ipex.cpu.runtime.CPUPool(node_id=0)
        task = ipex.cpu.runtime.Task(model, cpu_pool)

        # Copy task
        # NOTE(review): this binds a second reference to the same Task object,
        # not a deep copy — presumably intentional, to exercise submitting
        # through two handles; confirm against the Task implementation.
        task2 = task

        # Task submit and wait
        y_runtime_future = task(x)
        y_runtime = y_runtime_future.get()
        y_runtime_future2 = task2(x)
        y_runtime2 = y_runtime_future2.get()
        self.assertEqual(y, y_runtime)
        self.assertEqual(y, y_runtime2)
250
class TestMultiStreamModule(TestCase):
    """Tests for ipex.cpu.runtime.MultiStreamModule: batch splitting across
    streams, output concatenation, and uneven core/batch distributions."""

    @unittest.skipIf(
        not ipex.cpu.runtime.is_runtime_ext_enabled(),
        "Skip when IPEX Runtime extension is not enabled",
    )
    @runtime_thread_affinity_test_env
    def test_multi_stream_module(self):
        """Two-stream execution matches single-shot eager execution."""
        model = SimpleNet()
        model.eval()
        # One sample per core of NUMA node 0 so the batch divides cleanly.
        batch_size = ipex.cpu.runtime.get_core_list_of_node_id(0).__len__()
        x = torch.rand(batch_size, 64, 3, 3)

        # Calculate the reference result
        y = model(x)

        # Create MultiStreamModule
        cpu_pool = ipex.cpu.runtime.CPUPool(node_id=0)
        multi_stream_model = ipex.cpu.runtime.MultiStreamModule(
            model, num_streams=2, cpu_pool=cpu_pool
        )

        y_runtime = multi_stream_model(x)
        self.assertEqual(y, y_runtime)

    @unittest.skipIf(
        not ipex.cpu.runtime.is_runtime_ext_enabled(),
        "Skip when IPEX Runtime extension is not enabled",
    )
    @runtime_thread_affinity_test_env
    def test_multi_stream_module_with_dict_return_type(self):
        """Dict inputs/outputs are split and re-concatenated per the hints."""
        model = SimpleNet_dict()
        model.eval()
        batch_size = ipex.cpu.runtime.get_core_list_of_node_id(0).__len__()
        x1 = torch.rand(batch_size, 64, 3, 3)
        x2 = torch.rand(batch_size, 64, 3, 3)
        x_dict = {"x1": x1, "x2": x2}

        # Calculate the reference result
        y_dict = model(**x_dict)

        # Create MultiStreamModule
        cpu_pool = ipex.cpu.runtime.CPUPool(node_id=0)

        # Hint value 0: split/concat the corresponding entry along dim 0.
        input_hint_object = {"x1": 0, "x2": 0}
        multi_stream_input_hint = ipex.cpu.runtime.MultiStreamModuleHint(
            **input_hint_object
        )
        output_concat_object = {"y1": 0, "y2": 0}
        multi_stream_output_hint = ipex.cpu.runtime.MultiStreamModuleHint(
            **output_concat_object
        )

        multi_stream_model = ipex.cpu.runtime.MultiStreamModule(
            model,
            num_streams=2,
            cpu_pool=cpu_pool,
            input_split_hint=multi_stream_input_hint,
            output_concat_hint=multi_stream_output_hint,
        )

        y_runtime_dict = multi_stream_model(**x_dict)
        self.assertEqual(y_dict["y1"], y_runtime_dict["y1"])
        self.assertEqual(y_dict["y2"], y_runtime_dict["y2"])

    @unittest.skipIf(
        not ipex.cpu.runtime.is_runtime_ext_enabled(),
        "Skip when IPEX Runtime extension is not enabled",
    )
    @runtime_thread_affinity_test_env
    def test_multi_stream_module_with_tensor_and_dict_return_type(self):
        """(Tensor, dict) return values concat correctly with a nested hint."""
        model = SimpleNet_tensor_dict()
        model.eval()
        batch_size = ipex.cpu.runtime.get_core_list_of_node_id(0).__len__()
        x1 = torch.rand(batch_size, 64, 3, 3)
        x2 = torch.rand(batch_size, 64, 3, 3)
        x_dict = {"x1": x1, "x2": x2}

        # Calculate the reference result
        y, y_dict = model(**x_dict)

        # Create MultiStreamModule
        cpu_pool = ipex.cpu.runtime.CPUPool(node_id=0)

        input_hint_object = {"x1": 0, "x2": 0}
        multi_stream_input_hint = ipex.cpu.runtime.MultiStreamModuleHint(
            **input_hint_object
        )
        # Nested hint mirrors the module's (Tensor, dict) output structure.
        output_concat_object = (0, {"y1": 0, "y2": 0})
        multi_stream_output_hint = ipex.cpu.runtime.MultiStreamModuleHint(
            output_concat_object
        )

        multi_stream_model = ipex.cpu.runtime.MultiStreamModule(
            model,
            num_streams=2,
            cpu_pool=cpu_pool,
            input_split_hint=multi_stream_input_hint,
            output_concat_hint=multi_stream_output_hint,
        )

        y_runtime, y_runtime_dict = multi_stream_model(**x_dict)
        self.assertEqual(y, y_runtime)
        self.assertEqual(y_dict["y1"], y_runtime_dict["y1"])
        self.assertEqual(y_dict["y2"], y_runtime_dict["y2"])

    @unittest.skipIf(
        not ipex.cpu.runtime.is_runtime_ext_enabled(),
        "Skip when IPEX Runtime extension is not enabled",
    )
    @runtime_thread_affinity_test_env
    def test_single_stream_module(self):
        """num_streams=1 degenerates to plain execution; with
        concat_output=False the result arrives as a one-chunk sequence."""
        model = SimpleNet()
        model.eval()
        batch_size = ipex.cpu.runtime.get_core_list_of_node_id(0).__len__()
        x = torch.rand(batch_size, 64, 3, 3)

        # Calculate the reference result
        y = model(x)

        # Create MultiStreamModule
        cpu_pool = ipex.cpu.runtime.CPUPool(node_id=0)
        multi_stream_model = ipex.cpu.runtime.MultiStreamModule(
            model, num_streams=1, cpu_pool=cpu_pool
        )
        multi_stream_model2 = ipex.cpu.runtime.MultiStreamModule(
            model, num_streams=1, cpu_pool=cpu_pool, concat_output=False
        )

        y_runtime = multi_stream_model(x)
        y_runtime2 = multi_stream_model2(x)
        self.assertEqual(y, y_runtime)
        self.assertEqual(y, y_runtime2[0])

    @unittest.skipIf(
        not ipex.cpu.runtime.is_runtime_ext_enabled(),
        "Skip when IPEX Runtime extension is not enabled",
    )
    @runtime_thread_affinity_test_env
    def test_core_number_not_divisible_by_stream_number(self):
        """3 cores over 2 streams: results still concat back to eager."""
        model = SimpleNet()
        model.eval()
        num_streams = 2
        batch_size = num_streams
        x = torch.rand(batch_size, 64, 3, 3)
        # Calculate the reference result
        y = model(x)

        # Create MultiStreamModule
        # Core Number is 3, stream Number is 2
        cpu_pool = ipex.cpu.runtime.CPUPool(core_ids=[0, 1, 2])
        multi_stream_model = ipex.cpu.runtime.MultiStreamModule(
            model, num_streams=num_streams, cpu_pool=cpu_pool
        )
        multi_stream_model2 = ipex.cpu.runtime.MultiStreamModule(
            model, num_streams=num_streams, cpu_pool=cpu_pool, concat_output=False
        )

        y_runtime = multi_stream_model(x)
        y_runtime2 = multi_stream_model2(x)
        self.assertEqual(y, y_runtime)
        self.assertEqual(y, torch.cat(y_runtime2))

    @unittest.skipIf(
        not ipex.cpu.runtime.is_runtime_ext_enabled(),
        "Skip when IPEX Runtime extension is not enabled",
    )
    @runtime_thread_affinity_test_env
    def test_batchsize_less_than_stream_number(self):
        """Batch 2 over 3 streams: fewer samples than streams still works."""
        model = SimpleNet()
        model.eval()
        num_streams = 3
        batch_size = 2
        x = torch.rand(batch_size, 64, 3, 3)
        # Calculate the reference result
        y = model(x)

        # Create MultiStreamModule
        # Batchsize 2, Core Number is 3, stream Number is 3
        cpu_pool = ipex.cpu.runtime.CPUPool(core_ids=[0, 1, 2])
        multi_stream_model = ipex.cpu.runtime.MultiStreamModule(
            model, num_streams=num_streams, cpu_pool=cpu_pool
        )
        multi_stream_model2 = ipex.cpu.runtime.MultiStreamModule(
            model, num_streams=num_streams, cpu_pool=cpu_pool, concat_output=False
        )

        y_runtime = multi_stream_model(x)
        y_runtime2 = multi_stream_model2(x)
        self.assertEqual(y, y_runtime)
        self.assertEqual(y, torch.cat(y_runtime2))

    @unittest.skipIf(
        not ipex.cpu.runtime.is_runtime_ext_enabled(),
        "Skip when IPEX Runtime extension is not enabled",
    )
    @runtime_thread_affinity_test_env
    def test_batchsize_not_divisible_by_stream_number(self):
        """Batch 4 over 3 streams: expects a 2/1/1 per-stream split."""
        model = SimpleNet()
        model.eval()
        num_streams = 3
        batch_size = 4
        x = torch.rand(batch_size, 64, 3, 3)
        # Calculate the reference result
        y = model(x)

        # Create MultiStreamModule
        # Batchsize 4, Core Number is 3, stream Number is 3
        cpu_pool = ipex.cpu.runtime.CPUPool(core_ids=[0, 1, 2])
        multi_stream_model = ipex.cpu.runtime.MultiStreamModule(
            model, num_streams=num_streams, cpu_pool=cpu_pool
        )
        multi_stream_model2 = ipex.cpu.runtime.MultiStreamModule(
            model, num_streams=num_streams, cpu_pool=cpu_pool, concat_output=False
        )

        y_runtime = multi_stream_model(x)
        y_runtime2 = multi_stream_model2(x)
        self.assertEqual(y, y_runtime)
        self.assertEqual(y, torch.cat(y_runtime2))
        # The remainder goes to the first stream: chunk sizes 2, 1, 1.
        self.assertEqual(y_runtime2[0].size(0), 2)
        self.assertEqual(y_runtime2[1].size(0), 1)
        self.assertEqual(y_runtime2[2].size(0), 1)
474
class TestModuleMultiStreamModuleHint(TestCase):
    """MultiStreamModuleHint coverage for input/output structures (mixed
    tensor/scalar, tuple, dict, nested tuple) that cannot be jit.trace'd."""

    # For the inputs format which can't be jit.trace
    def init_set_up(self):
        # Create Multi Stream Module without concat output
        # One sample and one stream per core of the default pool.
        cpu_pool = ipex.cpu.runtime.CPUPool()
        batch_size = cpu_pool.core_ids.__len__()
        num_streams = cpu_pool.core_ids.__len__()
        return batch_size, num_streams, cpu_pool

    def create_multi_stream_module(
        self,
        traced_model,
        num_streams,
        cpu_pool,
        multi_stream_input_hint,
        multi_stream_output_hint=None,
        concat_output=True,
    ):
        """Build a MultiStreamModule; the output hint is only passed when
        outputs are being concatenated."""
        if not concat_output:
            return ipex.cpu.runtime.MultiStreamModule(
                traced_model,
                num_streams=num_streams,
                cpu_pool=cpu_pool,
                concat_output=False,
                input_split_hint=multi_stream_input_hint,
            )
        else:
            return ipex.cpu.runtime.MultiStreamModule(
                traced_model,
                num_streams=num_streams,
                cpu_pool=cpu_pool,
                input_split_hint=multi_stream_input_hint,
                output_concat_hint=multi_stream_output_hint,
            )

    @unittest.skipIf(
        not ipex.cpu.runtime.is_runtime_ext_enabled(),
        "Skip when IPEX Runtime extension is not enabled",
    )
    @runtime_thread_affinity_test_env
    def test_mix_tensor_bool_input_output_hint(self):
        # This module:
        # * Accept 2 tensors + 1 scalar as input
        # * Return 2 tensors + 1 scalar as output
        # Since Type 'Tuple[Tensor, bool, Tensor]' cannot be traced, we put this test input type in imperative mode.
        model = TestInputOutputModule().eval()
        batch_size, num_streams, cpu_pool = self.init_set_up()

        input_tensor1 = torch.rand(batch_size, 1)
        input_tensor2 = torch.rand(batch_size, 3)

        # Calculate the reference result
        y_ref = model(input_tensor1, False, input_tensor2)

        # 0 = split along dim 0; None presumably marks the non-tensor arg
        # (the bool) as passed through unsplit — confirm against the
        # MultiStreamModuleHint implementation.
        multi_stream_input_hint = ipex.cpu.runtime.MultiStreamModuleHint(0, None, 0)
        multi_stream_output_hint = ipex.cpu.runtime.MultiStreamModuleHint((0, None, 0))

        multi_stream_model = self.create_multi_stream_module(
            model,
            num_streams,
            cpu_pool,
            multi_stream_input_hint,
            multi_stream_output_hint,
            concat_output=True,
        )
        y_runtime_res = multi_stream_model(input_tensor1, False, input_tensor2)
        self.assertEqual(y_ref, y_runtime_res)

    @unittest.skipIf(
        not ipex.cpu.runtime.is_runtime_ext_enabled(),
        "Skip when IPEX Runtime extension is not enabled",
    )
    @runtime_thread_affinity_test_env
    def test_tuple_input_output_hint(self):
        # This module:
        # * Accept 1 tuple(3 tensors) as input
        # * Return 1 tuple(3 tensors) as output
        model = TestInputOutputModule2().eval()
        batch_size, num_streams, cpu_pool = self.init_set_up()

        input_tensor1 = torch.rand(batch_size, 1)
        input_tensor2 = torch.rand(batch_size, 2)
        input_tensor3 = torch.rand(batch_size, 3)
        input = (input_tensor1, input_tensor2, input_tensor3)
        y_ref = model(input)

        multi_stream_input_hint = ipex.cpu.runtime.MultiStreamModuleHint((0, 0, 0))
        multi_stream_output_hint = ipex.cpu.runtime.MultiStreamModuleHint((0, 0, 0))

        multi_stream_model = self.create_multi_stream_module(
            model,
            num_streams,
            cpu_pool,
            multi_stream_input_hint,
            multi_stream_output_hint,
            concat_output=True,
        )
        y_runtime_res = multi_stream_model(input)
        self.assertEqual(y_ref, y_runtime_res)

    @unittest.skipIf(
        not ipex.cpu.runtime.is_runtime_ext_enabled(),
        "Skip when IPEX Runtime extension is not enabled",
    )
    @runtime_thread_affinity_test_env
    def test_dict_input_output_hint(self):
        # This module:
        # * Accept 1 dict(3 tensors) as input
        # * Return 1 dict(3 tensors) as output
        model = TestInputOutputModule2().eval()
        batch_size, num_streams, cpu_pool = self.init_set_up()

        input_tensor1 = torch.rand(batch_size, 1)
        input_tensor2 = torch.rand(batch_size, 2)
        input_tensor3 = torch.rand(batch_size, 3)
        input = {"key1": input_tensor1, "key2": input_tensor2, "key3": input_tensor3}
        y_ref = model(input)

        multi_stream_input_hint = ipex.cpu.runtime.MultiStreamModuleHint(
            {"key1": 0, "key2": 0, "key3": 0}
        )
        multi_stream_output_hint = ipex.cpu.runtime.MultiStreamModuleHint(
            {"key1": 0, "key2": 0, "key3": 0}
        )

        multi_stream_model = self.create_multi_stream_module(
            model,
            num_streams,
            cpu_pool,
            multi_stream_input_hint,
            multi_stream_output_hint,
            concat_output=True,
        )
        y_runtime_res = multi_stream_model(input)
        self.assertEqual(y_ref, y_runtime_res)

    @unittest.skipIf(
        not ipex.cpu.runtime.is_runtime_ext_enabled(),
        "Skip when IPEX Runtime extension is not enabled",
    )
    @runtime_thread_affinity_test_env
    def test_nested_tuple_input_output_hint(self):
        # This module:
        # * Accept nested tuple ((tensor1, tensor2), tensor3) as input
        # * Return nested tuple ((tensor1, tensor2), tensor3) as output
        model = TestInputOutputModule2().eval()
        batch_size, num_streams, cpu_pool = self.init_set_up()

        input_tensor1 = torch.rand(batch_size, 1)
        input_tensor2 = torch.rand(batch_size, 2)
        input_tensor3 = torch.rand(batch_size, 3)
        input = ((input_tensor1, input_tensor2), input_tensor3)
        y_ref = model(input)

        multi_stream_input_hint = ipex.cpu.runtime.MultiStreamModuleHint(((0, 0), 0))
        multi_stream_output_hint = ipex.cpu.runtime.MultiStreamModuleHint(((0, 0), 0))

        multi_stream_model = self.create_multi_stream_module(
            model,
            num_streams,
            cpu_pool,
            multi_stream_input_hint,
            multi_stream_output_hint,
            concat_output=True,
        )
        y_runtime_res = multi_stream_model(input)
        self.assertEqual(y_ref, y_runtime_res)
643
def is_numactl_available():
    """Return True if ``numactl`` can bind to core 0 / memory node 0 here.

    Probes by running a trivial command (``ls``) under numactl.  The probe's
    output is discarded so it does not pollute the test log.

    Returns:
        bool: True when the probe exits with status 0; False when the
        ``numactl`` binary is missing or the binding fails.
    """
    cmd = ["numactl", "-C", "0", "-m", "0", "ls"]
    try:
        # DEVNULL keeps the probe's directory listing out of test output.
        r = subprocess.run(
            cmd,
            env=os.environ,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )
    except (OSError, subprocess.SubprocessError):
        # Narrow catch: the original caught BaseException, which would also
        # swallow KeyboardInterrupt/SystemExit. Launch failure => unavailable.
        return False
    return r.returncode == 0
655
class TestRuntimeExtensionWithNumactl(TestCase):
    """End-to-end check that CPUPool creation respects an external numactl
    core binding (child process sees only the numactl-granted cores)."""

    @unittest.skipIf(
        not (is_numactl_available() and ipex.cpu.runtime.is_runtime_ext_enabled()),
        # Fixed: the old message mentioned only numactl, but the condition
        # also requires the IPEX Runtime extension.
        "Skip when numactl is not available or IPEX Runtime extension is not enabled",
    )
    @runtime_thread_affinity_test_env
    def test_cpupool_creation_with_numactl(self):
        """Launch runtime.py under `numactl -C 0-1` (with and without explicit
        OMP/KMP pinning) and verify the child reports the expected core ids."""
        loc = os.path.dirname(os.path.abspath(__file__))
        cmd1 = "numactl -C 0-1 -m 0 python -u {}/runtime.py --case-name={}".format(
            loc, "create_cpu_pool"
        )
        cmd2 = "OMP_NUM_THREADS=1 KMP_AFFINITY=granularity=fine,compact,1,0 numactl -C 0-1 -m 0 \
python -u {}/runtime.py --case-name={}".format(
            loc, "create_cpu_pool"
        )
        cmds = [cmd1, cmd2]
        for cmd in cmds:
            match = False
            with subprocess.Popen(
                cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT
            ) as p:
                # Scan the child's merged stdout/stderr for the pool report.
                for line in p.stdout.readlines():
                    line = str(line, "utf-8").strip()
                    if "The created CPUPool has core is:" in line:
                        x = line.split(":")
                        assert (
                            "[1]" in x[1]
                        ), "The core ids in test_cpupool_creation with numactl is not as expected."
                        match = True
            assert match, "Test Case Failed to create CPUPool"
687
if __name__ == "__main__":
    # Entry point: run every test case in this module via unittest.
    test = unittest.main()