11
from torch.autograd import Variable
12
from torch.onnx import OperatorExportTypes
15
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
16
sys.path.append(pytorch_test_dir)
17
import pytorch_test_common
19
from torch.testing._internal import common_utils
23
# NOTE(review): This chunk looks like a partially-extracted copy of an ONNX
# export-modes test suite. The bare integer lines (`24`, `26`, ...) appear to
# be original-file line numbers that leaked into the content, indentation has
# been stripped, and the jumps in those numbers indicate that many source
# lines are missing (e.g. `def __init__`/`def forward` headers, the
# `torch.onnx._export(...)` call lines whose keyword arguments are visible,
# and the definition of `abcd` used below). Code is kept byte-identical;
# comments only describe what the visible lines establish — confirm against
# the original file before relying on them.
class TestExportModes(pytorch_test_common.ExportTestCase):
24
# Minimal fixture module shared by the export-mode tests below.
class MyModel(nn.Module):
26
# Visible fragment of __init__ (the `def __init__` header itself is missing
# from this extract).
super(TestExportModes.MyModel, self).__init__()
29
# Visible fragment of forward: swaps the first two dimensions of the input.
return x.transpose(0, 1)
31
# Exercises export with export_type=PROTOBUF_FILE; the export call's opening
# line is not visible in this extract — presumably torch.onnx._export(...).
def test_protobuf(self):
32
torch_model = TestExportModes.MyModel()
33
fake_input = Variable(torch.randn(1, 1, 224, 224), requires_grad=True)
40
export_type=torch.onnx.ExportTypes.PROTOBUF_FILE,
43
# Same fixture exported as an (uncompressed) zip archive.
def test_zipfile(self):
44
torch_model = TestExportModes.MyModel()
45
fake_input = Variable(torch.randn(1, 1, 224, 224), requires_grad=True)
52
export_type=torch.onnx.ExportTypes.ZIP_ARCHIVE,
55
# Same fixture exported as a compressed zip archive.
def test_compressed_zipfile(self):
56
torch_model = TestExportModes.MyModel()
57
fake_input = Variable(torch.randn(1, 1, 224, 224), requires_grad=True)
64
export_type=torch.onnx.ExportTypes.COMPRESSED_ZIP_ARCHIVE,
67
# Same fixture exported into a temporary directory (note: mkdtemp is not
# cleaned up in the visible lines — TODO confirm whether the missing lines
# remove it).
def test_directory(self):
68
torch_model = TestExportModes.MyModel()
69
fake_input = Variable(torch.randn(1, 1, 224, 224), requires_grad=True)
70
d = tempfile.mkdtemp()
76
export_type=torch.onnx.ExportTypes.DIRECTORY,
80
# Exports a model (definition not visible here) that presumably returns
# multiple outputs; only the final export call survives in this extract.
def test_onnx_multiple_return(self):
87
torch.onnx.export(foo, (x,), f)
89
@common_utils.skipIfNoCaffe2
90
@common_utils.skipIfNoLapack
91
# ATen-fallback export path under Caffe2: uses torch.linalg.qr, which has no
# native ONNX op, so the exporter must fall back to an ATen node.
def test_caffe2_aten_fallback(self):
92
class ModelWithAtenNotONNXOp(nn.Module):
93
def forward(self, x, y):
95
# `abcd` is defined on a line missing from this extract (presumably x + y —
# verify against the original file).
defg = torch.linalg.qr(abcd)
100
torch.onnx.export_to_pretty_string(
101
ModelWithAtenNotONNXOp(),
103
add_node_names=False,
104
do_constant_folding=False,
105
operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK,
108
@common_utils.skipIfCaffe2
109
@common_utils.skipIfNoLapack
110
# Same ATen-fallback scenario as above, but on the non-Caffe2 exporter path.
def test_aten_fallback(self):
111
class ModelWithAtenNotONNXOp(nn.Module):
112
def forward(self, x, y):
114
# As above, `abcd`'s definition is missing from this extract.
defg = torch.linalg.qr(abcd)
119
torch.onnx.export_to_pretty_string(
120
ModelWithAtenNotONNXOp(),
122
add_node_names=False,
123
do_constant_folding=False,
124
operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK,
132
# Forces every op to export as an ATen node (ONNX_ATEN) using fmod, which is
# representable but exported verbatim under this mode.
def test_onnx_aten(self):
133
class ModelWithAtenFmod(nn.Module):
134
def forward(self, x, y):
135
return torch.fmod(x, y)
137
x = torch.randn(3, 4, dtype=torch.float32)
138
y = torch.randn(3, 4, dtype=torch.float32)
139
torch.onnx.export_to_pretty_string(
142
add_node_names=False,
143
do_constant_folding=False,
144
operator_export_type=OperatorExportTypes.ONNX_ATEN,
148
# Standard script entry point: run this file's tests when executed directly.
# Fix: a stray extracted line-number artifact ("149") sat between the guard
# and its body, and the body's indentation had been stripped, leaving the
# guard syntactically invalid.
if __name__ == "__main__":
    common_utils.run_tests()