# Page header from the repository web view:
# pytorch · Fork (0) · test_autograd_funs.py · 212 lines · 6.6 KB
1
# Owner(s): ["module: onnx"]
2

3
import pytorch_test_common
4

5
import torch
6
from onnx_test_common import run_model_test
7
from torch.onnx import OperatorExportTypes
8
from torch.onnx._globals import GLOBALS
9
from torch.onnx.utils import _model_to_graph
10
from torch.testing._internal import common_utils
11

12

13
class TestAutogradFuns(pytorch_test_common.ExportTestCase):
    """ONNX export tests for models that call custom ``torch.autograd.Function``s.

    Each test defines one or more Functions locally and wraps them in a small
    ``torch.nn.Module``. It then exports the module, either through
    ``run_model_test`` or by inspecting the graph returned by
    ``_model_to_graph`` directly.
    """

    # Export settings for this test case. They are not read anywhere in this
    # class; presumably ``run_model_test`` (which receives ``self``) or the
    # ExportTestCase base consumes them -- confirm.
    opset_version = GLOBALS.export_onnx_opset_version
    keep_initializers_as_inputs = False
    onnx_shape_inference = True

    def test_single_output(self):
        """Export a Function returning one tensor, surrounded by regular ops."""

        class SingleOut(torch.autograd.Function):
            @staticmethod
            def forward(ctx, i):
                result = i.exp()
                result = result.log()
                ctx.save_for_backward(result)
                return result

            @staticmethod
            def backward(ctx, grad_output):
                # NOTE(review): not the true derivative of forward()
                # (log(exp(i)) has derivative 1); only the forward pass is
                # exercised by this export test.
                (result,) = ctx.saved_tensors
                return grad_output * result

        class Caller(torch.nn.Module):
            def forward(self, input):
                result = input + 5
                return SingleOut.apply(result) + 3

        model = Caller()
        x = torch.ones(1)
        run_model_test(self, model, input_args=(x,))

    def test_multi_output(self):
        """Export a Function whose forward returns two tensors."""

        class MultiOut(torch.autograd.Function):
            @staticmethod
            def forward(ctx, i):
                result_exp = i.exp()
                result_log = result_exp.log()
                ctx.save_for_backward(result_exp, result_log)
                return result_exp, result_log

            @staticmethod
            def backward(ctx, grad_exp, grad_log):
                # forward() has two outputs, so autograd passes one incoming
                # gradient per output, and two tensors were saved.
                # d(exp(i))/di = exp(i) = result_exp; d(log(exp(i)))/di = 1.
                result_exp, _ = ctx.saved_tensors
                return grad_exp * result_exp + grad_log

        class Caller(torch.nn.Module):
            def forward(self, input):
                return MultiOut.apply(input)

        model = Caller()
        x = torch.ones(1, 5)
        run_model_test(self, model, input_args=(x,))

    def test_partial_output(self):
        """Export a Function that returns only the values produced by ``torch.topk``."""

        class PartialOut(torch.autograd.Function):
            @staticmethod
            def forward(ctx, input):
                ctx.save_for_backward(input)
                # topk yields (values, indices); only the values are returned,
                # so the Function exposes just part of the op's outputs.
                values, _ = torch.topk(input, 3)
                return values

        class Caller(torch.nn.Module):
            def forward(self, input):
                return PartialOut.apply(input)

        model = Caller()
        x = torch.ones(1, 5)
        run_model_test(self, model, input_args=(x,))

    def test_nested_autograd(self):
        """Export a Function (Parent) whose forward itself applies another Function (Child)."""

        class Child(torch.autograd.Function):
            @staticmethod
            def forward(ctx, i):
                result = i.log()
                result_log = result.log()
                ctx.save_for_backward(result_log)
                return result_log

            @staticmethod
            def backward(ctx, grad_output):
                # NOTE(review): not the true derivative of log(log(i));
                # only the forward pass is exercised by this export test.
                (result,) = ctx.saved_tensors
                return grad_output * result

        class Parent(torch.autograd.Function):
            @staticmethod
            def forward(ctx, i):
                result_exp = i.exp()
                result_log = Child.apply(result_exp)
                ctx.save_for_backward(result_exp, result_log)
                return result_exp, result_log

            @staticmethod
            def backward(ctx, grad_exp, grad_log):
                # forward() has two outputs, so autograd passes two incoming
                # gradients, and two tensors were saved.
                # result_exp = exp(i) and result_log = log(log(exp(i))) = log(i),
                # so d(result_exp)/di = result_exp and
                # d(result_log)/di = 1 / i = exp(-result_log).
                result_exp, result_log = ctx.saved_tensors
                return grad_exp * result_exp + grad_log * torch.exp(-result_log)

        class Caller(torch.nn.Module):
            def forward(self, input):
                return Parent.apply(input)

        model = Caller()
        x = torch.ones(1, 5)
        run_model_test(self, model, input_args=(x,))

    def test_aten_unsupported(self):
        """Export a Function calling ``torch.special.erf`` under fallback export types.

        The original note for this test states that ``torch.erf()`` is not
        supported, which is why the export runs in the ONNX_FALLTHROUGH and
        ONNX_ATEN_FALLBACK operator export types. Each case checks the kind of
        the first node in the resulting graph.
        """

        class Erf(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                erf_out = torch.special.erf(x)
                ctx.save_for_backward(x)
                return erf_out

            @staticmethod
            def backward(ctx, grad_output):
                import math

                # d/dx erf(x) = 2 / sqrt(pi) * exp(-x**2). forward() takes a
                # single tensor input, so exactly one gradient is returned.
                (x,) = ctx.saved_tensors
                return grad_output * (2.0 / math.sqrt(math.pi)) * torch.exp(-x * x)

        class Caller(torch.nn.Module):
            def forward(self, input):
                return Erf.apply(input)

        model = Caller()
        x = torch.ones(1, 5)

        # ONNX_FALLTHROUGH: the Function is expected to remain a prim::PythonOp.
        graph, _, _ = _model_to_graph(
            model,
            (x,),
            operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH,
        )
        nodes = graph.nodes()
        self.assertEqual(next(nodes).kind(), "prim::PythonOp")

        # ONNX_ATEN_FALLBACK: the first node is expected to be an aten::ATen op.
        graph, _, _ = _model_to_graph(
            model,
            (x,),
            operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK,
        )
        nodes = graph.nodes()
        self.assertEqual(next(nodes).kind(), "aten::ATen")

    def test_inline_and_symbolic(self):
        """Export a Function with a ``symbolic`` (Exp) followed by one without (LogLog)."""

        class Exp(torch.autograd.Function):
            @staticmethod
            def forward(ctx, i):
                # Save this Function's own input, not a variable captured
                # from the enclosing test method.
                ctx.save_for_backward(i)
                return i.exp()

            @staticmethod
            def symbolic(g, input):
                return g.op("Exp", input)

        class LogLog(torch.autograd.Function):
            @staticmethod
            def forward(ctx, i):
                ctx.save_for_backward(i)
                return i.log().log()

        class Caller(torch.nn.Module):
            def forward(self, input):
                exp_result = Exp.apply(input)
                return LogLog.apply(exp_result)

        model = Caller()
        x = torch.ones(1)
        run_model_test(self, model, input_args=(x,))

    def test_inline_with_scoped_tracing(self):
        """Same model as ``test_inline_and_symbolic``, exported with a populated trace module map."""

        class Exp(torch.autograd.Function):
            @staticmethod
            def forward(ctx, i):
                # Save this Function's own input, not a variable captured
                # from the enclosing test method.
                ctx.save_for_backward(i)
                return i.exp()

            @staticmethod
            def symbolic(g, input):
                return g.op("Exp", input)

        class LogLog(torch.autograd.Function):
            @staticmethod
            def forward(ctx, i):
                ctx.save_for_backward(i)
                return i.log().log()

        class Caller(torch.nn.Module):
            def forward(self, input):
                exp_result = Exp.apply(input)
                return LogLog.apply(exp_result)

        model = Caller()
        x = torch.ones(1)

        # Map every submodule to its type name in the global
        # ``torch.jit._trace._trace_module_map``. Always reset it afterwards,
        # even if the export fails, so the global does not leak into other tests.
        torch.jit._trace._trace_module_map = {
            _m: torch.typename(type(_m)) for _m in model.modules()
        }
        try:
            run_model_test(self, model, input_args=(x,))
        finally:
            torch.jit._trace._trace_module_map = None
209

210

211
if __name__ == "__main__":
    # Run this module's tests through PyTorch's common test runner.
    common_utils.run_tests()
213

# Page footer captured from the GitVerse web view (cookie notice, translated from Russian):
#
# Use of cookies
#
# We use cookies in accordance with the Privacy Policy and the Cookie Policy.
#
# By clicking "Accept", you consent to JSC "SberTech" processing your personal
# data in order to improve our website and the GitVerse Service, and to make
# them more convenient to use.
#
# You can disable cookies yourself in your browser settings.