# Owner(s): ["module: unknown"]

from functools import partial
from textwrap import dedent

import torch

from torch.testing import FileCheck
from torch.testing._internal.common_utils import \
    (run_tests, IS_SANDCASTLE, clone_input_helper, first_sample, TestCase)
from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.common_device_type import instantiate_device_type_tests, ops, OpDTypes
from torch.testing._internal.common_jit import JitCommonTestCase, check_against_reference
from torch.testing._internal.jit_metaprogramming_utils import create_script_fn, create_traced_fn, check_alias_annotation
from torch.testing._internal.jit_utils import disable_autodiff_subgraph_inlining, is_lambda
from torch.testing._internal.common_utils import unMarkDynamoStrictTest

# variant testing is only done with torch.float and torch.cfloat to avoid
#   excessive test times and maximize signal to noise ratio
_variant_ops = partial(ops, dtypes=OpDTypes.supported,
                       allowed_dtypes=(torch.float, torch.cfloat))



# Tests operators for consistency between JIT and eager, also checks
#   correctness of JIT specific alias schemas and intended
#   autodifferentiation behavior.
# Inherits from JitCommonTestCase instead of TestCase directly to share
#   functionality with original test_jit.py method operator tests
@unMarkDynamoStrictTest
class TestJit(JitCommonTestCase):
    exact_dtype = True

    # Tests that the forward and backward passes of operations produce the
    #   same values for the cross-product of op variants (function, method, inplace)
    #   and runtimes (eager, traced, scripted).
    # TODO WARNING: inplace x {traced, scripted} not currently tested
    @_variant_ops(op_db)
    def test_variant_consistency_jit(self, device, dtype, op):
        _requires_grad = (dtype in op.supported_backward_dtypes(torch.device(device).type))

        include_conjugated_inputs = op.test_conjugated_samples and dtype.is_complex
        samples = op.sample_inputs(device, dtype, requires_grad=_requires_grad, include_conjugated_inputs=include_conjugated_inputs)

        # Acquires variants to test
        func = op.get_op()
        method = op.get_method()
        variants = {
            # TODO: inplace tests currently fail, fix and add inplace variant
            'function': func, 'method': method,
        }

        # scripting strips the torch.ops prefix from these operators
        # incorrectly; don't bother testing this case.  Count this
        # as "testing"
        if isinstance(func, torch._ops.OpOverload):
            self.skipTest("variant consistency doesn't work on torch.ops")

        # TODO: find better way to standardize on op registration itself..
        has_fake_function = op.name in ["resize_", 'resize_as_']

        if has_fake_function:
            variants = {'method': getattr(torch.Tensor, op.name)}
            samples = op.sample_inputs(device, dtype, requires_grad=False)


        tested = False
        for sample in samples:
            # Test traced and scripted consistency
            for func_type, variant in variants.items():
                if variant is None:
                    continue

                # scripting and check_alias_analysis do not work with lambdas
                # lambdas are typically used as a way to simulate methods without
                # functional variants, so rely on the other variant for testing
                # for now
                if is_lambda(variant):
                    continue

                tested = True
                try:
                    self.indiv_variant_test_jit(device, dtype, op, sample, func_type, variant, has_fake_function)
                except Exception as e:
                    variant_error_info = dedent(f"""
                        Error testing {op.name} {func_type} variant
                        with dtype: {dtype}
                        with inputs {sample}:
                    """)
                    raise Exception(variant_error_info) from e

        assert tested, "JIT Test does not execute any logic"

    def indiv_variant_test_jit(self, device, dtype, op, sample, func_type, variant, has_fake_function):
        _requires_grad = (dtype in op.supported_backward_dtypes(torch.device(device).type))
        support_script = op.supports_scripting
        # Create accessor for script function variant
        name = op.name + '_' if func_type == 'inplace' else op.name

        # run with disable_autodiff_subgraph_inlining(True) to test
        #   autodiff support. Context manager forces the graph to contain
        #   DifferentiableGraph nodes if they are present
        with disable_autodiff_subgraph_inlining():
            # Check scripted forward, grad, and grad grad
            if support_script:
                script_fn = create_script_fn(self, name, func_type)

            def out_fn(output):
                # Processes the output for autograd
                if sample.output_process_fn_grad is not None:
                    return sample.output_process_fn_grad(output)
                return output

            def get_sample():
                return clone_input_helper(sample.input) if op.name[-1] == '_' else sample.input
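            # Note: get_sample() hands out a fresh clone for inplace ops (names ending
            # in '_') because they mutate their input; other ops reuse sample.input.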

            if support_script:
                check_against_reference(self,
                                        script_fn,
                                        op.get_op(),
                                        out_fn,
                                        (get_sample(),) + sample.args,
                                        sample.kwargs,
                                        no_grad=not _requires_grad, no_gradgrad=not op.supports_gradgrad)

            # Check traced forward, grad, and grad grad
            # TODO: fix tracing here
            supports_tracing = op.supports_tracing and not has_fake_function
            if op.assert_jit_shape_analysis:
                self.assertTrue(supports_tracing)

            if supports_tracing:
                traced_fn = create_traced_fn(self, variant)
                check_against_reference(self,
                                        traced_fn,
                                        op.get_op(),
                                        out_fn,
                                        (get_sample(),) + sample.args,
                                        sample.kwargs,
                                        no_grad=not _requires_grad, no_gradgrad=not op.supports_gradgrad)

            # Check alias annotation schema for correctness (make
            #   sure inputs that aren't supposed to be modified aren't)
            # Note: only runs in float32 because schema isn't affected by dtype,
            #   so running it on all dtypes would be excessive
            if dtype == torch.float32:
                # TODO: no reason why we can't run this with the tracing graph
                if support_script and op.name != "rsub":
                    check_alias_annotation(name, (get_sample(),) + sample.args, sample.kwargs,
                                           func_type=func_type, aten_name=op.aten_name)

                # TODO: use script graph as well
                checked_shape_analysis = False
                if supports_tracing:
                    out = variant(get_sample(), *sample.args, **sample.kwargs)

                    # right now, tuple of outputs and tensor output supported
                    # TODO: list of tensor outputs
                    tuple_of_tensors = isinstance(out, tuple) and all(isinstance(elem, torch.Tensor) for elem in out)

                    if isinstance(out, torch.Tensor) or tuple_of_tensors:
                        if tuple_of_tensors:
                            sizes = [elem.size() for elem in out]
                        else:
                            sizes = out.size()
                        self.checkShapeAnalysis(sizes, traced_fn.graph, op.assert_jit_shape_analysis)
                        checked_shape_analysis = True
                if op.assert_jit_shape_analysis:
                    self.assertTrue(checked_shape_analysis)

            # Check autodifferentiation of nodes for traced and scripted graphs, only need to check once per sample
            if dtype is torch.float32:
                # Sandcastle doesn't fuse nodes
                if IS_SANDCASTLE:
                    # fusible nodes are expected to be found in FusionGroups in the DifferentiableGraphs
                    nonfusible_nodes = op.autodiff_nonfusible_nodes + op.autodiff_fusible_nodes
                    fusible_nodes = []
                else:
                    nonfusible_nodes = op.autodiff_nonfusible_nodes
                    fusible_nodes = op.autodiff_fusible_nodes

                if supports_tracing:
                    self.assertAutodiffNode(traced_fn.last_graph, op.assert_autodiffed, nonfusible_nodes, fusible_nodes)
                if support_script:
                    self.assertAutodiffNode(script_fn.last_graph, op.assert_autodiffed, nonfusible_nodes, fusible_nodes)

    # alias testing is only done with torch.float for the same reason
    _alias_ops = partial(ops, dtypes=OpDTypes.supported,
                         allowed_dtypes=(torch.float,))

    @_alias_ops(op for op in op_db if op.aliases)
    def test_jit_alias_remapping(self, device, dtype, op):
        # NOTE: only tests on first sample
        samples = op.sample_inputs(device, dtype, requires_grad=True)
        sample = first_sample(self, samples)

        # [Scripting Data Preparation]
        # Prepare data for test scripting
        # Below we prepare strings of args/kwargs with and without type annotations.
        # These strings are inserted into function template strings, which are then scripted with TorchScript.
        # - the args string is ["t0"], corresponding to the "input" tensor required by the op
        # - args_kw is args plus string forms of the op's positional args and kwargs (without type annotations), for example,
        #   ["t0", "1.0", "(1,)", "True", "tensor(1.0)"] -> def fn(t0): return variant(t0, 1.0, (1,), True, tensor(1.0))
        args = ["t0"]

        def quote_strs(v):
            if isinstance(v, str):
                return f"'{v}'"

            return str(v)
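        # For illustration: quote_strs("mean") returns "'mean'" while quote_strs(1.5)
        # returns "1.5", so string-valued kwargs stay quoted in the generated source.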

        args_kw = args + \
            [f"{v}" for v in sample.args] + \
            [f"{k}={quote_strs(v)}" for k, v in sample.kwargs.items()]
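        # Worked example (hypothetical sample, for illustration only): with
        # sample.args == (2,) and sample.kwargs == {'keepdim': True}, args_kw becomes
        # ["t0", "2", "keepdim=True"], and the function-variant template below expands to
        #   def _fn(t0): return variant(t0, 2, keepdim=True)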

        # Prepare data for test tracing
        sample_args_kwargs = ()
        if len(sample.args) > 0:
            sample_args_kwargs += (sample.args, )
        if len(sample.kwargs) > 0:
            sample_args_kwargs += (sample.kwargs, )

        original_name = op.aten_name
        original_name_inplace = original_name + "_"
        expected_dtype = op(sample.input, *sample.args, **sample.kwargs).dtype

        for a_op in op.aliases:
            inplace = a_op.inplace_variant
            method_or_inplace = [a_op.inplace_variant, a_op.method_variant]
            # Materialized as a tuple (rather than a generator) so that both the
            # scripting and tracing loops below can iterate over the variants.
            variants = tuple(v for v in (a_op.op, a_op.method_variant, a_op.inplace_variant) if v is not None)

            # Test scripting:
            for variant in variants:
                variant_name = variant.__name__
                op_name = original_name_inplace if variant is inplace else original_name

                if variant in method_or_inplace:
                    fn_template = '''
                        def _fn(t0{c}):
                            return t0.{alias_name}({args_kw})
                    '''
                    # remove the first input tensor
                    script = fn_template.format(
                        c=", " if len(args_kw[1:]) > 1 else "",
                        args_kw=", ".join(args_kw[1:]),
                        alias_name=variant_name,
                    )
                else:
                    fn_template = '''
                        def _fn({args}):
                            return variant({args_kw})
                    '''
                    script = fn_template.format(
                        args=", ".join(args),
                        args_kw=", ".join(args_kw),
                    )

                # Required to avoid undefined value: tensor error in JIT
                # compilation of the function template
                script = script.replace("tensor(", "torch.tensor(")

                scripted = torch.jit.CompilationUnit(script)._fn

                if (variant is inplace and not torch.can_cast(expected_dtype, dtype)):
                    try:
                        inp = clone_input_helper(sample.input)
                        scripted(inp)
                    except Exception as e:
                        continue
                    self.fail("Inplace operation on integer tensor that should be promoted to float didn't fail!")

                inp = clone_input_helper(sample.input)
                scripted(inp)
                inp = clone_input_helper(sample.input)
                graph = scripted.graph_for(inp)
                FileCheck().check(op.aten_name).check_not(variant_name).run(graph)

            # Test tracing:
            for variant in variants:
                variant_name = variant.__name__
                op_name = original_name_inplace if variant is inplace else original_name

                def _fn(*sample_args, **sample_kwargs):
                    return variant(*sample_args, **sample_kwargs)

                inp = (clone_input_helper(sample.input),) + sample_args_kwargs
                traced = torch.jit.trace(_fn, *inp)
                inp = (clone_input_helper(sample.input),) + sample_args_kwargs
                traced(*inp)
                inp = (clone_input_helper(sample.input),) + sample_args_kwargs
                graph = traced.graph_for(*inp)
                FileCheck().check(op_name).check_not(variant_name).run(graph)


instantiate_device_type_tests(TestJit, globals())

if __name__ == '__main__':
    TestCase._default_dtype_check_enabled = True
    run_tests()
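# A common way to run this file directly (assuming a PyTorch checkout where the
# torch.testing._internal helpers are importable); the -k test filter is optional:
#   python test_ops_jit.py -k test_variant_consistency_jit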