# pytorch
# 36 lines · 1.4 KB
import torch


def set_fuser(fuser_name: str, executor_name: str) -> None:
    """Configure the TorchScript fuser and graph executor via ``torch._C`` flags.

    The fuser settings are applied first; the executor settings are applied
    afterwards and therefore override any executor flags the fuser choice set.

    Args:
        fuser_name: One of ``"te"``, ``"old"``, ``"none"`` or ``"default"``.
            ``"default"`` leaves every fuser-related flag untouched.
        executor_name: ``"profiling"``, ``"simple"``, ``"legacy"`` or
            ``"default"``. ``"default"`` changes nothing; any other
            unrecognised value is also ignored (it is not validated).

    Raises:
        AssertionError: If ``fuser_name`` is not one of the accepted names.
    """
    assert fuser_name in ["te", "old", "none", "default"], (
        f"unknown fuser_name: {fuser_name!r}"
    )

    # NOTE(review): ``_get_graph_executor_optimize`` is called with a value
    # throughout; despite the ``get`` prefix it is expected to update the flag
    # when given an argument -- confirm against the installed torch build.
    if fuser_name == "te":
        # Texpr fuser on, profiling executor on, fusion on GPU only.
        torch._C._jit_set_profiling_executor(True)
        torch._C._get_graph_executor_optimize(True)
        torch._C._jit_override_can_fuse_on_cpu(False)
        torch._C._jit_override_can_fuse_on_gpu(True)
        torch._C._jit_set_texpr_fuser_enabled(True)
    elif fuser_name == "old":
        # Texpr fuser off, profiling executor off, GPU fusion still allowed.
        # The CPU fusion override is deliberately left as it was.
        torch._C._jit_set_profiling_executor(False)
        torch._C._get_graph_executor_optimize(False)
        torch._C._jit_override_can_fuse_on_gpu(True)
        torch._C._jit_set_texpr_fuser_enabled(False)
    elif fuser_name == "none":
        # Every fusion-related flag switched off.
        torch._C._jit_set_profiling_executor(False)
        torch._C._get_graph_executor_optimize(False)
        torch._C._jit_override_can_fuse_on_gpu(False)
        torch._C._jit_override_can_fuse_on_cpu(False)
        torch._C._jit_set_texpr_fuser_enabled(False)
    elif fuser_name == "default":
        pass

    # --executor overrides settings of --fuser
    if executor_name == "profiling":
        torch._C._jit_set_profiling_executor(True)
        torch._C._get_graph_executor_optimize(True)
    elif executor_name == "simple":
        torch._C._get_graph_executor_optimize(False)
    elif executor_name == "legacy":
        torch._C._jit_set_profiling_executor(False)
        torch._C._get_graph_executor_optimize(True)
    elif executor_name == "default":
        pass