# intel-extension-for-pytorch
# 163 lines, 4.7 KB
import ast
import logging
import os
import platform
import shlex
import subprocess
import sys
import uuid
from argparse import ArgumentParser, REMAINDER
from argparse import RawTextHelpFormatter
from tempfile import mkstemp
11
# Module-wide logging configuration: every record is timestamped and tagged
# with logger name and level; INFO and above are emitted.
format_str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(format=format_str, level=logging.INFO)

# Logger shared by every function and class in this module.
logger = logging.getLogger(__name__)
16
def _monkey_patch_insert_index(program_lines):
    """Return the index in ``program_lines`` where the IPEX import lines may go.

    ``from __future__`` statements must precede every other statement in a
    module, so the patch is placed right after the last top-level one.  When
    there is none (or the program cannot be parsed), the patch goes at the top.
    """
    try:
        tree = ast.parse("".join(program_lines))
    except (SyntaxError, ValueError):
        # Unparsable program: keep the old behaviour and let the interpreter
        # report the error when the patched copy is run.
        return 0
    index = 0
    for node in tree.body:
        if isinstance(node, ast.ImportFrom) and node.module == "__future__":
            # end_lineno is 1-based, so it is also the list index just after
            # the (possibly multi-line) statement.
            index = node.end_lineno
    return index


def apply_monkey_patch(args):
    """Write a copy of ``args.program`` that imports IPEX before its own code.

    The copy imports ``torch`` and ``intel_extension_for_pytorch as ipex``.
    When ``args.convert_fp64_to_fp32`` is set, it also calls
    ``ipex.xpu.overrides.convert_default_dtype(torch.float64, torch.float32,
    True)``.  The lines are inserted after any ``from __future__`` imports,
    otherwise at the top of the file.  The copy is created in the same
    directory as the original program, presumably so that imports relative to
    the script's directory keep working.  The caller must delete it.

    Args:
        args: parsed launcher arguments; reads ``args.program`` (path of the
            script) and ``args.convert_fp64_to_fp32`` (bool).

    Returns:
        str: path of the generated, monkey-patched script.
    """
    program = args.program
    with open(program) as f:
        original_program_lines = f.readlines()

    # Lines injected into the program.
    monkey_patch = "import torch\nimport intel_extension_for_pytorch as ipex\n"
    if args.convert_fp64_to_fp32:
        monkey_patch += (
            "ipex.xpu.overrides.convert_default_dtype("
            "torch.float64, torch.float32, True)\n"
        )

    insert_at = _monkey_patch_insert_index(original_program_lines)
    # If the patch follows a final line with no trailing newline, terminate
    # that line so the two statements do not run together.
    if insert_at > 0 and not original_program_lines[insert_at - 1].endswith("\n"):
        original_program_lines[insert_at - 1] += "\n"
    original_program_lines.insert(insert_at, monkey_patch)

    program_absolute_path = os.path.abspath(program)
    program_absolute_path_dir = os.path.dirname(program_absolute_path)
    generate_file_suffix = (
        str(hash(program_absolute_path)) + str(uuid.uuid1()) + "_auto_ipex"
    )
    fd, generate_file = mkstemp(
        suffix=generate_file_suffix, dir=program_absolute_path_dir, text=True
    )

    # Write through the descriptor mkstemp already opened; discarding it (as
    # before) leaked one file descriptor per call.
    try:
        with os.fdopen(fd, "w") as f:
            f.writelines(original_program_lines)
    except BaseException:
        # Do not leave a half-written temp file next to the user's script.
        os.remove(generate_file)
        raise

    return generate_file
48
class Launcher:
    r"""
    Base class for launchers.

    Subclasses implement ``launch``; ``set_env`` and ``logger_env`` are
    helpers for setting and logging environment variables.
    """

    def __init__(self):
        pass

    def launch(self, args):
        """Launch the program described by ``args``.

        Does nothing in the base class; subclasses override it.
        """
        pass

    def logger_env(self, env_name=""):
        """Log ``env_name=<value>`` at INFO level if ``env_name`` is set."""
        if env_name in os.environ:
            logger.info("{}={}".format(env_name, os.environ[env_name]))

    def set_env(self, env_name, env_value=None):
        """Set environment variable ``env_name`` unless it is already set.

        An existing value is never overwritten.  If it differs from
        ``env_value``, a warning is logged instead.  The variable's resulting
        value (if set) is then logged via ``logger_env``.

        Args:
            env_name (str): name of the environment variable.
            env_value: value to set.  Non-string values are converted with
                ``str()``.  A falsy value triggers a warning, and ``None`` is
                never written because ``os.environ`` only accepts strings
                (previously this raised ``TypeError``).
        """
        if not env_value:
            logger.warning("{} is None".format(env_name))
        if env_value is not None:
            # os.environ stores strings only; also makes the comparison
            # below meaningful for e.g. integer thread counts.
            env_value = str(env_value)
        if env_name not in os.environ:
            if env_value is not None:
                os.environ[env_name] = env_value
        elif os.environ[env_name] != env_value:
            logger.warning(
                "{} in environment variable is {} while the value you set is {}".format(
                    env_name, os.environ[env_name], env_value
                )
            )
        self.logger_env(env_name)
77
class XPUDefaultLauncher(Launcher):
    """
    Run the program using XPU.

    Note: for now only a single instance is supported by this launcher.
    """

    def launch(self, args):
        """Run the monkey-patched copy of ``args.program`` and wait for it.

        The copy produced by ``apply_monkey_patch`` is executed with the
        current interpreter (``sys.executable``).  It receives
        ``args.program_args`` unchanged and inherits ``os.environ``.  The copy
        is deleted afterwards, whether or not the run succeeded.

        Args:
            args: parsed launcher arguments (``program``, ``program_args``,
                ``convert_fp64_to_fp32``).

        Raises:
            subprocess.CalledProcessError: if the program exits with a
                non-zero status.  Previously the failure was swallowed and the
                launcher exited successfully.
        """
        monkey_program = apply_monkey_patch(args)
        try:
            # Pass an argument list with no shell.  Arguments containing
            # spaces or shell metacharacters then reach the program intact,
            # and nothing in them is interpreted by a shell.
            cmd = [sys.executable, monkey_program, *args.program_args]
            process = subprocess.Popen(cmd, env=os.environ)
            returncode = process.wait()
            if returncode != 0:
                raise subprocess.CalledProcessError(
                    returncode=returncode, cmd=shlex.join(cmd)
                )
        finally:
            # Also reached if Popen itself fails, so the temp file never leaks.
            os.remove(monkey_program)
109
def init_parser(parser):
    """Register the launcher's command-line options on ``parser``.

    Args:
        parser (ArgumentParser): parser to extend.

    Returns:
        ArgumentParser: the same ``parser``, for chaining.
    """
    # Optional flag; both spellings store into ``convert_fp64_to_fp32``.
    parser.add_argument(
        "--convert-fp64-to-fp32",
        "--convert_fp64_to_fp32",
        action="store_true",
        dest="convert_fp64_to_fp32",
        help="To automatically convert torch.float64(double) dtype to torch.float32",
    )
    # Positional: the script to launch.
    parser.add_argument(
        "program",
        type=str,
        help="The full path to the program/script to be launched, "
        "followed by all the arguments for the script",
    )

    # Everything after the program path is forwarded to the program untouched.
    parser.add_argument(
        "program_args",
        nargs=REMAINDER,
        help="Arguments passed through to the program/script",
    )
    return parser
135
def run_main_with_args(args):
    """Launch ``args.program`` on XPU, then log env vars the launch added.

    Args:
        args: namespace produced by the parser built in ``init_parser``.

    Raises:
        RuntimeError: when running on Windows, which is not supported.
    """
    env_keys_before = set(os.environ)
    if platform.system() == "Windows":
        raise RuntimeError("Windows platform is not supported!!!")

    XPUDefaultLauncher().launch(args)

    # Report (at DEBUG level) every variable that appeared during the launch.
    added_keys = set(os.environ) - env_keys_before
    for key in sorted(added_keys):
        logger.debug("{0}={1}".format(key, os.environ[key]))
146
def main():
    """Command-line entry point: parse ``sys.argv`` and launch the program."""
    parser = ArgumentParser(
        # Adjacent literals are concatenated, so the trailing space after
        # "Series" is needed (it was missing: "Serieswith").
        description="This is a script for launching PyTorch training and inference on "
        "Intel GPU Series "
        "with optimal configurations. "
        "\n################################# Basic usage ############################# \n"
        "\n 1. Run with args\n"
        "\n >>> ipexrun xpu python_script args \n"
        "\n############################################################################# \n",
        formatter_class=RawTextHelpFormatter,
    )
    parser = init_parser(parser)
    args = parser.parse_args()
    run_main_with_args(args)


if __name__ == "__main__":
    main()