import torch
from torch._export import aot_compile
from torch.export import Dim


torch.manual_seed(1337)


# Small model with two tensor-constant attributes, used for the basic AOTI tests.
class Net(torch.nn.Module):
    def __init__(self, device):
        super().__init__()
        self.w_pre = torch.randn(4, 4, device=device)
        self.w_add = torch.randn(4, 4, device=device)

    def forward(self, x):
        w_transpose = torch.transpose(self.w_pre, 0, 1)
        w_relu = torch.nn.functional.relu(w_transpose)
        w = w_relu + self.w_add
        return torch.matmul(x, w)


# Model with a CUDA tensor constant and advanced indexing in forward.
class NetWithTensorConstants(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.w = torch.randn(30, 1, device="cuda")

    def forward(self, x, y):
        z = self.w * x * y
        return z[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 17]]


data = {}
data_with_tensor_constants = {}


# Basic AOTI model test generation.
def generate_basic_tests():
    for device in ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]:
        for use_runtime_constant_folding in [True, False]:
            if device == "cpu" and use_runtime_constant_folding:
                # We do not test runtime const folding for CPU mode.
                continue
            model = Net(device).to(device=device)
            x = torch.randn((4, 4), device=device)
            with torch.no_grad():
                ref_output = model(x)

            torch._dynamo.reset()
            with torch.no_grad():
                dim0_x = Dim("dim0_x", min=1, max=1024)
                dynamic_shapes = {"x": {0: dim0_x}}
                model_so_path = aot_compile(
                    model,
                    (x,),
                    dynamic_shapes=dynamic_shapes,
                    options={
                        "aot_inductor.use_runtime_constant_folding": use_runtime_constant_folding
                    },
                )

            suffix = f"{device}"
            if use_runtime_constant_folding:
                suffix += "_use_runtime_constant_folding"
            data.update(
                {
                    f"model_so_path_{suffix}": model_so_path,
                    f"inputs_{suffix}": [x],
                    f"outputs_{suffix}": [ref_output],
                    f"w_pre_{suffix}": model.w_pre,
                    f"w_add_{suffix}": model.w_add,
                }
            )


# AOTI model that will create additional tensors during autograd.
def generate_test_with_additional_tensors():
    if not torch.cuda.is_available():
        return

    model = NetWithTensorConstants()
    x = torch.randn((30, 1), device="cuda")
    y = torch.randn((30, 1), device="cuda")
    with torch.no_grad():
        ref_output = model(x, y)

    torch._dynamo.reset()
    with torch.no_grad():
        model_so_path = aot_compile(model, (x, y))

    data_with_tensor_constants.update(
        {
            "model_so_path": model_so_path,
            "inputs": [x, y],
            "outputs": [ref_output],
            "w": model.w,
        }
    )


generate_basic_tests()
generate_test_with_additional_tensors()
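# Illustrative only, not used by the test generation above: a minimal sketch of
# how one of the compiled .so files could be loaded back and run from Python as
# a sanity check. It assumes torch._export.aot_load is available in this PyTorch
# build; the helper name and the (8, 4) input shape are arbitrary examples
# (any first dimension within dim0_x's [1, 1024] range should be accepted).
def _sanity_check_compiled_model(so_path, device):
    runner = torch._export.aot_load(so_path, device)  # load the AOTI-compiled model
    x = torch.randn((8, 4), device=device)
    with torch.no_grad():
        return runner(x)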
# Use this to communicate tensors to the C++ code.
class Serializer(torch.nn.Module):
    def __init__(self, data):
        super().__init__()
        for key in data:
            setattr(self, key, data[key])


torch.jit.script(Serializer(data)).save("data.pt")
torch.jit.script(Serializer(data_with_tensor_constants)).save(
    "data_with_tensor_constants.pt"
)
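# Illustrative only: a minimal sketch of how the serialized tensors can be read
# back; the C++ side can load the same files with torch::jit::load, and here we
# use torch.jit.load. The helper name is an example; "inputs_cpu" is one of the
# keys written by generate_basic_tests above.
def _load_serialized_tensors(path="data.pt", key="inputs_cpu"):
    container = torch.jit.load(path)  # ScriptModule whose attributes hold the tensors
    return getattr(container, key)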
