# pytorch
# 95 lines · 2.3 KB
import torch
from torch.export import Dim
4
# Custom op library that loads the AOT-compiled model. It is loaded at
# import time, before any class below is instantiated.
AOTI_CUSTOM_OP_LIB = "libaoti_custom_class.so"
torch.classes.load_library(AOTI_CUSTOM_OP_LIB)
9
class TensorSerializer(torch.nn.Module):
    """Module that exposes every entry of ``data`` as an attribute of itself.

    Args:
        data: mapping from attribute name to the value to store under it.
    """

    def __init__(self, data):
        super().__init__()
        # Iterate key/value pairs directly instead of re-indexing the mapping.
        for key, value in data.items():
            setattr(self, key, value)
16
class SimpleModule(torch.nn.Module):
    """A simple module to be compiled: ``Linear(4, 6)`` followed by ``ReLU``."""

    def __init__(self) -> None:
        super().__init__()
        self.fc = torch.nn.Linear(4, 6)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        """Apply the linear layer to ``x``, then ReLU, and return the result."""
        a = self.fc(x)
        b = self.relu(a)
        return b
32
class MyAOTIModule(torch.nn.Module):
    """A wrapper nn.Module whose forward method delegates to MyAOTIClass.

    Args:
        lib_path: passed as the first argument to
            ``torch.classes.aoti.MyAOTIClass``.
        device: passed as the second argument to
            ``torch.classes.aoti.MyAOTIClass``.
    """

    def __init__(self, lib_path, device):
        super().__init__()
        self.aoti_custom_op = torch.classes.aoti.MyAOTIClass(
            lib_path,
            device,
        )

    def forward(self, *x):
        """Pass the positional inputs as one tuple; return outputs as a tuple."""
        outputs = self.aoti_custom_op.forward(x)
        return tuple(outputs)
50
def make_script_module(lib_path, device, *inputs):
    """Build a MyAOTIModule for ``lib_path``/``device`` and trace it.

    Args:
        lib_path: forwarded to ``MyAOTIModule``.
        device: forwarded to ``MyAOTIModule``.
        *inputs: example inputs, used for one eager call and for tracing.

    Returns:
        The result of ``torch.jit.trace`` on the wrapper module.
    """
    m = MyAOTIModule(lib_path, device)
    # sanity check: run the module eagerly once before tracing
    m(*inputs)
    return torch.jit.trace(m, inputs)
57
def compile_model(device, data):
    """AOT-compile SimpleModule for ``device`` and record sample data.

    Saves the traced wrapper to ``script_model_{device}.pt`` and adds the
    sample inputs and reference outputs to ``data`` in place.

    Args:
        device: device string the module and sample input are placed on.
        data: dict updated with ``inputs_{device}`` and ``outputs_{device}``.
    """
    module = SimpleModule().to(device)
    x = torch.randn((4, 4), device=device)
    inputs = (x,)
    # make batch dimension dynamic
    batch_dim = Dim("batch", min=1, max=1024)
    dynamic_shapes = {
        "x": {0: batch_dim},
    }
    with torch.no_grad():
        # aot-compile the module into a .so pointed by lib_path
        lib_path = torch._export.aot_compile(
            module, inputs, dynamic_shapes=dynamic_shapes
        )
    script_module = make_script_module(lib_path, device, *inputs)
    aoti_script_model = f"script_model_{device}.pt"
    script_module.save(aoti_script_model)

    # save sample inputs and ref output
    with torch.no_grad():
        ref_output = module(*inputs)
    data.update(
        {
            f"inputs_{device}": list(inputs),
            f"outputs_{device}": [ref_output],
        }
    )
86
def main():
    """Compile the model for each available device and save the sample data.

    Always runs for "cpu", and also for "cuda" when
    ``torch.cuda.is_available()`` is true. The collected data is wrapped in
    a TensorSerializer, scripted, and saved to ``script_data.pt``.
    """
    data = {}
    devices = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
    for device in devices:
        compile_model(device, data)
    torch.jit.script(TensorSerializer(data)).save("script_data.pt")
93
# Run the generator only when executed as a script, not on import.
if __name__ == "__main__":
    main()