intel-extension-for-pytorch
122 lines · 3.5 KB
import os
import subprocess
import sys
import unittest

import intel_extension_for_pytorch._C as core
# ISA names known to these tests. A name's position in this list is its
# numeric level (see get_isa_val), and the tests treat a larger level as a
# higher ISA, so new names must be inserted in level order.
supported_isa_set = [
    "default",
    "avx2",
    "avx2_vnni",
    "avx512",
    "avx512_vnni",
    "avx512_bf16",
    "amx",
    "avx512_fp16",
]

# Level reported for any name outside supported_isa_set. It is larger than
# every known level, so taking min() against a real level ignores it.
_UNKNOWN_ISA_VAL = 100


def get_isa_val(isa_name):
    """Return the numeric level of ``isa_name`` for ordering comparisons.

    The level is the index of ``isa_name`` in ``supported_isa_set``, so
    "default" is 0 and "avx512_fp16" is 7. Deriving it from the list keeps a
    single source of truth for the ordering. The match is exact and
    case-sensitive.

    Args:
        isa_name: ISA name, expected lower-case (e.g. "avx512_vnni"), or None.

    Returns:
        The level as an int, or 100 for any name not in ``supported_isa_set``
        (including None).
    """
    try:
        return supported_isa_set.index(isa_name)
    except ValueError:
        return _UNKNOWN_ISA_VAL
def get_ipex_isa_env_setting():
    """Return the value of ATEN_CPU_CAPABILITY, or None when it is not set."""
    return os.environ.get("ATEN_CPU_CAPABILITY")
def get_currnet_isa_level():
    """Return ``core._get_current_isa_level()`` converted to lower case.

    The misspelled name is kept because the tests in this module call it.
    """
    isa_level = core._get_current_isa_level()
    return isa_level.lower()
def get_highest_binary_support_isa_level():
    """Return ``core._get_highest_binary_support_isa_level()`` in lower case."""
    isa_level = core._get_highest_binary_support_isa_level()
    return isa_level.lower()
def get_highest_cpu_support_isa_level():
    """Return ``core._get_highest_cpu_support_isa_level()`` in lower case."""
    isa_level = core._get_highest_cpu_support_isa_level()
    return isa_level.lower()
def check_not_sync_onednn_isa_level():
    """Return the result of ``core._check_not_sync_onednn_isa_level()``.

    The oneDNN-related tests in this module use it as a ``skipIf`` condition.
    """
    not_synced = core._check_not_sync_onednn_isa_level()
    return not_synced
class TestDynDisp(unittest.TestCase):
    """Checks on IPEX's runtime ISA selection and its oneDNN ISA coupling."""

    def _run_python_snippet(self, code, extra_env):
        """Run ``code`` via ``-c`` in a child interpreter; return its last stdout line.

        The child is the interpreter running these tests (``sys.executable``),
        so it sees the same torch/IPEX installation. It is started without a
        shell. ``extra_env`` entries are layered over a copy of the current
        environment.

        Fails the calling test if the child exits non-zero or prints nothing
        to stdout. stderr is kept separate so warnings or tracebacks cannot
        be mistaken for the printed result.
        """
        env = dict(os.environ)
        env.update(extra_env)
        proc = subprocess.run(
            [sys.executable, "-c", code],
            env=env,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            encoding="utf-8",
        )
        self.assertEqual(
            proc.returncode,
            0,
            msg="child interpreter failed:\n%s" % proc.stderr,
        )
        lines = [line.strip() for line in proc.stdout.splitlines() if line.strip()]
        self.assertTrue(lines, msg="child interpreter printed nothing to stdout")
        return lines[-1]

    def test_manual_select_kernel(self):
        env_isa = get_ipex_isa_env_setting()
        cur_isa = get_currnet_isa_level()
        max_bin_isa = get_highest_binary_support_isa_level()
        max_cpu_isa = get_highest_cpu_support_isa_level()

        # Expected ceiling: the lower of the binary's and the CPU's highest
        # supported level, further capped by ATEN_CPU_CAPABILITY when it is set.
        expected_isa_val = min(get_isa_val(max_bin_isa), get_isa_val(max_cpu_isa))
        if env_isa is not None:
            expected_isa_val = min(get_isa_val(env_isa), expected_isa_val)

        actual_isa_val = get_isa_val(cur_isa)

        # Only an upper bound is checked. ISA level and compiler version are
        # not linearly related: gcc 9.4 can build avx512_vnni, while avx2_vnni
        # support starts with gcc 11.3.
        self.assertLessEqual(
            actual_isa_val,
            expected_isa_val,
            msg="current ISA %r above expected ceiling "
            "(binary max %r, CPU max %r, ATEN_CPU_CAPABILITY %r)"
            % (cur_isa, max_bin_isa, max_cpu_isa, env_isa),
        )

    def test_dyndisp_in_supported_set(self):
        # This check is only run without an ATEN_CPU_CAPABILITY override.
        # Report that case as a skip rather than a silent pass.
        if get_ipex_isa_env_setting() is not None:
            self.skipTest("ATEN_CPU_CAPABILITY is set")
        self.assertIn(get_currnet_isa_level(), supported_isa_set)

    @unittest.skipIf(
        check_not_sync_onednn_isa_level(), "skip this if not sync onednn isa level"
    )
    def test_ipex_set_onednn_isa_level(self):
        # With ATEN_CPU_CAPABILITY=avx2 in a fresh process, the oneDNN ISA
        # level reported by IPEX is expected to be "AVX2".
        code = (
            "import torch; import intel_extension_for_pytorch._C as core; "
            "print(core._get_current_onednn_isa_level())"
        )
        onednn_isa_level = self._run_python_snippet(
            code, {"ATEN_CPU_CAPABILITY": "avx2"}
        )
        self.assertEqual(onednn_isa_level, "AVX2")

    @unittest.skipIf(
        check_not_sync_onednn_isa_level(), "skip this if not sync onednn isa level"
    )
    def test_onednn_do_not_set_isa_level(self):
        # Setting ONEDNN_MAX_CPU_ISA in a fresh process must leave IPEX's own
        # ISA level equal to the one seen by this process.
        code = (
            "import torch; import intel_extension_for_pytorch._C as core; "
            "print(core._get_current_isa_level().lower())"
        )
        cur_ipex_isa = get_currnet_isa_level()
        child_ipex_isa = self._run_python_snippet(
            code, {"ONEDNN_MAX_CPU_ISA": "avx2"}
        )
        self.assertEqual(cur_ipex_isa, child_ipex_isa)
if __name__ == "__main__":
    # Run this module's tests when the file is executed directly.
    unittest.main()