intel-extension-for-pytorch

Форк
0
54 строки · 1.8 Кб
1
import torch
2
import intel_extension_for_pytorch as ipex
3
import torch.nn as nn
4
import itertools
5

6

7
class Model(nn.Module):
8
    def __init__(self, ic, oc, bias):
9
        super(Model, self).__init__()
10
        self.linear = nn.Linear(ic, oc, bias=bias)
11

12
    def forward(self, input):
13
        return self.linear(input)
14

15

16
def run_model(dtype=None):
17
    out_feature = [1024, 256, 1, torch.randint(3, 10, (1,)).item()]
18
    in_feature = [128, 479, torch.randint(3, 10, (1,)).item()]
19
    input_shapes = []
20
    for s in in_feature:
21
        input_shapes += [(128, s), (2, 64, s), (2, 2, 32, s)]
22
    options = itertools.product(out_feature, [True, False], input_shapes)
23
    for out_features, bias, x_shape in options:
24
        in_features = x_shape[-1]
25
        x = torch.randn(x_shape, dtype=torch.float32).requires_grad_()
26
        model = Model(in_features, out_features, bias)
27
        optimizer = torch.optim.Adagrad(model.parameters(), lr=0.1)
28
        if dtype == 0:
29
            conf = ipex.AmpConf(torch.float32)
30
            model, optimizer = ipex.optimize(
31
                model, dtype=torch.float32, optimizer=optimizer, level="O1"
32
            )
33
            with ipex.amp.autocast(enabled=True, configure=conf):
34
                run_mod = model.forward(x).sum()
35
        elif dtype == 1:
36
            conf = ipex.AmpConf(torch.bfloat16)
37
            model, optimizer = ipex.optimize(
38
                model, dtype=torch.bfloat16, optimizer=optimizer, level="O1"
39
            )
40
            with ipex.amp.autocast(enabled=True, configure=conf):
41
                run_mod = model.forward(x).sum()
42
        else:  # reserved
43
            pass
44
        optimizer.zero_grad()
45
        run_mod.backward()
46
        optimizer.step()
47

48

49
if __name__ == "__main__":
50
    print(f"fp32, {'*' * 50}")
51
    run_model(0)
52

53
    print(f"bf16, {'*' * 50}")
54
    run_model(1)
55

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.