import itertools
import re
import unittest
from collections import defaultdict
from functools import partial

import torch._inductor.decomposition
from torch import Tensor
from torch._decomp import core_aten_decompositions, decomposition_table
from torch._dispatch.python import enable_python_dispatcher
from torch._ops import DispatchKey
from torch.testing import make_tensor
from torch.testing._internal.common_cuda import tf32_off
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    onlyNativeDeviceTypes,
    ops,
)
from torch.testing._internal.common_methods_invocations import (
    op_db,
    skip,
    skipOps,
    xfail,
)
from torch.testing._internal.common_modules import module_db, modules
from torch.testing._internal.common_utils import (
    is_iterable_of_tensors,
    run_tests,
    skipIfTorchDynamo,
    TEST_WITH_ASAN,
    TEST_WITH_SLOW,
    TestCase,
    unMarkDynamoStrictTest,
)
from torch.utils import _pytree as pytree
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten

aten = torch.ops.aten
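
# This suite cross-references the Python decompositions registered in
# decomposition_table against the original ATen kernels: each decomposed op is
# run both ways under a TorchDispatchMode and the outputs are compared
# numerically (see DecompCrossRefMode below).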


def overload_to_aten_name(op):
    return op._schema.name.split("::")[1]
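
# For example, overload_to_aten_name(torch.ops.aten.add.Tensor) returns "add",
# since the overload's schema name is "aten::add".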


# All operators that can have decomp tests
decomposition_names = {
    overload_to_aten_name(k)
    for k in decomposition_table
    if isinstance(k, torch._ops.OpOverload)
}
core_decomposition_names = {
    overload_to_aten_name(k)
    for k in core_aten_decompositions()
    if isinstance(k, torch._ops.OpOverload)
}
_decomp_test_ops = [
    op
    for op in op_db
    if op.aten_name in decomposition_names
    or op.aten_backward_name in decomposition_names
]
_decomp_test_ops_core_autograd = [
    op
    for op in op_db
    if op.aten_name in core_decomposition_names and op.supports_autograd
]
_sdpa_op_info = [op for op in op_db if "scaled_dot_product_attention" in op.aten_name]
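
# _decomp_test_ops covers every OpInfo whose forward or backward aten op has a
# registered decomposition; _decomp_test_ops_core_autograd restricts that to
# ops with core decompositions that also support autograd.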


def diff_arg(arg, requires_grad=True):
    def is_differentiable_arg(arg):
        if requires_grad:
            return arg.requires_grad
        else:
            return arg.is_floating_point() or arg.is_complex()

    if is_iterable_of_tensors(arg):
        if all(is_differentiable_arg(a) for a in arg):
            return True
        if all(not is_differentiable_arg(a) for a in arg):
            return False
        raise RuntimeError("NYI: The test runner can't handle this")
    return isinstance(arg, Tensor) and is_differentiable_arg(arg)


# autograd.grad variant: accepts pytree inputs and returns zeros for inputs
# that do not receive a gradient.
def _autograd_grad(
    outputs, inputs, grad_outputs=None, retain_graph=False, create_graph=True
):
    inputs, inputs_spec = tree_flatten(inputs)
    diff_inputs = tuple(inp for inp in inputs if inp.requires_grad)
    if grad_outputs is None:
        diff_outputs = tuple(out for out in outputs if out.requires_grad)
    else:
        diff_grad_outputs = [
            (out, go) for out, go in zip(outputs, grad_outputs) if out.requires_grad
        ]
        if len(diff_grad_outputs) == 0:
            diff_outputs, grad_outputs = (), ()
        else:
            diff_outputs, grad_outputs = zip(*diff_grad_outputs)
    grad_inputs = torch.autograd.grad(
        diff_outputs,
        diff_inputs,
        grad_outputs,
        retain_graph=retain_graph,
        create_graph=create_graph,
        allow_unused=True,
    )
    result = []
    grad_inputs_iter = iter(grad_inputs)
    for inp in inputs:
        if inp.requires_grad:
            grad_input = next(grad_inputs_iter)
            if grad_input is None:
                result.append(torch.zeros_like(inp))
            else:
                result.append(grad_input)
        else:
            result.append(torch.zeros_like(inp))
    return tree_unflatten(result, inputs_spec)


def _as_tuple(val):
    if isinstance(val, tuple):
        return val
    return (val,)


def ref_vjp_no_create(f, *primals):
    result = f(*primals)

    def wrapped(cotangents):
        return _autograd_grad(
            _as_tuple(result), primals, _as_tuple(cotangents), create_graph=False
        )

    return result, wrapped
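
# Usage sketch (illustrative, not from the original file):
#   y, vjp_fn = ref_vjp_no_create(torch.sin, torch.randn(3, requires_grad=True))
#   (grad_x,) = vjp_fn(torch.ones_like(y))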


# Default (rtol, atol) pairs per dtype.
dtype_precisions = {
    torch.float16: (0.001, 1e-5),
    torch.bfloat16: (0.016, 1e-4),
    torch.float32: (1.3e-6, 1e-5),
    torch.float64: (1e-7, 1e-7),
    torch.complex32: (0.001, 1e-5),
    torch.complex64: (1.3e-6, 1e-5),
    torch.complex128: (1e-7, 1e-7),
}


def _getDefaultRtolAndAtol(dtype0, dtype1):
    rtol = max(
        dtype_precisions.get(dtype0, (0, 0))[0], dtype_precisions.get(dtype1, (0, 0))[0]
    )
    atol = max(
        dtype_precisions.get(dtype0, (0, 0))[1], dtype_precisions.get(dtype1, (0, 0))[1]
    )
    return rtol, atol
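
# For example, _getDefaultRtolAndAtol(torch.float16, torch.float32) returns
# (0.001, 1e-5): the looser rtol and the looser atol of the two dtypes.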


def op_assert_ref(test_case, op, test_dtype, i, orig, decomp, ref, args, kwargs):
    assert orig.dtype == decomp.dtype, f"{i} Operation: {op}"
    if orig.numel() == 0 or decomp.numel() == 0:
        assert orig.numel() == decomp.numel()
        return
    assert orig.shape == decomp.shape, f"{i} Operation: {op}"
    tol_table = {
        (torch.bfloat16, torch.ops.aten.native_layer_norm.default): 1e-5,
        (torch.float16, torch.ops.aten.native_layer_norm.default): 1e-5,
        (torch.float16, torch.ops.aten.native_layer_norm_backward.default): 1e-3,
        (torch.bfloat16, torch.ops.aten.native_layer_norm_backward.default): 2e-2,
        (torch.bfloat16, torch.ops.aten.native_batch_norm.default): 1e-5,
        (torch.float16, torch.ops.aten.native_batch_norm.default): 1e-5,
        (torch.bfloat16, torch.ops.aten._native_batch_norm_legit.default): 1e-5,
        (torch.bfloat16, torch.ops.aten._native_batch_norm_legit.no_stats): 1e-5,
        (torch.float16, torch.ops.aten._native_batch_norm_legit.default): 1e-5,
        (torch.float16, torch.ops.aten._native_batch_norm_legit.no_stats): 1e-5,
        (torch.bfloat16, torch.ops.aten.linalg_vector_norm.default): 1e-4,
        (torch.float16, torch.ops.aten.linalg_vector_norm.default): 1e-4,
        (torch.bfloat16, torch.ops.aten.var_mean.correction): 5e-7,
        (torch.float16, torch.ops.aten.var_mean.correction): 5e-7,
        (torch.bfloat16, torch.ops.aten.var_mean.dim): 5e-7,
        (torch.float16, torch.ops.aten.var_mean.dim): 5e-7,
        (torch.float16, torch.ops.aten.nll_loss_forward.default): 1e-2,
        (torch.bfloat16, torch.ops.aten.nll_loss_forward.default): 1e-1,
        (torch.float16, torch.ops.aten.nll_loss2d_forward.default): 1e-2,
        (torch.bfloat16, torch.ops.aten.nll_loss2d_forward.default): 2e-1,
        (torch.float16, torch.ops.aten.hardswish.default): 2e-7,
        (torch.bfloat16, torch.ops.aten.hardswish.default): 2e-7,
        (torch.float16, torch.ops.aten.multi_margin_loss.default): 3e-2,
        (torch.bfloat16, torch.ops.aten.multi_margin_loss.default): 5e-2,
        (torch.float16, torch.ops.aten.multilabel_margin_loss_forward.default): 3e-2,
        (torch.bfloat16, torch.ops.aten.multilabel_margin_loss_forward.default): 3e-2,
        (torch.float16, torch.ops.aten.reflection_pad1d_backward.default): 5e-3,
        (torch.bfloat16, torch.ops.aten.reflection_pad1d_backward.default): 5e-3,
        (torch.float16, torch.ops.aten.reflection_pad2d_backward.default): 5e-3,
        (torch.bfloat16, torch.ops.aten.reflection_pad2d_backward.default): 5e-3,
        (torch.float16, torch.ops.aten.reflection_pad3d_backward.default): 5e-3,
        (torch.bfloat16, torch.ops.aten.reflection_pad3d_backward.default): 5e-2,
        (torch.float16, torch.ops.aten.mv.default): 1e-5,
        (torch.bfloat16, torch.ops.aten.mv.default): 1e-5,
        (torch.float16, torch.ops.aten.log_sigmoid_backward.default): 2e-5,
        (torch.float16, torch.ops.aten._softmax_backward_data.default): 3e-7,
    }
    if ref.is_floating_point():
        orig_diff = (orig - ref).abs().max()
        decomp_diff = (decomp - ref).abs().max()
        atol = tol_table.get((test_dtype, op), 1e-7)
        if decomp_diff > orig_diff + atol:
            raise RuntimeError(
                f"Difference from float64 is larger with decomposition {op.__name__}"
                f" than original on output {i}. Original max diff: {orig_diff}, Decomp max diff: {decomp_diff}\n"
                f"atol = {atol}\n"
                f"args = {args}\n"
                f"kwargs = {kwargs}"
            )
    else:
        test_case.assertEqual(
            orig, decomp, msg=f"{op.__name__}\nargs = {args}\nkwargs = {kwargs}"
        )
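
# op_assert_ref allows the decomposition to stray from the float64 reference by
# at most `atol` more than the eager kernel does, rather than comparing the two
# outputs to each other directly.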


def op_assert_equal(test_case, op, test_dtype, orig, decomp, args, kwargs):
    test_case.assertEqual(
        orig.dtype,
        decomp.dtype,
        f"Operation: {op}, orig.dtype: {orig.dtype}, decomp.dtype: {decomp.dtype}, {args}, {kwargs}",
    )
    # Per-(dtype, op) (rtol, atol) overrides; everything else falls back to
    # _getDefaultRtolAndAtol.
    tol_table = {
        (torch.float32, torch.ops.aten.native_layer_norm.default): (1e-3, 1e-3),
        (torch.float32, torch.ops.aten.native_layer_norm_backward.default): (
            1e-3,
            1e-3,
        ),
        (torch.float64, torch.ops.aten.native_layer_norm.default): (1e-6, 1e-6),
        (torch.float32, torch.ops.aten.grid_sampler_2d.default): (7e-6, 3e-5),
        (torch.float32, torch.ops.aten.mv.default): (1e-5, 3e-5),
        (torch.complex64, torch.ops.aten.mv.default): (5e-5, 5e-5),
        (torch.float64, torch.ops.aten.upsample_bicubic2d.vec): (1e-5, 5e-4),
        (torch.float64, torch.ops.aten.upsample_bicubic2d.default): (1e-5, 5e-4),
        # The linspace decomposition may be off by one for integer dtypes.
        (torch.int8, torch.ops.aten.linspace.default): (0, 1),
        (torch.uint8, torch.ops.aten.linspace.default): (0, 1),
        (torch.int16, torch.ops.aten.linspace.default): (0, 1),
        (torch.int32, torch.ops.aten.linspace.default): (0, 1),
        (torch.int64, torch.ops.aten.linspace.default): (0, 1),
        (torch.int8, torch.ops.aten.linspace.Tensor_Tensor): (0, 1),
        (torch.uint8, torch.ops.aten.linspace.Tensor_Tensor): (0, 1),
        (torch.int16, torch.ops.aten.linspace.Tensor_Tensor): (0, 1),
        (torch.int32, torch.ops.aten.linspace.Tensor_Tensor): (0, 1),
        (torch.int64, torch.ops.aten.linspace.Tensor_Tensor): (0, 1),
        (torch.int8, torch.ops.aten.linspace.Tensor_Scalar): (0, 1),
        (torch.uint8, torch.ops.aten.linspace.Tensor_Scalar): (0, 1),
        (torch.int16, torch.ops.aten.linspace.Tensor_Scalar): (0, 1),
        (torch.int32, torch.ops.aten.linspace.Tensor_Scalar): (0, 1),
        (torch.int64, torch.ops.aten.linspace.Tensor_Scalar): (0, 1),
        (torch.int8, torch.ops.aten.linspace.Scalar_Tensor): (0, 1),
        (torch.uint8, torch.ops.aten.linspace.Scalar_Tensor): (0, 1),
        (torch.int16, torch.ops.aten.linspace.Scalar_Tensor): (0, 1),
        (torch.int32, torch.ops.aten.linspace.Scalar_Tensor): (0, 1),
        (torch.int64, torch.ops.aten.linspace.Scalar_Tensor): (0, 1),
    }
    if (decomp.dtype, op) in tol_table:
        rtol, atol = tol_table[(decomp.dtype, op)]
    else:
        rtol, atol = _getDefaultRtolAndAtol(orig.dtype, decomp.dtype)
    test_case.assertEqual(
        orig,
        decomp,
        rtol=rtol,
        atol=atol,
        msg=f"{op.__name__}\nargs = {args}\nkwargs = {kwargs}",
    )


def normalize_op_input_output2(
    f, args, kwargs, output_process_fn_grad=None, requires_grad=True
):
    flat_args, args_spec = tree_flatten(args)
    diff_argnums = tuple(
        i
        for i, arg in enumerate(flat_args)
        if diff_arg(arg, requires_grad=requires_grad)
    )
    assert len(diff_argnums) > 0
    primals = tuple(flat_args[i] for i in diff_argnums)

    def wrapped(*primals):
        _args = list(flat_args)
        for num, arg in zip(diff_argnums, primals):
            _args[num] = arg
        _args = tree_unflatten(_args, args_spec)
        result = f(*_args, **kwargs)
        if output_process_fn_grad is not None:
            result = output_process_fn_grad(result)
        if isinstance(result, tuple):
            result = tuple(
                r
                for r in result
                if isinstance(r, Tensor) and (r.is_floating_point() or r.is_complex())
            )
            assert len(result) > 0
        return result

    return wrapped, primals
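
# wrapped(*primals) splices the differentiable leaves back into the original
# argument pytree before calling f, so gradients are only requested for the
# tensors selected by diff_arg above.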


def upcast_tensor(x, dtype=torch.float32):
    if isinstance(x, Tensor) and x.dtype.is_floating_point:
        return x.to(dtype=dtype)
    elif isinstance(x, torch.dtype) and x in [
        torch.float16,
        torch.bfloat16,
        torch.float,
    ]:
        return dtype
    return x


def normalize_op_input_output(f, sample, requires_grad=True):
    args = tuple([sample.input] + list(sample.args))
    return normalize_op_input_output2(
        f,
        args,
        sample.kwargs,
        sample.output_process_fn_grad,
        requires_grad=requires_grad,
    )


# (device_type, dtype, op_name) triples to exclude; None acts as a wildcard.
CROSS_REF_EXCLUDE_SET = {
    ("cuda", torch.bfloat16, "nn.functional.bilinear"),
    (None, None, "special.ndtr"),
    (None, None, "new_empty"),
    (None, None, "empty_like"),
    (None, None, "empty"),
    (None, None, "item"),
    (None, None, "zero_"),
    (None, torch.float32, "masked.logsumexp"),
    (None, torch.float64, "masked.logsumexp"),
    (torch.cpu, torch.float16, "signal.windows.exponential"),
    (torch.cpu, torch.float16, "signal.windows.gaussian"),
    (torch.cpu, torch.float16, "signal.windows.cosine"),
    (None, None, "nn.functional.relu6"),
    (None, None, "nn.functional.rrelu"),
    (None, None, "meshgrid"),
    (None, None, "nn.functional.hardshrink"),
    (None, None, "nn.functional.softshrink"),
    (None, None, "diag"),
    ("cpu", torch.bfloat16, "_softmax_backward_data"),
    (None, None, "norm"),
    (None, None, "native_batch_norm"),
    (None, None, "_upsample_bilinear2d_aa"),
    (None, None, "empty_strided"),
}

CROSS_REF_BACKWARD_EXCLUDE_SET = {
    ("cpu", torch.bfloat16, "nn.functional.hardswish"),
    ("cuda", torch.float16, "nn.functional.cross_entropy"),
}

all_decomposed = set()
all_called = defaultdict(int)

# Helpful snippet for testing decomposition coverage
"""
import atexit
def check_coverage():
    print("missing coverage:")
    print("\n".join(map(str, decomposition_table.keys() - all_decomposed)))
atexit.register(check_coverage)
"""

# Helpful snippet for dumping which ops were run and decomposed
"""
import atexit
def dump_ops():
    with open('run_ops.txt', 'w') as f, open('count_ops.txt', 'w') as g:
        for op, count in sorted(all_called.items(), key=lambda x: x[0].__name__):
            f.write(f'{op.__name__}\n')
            g.write(f'{count}\n')
    with open('run_decompositions.txt', 'w') as f:
        for op in sorted([i.__name__ for i in all_decomposed]):
            f.write(f'{op}\n')

atexit.register(dump_ops)
"""


def any_unsupported(args, kwargs):
    def test_unsupported(t):
        if type(t) is torch.Tensor or type(t) is torch.nn.Parameter:
            return any(
                [
                    t.is_sparse_csr,
                    t.is_sparse,
                    t.is_mkldnn,
                    t.is_quantized,
                    t.is_nested,
                    torch._is_functional_tensor(t),
                ]
            )
        elif torch.overrides.is_tensor_like(t):
            return True
        else:
            return False

    flat_args = pytree.arg_tree_leaves(*args, **kwargs)
    return any(test_unsupported(x) for x in flat_args)
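
# Sparse, MKLDNN, quantized, nested, and functionalized tensors (as well as
# arbitrary tensor-like subclasses) are bypassed above because the
# decompositions are not written to handle them.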


# Ops whose backward formulas fail or are too slow under the core-decomposition
# gradcheck test below (see test_quick_core_backward).
core_backward_failures = {
    skip("_softmax_backward_data"),
    skip("grid_sampler_2d"),
    skip("native_dropout_backward"),
    xfail("nn.functional.binary_cross_entropy_with_logits"),
    skip("nn.functional.glu"),
    xfail("nn.functional.hardshrink"),
    xfail("nn.functional.softshrink"),
    skip("nn.functional.unfold"),
    xfail("norm", "fro"),
    xfail("norm", "inf"),
    xfail("norm", "nuc"),
    skip("special.xlog1py"),
}
if not TEST_WITH_SLOW:
    core_backward_failures.update(
        {
            skip("nn.functional.hardswish"),
            skip("split", variant_name="list_args"),
            skip("unsafe_split"),
        }
    )

comprehensive_failures = {
    xfail(
        "nn.functional.interpolate", "bilinear", dtypes=(torch.uint8,)
    ),
    xfail(
        "nn.functional.interpolate", "bicubic", dtypes=(torch.uint8,)
    ),
    xfail(
        "nn.functional.upsample_bilinear", "", dtypes=(torch.uint8,)
    ),
}


@unMarkDynamoStrictTest
class TestDecomp(TestCase):
    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @onlyNativeDeviceTypes
    @ops(_decomp_test_ops)
    def test_quick(self, device, dtype, op):
        self.do_cross_ref(device, dtype, op, run_all=False)

    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @skipOps("TestDecomp", "test_quick_core_backward", core_backward_failures)
    @onlyNativeDeviceTypes
    @ops(_decomp_test_ops_core_autograd, allowed_dtypes=(torch.float64,))
    def test_quick_core_backward(self, device, dtype, op):
        for sample_input in op.sample_inputs(device, dtype, requires_grad=True):
            aten_name = op.decomp_aten_name or op.aten_name
            args = [sample_input.input] + list(sample_input.args)
            kwargs = sample_input.kwargs
            func = partial(op.get_op(), **kwargs)
            with self.DecompCrossRefMode(
                self, self.precision, self.rel_tol, dtype, run_all=False
            ) as mode, enable_python_dispatcher():
                torch.autograd.gradcheck(func, args)
            self.check_decomposed(aten_name, mode)

    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @onlyNativeDeviceTypes
    @skipOps("TestDecomp", "test_comprehensive", comprehensive_failures)
    @ops(op_db)
    def test_comprehensive(self, device, dtype, op):
        self.do_cross_ref(device, dtype, op, run_all=True)

    def test_uniform(self, device):
        # size/low/high below are illustrative; the original values are elided
        size = (10, 10)
        low = 0.3
        high = 0.9
        dtype = torch.float32
        x = make_tensor(size, dtype=dtype, device=device)

        torch.manual_seed(123)
        ref = torch.ops.aten.uniform(x, low, high)
        torch.manual_seed(123)
        res = torch._decomp.decompositions.uniform(x, low=low, high=high)
        self.assertEqual(ref, res)

    def test_broadcasting_index_copy(self, device):
        x = torch.zeros([1, 10], device=device)
        xs = torch.ones([2, 10], device=device)

        def index_copy(xs, x):
            torch._decomp.decompositions.index_copy_(
                xs, 0, torch.tensor(0).to(device), x
            )

        index_copy(xs, x)

        # Compare against eager index_copy_
        xs_two = torch.ones([2, 10], device=device)
        xs_two.index_copy_(0, torch.tensor(0).to(device), x)

        self.assertEqual(xs, xs_two)

    def test_cat_single_input(self, device):
        decomp_table = torch._inductor.decomposition.select_decomp_table()
        cat_inductor = decomp_table[torch.ops.aten.cat.default]

        inp = torch.rand([2048, 2048], device=device)
        inps = [inp for _ in range(10)]

        for dim in (-1, 0, 1):
            self.assertEqual(torch.cat(inps, dim), cat_inductor(inps, dim))

    def test_rrelu_with_noise(self, device):
        dtype = torch.float64
        x = torch.tensor([-3.0, -2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype, device=device)
        # lower/upper below are illustrative; the original values are elided
        lower = 0.1
        upper = 0.3
        training = False

        torch.manual_seed(123)
        noise_ref = torch.zeros(x.shape, dtype=dtype, device=device)
        ref = torch.ops.aten.rrelu_with_noise(x, noise_ref, lower, upper, training)

        torch.manual_seed(123)
        noise_res = torch.zeros(x.shape, dtype=dtype, device=device)
        res = torch._decomp.decompositions.rrelu_with_noise(
            x, noise_res, lower, upper, training
        )
        self.assertEqual(ref, res)
        self.assertEqual(noise_ref, noise_res)

        # Now the training=True branch
        training = True

        torch.manual_seed(123)
        noise_ref = torch.zeros(x.shape, dtype=dtype, device=device)
        ref = torch.ops.aten.rrelu_with_noise(x, noise_ref, lower, upper, training)

        torch.manual_seed(123)
        noise_res = torch.zeros(x.shape, dtype=dtype, device=device)
        res = torch._decomp.decompositions.rrelu_with_noise(
            x, noise_res, lower, upper, training
        )
        self.assertEqual(ref, res)
        self.assertEqual(noise_ref, noise_res)

    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @tf32_off()
    @modules(
        filter(
            lambda m: m.module_cls in (torch.nn.RNN, torch.nn.LSTM, torch.nn.GRU),
            module_db,
        )
    )
    def test_rnn_decomp_module(self, device, dtype, module_info, training):
        module_cls = module_info.module_cls
        module_inputs = module_info.module_inputs_func(
            module_info,
            device=device,
            dtype=dtype,
            requires_grad=True,
            training=training,
        )
        for module_input in module_inputs:
            if module_input.forward_input is None:
                continue
            args, kwargs = (
                module_input.constructor_input.args,
                module_input.constructor_input.kwargs,
            )
            m = module_cls(*args, **kwargs)
            m.to(device).to(dtype)

            args, kwargs = (
                module_input.forward_input.args,
                module_input.forward_input.kwargs,
            )
            with self.DecompCrossRefMode(
                self, self.precision, self.rel_tol, dtype, run_all=True
            ), enable_python_dispatcher():
                decomp_out = m(*args, **kwargs)

            non_decomp_out = m(*args, **kwargs)
            self.assertEqual(decomp_out, non_decomp_out)

    def test_batch_norm_unflatten_weight_bias(self, device):
        # shape below is illustrative; the original value is elided
        shape = (1, 3, 2, 2)
        input = torch.randn(shape, device=device)
        weight = torch.randn((3, 1, 1, 1), device=device)
        bias = torch.randn(3, device=device)
        mean = torch.randn(3, device=device)
        var = torch.randn(3, device=device)
        res = torch._decomp.decompositions.native_batch_norm(
            input, weight, bias, mean, var, False, 1, 1e-05
        )
        self.assertEqual(shape, res[0].shape)

    def test_arange_graph(self, device):
        from torch.fx.experimental.proxy_tensor import make_fx

        def func(x, start):
            le = x.shape[-1]
            if start is None:
                a = torch.arange(le, dtype=torch.float32, device=x.device)
            else:
                a = torch.arange(start, le, dtype=torch.float32, device=x.device)
            return a

        pattern = r", device = device\(.+\), requires_grad = False"

        cfunc = make_fx(func, decomposition_table=decomposition_table)
        fx_g = cfunc(torch.rand(10, device=device), None)
        fx_g_code = fx_g.code.strip()
        # Remove device and requires_grad from the generated code before comparing
        fx_g_code = re.sub(pattern, "", fx_g_code)
        self.assertExpectedInline(
            fx_g_code,
            """\
def forward(self, x_1, start_1):
    iota = torch.ops.prims.iota.default(10, start = 0, step = 1, dtype = torch.int64)
    mul = torch.ops.prims.mul.default(iota, 1); iota = None
    add = torch.ops.prims.add.default(mul, 0); mul = None
    convert_element_type = torch.ops.prims.convert_element_type.default(add, torch.float32); add = None
    return convert_element_type""",
        )

        fx_g = cfunc(torch.rand(10, device=device), 1)
        fx_g_code = fx_g.code.strip()
        fx_g_code = re.sub(pattern, "", fx_g_code)
        self.assertExpectedInline(
            fx_g_code,
            """\
def forward(self, x_1, start_1):
    iota = torch.ops.prims.iota.default(9, start = 0, step = 1, dtype = torch.int64)
    mul = torch.ops.prims.mul.default(iota, 1); iota = None
    add = torch.ops.prims.add.default(mul, 1); mul = None
    convert_element_type = torch.ops.prims.convert_element_type.default(add, torch.float32); add = None
    return convert_element_type""",
        )

    def test_masked_fill(self, device):
        from torch.fx.experimental.proxy_tensor import make_fx

        if torch.device(device).type not in [
            "xpu",
            "cuda",
            torch._C._get_privateuse1_backend_name(),
        ]:
            self.skipTest("only runs on XPU and CUDA and PrivateUse1.")

        def func(scores, mask, value):
            return scores.masked_fill(mask, value)

        scores_t = torch.tensor([1, 2, 3, 4], device=device)
        mask_t = torch.tensor([True, True, True, True], device=device)
        value_t = torch.tensor(0, dtype=scores_t.dtype)
        cfunc = make_fx(func, decomposition_table=decomposition_table)
        fx_g = cfunc(scores_t, mask_t, value_t)
        self.assertExpectedInline(
            fx_g.code.strip(),
            """\
def forward(self, scores_1, mask_1, value_1):
    where = torch.ops.prims.where.default(mask_1, value_1, scores_1); mask_1 = value_1 = scores_1 = None
    return where""",
        )

    class DecompCrossRefMode(TorchDispatchMode):
        def __init__(self, test_case, saved_precision, saved_rel_tol, dtype, run_all):
            self.test_case = test_case
            self.saved_precision = saved_precision
            self.saved_rel_tol = saved_rel_tol
            self.test_dtype = dtype
            self.run_all = run_all

            # Each decomposition is checked right after it is run.
            self.called = set()
            self.decomposed = set()

        def __torch_dispatch__(self, func, types, args=(), kwargs=None):
            self.test_case.precision = self.saved_precision
            self.test_case.rel_tol = self.saved_rel_tol

            self.called.add(func)
            all_called[func] += 1

            # Ops we don't bother cross-referencing
            in_place = func.name()[-1] == "_"
            ignored_ops = [
                torch.ops.aten.detach.default,
                # non-deterministic ops
                torch.ops.aten.empty.memory_format,
                torch.ops.aten.empty_like.default,
                torch.ops.aten.new_empty.default,
                torch.ops.aten.empty_strided.default,
                torch.ops.aten.new_empty_strided.default,
                torch.ops.aten.randn.default,
                torch.ops.aten.native_dropout.default,
            ]
            if (
                func not in decomposition_table
                or func in ignored_ops
                or torch.Tag.nondeterministic_seeded in func.tags
                or any_unsupported(args, kwargs)
                or in_place
            ):
                return func(*args, **kwargs)

            self.decomposed.add(func)
            all_decomposed.add(func)

            # For fp16/bf16 we compare how far the decomposition and the eager
            # kernel each are from a float64 reference; otherwise we compare the
            # two outputs directly with dtype-based tolerances.
            decomposition = decomposition_table[func]

            do_relative_check = self.test_dtype in [torch.float16, torch.bfloat16]
            if self.run_all:
                # Re-enter the mode so nested decompositions are checked too
                with self:
                    decomp_out = pytree.tree_leaves(decomposition(*args, **kwargs))
            else:
                decomp_out = pytree.tree_leaves(decomposition(*args, **kwargs))

            # Run the eager kernel *after* the decomposition so that an in-place
            # mutation by the decomposition would show up as a mismatch.
            real_out_unflat = func(*args, **kwargs)
            real_out = pytree.tree_leaves(real_out_unflat)

            assert len(real_out) == len(decomp_out)

            if do_relative_check:
                upcast = partial(upcast_tensor, dtype=torch.float64)
                real_out_double, _ = tree_flatten(
                    func(*tree_map(upcast, args), **tree_map(upcast, kwargs))
                )
                for i, (orig, decomp, ref) in enumerate(
                    zip(real_out, decomp_out, real_out_double)
                ):
                    if not isinstance(orig, torch.Tensor):
                        assert type(orig) == type(decomp)
                        assert orig == decomp
                        continue
                    op_assert_ref(
                        self.test_case, func, self.test_dtype, i, orig, decomp, ref, args, kwargs
                    )
            else:
                for orig, decomp in zip(real_out, decomp_out):
                    if not isinstance(orig, torch.Tensor):
                        assert type(orig) == type(decomp)
                        assert orig == decomp
                        continue
                    op_assert_equal(
                        self.test_case, func, self.test_dtype, orig, decomp, args, kwargs
                    )

            return real_out_unflat
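
    # Note: DecompCrossRefMode is entered (together with the python dispatcher)
    # by do_cross_ref and the test_* methods above, so every aten call made by
    # an op under test flows through __torch_dispatch__ and gets cross-checked.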

    def check_decomposed(self, aten_name, mode):
        self.assertTrue(
            any(overload_to_aten_name(c) == aten_name for c in mode.decomposed),
            msg=(
                f"aten.{aten_name} was not decomposed, saw calls for: "
                f"{', '.join(map(str, list(mode.called)))}. If your op is "
                "CompositeImplicitAutograd you should skip this test "
                "by updating CROSS_REF_EXCLUDE_SET."
            ),
        )

    @skipIfTorchDynamo("Test does not work with TorchDynamo")
    def do_cross_ref(self, device, dtype, op, *, run_all):
        test_keys = [
            (torch.device(device).type, dtype, op.name),
            (None, dtype, op.name),
            (None, None, op.name),
        ]
        if any(key in CROSS_REF_EXCLUDE_SET for key in test_keys):
            self.skipTest(f"{op.name} in {dtype} not supported")

        skip_decomp_vjp = any(
            key in CROSS_REF_BACKWARD_EXCLUDE_SET for key in test_keys
        )

        requires_grad = (
            op.supports_autograd
            and dtype in op.supported_backward_dtypes(torch.device(device).type)
            # complex32 backward is not generally usable even when listed
            and not dtype == torch.complex32
        )
        samples = op.sample_inputs(device, dtype, requires_grad=requires_grad)

        aten_name = op.decomp_aten_name or op.aten_name
        func = op.get_op()

        def run_without_python_dispatcher(mode):
            return any(
                isinstance(op, torch._ops.OpOverload)
                and op.has_kernel_for_dispatch_key(
                    DispatchKey.CompositeImplicitAutograd
                )
                for op in mode.decomposed.union([func])
            )

        for sample_input in samples:
            if requires_grad:
                fn, primals = normalize_op_input_output(func, sample_input)
                primals = tree_map(
                    lambda x: x if isinstance(x, torch.Tensor) else x, primals
                )

                with self.DecompCrossRefMode(
                    self, self.precision, self.rel_tol, dtype, run_all
                ) as mode, enable_python_dispatcher():
                    decomp_out, decomp_vjp_fn = ref_vjp_no_create(fn, *primals)
                if run_without_python_dispatcher(mode):
                    # Also run without the python dispatcher to exercise decomps
                    # registered to the CompositeImplicitAutograd key.
                    with self.DecompCrossRefMode(
                        self, self.precision, self.rel_tol, dtype, run_all
                    ) as mode:
                        decomp_out, decomp_vjp_fn = ref_vjp_no_create(fn, *primals)
                if aten_name in decomposition_names:
                    self.check_decomposed(aten_name, mode)

                if not skip_decomp_vjp and (
                    op.aten_backward_name in decomposition_names or run_all
                ):
                    cotangents = tree_map(lambda x: torch.randn_like(x), decomp_out)

                    with self.DecompCrossRefMode(
                        self, self.precision, self.rel_tol, dtype, run_all
                    ) as mode, enable_python_dispatcher():
                        decomp_vjp_fn(cotangents)
                    if run_without_python_dispatcher(mode):
                        with self.DecompCrossRefMode(
                            self, self.precision, self.rel_tol, dtype, run_all
                        ) as mode:
                            decomp_vjp_fn(cotangents)
                    if op.aten_backward_name in decomposition_names:
                        self.check_decomposed(op.aten_backward_name, mode)

            elif aten_name in decomposition_names or run_all:
                args = [sample_input.input] + list(sample_input.args)
                kwargs = sample_input.kwargs
                with self.DecompCrossRefMode(
                    self, self.precision, self.rel_tol, dtype, run_all
                ) as mode, enable_python_dispatcher():
                    func(*args, **kwargs)
                if run_without_python_dispatcher(mode):
                    with self.DecompCrossRefMode(
                        self, self.precision, self.rel_tol, dtype, run_all
                    ) as mode:
                        func(*args, **kwargs)
                if aten_name in decomposition_names:
                    self.check_decomposed(aten_name, mode)
            else:
                assert op.supports_autograd
                self.skipTest(
                    "only backwards is decomposed, but dtype doesn't support AD"
                )


instantiate_device_type_tests(TestDecomp, globals())


class DecompOneOffTests(TestCase):
    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @onlyNativeDeviceTypes
    def test_contiguous_softmax(self, device):
        # transposed layout; size is illustrative, chosen to match the stride
        size = (2, 4, 3, 3)
        stride = (9, 18, 3, 1)
        dtype = torch.float32

        x = torch.randn(size, dtype=dtype, device=device)
        x = torch.as_strided(x, size, stride)

        ref = torch.ops.aten._softmax(x, -1, False)
        res = torch._decomp.decompositions._softmax(x, -1, False)
        self.assertEqual(ref.stride(), res.stride())

    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @onlyNativeDeviceTypes
    def test_contiguous_log_softmax(self, device):
        # size as above (illustrative)
        size = (2, 4, 3, 3)
        stride = (9, 18, 3, 1)
        dtype = torch.float32

        x = torch.randn(size, dtype=dtype, device=device)
        x = torch.as_strided(x, size, stride)

        ref = torch.ops.aten._log_softmax(x, -1, False)
        res = torch._decomp.decompositions._log_softmax(x, -1, False)
        self.assertEqual(ref.stride(), res.stride())

    def test_exponential_non_inf(self, device):
        inp = torch.empty((4, 400, 256), device=device)

        with torch._dynamo.utils.preserve_rng_state():
            exp_ref = inp.exponential_()
        exp = torch._refs.exponential(inp)

        self.assertEqual(exp, exp_ref)
        self.assertFalse(exp.isinf().any())

    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    def test_amp_batch_norm_backward(self):
        device = "cuda"  # assumes CUDA; the original device setup is elided
        grad_out = torch.randn((1, 2, 16, 16), dtype=torch.float16, device=device)
        x = torch.randn((1, 2, 16, 16), dtype=torch.float16, device=device)
        weight = torch.randn((2,), dtype=torch.float32, device=device)
        rmean = torch.randn((2,), dtype=torch.float32, device=device)
        rvar = torch.randn((2,), dtype=torch.float32, device=device)
        mean = torch.randn((0,), dtype=torch.float32, device=device)

        # The original argument lists are elided; the ones below follow
        # native_batch_norm_backward's schema: (grad_out, input, weight,
        # running_mean, running_var, save_mean, save_invstd, train, eps, output_mask)
        ref = torch.ops.aten.native_batch_norm_backward(
            grad_out, x, weight, rmean, rvar, mean, mean, False, 1e-05, [True, True, True]
        )
        res = torch._decomp.decompositions.native_batch_norm_backward(
            grad_out, x, weight, rmean, rvar, mean, mean, False, 1e-05, [True, True, True]
        )
        for a, b in zip(ref, res):
            self.assertEqual(a.stride(), b.stride())
            self.assertEqual(a.dtype, b.dtype)

    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @onlyNativeDeviceTypes
    def test_elu_backward(self, device):
        size = (2, 4, 3, 3)  # illustrative; the original size is elided
        dtype = torch.float32
        grad_out = torch.randn(size, dtype=dtype, device=device)
        out = torch.randn(size, dtype=dtype, device=device)

        ref = torch.ops.aten.elu_backward(grad_out, 1.0, 1, 1, True, out)
        res = torch._decomp.decompositions.elu_backward(grad_out, 1.0, 1, 1, True, out)
        self.assertEqual(ref, res)

    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @onlyNativeDeviceTypes
    def test_threshold_backward_dtype(self, device):
        grad = torch.randint(10, (4,), device=device)
        input_tensor = torch.randint(10, (4,), device=device)

        ref = torch.ops.aten.threshold_backward(grad, input_tensor, 1)
        res = torch._decomp.decompositions.threshold_backward(grad, input_tensor, 1)
        self.assertEqual(ref.dtype, res.dtype)

    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @onlyNativeDeviceTypes
    def test_weight_norm_interface(self, device):
        g = torch.randn((3, 10, 10), device=device)
        v = torch.randn((1, 1, 10), device=device)

        ref = torch.ops.aten._weight_norm_interface(g, v, 2)
        res = torch._decomp.decompositions._weight_norm_interface(g, v, 2)
        self.assertTrue(torch.allclose(ref[0], res[0]))
        self.assertTrue(torch.allclose(ref[1], res[1]))

        inp = torch.rand([30, 10], device=device)
        inp2 = torch.rand([30, 1], device=device)

        self.assertEqual(
            torch.ops.aten._weight_norm_interface(inp, inp2),
            torch._decomp.decompositions._weight_norm_interface(inp, inp2),
        )

    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @skipOps(
        "DecompOneOffTests",
        "test_sdpa",
        [
            xfail(
                "nn.functional.scaled_dot_product_attention",
                dtypes=[torch.half],
            ),
        ],
    )
    @ops(_sdpa_op_info)
    def test_sdpa(self, device, dtype, op):
        # Reference module used for the eager comparison (body reconstructed)
        class ScaledDotProductAttention(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(
                self, query_layer, key_layer, value_layer, mask=None, is_causal=True
            ):
                return torch.nn.functional.scaled_dot_product_attention(
                    query_layer,
                    key_layer,
                    value_layer,
                    attn_mask=mask,
                    dropout_p=0.0,
                    is_causal=is_causal,
                )

        query_layer = torch.randn(1, 128, 100, 64, device=device, dtype=dtype)
        key_layer = torch.randn(1, 128, 100, 64, device=device, dtype=dtype)
        value_layer = torch.randn(1, 128, 100, 64, device=device, dtype=dtype)
        masks = [None, torch.ones((1, 1, 100, 100), device=device, dtype=torch.bool)]

        atol, rtol = dtype_precisions[dtype]

        for mask in masks:
            is_causal = mask is None
            attention = ScaledDotProductAttention()
            decomposed_res = (
                torch._decomp.decompositions.scaled_dot_product_flash_attention_for_cpu(
                    query_layer, key_layer, value_layer, 0.0, is_causal, attn_mask=mask
                )
            )
            eager_res = attention(
                query_layer,
                key_layer,
                value_layer,
                mask=mask,
                is_causal=is_causal,
            )

            self.assertTrue(
                torch.allclose(decomposed_res[0], eager_res, atol=atol, rtol=rtol)
            )


instantiate_device_type_tests(DecompOneOffTests, globals())


class HasDecompTest(TestCase):
    @staticmethod
    def _can_appear_in_trace(op: torch._ops.OpOverload) -> bool:
        has_tensor_arg = any(
            "Tensor" in str(a.type)
            for a in itertools.chain(op._schema.arguments, op._schema.returns)
        )
        if not has_tensor_arg:
            return False

        try:
            # CompositeImplicitAutograd ops decompose before reaching the tracer
            return not op.has_kernel_for_dispatch_key(
                DispatchKey.CompositeImplicitAutograd
            )
        except RuntimeError as e:
            # has_kernel_for_dispatch_key fails for some jit-only ops
            if "does not exist" in str(e):
                return False
            raise

    def test_has_decomposition(self):
        def all_aten_overloads():
            for name in torch._C._dispatch_get_all_op_names():
                if not name.startswith("aten::"):
                    continue

                name = name[len("aten::"):]
                if "." in name:
                    packet_name, overload_name = name.split(".")
                else:
                    packet_name, overload_name = name, "default"

                packet = getattr(aten, packet_name)
                assert isinstance(packet, torch._ops.OpOverloadPacket)
                op = getattr(packet, overload_name)
                yield op

        # Ops that are allowed to be missing a decomposition
        allow_list = {aten.get_gradients.default}

        overloads_wanting_decomp = {
            op for op in all_aten_overloads() if self._can_appear_in_trace(op)
        }
        ops_missing_decomp = overloads_wanting_decomp - decomposition_table.keys()
        ops_missing_decomp -= allow_list
        self.assertExpected(
            "".join(sorted(op.name() + "\n" for op in ops_missing_decomp))
        )

    def test_aten_core_operators(self):
        # Only decompositions for ops that can actually appear in a traced
        # graph are considered here.
        useful_decomps = {
            op
            for op in decomposition_table.keys()
            if isinstance(op, torch._ops.OpOverload) and self._can_appear_in_trace(op)
        }
        core_decomps = torch._decomp.core_aten_decompositions().keys()
        core_aten_ops = useful_decomps - core_decomps
        self.assertExpected("".join(sorted(op.name() + "\n" for op in core_aten_ops)))


if __name__ == "__main__":
    run_tests()