pytorch

Форк
0
/
compile_model.py 
95 строк · 2.3 Кб
1
import torch
2
from torch.export import Dim
3

4

5
# custom op that loads the aot-compiled model
6
AOTI_CUSTOM_OP_LIB = "libaoti_custom_class.so"
7
torch.classes.load_library(AOTI_CUSTOM_OP_LIB)
8

9

10
class TensorSerializer(torch.nn.Module):
11
    def __init__(self, data):
12
        super().__init__()
13
        for key in data:
14
            setattr(self, key, data[key])
15

16

17
class SimpleModule(torch.nn.Module):
18
    """
19
    a simple module to be compiled
20
    """
21

22
    def __init__(self) -> None:
23
        super().__init__()
24
        self.fc = torch.nn.Linear(4, 6)
25
        self.relu = torch.nn.ReLU()
26

27
    def forward(self, x):
28
        a = self.fc(x)
29
        b = self.relu(a)
30
        return b
31

32

33
class MyAOTIModule(torch.nn.Module):
34
    """
35
    a wrapper nn.Module that instantiates its forward method
36
    on MyAOTIClass
37
    """
38

39
    def __init__(self, lib_path, device):
40
        super().__init__()
41
        self.aoti_custom_op = torch.classes.aoti.MyAOTIClass(
42
            lib_path,
43
            device,
44
        )
45

46
    def forward(self, *x):
47
        outputs = self.aoti_custom_op.forward(x)
48
        return tuple(outputs)
49

50

51
def make_script_module(lib_path, device, *inputs):
52
    m = MyAOTIModule(lib_path, device)
53
    # sanity check
54
    m(*inputs)
55
    return torch.jit.trace(m, inputs)
56

57

58
def compile_model(device, data):
59
    module = SimpleModule().to(device)
60
    x = torch.randn((4, 4), device=device)
61
    inputs = (x,)
62
    # make batch dimension
63
    batch_dim = Dim("batch", min=1, max=1024)
64
    dynamic_shapes = {
65
        "x": {0: batch_dim},
66
    }
67
    with torch.no_grad():
68
        # aot-compile the module into a .so pointed by lib_path
69
        lib_path = torch._export.aot_compile(
70
            module, inputs, dynamic_shapes=dynamic_shapes
71
        )
72
    script_module = make_script_module(lib_path, device, *inputs)
73
    aoti_script_model = f"script_model_{device}.pt"
74
    script_module.save(aoti_script_model)
75

76
    # save sample inputs and ref output
77
    with torch.no_grad():
78
        ref_output = module(*inputs)
79
    data.update(
80
        {
81
            f"inputs_{device}": list(inputs),
82
            f"outputs_{device}": [ref_output],
83
        }
84
    )
85

86

87
def main():
88
    data = {}
89
    for device in ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]:
90
        compile_model(device, data)
91
    torch.jit.script(TensorSerializer(data)).save("script_data.pt")
92

93

94
if __name__ == "__main__":
95
    main()
96

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.