intel-extension-for-pytorch

Форк
0
/
test_code_free_optimization.py 
139 строк · 5.9 Кб
1
import unittest
2
from common_utils import TestCase
3
import os
4
import subprocess
5
import itertools
6
import torch
7
import logging
8

9
# Set the root logger's threshold to DEBUG for everything run from this module.
logging.getLogger().setLevel(logging.DEBUG)
10

11

12
class TestCodeFreeOptimization(TestCase):
    """Launch ``code_free_optimization.py`` through auto-IPEX and scan its output.

    Every test starts the script that sits next to this file in a subprocess,
    passing ``--auto-ipex-verbose``, and then searches the merged
    stdout/stderr stream for marker strings: the ``_ipex_optimize_hit_count``
    marker, an IPEX convolution marker and, where checked, ``batch_norm``.
    """

    # Command prefixes for the two entry points exercised by the tests.
    _IPEXRUN_LAUNCHER = ("ipexrun", "--ninstances", "1", "--auto-ipex")
    _MODULE_LAUNCHER = (
        "python",
        "-m",
        "intel_extension_for_pytorch.cpu.auto_ipex",
    )

    @staticmethod
    def _configurations():
        """Return the ``(disable_ipex_graph_mode, dtype)`` pairs to exercise.

        ``bfloat16`` is included only when
        ``torch.ops.mkldnn._is_mkldnn_bf16_supported()`` returns a true value.
        """
        dtypes = ["float32"]
        if torch.ops.mkldnn._is_mkldnn_bf16_supported():
            dtypes.append("bfloat16")
        return list(itertools.product([False, True], dtypes))

    @staticmethod
    def _scan_output(launcher, script_flag, dtype, disable_ipex_graph_mode):
        """Run the sample script once and summarise the markers in its output.

        Args:
            launcher: Sequence of leading command-line tokens (entry point and
                its fixed options).
            script_flag: Flag passed to ``code_free_optimization.py``.
            dtype: Value passed after ``--dtype``.
            disable_ipex_graph_mode: When true, ``--disable-ipex-graph-mode``
                is added and the non-graph convolution marker is expected.

        Returns:
            A tuple ``(hit_count, saw_ipex_convolution, saw_batch_norm)``:
            the number of lines containing ``_ipex_optimize_hit_count``,
            whether any line contained the convolution marker (or
            ``fused_to``), and whether any line contained ``batch_norm``.
        """
        loc = os.path.dirname(os.path.abspath(__file__))
        # An argument list (no shell) keeps paths containing spaces or shell
        # metacharacters intact.
        cmd = list(launcher)
        cmd += ["--dtype", dtype, "--auto-ipex-verbose"]
        if disable_ipex_graph_mode:
            cmd.append("--disable-ipex-graph-mode")
        cmd += [os.path.join(loc, "code_free_optimization.py"), script_flag]

        conv_marker = (
            "torch_ipex::convolution_forward_impl"
            if disable_ipex_graph_mode
            else "ipex_prepack::convolution_run"
        )

        hit_count = 0
        saw_ipex_convolution = False
        saw_batch_norm = False
        with subprocess.Popen(
            cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT
        ) as proc:
            for raw_line in proc.stdout:
                # Undecodable bytes must not abort the scan.
                line = raw_line.decode("utf-8", errors="replace").strip()
                if "_ipex_optimize_hit_count" in line:
                    hit_count += 1
                if conv_marker in line or "fused_to" in line:
                    saw_ipex_convolution = True
                if "batch_norm" in line:
                    saw_batch_norm = True
        return hit_count, saw_ipex_convolution, saw_batch_norm

    def _check_auto_ipex(self, launcher, script_flag, check_batch_norm):
        """Run every configuration and verify the expected markers.

        Args:
            launcher: Sequence of leading command-line tokens.
            script_flag: Flag passed to ``code_free_optimization.py``.
            check_batch_norm: When true, also require that no output line
                mentions ``batch_norm``.

        Raises:
            AssertionError: If a configuration does not produce exactly one
                ``_ipex_optimize_hit_count`` line, lacks the convolution
                marker, or (when checked) mentions ``batch_norm``.
        """
        for disable_ipex_graph_mode, dtype in self._configurations():
            hit_count, saw_ipex_convolution, saw_batch_norm = self._scan_output(
                launcher, script_flag, dtype, disable_ipex_graph_mode
            )
            context = " (dtype={}, disable_ipex_graph_mode={})".format(
                dtype, disable_ipex_graph_mode
            )
            # Explicit raises instead of ``assert`` statements so the checks
            # are still performed when Python runs with -O.
            if hit_count != 1:
                raise AssertionError(
                    "Expect hit once of ipex.optimize globally, got {}".format(
                        hit_count
                    )
                    + context
                )
            if not saw_ipex_convolution:
                raise AssertionError(
                    "Expect use ipex convolution by ipex.optimize" + context
                )
            if check_batch_norm and saw_batch_norm:
                raise AssertionError("should not see bn" + context)

    def test_conv_bn(self):
        """``ipexrun --auto-ipex`` with ``--conv_bn``: conv marker, no batch_norm."""
        self._check_auto_ipex(
            self._IPEXRUN_LAUNCHER, "--conv_bn", check_batch_norm=True
        )

    def test_conv_bn_with_module_created_in_forward(self):
        """``ipexrun --auto-ipex`` with a module created inside ``forward``."""
        # Not check BN, because FX limitation, ipex.optimize failed to do fusion
        self._check_auto_ipex(
            self._IPEXRUN_LAUNCHER,
            "--conv_bn_with_module_created_in_forward",
            check_batch_norm=False,
        )

    def test_auto_ipex_module(self):
        """``python -m ...auto_ipex`` with ``--conv_bn``: conv marker, no batch_norm."""
        self._check_auto_ipex(
            self._MODULE_LAUNCHER, "--conv_bn", check_batch_norm=True
        )
136

137

138
if __name__ == "__main__":
    # Run this module's tests when the file is executed as a script; the
    # value returned by unittest.main() was never used, so it is not bound.
    unittest.main()
140

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.