import unittest
import torch
import intel_extension_for_pytorch as ipex
from torch.testing._internal.jit_utils import JitTestCase
from test_ao_jit_llga_utils import JitLlgaTestCase
from test_runtime_api import TestInputOutputModule
from common_ipex_conf import runtime_thread_affinity_test_env
9
class SimpleNet(torch.nn.Module):
    """Toy model: a single bias-free Conv2d(64 -> 128) followed by a flatten.

    Used as the fp32/bf16 workload for the runtime-extension tests below.
    """

    def __init__(self):
        super(SimpleNet, self).__init__()
        # 3x3 kernel, stride 2, padding 1, no bias: spatial dims roughly halve.
        self.conv = torch.nn.Conv2d(
            64, 128, (3, 3), stride=(2, 2), padding=(1, 1), bias=False
        )

    def forward(self, x):
        # Flatten everything except the batch dimension into a 2-D output.
        conv_out = self.conv(x)
        return torch.flatten(conv_out, start_dim=1)
22
class SimpleNet_v2(torch.nn.Module):
    """Toy model: Conv2d(3 -> 64) -> Conv2d(64 -> 64) -> flatten.

    Takes a 3-channel (image-like) input; used as the int8/LLGA workload
    in the tests below.
    """

    def __init__(self):
        super(SimpleNet_v2, self).__init__()
        # Two bias-free strided convolutions; each roughly halves H and W.
        self.conv = torch.nn.Conv2d(
            3, 64, (3, 3), stride=(2, 2), padding=(1, 1), bias=False
        )
        self.conv2 = torch.nn.Conv2d(
            64, 64, (3, 3), stride=(2, 2), padding=(1, 1), bias=False
        )

    def forward(self, x):
        # Chain both convolutions, then flatten all non-batch dimensions.
        features = self.conv2(self.conv(x))
        return torch.flatten(features, start_dim=1)
39
class TestJitRuntimeAPI(JitTestCase):
    """Tests for ipex.cpu.runtime.Task with TorchScript (JIT) models.

    Covers async submission (future-based), synchronous execution, bf16
    autocast-traced models, repeated submissions, and task aliasing.
    Every test is skipped unless the IPEX runtime extension is enabled.
    """

    @unittest.skipIf(
        not ipex.cpu.runtime.is_runtime_ext_enabled(),
        "Skip when IPEX Runtime extension is not enabled",
    )
    @runtime_thread_affinity_test_env
    def test_task_async_api_fp32_jit_model(self):
        model = SimpleNet()
        model.eval()
        x = torch.rand(64, 64, 3, 3)

        # Calculate the reference result
        trace_model = torch.jit.trace(model, x)
        y = trace_model(x)

        # Create task bound to the CPU pool of NUMA node 0
        cpu_pool = ipex.cpu.runtime.CPUPool(node_id=0)
        task = ipex.cpu.runtime.Task(trace_model, cpu_pool)

        # Task submit and get
        y_runtime_future = task(x)
        y_runtime = y_runtime_future.get()
        self.assertEqual(y, y_runtime)

    @unittest.skipIf(
        not ipex.cpu.runtime.is_runtime_ext_enabled(),
        "Skip when IPEX Runtime extension is not enabled",
    )
    @runtime_thread_affinity_test_env
    def test_task_sync_api_fp32_jit_model(self):
        model = SimpleNet()
        model.eval()
        x = torch.rand(64, 64, 3, 3)

        # Calculate the reference result
        # (local renamed from the "trace_mode" typo for consistency)
        trace_model = torch.jit.trace(model, x)
        y = trace_model(x)

        # Create task
        cpu_pool = ipex.cpu.runtime.CPUPool(node_id=0)
        task = ipex.cpu.runtime.Task(trace_model, cpu_pool)

        # Task sync run: blocks until the result is available
        y_runtime = task.run_sync(x)
        self.assertEqual(y, y_runtime)

    @unittest.skipIf(
        not ipex.cpu.runtime.is_runtime_ext_enabled(),
        "Skip when IPEX Runtime extension is not enabled",
    )
    @runtime_thread_affinity_test_env
    def test_task_async_api_bf16_jit_model(self):
        model = SimpleNet()
        model.eval()
        x = torch.rand(64, 64, 3, 3)

        # Calculate the reference result with a bf16-autocast traced model
        with torch.cpu.amp.autocast(
            enabled=True, dtype=torch.bfloat16
        ), torch.no_grad():
            trace_model = torch.jit.trace(model, x)
            y = trace_model(x)

        # Create task
        cpu_pool = ipex.cpu.runtime.CPUPool(node_id=0)
        task = ipex.cpu.runtime.Task(trace_model, cpu_pool)

        # Task submit and wait
        y_runtime_future = task(x)
        y_runtime = y_runtime_future.get()
        self.assertEqual(y, y_runtime)

    @unittest.skipIf(
        not ipex.cpu.runtime.is_runtime_ext_enabled(),
        "Skip when IPEX Runtime extension is not enabled",
    )
    @runtime_thread_affinity_test_env
    def test_task_async_api_bf16_jit_model_multi_submission(self):
        model = SimpleNet()
        model.eval()
        x = torch.rand(64, 64, 3, 3)

        # Calculate the reference result with a bf16-autocast traced model
        with torch.cpu.amp.autocast(
            enabled=True, dtype=torch.bfloat16
        ), torch.no_grad():
            trace_model = torch.jit.trace(model, x)
            y = trace_model(x)

        # Create task
        cpu_pool = ipex.cpu.runtime.CPUPool(node_id=0)
        task = ipex.cpu.runtime.Task(trace_model, cpu_pool)

        # Submit the task 3 times, then wait for all the results
        y_runtime = []
        y_runtime_future = []
        for _ in range(3):
            y_runtime_future.append(task(x))
        for item in y_runtime_future:
            y_runtime.append(item.get())

        self.assertEqual(y, y_runtime[0])
        self.assertEqual(y, y_runtime[1])
        self.assertEqual(y, y_runtime[2])

    @unittest.skipIf(
        not ipex.cpu.runtime.is_runtime_ext_enabled(),
        "Skip when IPEX Runtime extension is not enabled",
    )
    @runtime_thread_affinity_test_env
    def test_task_copy_bf16_jit_mode(self):
        model = SimpleNet()
        model.eval()
        x = torch.rand(64, 64, 3, 3)

        # Calculate the reference result with a bf16-autocast traced model
        with torch.cpu.amp.autocast(
            enabled=True, dtype=torch.bfloat16
        ), torch.no_grad():
            trace_model = torch.jit.trace(model, x)
            y = trace_model(x)

        # Create task
        cpu_pool = ipex.cpu.runtime.CPUPool(node_id=0)
        task = ipex.cpu.runtime.Task(trace_model, cpu_pool)

        # NOTE(review): this binds a second name to the SAME Task object
        # (aliasing), not a copy; a genuine copy test would use
        # copy.copy/copy.deepcopy. Kept as-is to preserve behavior.
        task2 = task

        # Submit through both names and verify both futures resolve correctly
        y_runtime_future = task(x)
        y_runtime = y_runtime_future.get()
        y_runtime_future2 = task2(x)
        y_runtime2 = y_runtime_future2.get()
        self.assertEqual(y, y_runtime)
        self.assertEqual(y, y_runtime2)
177
class TestJITMultiStreamModule(JitTestCase):
    """Tests for ipex.cpu.runtime.MultiStreamModule with JIT (bf16) models.

    Exercises stream/core/batch-size combinations: equal, non-divisible,
    batch smaller than streams, automatic stream count, and stream count
    exceeding the core count (which should be clipped).
    """

    @unittest.skipIf(
        not ipex.cpu.runtime.is_runtime_ext_enabled(),
        "Skip when IPEX Runtime extension is not enabled",
    )
    @runtime_thread_affinity_test_env
    def test_multi_stream_module_bf16_jit_model(self):
        model = SimpleNet()
        model.eval()
        cpu_pool = ipex.cpu.runtime.CPUPool()
        # One sample per core so each stream gets exactly one input.
        batch_size = len(cpu_pool.core_ids)
        x = torch.rand(batch_size, 64, 3, 3)
        num_streams = batch_size

        # Calculate the reference result
        with torch.cpu.amp.autocast(
            enabled=True, dtype=torch.bfloat16
        ), torch.no_grad():
            trace_model = torch.jit.trace(model, x)
            y = trace_model(x)

        # Create MultiStreamModule
        cpu_pool = ipex.cpu.runtime.CPUPool()
        multi_stream_model = ipex.cpu.runtime.MultiStreamModule(
            trace_model, num_streams=num_streams, cpu_pool=cpu_pool
        )

        y_runtime = multi_stream_model(x)
        self.assertEqual(y, y_runtime)

    @unittest.skipIf(
        not ipex.cpu.runtime.is_runtime_ext_enabled(),
        "Skip when IPEX Runtime extension is not enabled",
    )
    @runtime_thread_affinity_test_env
    def test_multi_stream_module_bf16_jit_model_concat_output(self):
        model = SimpleNet()
        model.eval()

        cpu_pool = ipex.cpu.runtime.CPUPool(node_id=0)
        batch_size = len(cpu_pool.core_ids)
        x = torch.rand(batch_size, 64, 3, 3)
        num_streams = batch_size

        # Trace the model under bf16 autocast
        with torch.cpu.amp.autocast(
            enabled=True, dtype=torch.bfloat16
        ), torch.no_grad():
            trace_model = torch.jit.trace(model, x)

        # Create MultiStreamModule (default concat_output=True)
        multi_stream_model = ipex.cpu.runtime.MultiStreamModule(
            trace_model, num_streams=num_streams, cpu_pool=cpu_pool
        )
        y_runtime = multi_stream_model(x)

        # Create MultiStreamModule with concat_output=False: returns one
        # output chunk per stream instead of a single concatenated tensor.
        multi_stream_model2 = ipex.cpu.runtime.MultiStreamModule(
            trace_model, num_streams=num_streams, cpu_pool=cpu_pool, concat_output=False
        )
        y_runtime2 = multi_stream_model2(x)
        self.assertEqual(len(y_runtime2), num_streams)
        self.assertEqual(y_runtime, torch.cat(y_runtime2))

    @unittest.skipIf(
        not ipex.cpu.runtime.is_runtime_ext_enabled(),
        "Skip when IPEX Runtime extension is not enabled",
    )
    @runtime_thread_affinity_test_env
    def test_single_stream_module_bf16_jit_model(self):
        model = SimpleNet()
        model.eval()
        batch_size = len(ipex.cpu.runtime.get_core_list_of_node_id(0))
        x = torch.rand(batch_size, 64, 3, 3)

        # Trace the model under bf16 autocast
        with torch.cpu.amp.autocast(
            enabled=True, dtype=torch.bfloat16
        ), torch.no_grad():
            trace_model = torch.jit.trace(model, x)

        # Calculate the reference result
        y = trace_model(x)

        # Create MultiStreamModule with a single stream
        cpu_pool = ipex.cpu.runtime.CPUPool(node_id=0)
        multi_stream_model = ipex.cpu.runtime.MultiStreamModule(
            trace_model, num_streams=1, cpu_pool=cpu_pool
        )
        y_runtime = multi_stream_model(x)

        # Create MultiStreamModule with concat_output=False
        multi_stream_model2 = ipex.cpu.runtime.MultiStreamModule(
            trace_model, num_streams=1, cpu_pool=cpu_pool, concat_output=False
        )
        y_runtime2 = multi_stream_model2(x)
        self.assertEqual(y, y_runtime)
        # With one stream, the non-concat output is a 1-element list.
        self.assertEqual(y, y_runtime2[0])

    @unittest.skipIf(
        not ipex.cpu.runtime.is_runtime_ext_enabled(),
        "Skip when IPEX Runtime extension is not enabled",
    )
    @runtime_thread_affinity_test_env
    def test_core_number_not_divisible_stream_number_bf16_jit_model(self):
        model = SimpleNet()
        model.eval()
        num_streams = 2
        batch_size = num_streams
        x = torch.rand(batch_size, 64, 3, 3)

        # Trace and freeze the model under bf16 autocast
        with torch.cpu.amp.autocast(
            enabled=True, dtype=torch.bfloat16
        ), torch.no_grad():
            traced_model = torch.jit.trace(model, x)
            traced_model = torch.jit.freeze(traced_model)

        # Calculate the reference result
        y = traced_model(x)

        # Create MultiStreamModule
        # Core number is 3, stream number is 2 (3 is not divisible by 2)
        cpu_pool = ipex.cpu.runtime.CPUPool(core_ids=[0, 1, 2])
        multi_stream_model = ipex.cpu.runtime.MultiStreamModule(
            traced_model, num_streams=num_streams, cpu_pool=cpu_pool
        )
        multi_stream_model2 = ipex.cpu.runtime.MultiStreamModule(
            traced_model,
            num_streams=num_streams,
            cpu_pool=cpu_pool,
            concat_output=False,
        )

        y_runtime = multi_stream_model(x)
        y_runtime2 = multi_stream_model2(x)
        self.assertEqual(y, y_runtime)
        self.assertEqual(y, torch.cat(y_runtime2))

    @unittest.skipIf(
        not ipex.cpu.runtime.is_runtime_ext_enabled(),
        "Skip when IPEX Runtime extension is not enabled",
    )
    @runtime_thread_affinity_test_env
    def test_batchsize_less_than_stream_number_bf16_jit_model(self):
        model = SimpleNet()
        model.eval()
        num_streams = 3
        batch_size = 2
        x = torch.rand(batch_size, 64, 3, 3)

        # Trace and freeze the model under bf16 autocast
        with torch.cpu.amp.autocast(
            enabled=True, dtype=torch.bfloat16
        ), torch.no_grad():
            traced_model = torch.jit.trace(model, x)
            traced_model = torch.jit.freeze(traced_model)

        # Calculate the reference result
        y = traced_model(x)

        # Create MultiStreamModule
        # Batch size 2, core number 3, stream number 3: some streams idle
        cpu_pool = ipex.cpu.runtime.CPUPool(core_ids=[0, 1, 2])
        multi_stream_model = ipex.cpu.runtime.MultiStreamModule(
            traced_model, num_streams=num_streams, cpu_pool=cpu_pool
        )
        multi_stream_model2 = ipex.cpu.runtime.MultiStreamModule(
            traced_model,
            num_streams=num_streams,
            cpu_pool=cpu_pool,
            concat_output=False,
        )

        y_runtime = multi_stream_model(x)
        y_runtime2 = multi_stream_model2(x)
        self.assertEqual(y, y_runtime)
        self.assertEqual(y, torch.cat(y_runtime2))

    @unittest.skipIf(
        not ipex.cpu.runtime.is_runtime_ext_enabled(),
        "Skip when IPEX Runtime extension is not enabled",
    )
    @runtime_thread_affinity_test_env
    def test_batchsize_not_divisible_stream_number_bf16_jit_model(self):
        model = SimpleNet()
        model.eval()
        num_streams = 3
        batch_size = 4
        x = torch.rand(batch_size, 64, 3, 3)

        # Trace and freeze the model under bf16 autocast
        with torch.cpu.amp.autocast(
            enabled=True, dtype=torch.bfloat16
        ), torch.no_grad():
            traced_model = torch.jit.trace(model, x)
            traced_model = torch.jit.freeze(traced_model)

        # Calculate the reference result
        y = traced_model(x)

        # Create MultiStreamModule
        # Batch size 4, core number 3, stream number 3: uneven split 2/1/1
        cpu_pool = ipex.cpu.runtime.CPUPool(core_ids=[0, 1, 2])
        multi_stream_model = ipex.cpu.runtime.MultiStreamModule(
            traced_model, num_streams=num_streams, cpu_pool=cpu_pool
        )
        multi_stream_model2 = ipex.cpu.runtime.MultiStreamModule(
            traced_model,
            num_streams=num_streams,
            cpu_pool=cpu_pool,
            concat_output=False,
        )

        y_runtime = multi_stream_model(x)
        y_runtime2 = multi_stream_model2(x)

        self.assertEqual(y, y_runtime)
        self.assertEqual(y, torch.cat(y_runtime2))
        # The first stream takes the remainder: chunk sizes 2, 1, 1.
        self.assertEqual(y_runtime2[0].size(0), 2)
        self.assertEqual(y_runtime2[1].size(0), 1)
        self.assertEqual(y_runtime2[2].size(0), 1)

    @unittest.skipIf(
        not ipex.cpu.runtime.is_runtime_ext_enabled(),
        "Skip when IPEX Runtime extension is not enabled",
    )
    @runtime_thread_affinity_test_env
    def test_stream_number_auto_bf16_jit_model(self):
        model = torch.nn.Softmax(dim=-1)
        model.eval()
        for i in range(len(ipex.cpu.runtime.get_core_list_of_node_id(0))):
            # One sample per core in the pool (was list(range(i+1)).__len__()).
            batch_size = i + 1
            x = torch.rand(batch_size, 64)

            # Trace and freeze the model under bf16 autocast
            with torch.cpu.amp.autocast(
                enabled=True, dtype=torch.bfloat16
            ), torch.no_grad():
                traced_model = torch.jit.trace(model, x)
                traced_model = torch.jit.freeze(traced_model)

            # Warm up
            for _ in range(3):
                traced_model(x)

            # Calculate the reference result
            y = traced_model(x)

            cpu_pool = ipex.cpu.runtime.CPUPool(core_ids=list(range(i + 1)))

            # The stream number will be determined automatically.
            multi_stream_model = ipex.cpu.runtime.MultiStreamModule(
                traced_model, cpu_pool=cpu_pool
            )
            y_runtime = multi_stream_model(x)
            stream_num_ground_truth = ipex.cpu.runtime.get_default_num_streams(cpu_pool)
            self.assertEqual(y, y_runtime)
            self.assertEqual(
                multi_stream_model.get_stream_number(), stream_num_ground_truth
            )

    @unittest.skipIf(
        not ipex.cpu.runtime.is_runtime_ext_enabled(),
        "Skip when IPEX Runtime extension is not enabled",
    )
    @runtime_thread_affinity_test_env
    def test_stream_number_larger_than_core_number(self):
        model = torch.nn.Softmax(dim=-1)
        model.eval()

        cpu_pool = ipex.cpu.runtime.CPUPool()
        batch_size = len(cpu_pool.core_ids)
        # Deliberately request one more stream than there are cores.
        num_streams = batch_size + 1
        x = torch.rand(batch_size, 64)

        # Trace and freeze the model under bf16 autocast
        with torch.cpu.amp.autocast(
            enabled=True, dtype=torch.bfloat16
        ), torch.no_grad():
            traced_model = torch.jit.trace(model, x)
            traced_model = torch.jit.freeze(traced_model)

        # Warm up
        for _ in range(3):
            traced_model(x)

        # Calculate the reference result
        y = traced_model(x)

        multi_stream_model = ipex.cpu.runtime.MultiStreamModule(
            traced_model, num_streams=num_streams, cpu_pool=cpu_pool
        )
        y_runtime = multi_stream_model(x)
        self.assertEqual(y, y_runtime)
        # The requested stream count should be clipped to the core count.
        self.assertEqual(
            multi_stream_model.get_stream_number(), len(cpu_pool.core_ids)
        )
478
class TestLLGARuntimeAPI(JitLlgaTestCase):
    """Tests for the runtime extension with LLGA (oneDNN Graph) int8 models.

    Each test prepares a quantized model via JitLlgaTestCase.prepareModel
    and checks Task / MultiStreamModule results against direct execution.
    """

    @unittest.skipIf(
        not ipex.cpu.runtime.is_runtime_ext_enabled(),
        "Skip when IPEX Runtime extension is not enabled",
    )
    @runtime_thread_affinity_test_env
    def test_task_async_api_int8_jit_model(self):
        with torch.no_grad():
            model = SimpleNet_v2()
            model.eval()
            x = torch.rand(2, 3, 224, 224).contiguous(memory_format=torch.channels_last)

            # Calculate the reference result
            graph, m_llga, m_cpu = self.prepareModel(model, [x])
            y = m_llga(x)

            # Create task
            cpu_pool = ipex.cpu.runtime.CPUPool(node_id=0)
            task = ipex.cpu.runtime.Task(m_llga, cpu_pool)

            # Task submit and wait
            y_runtime_future = task(x)
            y_runtime = y_runtime_future.get()
            self.assertEqual(y, y_runtime)

    @unittest.skipIf(
        not ipex.cpu.runtime.is_runtime_ext_enabled(),
        "Skip when IPEX Runtime extension is not enabled",
    )
    @runtime_thread_affinity_test_env
    def test_multi_stream_module_int8_jit_model(self):
        with torch.no_grad():
            model = SimpleNet_v2()
            model.eval()
            x = torch.rand(2, 3, 224, 224).contiguous(memory_format=torch.channels_last)

            # Calculate the reference result
            graph, m_llga, m_cpu = self.prepareModel(model, [x])
            y = m_llga(x)

            # Create single-stream MultiStreamModules (concat and non-concat)
            cpu_pool = ipex.cpu.runtime.CPUPool(node_id=0)
            multi_stream_model = ipex.cpu.runtime.MultiStreamModule(
                m_llga, num_streams=1, cpu_pool=cpu_pool
            )
            multi_stream_model2 = ipex.cpu.runtime.MultiStreamModule(
                m_llga, num_streams=1, cpu_pool=cpu_pool, concat_output=False
            )

            # Task submit and wait
            y_runtime = multi_stream_model(x)
            y_runtime2 = multi_stream_model2(x)
            self.assertEqual(y, y_runtime)
            self.assertEqual(y, torch.cat(y_runtime2))

    @unittest.skipIf(
        not ipex.cpu.runtime.is_runtime_ext_enabled(),
        "Skip when IPEX Runtime extension is not enabled",
    )
    @runtime_thread_affinity_test_env
    def test_core_number_not_divisible_stream_number_int8_jit_model(self):
        with torch.no_grad():
            model = SimpleNet_v2()
            model.eval()
            num_streams = 2
            batch_size = num_streams
            x = torch.rand(batch_size, 3, 16, 16).contiguous(
                memory_format=torch.channels_last
            )

            # Calculate the reference result
            graph, m_llga, m_cpu = self.prepareModel(model, [x])
            y = m_llga(x)

            # Create MultiStreamModule
            # Core number is 3, stream number is 2 (not divisible)
            cpu_pool = ipex.cpu.runtime.CPUPool(core_ids=[0, 1, 2])
            multi_stream_model = ipex.cpu.runtime.MultiStreamModule(
                m_llga, num_streams=num_streams, cpu_pool=cpu_pool
            )
            multi_stream_model2 = ipex.cpu.runtime.MultiStreamModule(
                m_llga, num_streams=num_streams, cpu_pool=cpu_pool, concat_output=False
            )

            # Task submit and wait
            y_runtime = multi_stream_model(x)
            y_runtime2 = multi_stream_model2(x)
            self.assertEqual(y, y_runtime)
            self.assertEqual(y, torch.cat(y_runtime2))

    @unittest.skipIf(
        not ipex.cpu.runtime.is_runtime_ext_enabled(),
        "Skip when IPEX Runtime extension is not enabled",
    )
    @runtime_thread_affinity_test_env
    def test_batchsize_less_than_stream_number_int8_jit_model(self):
        with torch.no_grad():
            model = SimpleNet_v2()
            model.eval()
            num_streams = 3
            batch_size = 2
            x = torch.rand(batch_size, 3, 16, 16).contiguous(
                memory_format=torch.channels_last
            )

            # Calculate the reference result
            graph, m_llga, m_cpu = self.prepareModel(model, [x])
            y = m_llga(x)

            # Create MultiStreamModule
            # Batch size is 2, core number is 3, stream number is 3:
            # only batch_size streams actually produce output.
            cpu_pool = ipex.cpu.runtime.CPUPool(core_ids=[0, 1, 2])
            multi_stream_model = ipex.cpu.runtime.MultiStreamModule(
                m_llga, num_streams=num_streams, cpu_pool=cpu_pool
            )
            multi_stream_model2 = ipex.cpu.runtime.MultiStreamModule(
                m_llga, num_streams=num_streams, cpu_pool=cpu_pool, concat_output=False
            )

            # Task submit and wait
            y_runtime = multi_stream_model(x)
            y_runtime2 = multi_stream_model2(x)
            self.assertEqual(y, y_runtime)
            self.assertEqual(y, torch.cat(y_runtime2))
            self.assertEqual(len(y_runtime2), batch_size)
605
class TestMultiStreamModuleHint(JitTestCase):
    """Tests for MultiStreamModuleHint-based input splitting / output concat.

    Covers positional inputs, keyword inputs, mixed positional+keyword
    calls, and splitting/concatenating along a non-zero dimension.
    """

    def init_set_up(self):
        """Return (batch_size, num_streams, cpu_pool) with one stream per core."""
        cpu_pool = ipex.cpu.runtime.CPUPool()
        batch_size = len(cpu_pool.core_ids)
        num_streams = len(cpu_pool.core_ids)
        return batch_size, num_streams, cpu_pool

    def create_jit_traced_model(self, example_input):
        # Kept for interface compatibility in signature order: (model, input).
        raise NotImplementedError

    @unittest.skipIf(
        not ipex.cpu.runtime.is_runtime_ext_enabled(),
        "Skip when IPEX Runtime extension is not enabled",
    )
    @runtime_thread_affinity_test_env
    def test_input_output_hint(self):
        batch_size, num_streams, cpu_pool = self.init_set_up()

        # This module:
        # * Accepts 3 tensors as input
        # * Returns a tuple of 3 tensors as output
        model = TestInputOutputModule().eval()
        for batch_size in (num_streams - 1, num_streams):
            # Also covers the case where batch_size is less than num_streams
            input_tensor1 = torch.rand(batch_size, 1)
            input_tensor2 = torch.rand(batch_size, 1)
            input_tensor3 = torch.rand(batch_size, 1)

            # jit.trace only accepts a single tensor or a tuple of tensors:
            # https://pytorch.org/docs/stable/generated/torch.jit.trace.html#torch-jit-trace
            jit_input = (input_tensor1, input_tensor2, input_tensor3)

            traced_model = self._trace_and_freeze(model, jit_input)

            # Warm up in the main thread to finish the jit pass optimizations
            for _ in range(3):
                traced_model(input_tensor1, input_tensor2, input_tensor3)

            # Calculate the reference result
            y_ref = traced_model(input_tensor1, input_tensor2, input_tensor3)

            # Split each of the 3 positional inputs along dim 0
            multi_stream_input_hint = ipex.cpu.runtime.MultiStreamModuleHint(0, 0, 0)

            multi_stream_model = self.create_multi_stream_module(
                traced_model,
                num_streams,
                cpu_pool,
                multi_stream_input_hint,
                concat_output=False,
            )
            y_runtime = multi_stream_model(input_tensor1, input_tensor2, input_tensor3)

            # Manually concat the per-stream outputs; when batch_size is
            # smaller than num_streams only batch_size streams produce output.
            y_runtime_res1 = []
            y_runtime_res2 = []
            y_runtime_res3 = []
            active_streams = (
                num_streams if (batch_size // num_streams) >= 1 else batch_size
            )
            for stream_id in range(active_streams):
                y_runtime_res1.append(y_runtime[stream_id][0])
                y_runtime_res2.append(y_runtime[stream_id][1])
                y_runtime_res3.append(y_runtime[stream_id][2])
            y_runtime_res = (
                torch.cat(y_runtime_res1),
                torch.cat(y_runtime_res2),
                torch.cat(y_runtime_res3),
            )
            self.assertEqual(y_ref, y_runtime_res)

            # Create Multi Stream Module with concat output
            multi_stream_output_hint = ipex.cpu.runtime.MultiStreamModuleHint((0, 0, 0))

            multi_stream_model2 = self.create_multi_stream_module(
                traced_model,
                num_streams,
                cpu_pool,
                multi_stream_input_hint,
                multi_stream_output_hint,
                concat_output=True,
            )
            y_runtime_concat = multi_stream_model2(
                input_tensor1, input_tensor2, input_tensor3
            )
            self.assertEqual(y_ref, y_runtime_concat)

    @unittest.skipIf(
        not ipex.cpu.runtime.is_runtime_ext_enabled(),
        "Skip when IPEX Runtime extension is not enabled",
    )
    @runtime_thread_affinity_test_env
    def test_simulate_bert_large_input_output(self):
        class TestModule(torch.nn.Module):
            def forward(self, key1, key2, key3):
                # key3 is intentionally unused by the computation
                return key1 * 2, key2 * 2

        # This module simulates the behaviour of Bert Large LZ models:
        # * Accepts 3 tensors (keyword arguments) as input
        # * Returns a tuple of 2 tensors as output
        model = TestModule().eval()

        batch_size, num_streams, cpu_pool = self.init_set_up()
        jit_input = (
            torch.rand(batch_size, 1),
            torch.rand(batch_size, 2),
            torch.rand(batch_size, 3),
        )
        traced_model = self._trace_and_freeze(model, jit_input)

        input_tensor1 = torch.rand(batch_size, 1)
        input_tensor2 = torch.rand(batch_size, 1)
        input_tensor3 = torch.rand(batch_size, 1)

        # Warm up
        for _ in range(3):
            traced_model(key1=input_tensor1, key2=input_tensor2, key3=input_tensor3)

        # Calculate the reference result
        y_ref = traced_model(key1=input_tensor1, key2=input_tensor2, key3=input_tensor3)

        # Split each keyword input along dim 0
        multi_stream_input_hint = ipex.cpu.runtime.MultiStreamModuleHint(
            key1=0, key2=0, key3=0
        )

        multi_stream_model = self.create_multi_stream_module(
            traced_model,
            num_streams,
            cpu_pool,
            multi_stream_input_hint,
            concat_output=False,
        )
        y_runtime = multi_stream_model(
            key1=input_tensor1, key2=input_tensor2, key3=input_tensor3
        )

        # Manually concat the per-stream outputs
        y_runtime_res1 = []
        y_runtime_res2 = []
        for i in range(num_streams):
            y_runtime_res1.append(y_runtime[i][0])
            y_runtime_res2.append(y_runtime[i][1])
        y_runtime_res = (torch.cat(y_runtime_res1), torch.cat(y_runtime_res2))
        self.assertEqual(y_ref, y_runtime_res)

        # Concat the 2-tuple of outputs along dim 0
        multi_stream_output_hint = ipex.cpu.runtime.MultiStreamModuleHint((0, 0))

        multi_stream_model2 = self.create_multi_stream_module(
            traced_model,
            num_streams,
            cpu_pool,
            multi_stream_input_hint,
            multi_stream_output_hint,
            concat_output=True,
        )
        y_runtime_concat = multi_stream_model2(
            key1=input_tensor1, key2=input_tensor2, key3=input_tensor3
        )

        self.assertEqual(y_ref, y_runtime_concat)

    @unittest.skipIf(
        not ipex.cpu.runtime.is_runtime_ext_enabled(),
        "Skip when IPEX Runtime extension is not enabled",
    )
    @runtime_thread_affinity_test_env
    def test_mix_position_keyword_input_output_hint(self):
        class TestModule(torch.nn.Module):
            def forward(self, param1, param2, key1=None):
                return param1, param2, key1

        batch_size, num_streams, cpu_pool = self.init_set_up()
        # This module simulates the behaviour of Bert Large LZ models:
        # * Accepts 3 tensors (2 positional + 1 keyword parameter) as input
        # * Returns a tuple of 3 tensors as output
        model = TestModule().eval()

        jit_input = (
            torch.rand(batch_size, 1),
            torch.rand(batch_size, 2),
            torch.rand(batch_size, 3),
        )

        traced_model = self._trace_and_freeze(model, jit_input)

        input_tensor1 = torch.rand(batch_size, 1)
        input_tensor2 = torch.rand(batch_size, 2)
        input_tensor3 = torch.rand(batch_size, 3)
        # (renamed from "input" to avoid shadowing the builtin)
        pos_input = (input_tensor1, input_tensor2)
        k_input = {"key1": input_tensor3}

        # Warm up
        for _ in range(3):
            traced_model(input_tensor1, input_tensor2, key1=input_tensor3)

        # Calculate the reference result; all invocation styles must agree
        y_ref = traced_model(*pos_input, **k_input)
        y_ref2 = traced_model(input_tensor1, input_tensor2, input_tensor3)
        y_ref3 = traced_model(input_tensor1, input_tensor2, key1=input_tensor3)
        self.assertEqual(y_ref, y_ref2)
        self.assertEqual(y_ref, y_ref3)

        # Be careful, the jit traced model will change the input type
        multi_stream_input_hint = ipex.cpu.runtime.MultiStreamModuleHint(0, 0, key1=0)

        # Create Multi Stream Module with concat output
        multi_stream_output_hint = ipex.cpu.runtime.MultiStreamModuleHint((0, 0, 0))

        multi_stream_model = self.create_multi_stream_module(
            traced_model,
            num_streams,
            cpu_pool,
            multi_stream_input_hint,
            multi_stream_output_hint,
            concat_output=True,
        )
        # There are 2 equivalent ways to invoke the module
        y_runtime_res = multi_stream_model(
            input_tensor1, input_tensor2, key1=input_tensor3
        )
        y_runtime_res2 = multi_stream_model(*pos_input, **k_input)
        self.assertEqual(y_ref, y_runtime_res)
        self.assertEqual(y_ref, y_runtime_res2)

    @unittest.skipIf(
        not ipex.cpu.runtime.is_runtime_ext_enabled(),
        "Skip when IPEX Runtime extension is not enabled",
    )
    @runtime_thread_affinity_test_env
    def test_input_output_hint_not_along_dim_zero(self):
        batch_size, num_streams, cpu_pool = self.init_set_up()

        # This module:
        # * Accepts 3 tensors as input
        # * Returns a tuple of 3 tensors as output
        model = TestInputOutputModule().eval()

        # Inputs 1 and 3 carry the batch on a non-zero dimension
        input_tensor1 = torch.rand(1, batch_size)
        input_tensor2 = torch.rand(batch_size, 2)
        input_tensor3 = torch.rand(3, batch_size)

        # jit.trace only accepts a single tensor or a tuple of tensors:
        # https://pytorch.org/docs/stable/generated/torch.jit.trace.html#torch-jit-trace
        jit_input = (input_tensor1, input_tensor2, input_tensor3)

        traced_model = self._trace_and_freeze(model, jit_input)

        # Warm up in the main thread to finish the jit pass optimizations
        for _ in range(3):
            traced_model(input_tensor1, input_tensor2, input_tensor3)

        # Calculate the reference result
        y_ref = traced_model(input_tensor1, input_tensor2, input_tensor3)

        # Split input 1 along dim 1, input 2 along dim 0, input 3 along dim 1
        multi_stream_input_hint = ipex.cpu.runtime.MultiStreamModuleHint(1, 0, 1)

        multi_stream_model = self.create_multi_stream_module(
            traced_model,
            num_streams,
            cpu_pool,
            multi_stream_input_hint,
            concat_output=False,
        )
        y_runtime = multi_stream_model(input_tensor1, input_tensor2, input_tensor3)

        # Manually concat the per-stream outputs along the matching dims
        y_runtime_res1 = []
        y_runtime_res2 = []
        y_runtime_res3 = []
        active_streams = (
            num_streams if (batch_size // num_streams) >= 1 else batch_size
        )
        for stream_id in range(active_streams):
            y_runtime_res1.append(y_runtime[stream_id][0])
            y_runtime_res2.append(y_runtime[stream_id][1])
            y_runtime_res3.append(y_runtime[stream_id][2])
        y_runtime_res = (
            torch.cat(y_runtime_res1, 1),
            torch.cat(y_runtime_res2, 0),
            torch.cat(y_runtime_res3, 1),
        )
        self.assertEqual(y_ref, y_runtime_res)

        # Create Multi Stream Module with concat output along dims (1, 0, 1)
        multi_stream_output_hint = ipex.cpu.runtime.MultiStreamModuleHint((1, 0, 1))

        multi_stream_model2 = self.create_multi_stream_module(
            traced_model,
            num_streams,
            cpu_pool,
            multi_stream_input_hint,
            multi_stream_output_hint,
            concat_output=True,
        )
        y_runtime_concat = multi_stream_model2(
            input_tensor1, input_tensor2, input_tensor3
        )
        self.assertEqual(y_ref, y_runtime_concat)

    def _trace_and_freeze(self, model, example_input):
        """Trace *model* with *example_input*, set eval mode, and freeze."""
        traced_model = torch.jit.trace(model, example_input).eval()
        traced_model = torch.jit.freeze(traced_model)
        return traced_model

    # Backward-compatible public alias with the original signature
    # (model, input); the positional parameter named "input" shadowed the
    # builtin, so the implementation lives in _trace_and_freeze.
    def create_jit_traced_model(self, model, input):
        return self._trace_and_freeze(model, input)

    def create_multi_stream_module(
        self,
        traced_model,
        num_streams,
        cpu_pool,
        multi_stream_input_hint,
        multi_stream_output_hint=None,
        concat_output=True,
    ):
        """Build a MultiStreamModule with the given split/concat hints.

        When concat_output is False the module returns one output per
        stream and no output hint is needed.
        """
        if not concat_output:
            return ipex.cpu.runtime.MultiStreamModule(
                traced_model,
                num_streams=num_streams,
                cpu_pool=cpu_pool,
                concat_output=False,
                input_split_hint=multi_stream_input_hint,
            )
        else:
            return ipex.cpu.runtime.MultiStreamModule(
                traced_model,
                num_streams=num_streams,
                cpu_pool=cpu_pool,
                input_split_hint=multi_stream_input_hint,
                output_concat_hint=multi_stream_output_hint,
            )
938
class TestMultiStreamBenchmarkModule(JitTestCase):
    """Smoke test for the internal _MultiStreamBenchmarkModule wrapper."""

    @unittest.skipIf(
        not ipex.cpu.runtime.is_runtime_ext_enabled(),
        "Skip when IPEX Runtime extension is not enabled",
    )
    @runtime_thread_affinity_test_env
    def test_multi_stream_benchmark_module_bf16_jit_model(self):
        model = SimpleNet().eval()
        batch_size = 1
        x = torch.rand(batch_size, 64, 3, 3)

        # Trace under bf16 autocast and warm the traced graph up so the
        # jit pass optimizations are finished before benchmarking.
        with torch.cpu.amp.autocast(
            enabled=True, dtype=torch.bfloat16
        ), torch.no_grad():
            trace_model = torch.jit.trace(model, x)
            # Warm up
            for _ in range(3):
                trace_model(x)

        # Create the benchmark module and run one step; this is a smoke
        # test, so no result value is asserted.
        multi_stream_model = ipex.cpu.runtime._MultiStreamBenchmarkModule(trace_model)
        multi_stream_model(x)
963
if __name__ == "__main__":
    # unittest.main() raises SystemExit on completion; the previous
    # "test =" assignment of its result was never used and is dropped.
    unittest.main()