3
import pytorch_test_common
6
from onnx_test_common import run_model_test
7
from torch.onnx import OperatorExportTypes
8
from torch.onnx._globals import GLOBALS
9
from torch.onnx.utils import _model_to_graph
10
from torch.testing._internal import common_utils
13
class TestAutogradFuns(pytorch_test_common.ExportTestCase):
14
opset_version = GLOBALS.export_onnx_opset_version
15
keep_initializers_as_inputs = False
16
onnx_shape_inference = True
18
def test_single_output(self):
19
class SingleOut(torch.autograd.Function):
24
ctx.save_for_backward(result)
28
def backward(ctx, grad_output):
29
(result,) = ctx.saved_tensors
30
return grad_output * result
32
class Caller(torch.nn.Module):
33
def forward(self, input):
35
return SingleOut.apply(result) + 3
39
run_model_test(self, model, input_args=(input,))
41
def test_multi_output(self):
42
class MultiOut(torch.autograd.Function):
46
result_log = result_exp.log()
47
ctx.save_for_backward(result_exp, result_log)
48
return result_exp, result_log
51
def backward(ctx, grad_output):
52
(result,) = ctx.saved_tensors
53
return grad_output * result
55
class Caller(torch.nn.Module):
56
def forward(self, input):
57
return MultiOut.apply(input)
60
input = torch.ones(1, 5)
61
run_model_test(self, model, input_args=(input,))
63
def test_partial_output(self):
64
class PartialOut(torch.autograd.Function):
66
def forward(ctx, input):
67
ctx.save_for_backward(input)
68
values, indices = torch.topk(input, 3)
71
class Caller(torch.nn.Module):
72
def forward(self, input):
73
return PartialOut.apply(input)
76
input = torch.ones(1, 5)
77
run_model_test(self, model, input_args=(input,))
79
def test_nested_autograd(self):
80
class Child(torch.autograd.Function):
84
result_log = result.log()
85
ctx.save_for_backward(result_log)
89
def backward(ctx, grad_output):
90
(result,) = ctx.saved_tensors
91
return grad_output * result
93
class Parent(torch.autograd.Function):
97
result_log = Child.apply(result_exp)
98
ctx.save_for_backward(result_exp, result_log)
99
return result_exp, result_log
102
def backward(ctx, grad_output):
103
(result,) = ctx.saved_tensors
104
return grad_output * result
106
class Caller(torch.nn.Module):
107
def forward(self, input):
108
return Parent.apply(input)
111
input = torch.ones(1, 5)
112
run_model_test(self, model, input_args=(input,))
115
def test_aten_unsupported(self):
116
class Erf(torch.autograd.Function):
119
erf_out = torch.special.erf(x)
120
ctx.save_for_backward(erf_out)
124
def backward(ctx, grad_output):
125
result = ctx.saved_tensors
126
return torch.special.erfinv(result), None
128
class Caller(torch.nn.Module):
129
def forward(self, input):
130
return Erf.apply(input)
133
input = torch.ones(1, 5)
136
graph, _, _ = _model_to_graph(
139
operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH,
142
self.assertEqual(next(iter).kind(), "prim::PythonOp")
145
graph, _, _ = _model_to_graph(
148
operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK,
151
self.assertEqual(next(iter).kind(), "aten::ATen")
153
def test_inline_and_symbolic(self):
154
class Exp(torch.autograd.Function):
157
ctx.save_for_backward(input)
161
def symbolic(g, input):
162
return g.op("Exp", input)
164
class LogLog(torch.autograd.Function):
167
ctx.save_for_backward(input)
170
class Caller(torch.nn.Module):
171
def forward(self, input):
172
exp_result = Exp.apply(input)
173
return LogLog.apply(exp_result)
176
input = torch.ones(1)
177
run_model_test(self, model, input_args=(input,))
179
def test_inline_with_scoped_tracing(self):
180
class Exp(torch.autograd.Function):
183
ctx.save_for_backward(input)
187
def symbolic(g, input):
188
return g.op("Exp", input)
190
class LogLog(torch.autograd.Function):
193
ctx.save_for_backward(input)
196
class Caller(torch.nn.Module):
197
def forward(self, input):
198
exp_result = Exp.apply(input)
199
return LogLog.apply(exp_result)
202
input = torch.ones(1)
204
torch.jit._trace._trace_module_map = {
205
_m: torch.typename(type(_m)) for _m in model.modules()
207
run_model_test(self, model, input_args=(input,))
208
torch.jit._trace._trace_module_map = None
211
if __name__ == "__main__":
212
common_utils.run_tests()