intel-extension-for-pytorch

Форк
0
163 строки · 4.7 Кб
1
import platform
2
import subprocess
3
import os
4
import sys
5
import logging
6
from tempfile import mkstemp
7
import uuid
8
from argparse import ArgumentParser, REMAINDER
9
from argparse import RawTextHelpFormatter
10

11

12
format_str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
13
logging.basicConfig(level=logging.INFO, format=format_str)
14
logger = logging.getLogger(__name__)
15

16

17
def apply_monkey_patch(args):
18
    # Auto apply the ipex features
19
    # Open the original file and get the content
20
    program = args.program
21
    with open(program) as f:
22
        original_program_lines = f.readlines()
23

24
    # Modify the content with import ipex
25
    monkey_patch = """import torch
26
import intel_extension_for_pytorch as ipex
27
"""
28
    if args.convert_fp64_to_fp32:
29
        monkey_patch += """ipex.xpu.overrides.convert_default_dtype(torch.float64, torch.float32, True)
30
"""
31
    original_program_lines.insert(0, monkey_patch)
32

33
    program_absolute_path = os.path.abspath(program)
34
    program_absolute_path_dir = os.path.dirname(program_absolute_path)
35
    generate_file_suffix = (
36
        str(hash(program_absolute_path)) + str(uuid.uuid1()) + "_auto_ipex"
37
    )
38
    _, generate_file = mkstemp(
39
        suffix=generate_file_suffix, dir=program_absolute_path_dir, text=True
40
    )
41

42
    # Write the monkey_patched content to temp file
43
    with open(generate_file, "w") as f:
44
        f.writelines(original_program_lines)
45

46
    return generate_file
47

48

49
class Launcher:
50
    r"""
51
    Base class for launcher
52
    """
53

54
    def __init__(self):
55
        pass
56

57
    def launch(self, args):
58
        pass
59

60
    def logger_env(self, env_name=""):
61
        if env_name in os.environ:
62
            logger.info("{}={}".format(env_name, os.environ[env_name]))
63

64
    def set_env(self, env_name, env_value=None):
65
        if not env_value:
66
            logger.warning("{} is None".format(env_name))
67
        if env_name not in os.environ:
68
            os.environ[env_name] = env_value
69
        elif os.environ[env_name] != env_value:
70
            logger.warning(
71
                "{} in environment variable is {} while the value you set is {}".format(
72
                    env_name, os.environ[env_name], env_value
73
                )
74
            )
75
        self.logger_env(env_name)
76

77

78
class XPUDefaultLauncher(Launcher):
79
    """
80
    Run the program using XPU.
81
    # Note: For now, we only support single instance in this script
82
    """
83

84
    def launch(self, args):
85
        processes = []
86
        cmd = []
87

88
        monkey_program = apply_monkey_patch(args)
89

90
        cmd.append(sys.executable)
91
        cmd.append(monkey_program)
92
        cmd.extend(args.program_args)
93

94
        cmd_s = " ".join(cmd)
95
        process = subprocess.Popen(cmd_s, env=os.environ, shell=True)
96
        processes.append(process)
97
        try:
98
            for process in processes:
99
                process.wait()
100
                if process.returncode != 0:
101
                    raise subprocess.CalledProcessError(
102
                        returncode=process.returncode, cmd=cmd_s
103
                    )
104
        except subprocess.CalledProcessError as e:
105
            print(e.output)
106
        finally:
107
            os.remove(monkey_program)
108

109

110
def init_parser(parser):
111
    """
112
    Helper function parsing the command line options
113
    @retval ArgumentParser
114
    """
115

116
    # positional
117
    parser.add_argument(
118
        "--convert-fp64-to-fp32",
119
        "--convert_fp64_to_fp32",
120
        action="store_true",
121
        dest="convert_fp64_to_fp32",
122
        help="To automatically convert torch.float64(double) dtype to torch.float32",
123
    )
124
    parser.add_argument(
125
        "program",
126
        type=str,
127
        help="The full path to the proram/script to be launched. "
128
        "followed by all the arguments for the script",
129
    )
130

131
    # rest from the training program
132
    parser.add_argument("program_args", nargs=REMAINDER)
133
    return parser
134

135

136
def run_main_with_args(args):
137
    env_before = set(os.environ.keys())
138
    if platform.system() == "Windows":
139
        raise RuntimeError("Windows platform is not supported!!!")
140
    launcher = None
141
    launcher = XPUDefaultLauncher()
142
    launcher.launch(args)
143
    for x in sorted(set(os.environ.keys()) - env_before):
144
        logger.debug("{0}={1}".format(x, os.environ[x]))
145

146

147
def main():
148
    parser = ArgumentParser(
149
        description="This is a script for launching PyTorch training and inference on Intel GPU Series"
150
        "with optimal configurations. "
151
        "\n################################# Basic usage ############################# \n"
152
        "\n 1. Run with args\n"
153
        "\n   >>> ipexrun xpu python_script args \n"
154
        "\n############################################################################# \n",
155
        formatter_class=RawTextHelpFormatter,
156
    )
157
    parser = init_parser(parser)
158
    args = parser.parse_args()
159
    run_main_with_args(args)
160

161

162
if __name__ == "__main__":
163
    main()
164

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.