pytorch
# Usage: python create_dummy_model.py <output_path>
# Writes a TorchScript model to <output_path> and a pickled nn.Sequential to <output_path>.orig.
2import sys3
4import torch5from torch import nn6
7
class NeuralNetwork(nn.Module):
    """Dummy fully-connected classifier: Flatten followed by a 3-layer ReLU MLP.

    The first Linear layer expects 28 * 28 features after flattening, and the
    last one produces 10 outputs. Attribute names (``flatten``,
    ``linear_relu_stack``) determine the ``state_dict`` keys, so they are kept
    stable.
    """

    def __init__(self) -> None:
        super().__init__()
        self.flatten = nn.Flatten()
        width = 512
        # Layers are created in the same order as listed, so parameter
        # initialisation consumes the RNG in a fixed sequence.
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, width),
            nn.ReLU(),
            nn.Linear(width, width),
            nn.ReLU(),
            nn.Linear(width, 10),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Flatten ``x`` (nn.Flatten default: keep dim 0) and return raw logits."""
        return self.linear_relu_stack(self.flatten(x))
25
if __name__ == "__main__":
    # The output path is required; report usage instead of crashing with an
    # IndexError traceback when it is missing. Extra arguments are ignored,
    # as before.
    if len(sys.argv) < 2:
        sys.exit(f"Usage: python {sys.argv[0]} <output_path>")
    output_path = sys.argv[1]

    # TorchScript export of the full model (Flatten + MLP).
    jit_module = torch.jit.script(NeuralNetwork())
    torch.jit.save(jit_module, output_path)

    # Eager-mode nn.Sequential with the same layer stack (no Flatten). It is
    # constructed separately, so its weights are initialised independently of
    # jit_module's. torch.save pickles the whole module object.
    # NOTE(review): loaders on newer torch versions may need
    # torch.load(..., weights_only=False) to read this file. Confirm against
    # the consumer.
    orig_module = nn.Sequential(
        nn.Linear(28 * 28, 512),
        nn.ReLU(),
        nn.Linear(512, 512),
        nn.ReLU(),
        nn.Linear(512, 10),
    )
    torch.save(orig_module, output_path + ".orig")