from functools import partial
from textwrap import dedent

import torch

from torch.testing import FileCheck
from torch.testing._internal.common_utils import \
    (run_tests, IS_SANDCASTLE, clone_input_helper, first_sample, TestCase)
from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.common_device_type import instantiate_device_type_tests, ops, OpDTypes
from torch.testing._internal.common_jit import JitCommonTestCase, check_against_reference
from torch.testing._internal.jit_metaprogramming_utils import create_script_fn, create_traced_fn, check_alias_annotation
from torch.testing._internal.jit_utils import disable_autodiff_subgraph_inlining, is_lambda
from torch.testing._internal.common_utils import unMarkDynamoStrictTest

# Variant testing is only done with torch.float and torch.cfloat to avoid
#   excessive test times and maximize signal-to-noise ratio
_variant_ops = partial(ops, dtypes=OpDTypes.supported,
                       allowed_dtypes=(torch.float, torch.cfloat))
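

# Tests operators for consistency between JIT and eager, also checks
#   correctness of JIT-specific alias schemas and intended
#   autodifferentiation behavior
# Inherits from JitCommonTestCase instead of TestCase directly to share
#   functionality with the original test_jit.py operator tests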
@unMarkDynamoStrictTest
class TestJit(JitCommonTestCase):
    exact_dtype = True

    # Tests that the forward and backward passes of operations produce the
    #   same values for the cross-product of op variants (function, method)
    #   and runtimes (eager, traced, scripted)
    @_variant_ops(op_db)
    def test_variant_consistency_jit(self, device, dtype, op):
        _requires_grad = (dtype in op.supported_backward_dtypes(torch.device(device).type))

        include_conjugated_inputs = op.test_conjugated_samples and dtype.is_complex
        samples = op.sample_inputs(device, dtype, requires_grad=_requires_grad, include_conjugated_inputs=include_conjugated_inputs)

        # Acquires variants to test
        func = op.get_op()
        method = op.get_method()
        variants = {
            # TODO: inplace tests currently fail; fix and add an inplace variant
            'function': func, 'method': method,
        }

        # Scripting strips the torch.ops prefix from these operators, so
        #   variant consistency can't be checked; skipping still counts as
        #   "tested" for the assertion below
        if isinstance(func, torch._ops.OpOverload):
            self.skipTest("variant consistency doesn't work on torch.ops")

        # These ops are registered with fake functions, so only the method
        #   variant is testable and gradients are unsupported
        has_fake_function = op.name in ["resize_", "resize_as_"]

        if has_fake_function:
            variants = {'method': getattr(torch.Tensor, op.name)}
            samples = op.sample_inputs(device, dtype, requires_grad=False)

        tested = False

        for sample in samples:
            # Test traced and scripted consistency
            for func_type, variant in variants.items():
                if variant is None:
                    continue

                # Scripting and check_alias_annotation do not work with lambdas;
                #   lambdas are typically used to simulate methods without
                #   functional variants, so rely on the other variants for coverage
                if is_lambda(variant):
                    continue

                tested = True
                try:
                    self.indiv_variant_test_jit(device, dtype, op, sample, func_type, variant, has_fake_function)
                except Exception as e:
                    variant_error_info = dedent(f"""
                        Error testing {op.name} {func_type} variant
                        with dtype: {dtype}
                        with inputs {sample}:
                    """)
                    raise Exception(variant_error_info) from e

        assert tested, "JIT Test does not execute any logic"
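
    # Runs the consistency checks for a single (sample, variant) pair; factored
    #   out of test_variant_consistency_jit so each failure can be wrapped with
    #   the per-variant error info above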
    def indiv_variant_test_jit(self, device, dtype, op, sample, func_type, variant, has_fake_function):
        _requires_grad = (dtype in op.supported_backward_dtypes(torch.device(device).type))
        support_script = op.supports_scripting
        # Create accessor for script function variant
        name = op.name + '_' if func_type == 'inplace' else op.name
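
        # Run with disable_autodiff_subgraph_inlining(True) to test autodiff
        #   support; the context manager forces the graph to contain
        #   DifferentiableGraph nodes when they are present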
        with disable_autodiff_subgraph_inlining():
            # Check scripted forward, grad, and grad grad
            if support_script:
                script_fn = create_script_fn(self, name, func_type)

            def out_fn(output):
                # Processes the output for autograd
                if sample.output_process_fn_grad is not None:
                    return sample.output_process_fn_grad(output)
                return output

            def get_sample():
                # Inplace variants consume their input, so it is cloned per run
                return clone_input_helper(sample.input) if op.name[-1] == '_' else sample.input

            if support_script:
                check_against_reference(self,
                                        script_fn,
                                        op.get_op(),
                                        out_fn,
                                        (get_sample(),) + sample.args,
                                        sample.kwargs,
                                        no_grad=not _requires_grad, no_gradgrad=not op.supports_gradgrad)

            # Check traced forward, grad, and grad grad
            supports_tracing = op.supports_tracing and not has_fake_function
            if op.assert_jit_shape_analysis:
                self.assertTrue(supports_tracing)

            if supports_tracing:
                traced_fn = create_traced_fn(self, variant)
                check_against_reference(self,
                                        traced_fn,
                                        op.get_op(),
                                        out_fn,
                                        (get_sample(),) + sample.args,
                                        sample.kwargs,
                                        no_grad=not _requires_grad, no_gradgrad=not op.supports_gradgrad)

            # Check alias annotation schema for correctness (i.e. that inputs
            #   that aren't supposed to be modified aren't). The schema isn't
            #   affected by dtype, so running this on float32 alone suffices
            if dtype == torch.float32:
                if support_script and op.name != "rsub":
                    check_alias_annotation(name, (get_sample(),) + sample.args, sample.kwargs,
                                           func_type=func_type, aten_name=op.aten_name)

                # Check JIT shape analysis on the traced graph; currently only
                #   tensor outputs and tuples of tensors are supported
                checked_shape_analysis = False
                if supports_tracing:
                    out = variant(get_sample(), *sample.args, **sample.kwargs)

                    tuple_of_tensors = isinstance(out, tuple) and all(isinstance(elem, torch.Tensor) for elem in out)

                    if isinstance(out, torch.Tensor) or tuple_of_tensors:
                        if tuple_of_tensors:
                            sizes = [elem.size() for elem in out]
                        else:
                            sizes = out.size()
                        self.checkShapeAnalysis(sizes, traced_fn.graph, op.assert_jit_shape_analysis)
                        checked_shape_analysis = True
                if op.assert_jit_shape_analysis:
                    self.assertTrue(checked_shape_analysis)

            # Check autodifferentiation of nodes for traced and scripted graphs;
            #   this only needs to run once per sample
            if dtype is torch.float32:
                # Sandcastle doesn't fuse nodes, so the fusible nodes are expected
                #   to show up among the nonfusible DifferentiableGraph nodes there
                if IS_SANDCASTLE:
                    nonfusible_nodes = op.autodiff_nonfusible_nodes + op.autodiff_fusible_nodes
                    fusible_nodes = []
                else:
                    nonfusible_nodes = op.autodiff_nonfusible_nodes
                    fusible_nodes = op.autodiff_fusible_nodes

                if supports_tracing:
                    self.assertAutodiffNode(traced_fn.last_graph, op.assert_autodiffed, nonfusible_nodes, fusible_nodes)
                if support_script:
                    self.assertAutodiffNode(script_fn.last_graph, op.assert_autodiffed, nonfusible_nodes, fusible_nodes)

    # Alias testing is only done with torch.float for the same reason as above
    _alias_ops = partial(ops, dtypes=OpDTypes.supported,
                         allowed_dtypes=(torch.float,))

    @_alias_ops(op for op in op_db if op.aliases)
    def test_jit_alias_remapping(self, device, dtype, op):
        # NOTE: only tests on the first sample
        samples = op.sample_inputs(device, dtype, requires_grad=True)
        sample = first_sample(self, samples)
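
        # Prepare strings of args/kwargs that are spliced into function
        #   templates which are then scripted:
        # - args is ["t0"], the "input" tensor required by the op
        # - args_kw is args plus stringified sample args/kwargs used to call
        #   the op, e.g. ["t0", "1.0", "(1,)", "True"] produces
        #   def fn(t0): return variant(t0, 1.0, (1,), True)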
        args = ["t0"]

        def quote_strs(v):
            # String values must be quoted before being spliced into a template
            if isinstance(v, str):
                return f"'{v}'"
            return str(v)

        args_kw = args + \
            [f"{v}" for v in sample.args] + \
            [f"{k}={quote_strs(v)}" for k, v in sample.kwargs.items()]

        # Prepare data for test tracing
        sample_args_kwargs = ()
        if len(sample.args) > 0:
            sample_args_kwargs += (sample.args,)
        if len(sample.kwargs) > 0:
            sample_args_kwargs += (sample.kwargs,)

        original_name = op.aten_name
        original_name_inplace = original_name + "_"
        expected_dtype = op(sample.input, *sample.args, **sample.kwargs).dtype
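
        # For each alias: script and trace a small function that calls the
        #   alias, then use FileCheck to verify the resulting graph contains
        #   the original aten name and not the alias, i.e. that the JIT
        #   remapped the alias to its canonical op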
        for a_op in op.aliases:
            inplace = a_op.inplace_variant
            method_or_inplace = [a_op.inplace_variant, a_op.method_variant]
            # Materialized as a tuple rather than left as a generator because
            #   the variants are iterated twice (scripting, then tracing)
            variants = tuple(v for v in (a_op.op, a_op.method_variant, a_op.inplace_variant) if v is not None)

            # Test scripting:
            for variant in variants:
                variant_name = variant.__name__
                op_name = original_name_inplace if variant is inplace else original_name

                if variant in method_or_inplace:
                    fn_template = '''
                        def _fn(t0{c}):
                            return t0.{alias_name}({args_kw})
                    '''
                    # remove the first input tensor
                    script = fn_template.format(
                        c=", " if len(args_kw[1:]) > 1 else "",
                        args_kw=", ".join(args_kw[1:]),
                        alias_name=variant_name,
                    )
                else:
                    fn_template = '''
                        def _fn({args}):
                            return variant({args_kw})
                    '''
                    script = fn_template.format(
                        args=", ".join(args),
                        args_kw=", ".join(args_kw),
                    )

                # Required to avoid an "undefined value: tensor" error when the
                #   function template is JIT-compiled
                script = script.replace("tensor(", "torch.tensor(")

                scripted = torch.jit.CompilationUnit(script)._fn

                # An inplace alias whose result dtype can't be cast back to the
                #   input dtype must raise rather than silently promote
                if variant is inplace and not torch.can_cast(expected_dtype, dtype):
                    try:
                        inp = clone_input_helper(sample.input)
                        scripted(inp)
                    except Exception:
                        continue
                    self.fail("Inplace operation on integer tensor that should be promoted to float didn't fail!")

                inp = clone_input_helper(sample.input)
                scripted(inp)
                inp = clone_input_helper(sample.input)
                graph = scripted.graph_for(inp)
                FileCheck().check(op.aten_name).check_not(variant_name).run(graph)

            # Test tracing:
            for variant in variants:
                variant_name = variant.__name__
                op_name = original_name_inplace if variant is inplace else original_name

                def _fn(*sample_args, **sample_kwargs):
                    return variant(*sample_args, **sample_kwargs)

                inp = (clone_input_helper(sample.input),) + sample_args_kwargs
                traced = torch.jit.trace(_fn, *inp)
                inp = (clone_input_helper(sample.input),) + sample_args_kwargs
                traced(*inp)
                inp = (clone_input_helper(sample.input),) + sample_args_kwargs
                graph = traced.graph_for(*inp)
                FileCheck().check(op_name).check_not(variant_name).run(graph)


instantiate_device_type_tests(TestJit, globals())

if __name__ == '__main__':
    TestCase._default_dtype_check_enabled = True
    run_tests()