# intel-extension-for-pytorch
# 48 lines · 1.2 KB
import torch
import intel_extension_for_pytorch as ipex
from common_utils import int8_calibration

# NOTE(review): name suggests this routes supported ops through oneDNN
# automatically — confirm against the ipex.core API docs.
ipex.core.enable_auto_dnnl()

ic = 1024  # in_features of the linear layer
oc = 1024  # out_features of the linear layer
bs = 16    # batch size used by get_input()

# Single shared Linear layer moved to the extension's device
# (ipex.DEVICE — presumably the DPC++/XPU device; verify).
LL = torch.nn.Linear(ic, oc).to(ipex.DEVICE)
13
def get_input():
    """Return a fresh random (bs, ic) activation on the ipex device."""
    batch = torch.rand(bs, ic)
    return batch.to(ipex.DEVICE)
17
def run_linear(auto_mix_conf=None):
    """Run three forward passes through the shared layer ``LL``.

    When *auto_mix_conf* is given, each pass executes under
    ``ipex.AutoMixPrecision`` with that configuration; otherwise the
    passes run as plain fp32. Outputs are discarded — this only
    exercises the forward path.
    """
    for _ in range(3):
        if auto_mix_conf is None:
            LL(get_input())
        else:
            with ipex.AutoMixPrecision(auto_mix_conf):
                LL(get_input())
26
if __name__ == "__main__":

    def back_to_fp32():
        # Reorder the layer's parameters back to plain fp32 so the
        # following run is no longer affected by the mixed-precision pass.
        ipex.core.reorder_to_float32(LL.weight)
        ipex.core.reorder_to_float32(LL.bias)

    # 1) plain fp32 baseline
    print(f"fp32, {'*' * 50}")
    run_linear()

    # 2) bf16 auto-mix
    print(f"auto-mix for bf16, {'*' * 50}")
    run_linear(ipex.AmpConf(torch.bfloat16))

    # 3) restore fp32 and rerun
    print(f"back to fp32, {'*' * 50}")
    back_to_fp32()
    run_linear()

    # 4) int8 auto-mix: calibrate first, then run with the saved config
    print(f"auto-mix for int8, {'*' * 50}")
    int8_calibration(LL, [get_input() for _ in range(3)], "./int8.config")
    run_linear(ipex.AmpConf(torch.int8, "./int8.config"))

    # 5) restore fp32 once more and rerun
    print(f"back to fp32, {'*' * 50}")
    back_to_fp32()
    run_linear()