# intel-extension-for-pytorch
# 49 lines · 1016.0 bytes
import functools

import torch
import intel_extension_for_pytorch as ipex
# Names of functions in the top-level ``torch`` namespace that this module
# replaces with wrappers (one entry per name, looked up via ``getattr``).
torch_function = [
    "rand",
    "randint",
    "arange",
    "bartlett_window",
    "blackman_window",
    "empty",
    "_empty_affine_quantized",
    "_empty_per_channel_affine_quantized",
    "empty_strided",
    "eye",
    "full",
    "from_file",
    "from_numpy",
    "hann_window",
    "hamming_window",
    "linspace",
    "logspace",
    "ones",
    "scalar_tensor",
    "randn",
    "randperm",
    "range",
    "zeros",
    "sparse_coo_tensor",
    "tril_indices",
    "triu_indices",
    "normal",
    "tensor",
]


def make_hooked_func(torch_func, device=None):
    """Wrap a torch function so its result is moved to a target device.

    If the caller passes ``device=`` (any value, including ``None``) or a
    non-``None`` ``out=`` tensor, placement was chosen explicitly and the
    result of ``torch_func`` is returned untouched. Otherwise the result is
    moved with ``Tensor.to`` to ``device``, or to ``ipex.DEVICE`` when
    ``device`` is ``None``.

    Args:
        torch_func: The original callable taken from the ``torch`` namespace.
            Assumed to return a ``torch.Tensor`` (it must support ``.to``,
            ``.is_leaf`` and ``.requires_grad``).
        device: Optional override for the destination device. ``None`` (the
            default) keeps the original behaviour of using ``ipex.DEVICE``,
            which is read on every call rather than at wrap time.

    Returns:
        A wrapper accepting the same arguments as ``torch_func``, with its
        ``__name__``/``__doc__`` copied over via ``functools.wraps``.
    """

    @functools.wraps(torch_func)
    def hooked_func(*args, **kwargs):
        result = torch_func(*args, **kwargs)
        # Explicit placement by the caller wins: honour ``device=`` and keep
        # the documented "returns ``out``" contract for ``out=``.
        if "device" in kwargs or kwargs.get("out") is not None:
            return result
        target = ipex.DEVICE if device is None else device
        moved = result.to(target)
        # ``Tensor.to`` on a leaf that requires grad returns a non-leaf copy,
        # so e.g. ``torch.ones(3, requires_grad=True)`` would never get a
        # populated ``.grad``. Re-create it as a leaf on the target device.
        if moved is not result and result.is_leaf and result.requires_grad:
            moved = moved.detach().requires_grad_(True)
        return moved

    return hooked_func


# Replace each listed ``torch`` function with its device-redirecting wrapper.
for torch_func_name in torch_function:
    torch_fn = getattr(torch, torch_func_name, None)
    if torch_fn is None:
        # Not every listed name is guaranteed to exist in every installed torch
        # build (e.g. the underscore-prefixed private ops); skip missing ones
        # instead of making this whole module fail to import.
        continue
    hooked_fn = make_hooked_func(torch_fn)
    setattr(torch, torch_func_name, hooked_fn)