intel-extension-for-pytorch
139 строк · 5.9 Кб
import itertools
import logging
import os
import subprocess
import sys
import unittest

import torch

from common_utils import TestCase
# Put the root logger (getLogger with no name) at DEBUG level for this module.
logging.getLogger(name=None).setLevel("DEBUG")
11
class TestCodeFreeOptimization(TestCase):
    """End-to-end tests for IPEX code-free ("auto-ipex") optimization.

    Every test launches ``code_free_optimization.py`` (expected to live next
    to this file) in a child process with ``--auto-ipex-verbose`` enabled and
    scans the combined stdout/stderr for log markers. Each test is repeated
    for every combination of IPEX graph mode (enabled / disabled) and dtype
    (float32, plus bfloat16 when oneDNN reports bf16 support).
    """

    # Command prefix for running the script through the ``ipexrun`` launcher.
    _IPEXRUN_LAUNCHER = ("ipexrun", "--ninstances", "1", "--auto-ipex")
    # Command prefix for running the auto_ipex module directly. Uses the
    # interpreter executing this test rather than whatever ``python`` is on PATH.
    _AUTO_IPEX_MODULE_LAUNCHER = (
        sys.executable,
        "-m",
        "intel_extension_for_pytorch.cpu.auto_ipex",
    )

    @staticmethod
    def _dtypes():
        """Return the dtypes to exercise; bfloat16 only if oneDNN supports it."""
        if torch.ops.mkldnn._is_mkldnn_bf16_supported():
            return ["float32", "bfloat16"]
        return ["float32"]

    @staticmethod
    def _build_command(launcher, dtype, disable_ipex_graph_mode, script_flag):
        """Build the argv list that runs ``code_free_optimization.py``.

        Args:
            launcher: command prefix (see the ``_*_LAUNCHER`` attributes).
            dtype: value passed to ``--dtype``.
            disable_ipex_graph_mode: whether to add ``--disable-ipex-graph-mode``.
            script_flag: option selecting the model variant in the script.

        Returns:
            A list of arguments. It is run without a shell, so the script path
            survives spaces or shell metacharacters in the test directory.
        """
        script = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), "code_free_optimization.py"
        )
        cmd = [*launcher, "--dtype", dtype, "--auto-ipex-verbose"]
        if disable_ipex_graph_mode:
            cmd.append("--disable-ipex-graph-mode")
        cmd += [script, script_flag]
        return cmd

    @staticmethod
    def _run(cmd):
        """Run *cmd* to completion and return ``(stripped output lines, exit code)``.

        stderr is merged into stdout so markers printed on either stream are
        seen. Undecodable bytes are replaced instead of raising.
        """
        proc = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
        output = proc.stdout.decode("utf-8", errors="replace")
        return [line.strip() for line in output.split("\n")], proc.returncode

    def _check_code_free_optimization(
        self, launcher, script_flag, *, expect_no_batchnorm
    ):
        """Run the script for every (graph mode, dtype) pair and check its log.

        For each pair this asserts three things:
        - exactly one line mentions ``_ipex_optimize_hit_count``;
        - an IPEX convolution marker (or ``fused_to``) appears;
        - if *expect_no_batchnorm* is true, no line mentions ``batch_norm``.

        Each pair runs in its own ``subTest``, so one failing combination
        neither hides nor aborts the others.
        """
        for disable_ipex_graph_mode, dtype in itertools.product(
            (False, True), self._dtypes()
        ):
            with self.subTest(
                disable_ipex_graph_mode=disable_ipex_graph_mode, dtype=dtype
            ):
                cmd = self._build_command(
                    launcher, dtype, disable_ipex_graph_mode, script_flag
                )
                lines, returncode = self._run(cmd)

                # The convolution kernel name that shows up depends on graph mode.
                conv_marker = (
                    "torch_ipex::convolution_forward_impl"
                    if disable_ipex_graph_mode
                    else "ipex_prepack::convolution_run"
                )
                hit_count = sum("_ipex_optimize_hit_count" in line for line in lines)
                ipex_convolution = any(
                    conv_marker in line or "fused_to" in line for line in lines
                )
                has_batchnorm = any("batch_norm" in line for line in lines)

                # Attach the command, exit status and output to every failure.
                details = "\ncommand: {}\nexit code: {}\noutput:\n{}".format(
                    " ".join(cmd), returncode, "\n".join(lines)
                )
                self.assertEqual(
                    hit_count, 1, "Expect hit once of ipex.optimize globally" + details
                )
                self.assertTrue(
                    ipex_convolution,
                    "Expect use ipex convolution by ipex.optimize" + details,
                )
                if expect_no_batchnorm:
                    self.assertFalse(has_batchnorm, "should not see bn" + details)

    def test_conv_bn(self):
        """Run ``--conv_bn`` via ipexrun and expect no ``batch_norm`` in the log."""
        self._check_code_free_optimization(
            self._IPEXRUN_LAUNCHER, "--conv_bn", expect_no_batchnorm=True
        )

    def test_conv_bn_with_module_created_in_forward(self):
        """Run the ``--conv_bn_with_module_created_in_forward`` variant via ipexrun.

        BatchNorm is not checked: because of an FX limitation, ipex.optimize
        fails to do the fusion for this variant.
        """
        self._check_code_free_optimization(
            self._IPEXRUN_LAUNCHER,
            "--conv_bn_with_module_created_in_forward",
            expect_no_batchnorm=False,
        )

    def test_auto_ipex_module(self):
        """Run ``--conv_bn`` via ``-m intel_extension_for_pytorch.cpu.auto_ipex``."""
        self._check_code_free_optimization(
            self._AUTO_IPEX_MODULE_LAUNCHER, "--conv_bn", expect_no_batchnorm=True
        )
137
if __name__ == "__main__":
    # unittest.main() runs the tests and exits the process with their status,
    # so its return value is not bound to anything.
    unittest.main()