intel-extension-for-pytorch
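"""Sanity-test nn.Linear through ipex.optimize with an Adagrad optimizer,
covering the fp32 and bf16 autocast paths over a grid of output sizes,
bias settings, and 2D/3D/4D input shapes."""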
import torch
import intel_extension_for_pytorch as ipex
import torch.nn as nn
import itertools


class Model(nn.Module):
    """Minimal single-layer model: one nn.Linear with optional bias."""

    def __init__(self, ic, oc, bias):
        super(Model, self).__init__()
        self.linear = nn.Linear(ic, oc, bias=bias)

    def forward(self, input):
        return self.linear(input)


def run_model(dtype=None):
    """Run one forward/backward/step cycle for every Linear configuration.

    dtype: 0 -> fp32, 1 -> bf16 autocast; any other value is reserved and skipped.
    """
    # Output sizes: large, mid-sized, the degenerate size 1, and a random small one.
    out_feature = [1024, 256, 1, torch.randint(3, 10, (1,)).item()]
    in_feature = [128, 479, torch.randint(3, 10, (1,)).item()]
    # For each input-feature size, cover 2D, 3D, and 4D input shapes.
    input_shapes = []
    for s in in_feature:
        input_shapes += [(128, s), (2, 64, s), (2, 2, 32, s)]
    options = itertools.product(out_feature, [True, False], input_shapes)
    for out_features, bias, x_shape in options:
        in_features = x_shape[-1]
        x = torch.randn(x_shape, dtype=torch.float32).requires_grad_()
        model = Model(in_features, out_features, bias)
        optimizer = torch.optim.Adagrad(model.parameters(), lr=0.1)
        if dtype == 0:
            # fp32 path: optimize model and optimizer, then run the forward
            # pass under autocast configured for float32.
            conf = ipex.AmpConf(torch.float32)
            model, optimizer = ipex.optimize(
                model, dtype=torch.float32, optimizer=optimizer, level="O1"
            )
            with ipex.amp.autocast(enabled=True, configure=conf):
                run_mod = model(x).sum()
        elif dtype == 1:
            # bf16 path: same flow, but with bfloat16 optimization and autocast.
            conf = ipex.AmpConf(torch.bfloat16)
            model, optimizer = ipex.optimize(
                model, dtype=torch.bfloat16, optimizer=optimizer, level="O1"
            )
            with ipex.amp.autocast(enabled=True, configure=conf):
                run_mod = model(x).sum()
        else:
            # Reserved for future dtypes; skip so run_mod is never used undefined.
            continue
        optimizer.zero_grad()
        run_mod.backward()
        optimizer.step()


if __name__ == "__main__":
    print(f"fp32, {'*' * 50}")
    run_model(0)

    print(f"bf16, {'*' * 50}")
    run_model(1)
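
# A minimal correctness sketch, not part of the sweep above: compare a bf16
# ipex-optimized forward (inference path, no optimizer) against the eager fp32
# result. It mirrors the API calls used above and assumes ipex.optimize without
# an optimizer returns just the model; the helper name and the tolerances are
# assumptions, not something the original test defines. Invoke it manually if desired.
def check_linear_bf16(ic=128, oc=256):
    x = torch.randn(2, ic)
    model = Model(ic, oc, bias=True).eval()
    ref = model(x)  # eager fp32 reference
    opt_model = ipex.optimize(model, dtype=torch.bfloat16)
    with ipex.amp.autocast(enabled=True, configure=ipex.AmpConf(torch.bfloat16)):
        out = opt_model(x)
    # Loose tolerances to absorb bf16 rounding (assumed, not tuned).
    torch.testing.assert_close(out.float(), ref, rtol=2e-2, atol=1e-1)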