# pytorch (fork) / test_decomp.py -- 1313 lines, 51.1 KB
1
# Owner(s): ["module: decompositions"]
2

3
import functools
4
import itertools
5
import re
6
import unittest
7
from collections import defaultdict
8
from functools import partial
9

10
import torch._inductor.decomposition
11
import torch.autograd
12
from torch import Tensor
13
from torch._decomp import core_aten_decompositions, decomposition_table
14
from torch._dispatch.python import enable_python_dispatcher
15
from torch._ops import DispatchKey
16
from torch.testing import make_tensor
17
from torch.testing._internal.common_cuda import tf32_off
18
from torch.testing._internal.common_device_type import (
19
    instantiate_device_type_tests,
20
    onlyCPU,
21
    onlyCUDA,
22
    onlyNativeDeviceTypes,
23
    ops,
24
)
25
from torch.testing._internal.common_methods_invocations import (
26
    op_db,
27
    skip,
28
    skipOps,
29
    xfail,
30
)
31
from torch.testing._internal.common_modules import module_db, modules
32
from torch.testing._internal.common_utils import (
33
    is_iterable_of_tensors,
34
    run_tests,
35
    skipIfCrossRef,
36
    skipIfTorchDynamo,
37
    suppress_warnings,
38
    TEST_WITH_ASAN,
39
    TEST_WITH_SLOW,
40
    TestCase,
41
    unMarkDynamoStrictTest,
42
)
43
from torch.utils import _pytree as pytree
44
from torch.utils._python_dispatch import TorchDispatchMode
45
from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten
46

47

48
# Shorthand for the aten operator namespace.
aten = torch.ops.aten


# TODO: this isn't going to work with non-aten namespaces
def overload_to_aten_name(op):
    """Return the name of ``op`` with its namespace prefix stripped.

    Reads ``op._schema.name``, splits it on ``"::"`` and returns the second
    piece (e.g. ``"aten::add"`` -> ``"add"``).
    """
    return op._schema.name.split("::")[1]
54

55

56
# All operators that can have decomp tests
57
# Namespace-stripped names of every OpOverload key in the global
# decomposition table.
decomposition_names = {
    overload_to_aten_name(k)
    for k in decomposition_table
    if isinstance(k, torch._ops.OpOverload)
}
# Same, but restricted to what core_aten_decompositions() yields.
core_decomposition_names = {
    overload_to_aten_name(k)
    for k in core_aten_decompositions()
    if isinstance(k, torch._ops.OpOverload)
}
# OpInfos whose forward or backward aten name has a registered decomposition.
_decomp_test_ops = [
    op
    for op in op_db
    if op.aten_name in decomposition_names
    or op.aten_backward_name in decomposition_names
]
# OpInfos that have a core decomposition and also support autograd.
_decomp_test_ops_core_autograd = [
    op
    for op in op_db
    if op.aten_name in core_decomposition_names and op.supports_autograd
]
# OpInfos whose aten name mentions scaled_dot_product_attention.
_sdpa_op_info = [op for op in op_db if "scaled_dot_product_attention" in op.aten_name]
79

80

81
def diff_arg(arg, requires_grad=True):
    """Tell whether ``arg`` counts as a differentiable argument.

    Args:
        arg: a tensor, an iterable of tensors, or any other object.
        requires_grad: when True a tensor qualifies if its ``requires_grad``
            flag is set; when False it qualifies if it is floating point or
            complex.

    Returns:
        True/False for a single tensor or for an iterable of tensors that all
        agree; False for anything that is not a tensor.

    Raises:
        RuntimeError: for an iterable of tensors in which only some of the
            tensors qualify.
    """

    def is_differentiable_arg(arg):
        if requires_grad:
            return arg.requires_grad
        else:
            return arg.is_floating_point() or arg.is_complex()

    if is_iterable_of_tensors(arg):
        # Evaluate the predicate once per element rather than in two passes.
        flags = [is_differentiable_arg(a) for a in arg]
        if all(flags):
            return True
        if not any(flags):
            return False
        raise RuntimeError("NYI: The test runner can't handle this")
    return isinstance(arg, Tensor) and is_differentiable_arg(arg)
95

96

97
# Version of autograd.grad with some differences:
#   - pytree inputs are allowed (but the leaves of the pytree all have to
#     be tensors)
#   - if an input is not used as part of derivatives, we will return a
#     zero-filled tensor for the result
102
def _autograd_grad(
103
    outputs, inputs, grad_outputs=None, retain_graph=False, create_graph=True
104
):
105
    inputs, inputs_spec = tree_flatten(inputs)
106
    diff_inputs = tuple(inp for inp in inputs if inp.requires_grad)
107
    if grad_outputs is None:
108
        diff_outputs = tuple(out for out in outputs if out.requires_grad)
109
    else:
110
        diff_grad_outputs = [
111
            (out, go) for out, go in zip(outputs, grad_outputs) if out.requires_grad
112
        ]
113
        if len(diff_grad_outputs) == 0:
114
            diff_outputs, grad_outputs = (), ()
115
        else:
116
            diff_outputs, grad_outputs = zip(*diff_grad_outputs)
117
    grad_inputs = torch.autograd.grad(
118
        diff_outputs,
119
        diff_inputs,
120
        grad_outputs,
121
        retain_graph=retain_graph,
122
        create_graph=create_graph,
123
        allow_unused=True,
124
    )
125
    result = []
126
    grad_inputs_iter = iter(grad_inputs)
127
    for inp in inputs:
128
        if inp.requires_grad:
129
            grad_input = next(grad_inputs_iter)
130
            if grad_input is None:
131
                result.append(torch.zeros_like(inp))
132
            else:
133
                result.append(grad_input)
134
        else:
135
            result.append(torch.zeros_like(inp))
136
    return tree_unflatten(result, inputs_spec)
137

138

139
def _as_tuple(val):
140
    if isinstance(val, tuple):
141
        return val
142
    return (val,)
143

144

145
def ref_vjp_no_create(f, *primals):
    """Evaluate ``f(*primals)`` and return it with a VJP closure.

    Returns:
        ``(result, wrapped)`` where ``wrapped(cotangents)`` calls
        ``_autograd_grad`` on the already-computed result with
        ``create_graph=False`` and ``retain_graph=True``, so it can be called
        more than once.
    """
    result = f(*primals)

    def wrapped(cotangents):
        return _autograd_grad(
            _as_tuple(result),
            primals,
            _as_tuple(cotangents),
            create_graph=False,
            retain_graph=True,
        )

    return result, wrapped
158

159

160
dtype_precisions = {
161
    torch.float16: (0.001, 1e-5),
162
    torch.bfloat16: (0.016, 1e-4),
163
    torch.float32: (1.3e-6, 1e-5),
164
    torch.float64: (1e-7, 1e-7),
165
    torch.complex32: (0.001, 1e-5),
166
    torch.complex64: (1.3e-6, 1e-5),
167
    torch.complex128: (1e-7, 1e-7),
168
}
169
# Returns the "default" rtol and atol for comparing scalars or
170
# tensors of the given dtypes.
171

172

173
def _getDefaultRtolAndAtol(dtype0, dtype1):
174
    rtol = max(
175
        dtype_precisions.get(dtype0, (0, 0))[0], dtype_precisions.get(dtype1, (0, 0))[0]
176
    )
177
    atol = max(
178
        dtype_precisions.get(dtype0, (0, 0))[1], dtype_precisions.get(dtype1, (0, 0))[1]
179
    )
180
    return rtol, atol
181

182

183
def op_assert_ref(test_case, op, test_dtype, i, orig, decomp, ref, args, kwargs):
    """Check that ``decomp`` is no further from ``ref`` than ``orig`` is.

    Args:
        test_case: object providing ``assertEqual``; only used when ``ref`` is
            not floating point.
        op: the operator under test; used for the tolerance lookup and in
            messages.
        test_dtype: dtype used as the first half of the tolerance-table key.
        i: index of the output being compared (messages only).
        orig: output of the original operator.
        decomp: output of the decomposition.
        ref: reference output (the error message calls it "float64").
        args, kwargs: the call arguments, echoed in failure messages.

    Raises:
        AssertionError: on dtype, shape or emptiness mismatch.
        RuntimeError: when the decomposition's max abs error against ``ref``
            exceeds the original's by more than the tolerance.
    """
    assert orig.dtype == decomp.dtype, f"{i} Operation:  {op}"
    if orig.numel() == 0 or decomp.numel() == 0:
        assert orig.numel() == decomp.numel()
        return
    assert orig.shape == decomp.shape, f"{i} Operation:  {op}"
    # (test dtype, overload) -> extra absolute slack allowed for the decomp.
    tol_table = {
        (torch.bfloat16, torch.ops.aten.native_layer_norm.default): 1e-5,
        (torch.float16, torch.ops.aten.native_layer_norm.default): 1e-5,
        (torch.float16, torch.ops.aten.native_layer_norm_backward.default): 1e-3,
        (torch.bfloat16, torch.ops.aten.native_layer_norm_backward.default): 2e-2,
        (torch.bfloat16, torch.ops.aten.native_batch_norm.default): 1e-5,
        (torch.float16, torch.ops.aten.native_batch_norm.default): 1e-5,
        (torch.bfloat16, torch.ops.aten._native_batch_norm_legit.default): 1e-5,
        (torch.bfloat16, torch.ops.aten._native_batch_norm_legit.no_stats): 1e-5,
        (torch.float16, torch.ops.aten._native_batch_norm_legit.default): 1e-5,
        (torch.float16, torch.ops.aten._native_batch_norm_legit.no_stats): 1e-5,
        (torch.bfloat16, torch.ops.aten.linalg_vector_norm.default): 1e-4,
        (torch.float16, torch.ops.aten.linalg_vector_norm.default): 1e-4,
        (torch.bfloat16, torch.ops.aten.var_mean.correction): 5e-7,
        (torch.float16, torch.ops.aten.var_mean.correction): 5e-7,
        (torch.bfloat16, torch.ops.aten.var_mean.dim): 5e-7,
        (torch.float16, torch.ops.aten.var_mean.dim): 5e-7,
        (torch.float16, torch.ops.aten.nll_loss_forward.default): 1e-2,
        (torch.bfloat16, torch.ops.aten.nll_loss_forward.default): 1e-1,
        (torch.float16, torch.ops.aten.nll_loss2d_forward.default): 1e-2,
        (torch.bfloat16, torch.ops.aten.nll_loss2d_forward.default): 2e-1,
        (torch.float16, torch.ops.aten.hardswish.default): 2e-7,
        (torch.bfloat16, torch.ops.aten.hardswish.default): 2e-7,
        (torch.float16, torch.ops.aten.multi_margin_loss.default): 3e-2,
        (torch.bfloat16, torch.ops.aten.multi_margin_loss.default): 5e-2,
        (torch.float16, torch.ops.aten.multilabel_margin_loss_forward.default): 3e-2,
        (torch.bfloat16, torch.ops.aten.multilabel_margin_loss_forward.default): 3e-2,
        (torch.float16, torch.ops.aten.reflection_pad1d_backward.default): 5e-3,
        (torch.bfloat16, torch.ops.aten.reflection_pad1d_backward.default): 5e-3,
        (torch.float16, torch.ops.aten.reflection_pad2d_backward.default): 5e-3,
        (torch.bfloat16, torch.ops.aten.reflection_pad2d_backward.default): 5e-3,
        (torch.float16, torch.ops.aten.reflection_pad3d_backward.default): 5e-3,
        (torch.bfloat16, torch.ops.aten.reflection_pad3d_backward.default): 5e-2,
        # see https://github.com/pytorch/pytorch/pull/96264
        (torch.float16, torch.ops.aten.mv.default): 1e-5,
        (torch.bfloat16, torch.ops.aten.mv.default): 1e-5,
        (torch.float16, torch.ops.aten.log_sigmoid_backward.default): 2e-5,
        (torch.float16, torch.ops.aten._softmax_backward_data.default): 3e-7,
    }
    if ref.is_floating_point():
        orig_diff = (orig - ref).abs().max()
        decomp_diff = (decomp - ref).abs().max()
        atol = tol_table.get((test_dtype, op), 1e-7)
        # The decomposition may be at most `atol` worse than the original.
        if decomp_diff > orig_diff + atol:
            raise RuntimeError(
                f"Difference from float64 is larger with decomposition {op.__name__}"
                f" than original on output {i}. Original max diff: {orig_diff}, Decomp max diff: {decomp_diff}\n"
                f"atol = {atol}\n"
                f"args = {args}\n"
                f"kwargs = {kwargs}"
            )
    else:
        test_case.assertEqual(
            orig, decomp, msg=f"{op.__name__}\nargs = {args}\nkwargs = {kwargs}"
        )
244

245

246
def op_assert_equal(test_case, op, test_dtype, orig, decomp, args, kwargs):
247
    test_case.assertEqual(
248
        orig.dtype,
249
        decomp.dtype,
250
        f"Operation: {op}, orig.dtype: {orig.dtype}, decomp.dtype: {decomp.dtype}, {args}, {kwargs}",
251
    )
252
    # Before adding an entry to this table, make sure your decomposition is right :)
253
    tol_table = {
254
        # Due to strange epsilon behaviors, see https://github.com/pytorch/pytorch/issues/73161
255
        (torch.float32, torch.ops.aten.native_layer_norm.default): (1e-3, 1e-3),
256
        (torch.float32, torch.ops.aten.native_layer_norm_backward.default): (
257
            1e-3,
258
            1e-3,
259
        ),
260
        (torch.float64, torch.ops.aten.native_layer_norm.default): (1e-6, 1e-6),
261
        # This exceeds default tolerances only on CPU, on CUDA it's fine
262
        (torch.float32, torch.ops.aten.grid_sampler_2d.default): (7e-6, 3e-5),
263
        # Exceeds tolerances on CUDA, likely due to fma
264
        (torch.float32, torch.ops.aten.mv.default): (1e-5, 3e-5),
265
        (torch.complex64, torch.ops.aten.mv.default): (5e-5, 5e-5),
266
        (torch.float64, torch.ops.aten.upsample_bicubic2d.vec): (1e-5, 5e-4),
267
        (torch.float64, torch.ops.aten.upsample_bicubic2d.default): (1e-5, 5e-4),
268
        # The decomposition is TOO correct. It computes everything in int64, so sometimes
269
        # there's an off-by-one error. See
270
        # https://github.com/pytorch/pytorch/issues/81996
271
        # https://github.com/pytorch/pytorch/issues/82230
272
        (torch.int8, torch.ops.aten.linspace.default): (0, 1),
273
        (torch.uint8, torch.ops.aten.linspace.default): (0, 1),
274
        (torch.int16, torch.ops.aten.linspace.default): (0, 1),
275
        (torch.int32, torch.ops.aten.linspace.default): (0, 1),
276
        (torch.int64, torch.ops.aten.linspace.default): (0, 1),
277
        (torch.int8, torch.ops.aten.linspace.Tensor_Tensor): (0, 1),
278
        (torch.uint8, torch.ops.aten.linspace.Tensor_Tensor): (0, 1),
279
        (torch.int16, torch.ops.aten.linspace.Tensor_Tensor): (0, 1),
280
        (torch.int32, torch.ops.aten.linspace.Tensor_Tensor): (0, 1),
281
        (torch.int64, torch.ops.aten.linspace.Tensor_Tensor): (0, 1),
282
        (torch.int8, torch.ops.aten.linspace.Tensor_Scalar): (0, 1),
283
        (torch.uint8, torch.ops.aten.linspace.Tensor_Scalar): (0, 1),
284
        (torch.int16, torch.ops.aten.linspace.Tensor_Scalar): (0, 1),
285
        (torch.int32, torch.ops.aten.linspace.Tensor_Scalar): (0, 1),
286
        (torch.int64, torch.ops.aten.linspace.Tensor_Scalar): (0, 1),
287
        (torch.int8, torch.ops.aten.linspace.Scalar_Tensor): (0, 1),
288
        (torch.uint8, torch.ops.aten.linspace.Scalar_Tensor): (0, 1),
289
        (torch.int16, torch.ops.aten.linspace.Scalar_Tensor): (0, 1),
290
        (torch.int32, torch.ops.aten.linspace.Scalar_Tensor): (0, 1),
291
        (torch.int64, torch.ops.aten.linspace.Scalar_Tensor): (0, 1),
292
    }
293
    if (decomp.dtype, op) in tol_table:
294
        rtol, atol = tol_table[(decomp.dtype, op)]
295
    else:
296
        rtol, atol = _getDefaultRtolAndAtol(orig.dtype, decomp.dtype)
297
    test_case.assertEqual(
298
        orig,
299
        decomp,
300
        rtol=rtol,
301
        atol=atol,
302
        msg=f"{op.__name__}\nargs = {args}\nkwargs = {kwargs}",
303
    )
304

305

306
# Given f, returns an f' such that:
307
# - f' takes only positional arguments
308
# - All arguments to f' are floating-point Tensors
309
# - All outputs of f' are floating-point Tensors
310
def normalize_op_input_output2(
    f, args, kwargs, output_process_fn_grad=None, requires_grad=True
):
    """Wrap ``f`` so it takes only its differentiable arguments positionally.

    Args:
        f: the callable to wrap.
        args: pytree of positional arguments for ``f``.
        kwargs: keyword arguments forwarded to ``f`` unchanged.
        output_process_fn_grad: optional callable applied to ``f``'s result.
        requires_grad: forwarded to ``diff_arg`` to decide which flattened
            arguments are differentiable.

    Returns:
        ``(wrapped, primals)``: ``primals`` are the differentiable leaves of
        ``args``; ``wrapped(*primals)`` substitutes them back, calls ``f`` and,
        for tuple results, keeps only floating-point/complex tensor outputs.

    Raises:
        AssertionError: if no argument is differentiable, or a tuple result
            has no floating-point/complex tensor in it.
    """
    flat_args, args_spec = tree_flatten(args)
    diff_argnums = tuple(
        i
        for i, arg in enumerate(flat_args)
        if diff_arg(arg, requires_grad=requires_grad)
    )
    assert len(diff_argnums) > 0
    primals = tuple(flat_args[i] for i in diff_argnums)

    @functools.wraps(f)
    def wrapped(*primals):
        # Rebuild the full argument pytree with the supplied primals in the
        # differentiable slots.
        _args = list(flat_args)
        for num, arg in zip(diff_argnums, primals):
            _args[num] = arg
        _args = tree_unflatten(_args, args_spec)
        result = f(*_args, **kwargs)
        if output_process_fn_grad is not None:
            result = output_process_fn_grad(result)
        if isinstance(result, tuple):
            # TODO We should check that the integer outputs also agree
            result = tuple(
                r
                for r in result
                if isinstance(r, Tensor) and (r.is_floating_point() or r.is_complex())
            )
            assert len(result) > 0
        return result

    return wrapped, primals
342

343

344
# NB: This also upcasts dtype arguments
345
# TODO: handle complex correctly
346
def upcast_tensor(x, dtype=torch.float32):
    """Convert floating-point tensors and low-precision dtype objects to ``dtype``.

    - A tensor with a floating-point dtype is converted with ``x.to(dtype)``;
      this applies to every floating-point tensor, float64 included.
    - A ``torch.dtype`` equal to float16, bfloat16 or float32 is replaced by
      ``dtype``.
    - Anything else is returned unchanged.
    """
    if isinstance(x, Tensor) and x.dtype.is_floating_point:
        return x.to(dtype=dtype)
    elif isinstance(x, torch.dtype) and x in (
        torch.float16,
        torch.bfloat16,
        torch.float,
    ):
        return dtype
    else:
        return x
357

358

359
def normalize_op_input_output(f, sample, requires_grad=True):
    """Apply ``normalize_op_input_output2`` to a sample-input object.

    The positional arguments are ``sample.input`` followed by ``sample.args``;
    ``sample.kwargs`` and ``sample.output_process_fn_grad`` are forwarded.
    """
    args = (sample.input, *sample.args)
    return normalize_op_input_output2(
        f,
        args,
        sample.kwargs,
        sample.output_process_fn_grad,
        requires_grad=requires_grad,
    )
368

369

370
CROSS_REF_EXCLUDE_SET = {
371
    # CUBLAS_STATUS_NOT_SUPPORTED when calling
372
    # `cublasGemmStridedBatchedExFix(handle, opa, opb, (int)m, (int)n, (int)k,
373
    # (void*)&falpha, a, CUDA_R_16BF, (int)lda, stridea, b, CUDA_R_16BF,
374
    # (int)ldb, strideb, (void*)&fbeta, c, CUDA_R_16BF, (int)ldc, stridec,
375
    # (int)num_batches, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)`
376
    ("cuda", torch.bfloat16, "nn.functional.bilinear"),
377
    # randomness
378
    (None, None, "special.ndtr"),  # aten.special_ndtr was not decomposed
379
    (None, None, "new_empty"),
380
    (None, None, "empty_like"),
381
    (None, None, "empty"),
382
    # AssertionError: False is not true : aten.item was not decomposed, saw calls for: aten._local_scalar_dense.default.
383
    (None, None, "item"),
384
    # It's the only in-place op without an out-of-place equivalent in the Python API
385
    # Its OpInfo wrongly registers it as `torch.zero_(x.clone())`.
386
    (None, None, "zero_"),
387
    # No idea what's going on here
388
    # In the recursive test logsumexp.default fails with args = (torch.tensor(-math.inf), [])
389
    # in the test, but it seems to pass when tested locally and in the logsumexp test
390
    (None, torch.float32, "masked.logsumexp"),
391
    (None, torch.float64, "masked.logsumexp"),
392
    # exp_vml_cpu not implemented for Half
393
    (torch.cpu, torch.float16, "signal.windows.exponential"),
394
    (torch.cpu, torch.float16, "signal.windows.gaussian"),
395
    # sin_vml_cpu not implemented for Half
396
    (torch.cpu, torch.float16, "signal.windows.cosine"),
397
    # CompositeAutogradImplicit
398
    # See https://github.com/pytorch/pytorch/issues/81669
399
    (None, None, "nn.functional.relu6"),
400
    # This decomp runs before autograd.
401
    (None, None, "nn.functional.rrelu"),
402
    (None, None, "meshgrid"),
403
    # Decomposition registered as Autograd
404
    (None, None, "nn.functional.hardshrink"),
405
    (None, None, "nn.functional.softshrink"),
406
    # diag was not decomposed (it just registers a decomp for diag_out, torch.diag is CompImplicit)
407
    (None, None, "diag"),
408
    # _softmax_backward_data's CPU kernel for bfloat16 always return the grad_input as float32
409
    ("cpu", torch.bfloat16, "_softmax_backward_data"),
410
    (None, None, "norm"),
411
    # native_batch_norm is only implicit when python dispatcher is on (and noncomposite otherwise)
412
    (None, None, "native_batch_norm"),
413
    (None, None, "_upsample_bilinear2d_aa"),
414
    (None, None, "empty_strided"),  # aten.empty_strided was not decomposed
415
}
416

417
# (device type, dtype, op name) triples, same layout as CROSS_REF_EXCLUDE_SET.
CROSS_REF_BACKWARD_EXCLUDE_SET = {
    # Decomposed backward formula is not as precise
    ("cpu", torch.bfloat16, "nn.functional.hardswish"),
    ("cuda", torch.float16, "nn.functional.cross_entropy"),
}

# Module-level bookkeeping read by the debugging snippets below: a set of
# decomposed ops and a per-op call counter that starts at 0.
all_decomposed = set()
all_called = defaultdict(int)

# The two triple-quoted strings below are never executed; paste one into the
# module to enable it.

# Helpful snippet for testing coverage
"""
import atexit
def check_coverage():
    print("missing coverage:")
    print("\n".join(map(str, decomposition_table.keys() - all_decomposed)))
atexit.register(check_coverage)
"""

# Helpful snippet for Horace to create his google sheet :)
"""
import atexit
def dump_ops():
    with open('run_ops.txt', 'w') as f, open('count_ops.txt', 'w') as g:
        for op, count in sorted(all_called.items(), key=lambda x: x[0].__name__):
            f.write(f'{op.__name__}\n')
            g.write(f'{count}\n')
    with open('run_decompositions.txt', 'w') as f:
        for op in sorted([i.__name__ for i in all_decomposed]):
            f.write(f'{op}\n')

atexit.register(dump_ops)
"""
449

450

451
def any_unsupported(args, kwargs):
452
    def test_unsupported(t):
453
        if type(t) is torch.Tensor or type(t) is torch.nn.Parameter:
454
            # These are all things that we haven't coded decompositions
455
            # to handle correctly.  Maybe they should.
456
            return any(
457
                [
458
                    t.is_sparse_csr,
459
                    t.is_sparse,
460
                    t.is_mkldnn,
461
                    t.is_quantized,
462
                    t.is_nested,
463
                    torch._is_functional_tensor(t),
464
                ]
465
            )
466
        elif torch.overrides.is_tensor_like(t):
467
            # Decompositions will generally change the behavior of Tensor-like
468
            # subclasses, so bypass tests in this case too
469
            return True
470
        else:
471
            return False
472

473
    flat_args = pytree.arg_tree_leaves(*args, **kwargs)
474
    return any(test_unsupported(x) for x in flat_args)
475

476

477
# skip/xfail markers handed to @skipOps for TestDecomp.test_quick_core_backward.
core_backward_failures = {
    skip("_softmax_backward_data"),  # slow: fails with --timeout=360 secs
    xfail("addcdiv"),
    skip("addcmul"),  # slow: fails with --timeout=360 secs
    skip("deg2rad"),  # slow: fails with --timeout=360 secs
    skip("diag_embed"),  # slow: fails with --timeout=360 secs
    skip("frac"),  # slow: fails with --timeout=360 secs
    skip("grid_sampler_2d"),  # slow: fails with --timeout=360 secs
    xfail("lerp"),
    skip("logaddexp"),  # slow: fails with --timeout=360 secs
    skip("native_dropout_backward"),  # slow: fails with --timeout=360 secs
    xfail("nn.functional.binary_cross_entropy_with_logits"),
    skip("nn.functional.glu"),  # slow: fails with --timeout=360 secs
    xfail("nn.functional.hardshrink"),
    xfail("nn.functional.softshrink"),
    skip("nn.functional.unfold"),  # slow: fails with --timeout=360 secs
    xfail("norm"),
    xfail("norm", "fro"),
    xfail("norm", "inf"),
    xfail("norm", "nuc"),
    skip("rad2deg"),  # slow: fails with --timeout=360 secs
    skip("renorm"),  # slow: fails with --timeout=360 secs
    skip("rot90"),  # slow: fails with --timeout=360 secs
    skip("rsub"),  # slow: fails with --timeout=360 secs
    skip("sgn"),  # slow: fails with --timeout=360 secs
    skip("special.xlog1py"),  # slow: fails with --timeout=360 secs
    xfail("stack"),
    skip("tril"),  # slow: fails with --timeout=360 secs
    skip("triu"),  # slow: fails with --timeout=360 secs
    skip("unfold_copy"),  # slow: fails with --timeout=360 secs
    skip("xlogy"),  # slow: fails with --timeout=360 secs
    xfail("zero_"),
}
# Outside slow-test runs, also skip the ops that merely take a long time.
if not TEST_WITH_SLOW:
    core_backward_failures.update(
        {
            skip("addr"),  # slow: takes 46 sec on A100
            skip("baddbmm"),  # slow: takes 800+ sec on A100
            skip("clamp_min"),  # slow: takes 800 sec on A100
            skip("clamp_max"),  # slow: takes 800 sec on A100
            skip("logit"),  # slow: takes 44 sec on A100
            skip("nn.functional.hardswish"),  # slow: takes 60 sec on A100
            skip("std_mean"),  # slow: takes 170 sec on A100
            skip("split", variant_name="list_args"),  # slow: takes 118 sec on A100
            skip("transpose"),  # slow: takes 50 sec on A100
            skip("unbind"),  # slow: takes 70 sec on A100
            skip("unsafe_split"),  # slow: takes 49 sec on A100
        }
    )
526

527
# xfail markers handed to @skipOps for TestDecomp.test_comprehensive.
comprehensive_failures = {
    xfail(
        "nn.functional.interpolate", "bilinear", dtypes=(torch.uint8,)
    ),  # off by one error
    xfail(
        "nn.functional.interpolate", "bicubic", dtypes=(torch.uint8,)
    ),  # off by one error
    xfail(
        "nn.functional.upsample_bilinear", "", dtypes=(torch.uint8,)
    ),  # off by one error
}
538

539

540
@unMarkDynamoStrictTest
541
class TestDecomp(TestCase):
542
    longMessage = True
543

544
    # NB: This actually overlaps with test_comprehensive, but it only
545
    # runs on things that are definitely decomposed so it's a lot faster
546
    # to run
547
    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
548
    @onlyNativeDeviceTypes
549
    @skipIfCrossRef
550
    @suppress_warnings
551
    @ops(_decomp_test_ops)
552
    def test_quick(self, device, dtype, op):
        """Cross-reference ``op`` against its decompositions with ``run_all=False``."""
        self.do_cross_ref(device, dtype, op, run_all=False)
554

555
    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
556
    @skipOps("TestDecomp", "test_quick_core_backward", core_backward_failures)
557
    @onlyNativeDeviceTypes
558
    @skipIfCrossRef
559
    @suppress_warnings
560
    @ops(_decomp_test_ops_core_autograd, allowed_dtypes=(torch.float64,))
561
    def test_quick_core_backward(self, device, dtype, op):
        """Run gradcheck on every grad-requiring sample of ``op`` under the cross-ref mode.

        After each sample, ``check_decomposed`` is called with the op's aten
        name and the mode that was active during gradcheck.
        """
        # Loop-invariant: prefer the explicit decomp name when the OpInfo has one.
        aten_name = op.decomp_aten_name or op.aten_name
        for sample_input in op.sample_inputs(device, dtype, requires_grad=True):
            args = [sample_input.input, *sample_input.args]
            # kwargs are bound up front so gradcheck only sees positional args.
            func = partial(op.get_op(), **sample_input.kwargs)
            with self.DecompCrossRefMode(
                self, self.precision, self.rel_tol, dtype, run_all=False
            ) as mode, enable_python_dispatcher():
                torch.autograd.gradcheck(func, args)
            self.check_decomposed(aten_name, mode)
572

573
    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
574
    @onlyNativeDeviceTypes
575
    @skipIfCrossRef
576
    @skipOps("TestDecomp", "test_comprehensive", comprehensive_failures)
577
    @suppress_warnings
578
    @ops(op_db)
579
    def test_comprehensive(self, device, dtype, op):
        """Cross-reference ``op`` against its decompositions with ``run_all=True``."""
        self.do_cross_ref(device, dtype, op, run_all=True)
581

582
    def test_uniform(self, device):
        """The ``uniform`` decomposition must match ``aten.uniform`` for the same seed."""
        size = (2, 3, 4, 5)
        dtype = torch.float32
        x = make_tensor(size, dtype=dtype, device=device)
        low = 0.3
        high = 0.9

        # Reseed before each call so both draw from the same RNG state.
        torch.manual_seed(123)
        ref = torch.ops.aten.uniform(x, low, high)
        torch.manual_seed(123)
        res = torch._decomp.decompositions.uniform(x, low=low, high=high)
        self.assertEqual(ref, res)
594

595
    def test_broadcasting_index_copy(self, device):
        """``index_copy_`` decomposition with a scalar index must equal ``xs[0] = x``."""
        x = torch.zeros([1, 10], device=device)
        xs = torch.ones([2, 10], device=device)

        # Copy x into row 0 of xs through the decomposition (in place).
        torch._decomp.decompositions.index_copy_(
            xs, 0, torch.tensor(0, device=device), x
        )

        # Reference result via plain indexing assignment.
        xs_two = torch.ones([2, 10], device=device)
        xs_two[0] = x

        self.assertEqual(xs, xs_two)
610

611
    def test_cat_single_input(self, device):
        """Inductor's ``cat`` decomposition must match ``torch.cat`` when one tensor is repeated."""
        decomp_table = torch._inductor.decomposition.select_decomp_table()
        cat_inductor = decomp_table[torch.ops.aten.cat.default]

        inp = torch.rand([2048, 2048], device=device)
        # Ten references to the same tensor object.
        inps = [inp] * 10

        for dim in (-1, 0, 1):
            self.assertEqual(torch.cat(inps, dim), cat_inductor(inps, dim))
620

621
    def test_rrelu_with_noise(self, device):
        """``rrelu_with_noise`` decomposition must match aten in output and in the noise buffer."""
        # rrelu_with_noise behavior depends on a) whether elements in the input
        # are <= 0, and b) whether we're in training mode. Cover all cases:
        dtype = torch.float64
        x = torch.tensor([-3.0, -2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype, device=device)
        lower = 1.0
        upper = 4.0

        for training in (False, True):
            # Reseed before each implementation so both see the same RNG state.
            torch.manual_seed(123)
            noise_ref = torch.zeros(x.shape, dtype=dtype, device=device)
            ref = torch.ops.aten.rrelu_with_noise(x, noise_ref, lower, upper, training)

            torch.manual_seed(123)
            noise_res = torch.zeros(x.shape, dtype=dtype, device=device)
            res = torch._decomp.decompositions.rrelu_with_noise(
                x,
                noise_res,
                lower,
                upper,
                training,
            )
            self.assertEqual(ref, res)
            self.assertEqual(noise_ref, noise_res)
664

665
    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
666
    @suppress_warnings
667
    @tf32_off()
668
    # only tests RNNs since we have py dispatcher decomps for them
669
    @modules(
670
        filter(
671
            lambda m: m.module_cls in (torch.nn.RNN, torch.nn.LSTM, torch.nn.GRU),
672
            module_db,
673
        )
674
    )
675
    def test_rnn_decomp_module(self, device, dtype, module_info, training):
        """Run each module sample under the cross-ref mode and compare with a plain run."""
        module_cls = module_info.module_cls
        module_inputs = module_info.module_inputs_func(
            module_info,
            device=device,
            dtype=dtype,
            requires_grad=True,
            training=training,
        )
        for module_input in module_inputs:
            if module_input.forward_input is None:
                continue
            ctor_args, ctor_kwargs = (
                module_input.constructor_input.args,
                module_input.constructor_input.kwargs,
            )
            m = module_cls(*ctor_args, **ctor_kwargs)
            m.to(device).to(dtype)

            args, kwargs = (
                module_input.forward_input.args,
                module_input.forward_input.kwargs,
            )
            with self.DecompCrossRefMode(
                self, self.precision, self.rel_tol, dtype, run_all=True
            ), enable_python_dispatcher():
                decomp_out = m(*args, **kwargs)

            non_decomp_out = m(*args, **kwargs)
            # without this check, incorrect decomps at the python dispatcher level can still pass because
            # they're checking aten decomps at the torch_dispatch level
            self.assertEqual(decomp_out, non_decomp_out)
707

708
    def test_batch_norm_unflatten_weight_bias(self, device):
        """``native_batch_norm`` decomposition keeps the input shape when weight is (3, 1, 1, 1)."""
        # https://github.com/pytorch/pytorch/issues/100970
        shape = (1, 3, 2, 2)
        inp = torch.randn(shape, device=device)
        weight = torch.randn((3, 1, 1, 1), device=device)
        bias = torch.randn(3, device=device)
        mean = torch.randn(3, device=device)
        var = torch.randn(3, device=device)
        res = torch._decomp.decompositions.native_batch_norm(
            inp, weight, bias, mean, var, False, 1, 1e-05
        )
        # Only the shape of the first output is checked, not its values.
        self.assertEqual(shape, res[0].shape)
720

721
    def test_arange_graph(self, device):
722
        from torch.fx.experimental.proxy_tensor import make_fx
723

724
        def func(x, start):
725
            le = x.shape[-1]
726
            if start is None:
727
                a = torch.arange(le, dtype=torch.float32, device=x.device)
728
            else:
729
                a = torch.arange(start, le, dtype=torch.float32, device=x.device)
730
            return a
731

732
        pattern = r", device = device\(.+\), requires_grad = False"
733

734
        cfunc = make_fx(func, decomposition_table=decomposition_table)
735
        fx_g = cfunc(torch.rand(10, device=device), None)
736
        fx_g_code = fx_g.code.strip()
737
        # Remove device and requires_grad
738
        fx_g_code = re.sub(pattern, "", fx_g_code)
739
        self.assertExpectedInline(
740
            fx_g_code,
741
            """\
742
def forward(self, x_1, start_1):
743
    iota = torch.ops.prims.iota.default(10, start = 0, step = 1, dtype = torch.int64)
744
    mul = torch.ops.prims.mul.default(iota, 1);  iota = None
745
    add = torch.ops.prims.add.default(mul, 0);  mul = None
746
    convert_element_type = torch.ops.prims.convert_element_type.default(add, torch.float32);  add = None
747
    return convert_element_type""",
748
        )
749

750
        fx_g = cfunc(torch.rand(10, device=device), 1)
751
        fx_g_code = fx_g.code.strip()
752
        # Remove device and requires_grad
753
        fx_g_code = re.sub(pattern, "", fx_g_code)
754
        self.assertExpectedInline(
755
            fx_g_code,
756
            """\
757
def forward(self, x_1, start_1):
758
    iota = torch.ops.prims.iota.default(9, start = 0, step = 1, dtype = torch.int64)
759
    mul = torch.ops.prims.mul.default(iota, 1);  iota = None
760
    add = torch.ops.prims.add.default(mul, 1);  mul = None
761
    convert_element_type = torch.ops.prims.convert_element_type.default(add, torch.float32);  add = None
762
    return convert_element_type""",
763
        )
764

765
    def test_masked_fill(self, device):
766
        from torch.fx.experimental.proxy_tensor import make_fx
767

768
        if torch.device(device).type not in [
769
            "xpu",
770
            "cuda",
771
            torch._C._get_privateuse1_backend_name(),
772
        ]:
773
            self.skipTest("only runs on XPU and CUDA and PrivateUse1.")
774

775
        def func(scores, mask, value):
776
            return scores.masked_fill(mask, value)
777

778
        scores_t = torch.tensor([1, 2, 3, 4], device=device)
779
        mask_t = torch.tensor([True, True, True, True], device=device)
780
        value_t = torch.tensor(0, dtype=scores_t.dtype)
781
        cfunc = make_fx(func, decomposition_table=decomposition_table)
782
        fx_g = cfunc(scores_t, mask_t, value_t)
783
        self.assertExpectedInline(
784
            fx_g.code.strip(),
785
            """\
786
def forward(self, scores_1, mask_1, value_1):
787
    where = torch.ops.prims.where.default(mask_1, value_1, scores_1);  mask_1 = value_1 = scores_1 = None
788
    return where""",
789
        )
790

791
    class DecompCrossRefMode(TorchDispatchMode):
        """Dispatch mode that cross-checks decompositions against the real ops.

        For every op that reaches ``__torch_dispatch__`` and has an entry in
        ``decomposition_table`` (and is not filtered out), the decomposition is
        run first, then the real op, and their outputs are compared. The real
        op's result is what gets returned to the caller.

        Attributes:
            called: every op seen by this mode instance.
            decomposed: every op whose decomposition was actually checked.
        """

        # Ops that have a decomposition but are deliberately not cross-checked.
        # Built once instead of on every dispatch.
        # (TODO: remove detach from the decomp table?)
        IGNORED_OPS = frozenset(
            {
                torch.ops.aten.detach.default,
                # non-deterministic ops
                torch.ops.aten.empty.memory_format,
                torch.ops.aten.empty_like.default,
                torch.ops.aten.new_empty.default,
                torch.ops.aten.empty_strided.default,
                torch.ops.aten.new_empty_strided.default,
                torch.ops.aten.randn.default,
                torch.ops.aten.native_dropout.default,
            }
        )

        def __init__(self, test_case, saved_precision, saved_rel_tol, dtype, run_all):
            """Store the test context.

            Args:
                test_case: test instance whose ``precision`` / ``rel_tol`` are
                    reset on every dispatch and which is handed to the
                    comparison helpers.
                saved_precision: value restored into ``test_case.precision``.
                saved_rel_tol: value restored into ``test_case.rel_tol``.
                dtype: dtype under test; float16 / bfloat16 select the
                    relative (fp64-referenced) comparison.
                run_all: when true, the decomposition itself is executed under
                    this mode again so nested decompositions are checked too.
            """
            self.test_case = test_case
            self.saved_precision = saved_precision
            self.saved_rel_tol = saved_rel_tol
            self.test_dtype = dtype
            self.run_all = run_all

            # We check the correctness of each decomposition right after running it.
            # So, when we encounter a decomposition, we run the function normally, and
            # then run the decomposition, and ensure they're identical.
            self.called = set()
            self.decomposed = set()

        @staticmethod
        def _assert_same_non_tensor(orig, decomp):
            # Non-tensor outputs must match exactly, both in type and in value.
            assert type(orig) is type(decomp), (
                f"output type mismatch: {type(orig)} vs {type(decomp)}"
            )
            assert orig == decomp, f"output value mismatch: {orig} vs {decomp}"

        def __torch_dispatch__(self, func, types, args=(), kwargs=None):
            # The signature defaults ``kwargs`` to None; normalise it so the
            # ``**kwargs`` expansions below cannot fail when it is omitted.
            kwargs = {} if kwargs is None else kwargs

            self.test_case.precision = self.saved_precision
            self.test_case.rel_tol = self.saved_rel_tol

            self.called.add(func)
            all_called[func] += 1

            # Stuff we shouldn't bother testing
            # N.b. Testing in-place ops would need dedicated logic
            # NOTE(review): this inspects the last character of ``func.name()``;
            # if that name carries an overload suffix, only un-suffixed names can
            # match here -- confirm this is the intended set of in-place ops.
            in_place = func.name().endswith("_")
            if (
                func not in decomposition_table
                or func in self.IGNORED_OPS
                or torch.Tag.nondeterministic_seeded in func.tags
                or any_unsupported(args, kwargs)
                or in_place
            ):
                return func(*args, **kwargs)

            self.decomposed.add(func)
            all_decomposed.add(func)

            # We take 2 main strategies for verifying correctness/numerical stability of decompositions
            # The first one is simply tolerance checking between decomp_out and pytorch_out
            # However, for fp16/bf16 and reductions, this becomes very
            # finicky, as there are not many guarantees we can make.
            # So, for fp16/bf16, we instead compare the difference of
            # {decomp_out, pytorch_out_64} and {pytorch_out,
            # pytorch_out_64}. In other words, we compare how far the
            # decomposition and pytorch are from the "ground truth" (i.e.
            # fp64). If the decomposition results in more error, we error

            # We also decompose the decomposition recursively for
            # further coverage, as some paths may not be exercised directly by
            # OpInfos (sadly) but just by other ops

            decomposition = decomposition_table[func]

            do_relative_check = self.test_dtype in [torch.float16, torch.bfloat16]
            if self.run_all:
                # Execute recursively via DFS, to find the root of a possible error first
                with self:
                    decomp_out = pytree.tree_leaves(decomposition(*args, **kwargs))
            else:
                decomp_out = pytree.tree_leaves(decomposition(*args, **kwargs))

            # At this stage we should not be decomposing an in-place op
            # We'd like to have decompositions that decompose out-of-place ops into out-of-place ops
            #  because decompositions are run after functionalisation and we would not like them to
            #  de-functionalise the graph, as that would break AoTAutograd
            # We run the real function *after* the decomposition to make sure that the
            # decomposition does not modify any of the inputs in-place. If it does,
            # real_out should be different from decomp_out so we should catch this
            real_out_unflat = func(*args, **kwargs)
            real_out = pytree.tree_leaves(real_out_unflat)

            assert len(real_out) == len(decomp_out), (
                f"{func}: real op returned {len(real_out)} leaves, "
                f"decomposition returned {len(decomp_out)}"
            )

            if do_relative_check:
                # Reference result: the real op evaluated on float64 inputs.
                upcast = partial(upcast_tensor, dtype=torch.float64)
                real_out_double = pytree.tree_leaves(
                    func(*tree_map(upcast, args), **tree_map(upcast, kwargs))
                )
                for i, (orig, decomp, ref) in enumerate(
                    zip(real_out, decomp_out, real_out_double)
                ):
                    if not isinstance(orig, torch.Tensor):
                        self._assert_same_non_tensor(orig, decomp)
                        continue
                    op_assert_ref(
                        self.test_case,
                        func,
                        self.test_dtype,
                        i,
                        orig,
                        decomp,
                        ref,
                        args,
                        kwargs,
                    )
            else:
                for orig, decomp in zip(real_out, decomp_out):
                    if not isinstance(orig, torch.Tensor):
                        self._assert_same_non_tensor(orig, decomp)
                        continue
                    op_assert_equal(
                        self.test_case,
                        func,
                        self.test_dtype,
                        orig,
                        decomp,
                        args,
                        kwargs,
                    )

            return real_out_unflat
915

916
    def check_decomposed(self, aten_name, mode):
        """Fail unless ``mode`` decomposed at least one overload of ``aten_name``.

        Args:
            aten_name: operator name without namespace, compared against
                ``overload_to_aten_name`` of every entry in ``mode.decomposed``.
            mode: object exposing the ``decomposed`` and ``called`` collections
                (here, a ``DecompCrossRefMode`` that has already been run).
        """
        # Sort the names so the failure message is stable between runs.
        seen_calls = ", ".join(sorted(map(str, mode.called)))
        self.assertTrue(
            any(overload_to_aten_name(c) == aten_name for c in mode.decomposed),
            msg=(
                f"aten.{aten_name} was not decomposed, saw calls for: "
                f"{seen_calls}. If your op is  "
                f"CompositeImplicitAutograd you should skip this test "
                f"by updating CROSS_REF_EXCLUDE_SET."
            ),
        )
926

927
    @skipIfTorchDynamo("Test does not work with TorchDynamo")
    def do_cross_ref(self, device, dtype, op, *, run_all):
        """Run every sample input of ``op`` under ``DecompCrossRefMode``.

        Args:
            device: device the samples are created on.
            dtype: dtype the samples are created with.
            op: operator description; this method reads ``name``,
                ``supports_autograd``, ``supported_backward_dtypes``,
                ``sample_inputs``, ``decomp_aten_name``, ``aten_name``,
                ``aten_backward_name`` and ``get_op`` from it.
            run_all: forwarded to ``DecompCrossRefMode``. When true, ops without
                a registered decomposition are exercised as well and the
                "was it decomposed" checks that are conditional on it are skipped.

        The test is skipped when ``op`` is listed in ``CROSS_REF_EXCLUDE_SET`` for
        this device/dtype, or when a sample can neither be differentiated nor has
        a forward decomposition to check.
        """
        device_type = torch.device(device).type
        test_keys = [
            (device_type, dtype, op.name),
            (None, dtype, op.name),
            (None, None, op.name),
        ]
        if any(key in CROSS_REF_EXCLUDE_SET for key in test_keys):
            self.skipTest(f"{op.name} in {dtype} not supported")

        skip_decomp_vjp = any(
            key in CROSS_REF_BACKWARD_EXCLUDE_SET for key in test_keys
        )

        requires_grad = (
            op.supports_autograd
            and dtype in op.supported_backward_dtypes(device_type)
            # TODO: OpInfo really ought to error out for this case, but it's
            # not exercised in test_ops_gradients atm.  The problem is not
            # complex32 per-se (which is supported by data movement only ops)
            # but that when we do backwards we expect other ops like add to work
            and dtype != torch.complex32
        )
        samples = op.sample_inputs(device, dtype, requires_grad=requires_grad)

        aten_name = op.decomp_aten_name or op.aten_name

        func = op.get_op()

        def run_without_python_dispatcher(mode):
            # True when ``func`` or any op decomposed by ``mode`` is an OpOverload
            # with a CompositeImplicitAutograd kernel.
            return any(
                isinstance(candidate, torch._ops.OpOverload)
                and candidate.has_kernel_for_dispatch_key(
                    DispatchKey.CompositeImplicitAutograd
                )
                for candidate in mode.decomposed.union([func])
            )

        def new_mode():
            # A fresh mode is created for every region, so its ``called`` and
            # ``decomposed`` sets never need explicit clearing.
            return self.DecompCrossRefMode(
                self, self.precision, self.rel_tol, dtype, run_all
            )

        def cross_ref(thunk):
            # Run ``thunk`` under the cross-ref mode with the python dispatcher
            # enabled and, if needed, once more without it. Returns the result of
            # the last run together with the mode used for that run.
            with new_mode() as mode, enable_python_dispatcher():
                result = thunk()
            if run_without_python_dispatcher(mode):
                # without this check, incorrect decomps at the python dispatcher
                # level can still pass because they're checking aten decomps at
                # the torch_dispatch level.
                with new_mode() as mode:
                    result = thunk()
            return result, mode

        for sample_input in samples:
            if requires_grad:
                fn, primals = normalize_op_input_output(func, sample_input)

                (decomp_out, decomp_vjp_fn), mode = cross_ref(
                    partial(ref_vjp_no_create, fn, *primals)
                )
                if aten_name in decomposition_names:
                    self.check_decomposed(aten_name, mode)

                if not skip_decomp_vjp and (
                    op.aten_backward_name in decomposition_names or run_all
                ):
                    cotangents = tree_map(torch.randn_like, decomp_out)

                    _, mode = cross_ref(partial(decomp_vjp_fn, cotangents))
                    if not run_all:
                        self.check_decomposed(op.aten_backward_name, mode)

            elif aten_name in decomposition_names or run_all:
                args = [sample_input.input] + list(sample_input.args)
                kwargs = sample_input.kwargs
                # A failure here might be because the decomposition for the op is wrong or because a
                # decomposition used by the particular op is wrong.
                _, mode = cross_ref(partial(func, *args, **kwargs))

                if not run_all:
                    self.check_decomposed(aten_name, mode)
            else:
                assert op.supports_autograd
                self.skipTest(
                    "only backwards is decomposed, but dtype doesn't support AD"
                )
1034

1035

1036
# Register TestDecomp with the device-type test framework, using this module's
# globals as the target namespace.
instantiate_device_type_tests(TestDecomp, globals())
1037

1038

1039
class DecompOneOffTests(TestCase):
    """One-off checks comparing individual decompositions with their ATen ops."""

    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @onlyNativeDeviceTypes
    @skipIfCrossRef
    def test_contiguous_softmax(self, device):
        """The ``_softmax`` decomposition must produce the same strides as ATen."""
        size = (2, 4, 3, 3)
        stride = (9, 18, 3, 1)
        dtype = torch.float32

        x = torch.randn(size, dtype=dtype, device=device)
        # Re-stride the data: (9, 18, 3, 1) is not the default layout for ``size``.
        x = torch.as_strided(x, size, stride)

        ref = torch.ops.aten._softmax(x, -1, False)
        res = torch._decomp.decompositions._softmax(x, -1, False)
        self.assertEqual(ref.stride(), res.stride())

    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @onlyNativeDeviceTypes
    @skipIfCrossRef
    def test_contiguous_log_softmax(self, device):
        """The ``_log_softmax`` decomposition must produce the same strides as ATen."""
        size = (2, 4, 3, 3)
        stride = (9, 18, 3, 1)

        dtype = torch.float32
        x = torch.randn(size, dtype=dtype, device=device)
        # Re-stride the data: (9, 18, 3, 1) is not the default layout for ``size``.
        x = torch.as_strided(x, size, stride)

        ref = torch.ops.aten._log_softmax(x, -1, False)
        res = torch._decomp.decompositions._log_softmax(x, -1, False)
        self.assertEqual(ref.stride(), res.stride())

    @onlyCUDA
    def test_exponential_non_inf(self, device):
        """``torch._refs.exponential`` must match ``Tensor.exponential_`` and contain no inf."""
        inp = torch.empty((4, 400, 256), device=device)

        # The in-place sample is drawn inside ``preserve_rng_state`` and the
        # reference implementation is called after leaving it; the two results
        # are then required to be equal.
        with torch._dynamo.utils.preserve_rng_state():
            exp_ref = inp.exponential_()
        exp = torch._refs.exponential(inp)

        self.assertEqual(exp, exp_ref)
        self.assertFalse(exp.isinf().any())

    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @skipIfCrossRef
    @onlyCUDA
    def test_amp_batch_norm_backward(self, device="cuda"):
        """``native_batch_norm_backward`` decomposition must match ATen strides and dtypes.

        The gradient and input are float16 while weight and running statistics
        are float32; empty tensors are passed for the two saved statistics.

        Args:
            device: device to create the tensors on. Defaults to ``"cuda"``,
                the value that used to be hard-coded here.
        """
        grad_out = torch.randn((1, 2, 16, 16), dtype=torch.float16, device=device)
        x = torch.randn((1, 2, 16, 16), dtype=torch.float16, device=device)
        weight = torch.randn((2,), dtype=torch.float32, device=device)
        rmean = torch.randn((2,), dtype=torch.float32, device=device)
        rvar = torch.randn((2,), dtype=torch.float32, device=device)
        mean = torch.randn((0,), dtype=torch.float32, device=device)

        # Both implementations receive exactly the same arguments.
        call_args = (
            grad_out,
            x,
            weight,
            rmean,
            rvar,
            mean,
            mean,
            False,
            1e-05,
            [True, True, True],
        )
        ref = torch.ops.aten.native_batch_norm_backward(*call_args)
        res = torch._decomp.decompositions.native_batch_norm_backward(*call_args)
        for a, b in zip(ref, res):
            self.assertEqual(a.stride(), b.stride())
            self.assertEqual(a.dtype, b.dtype)

    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @onlyNativeDeviceTypes
    @skipIfCrossRef
    def test_elu_backward(self, device):
        """The ``elu_backward`` decomposition must match the ATen op numerically."""
        size = (2, 4, 3, 3)
        dtype = torch.float32
        grad_out = torch.randn(size, dtype=dtype, device=device)
        out = torch.randn(size, dtype=dtype, device=device)

        ref = torch.ops.aten.elu_backward(grad_out, 1.0, 1, 1, True, out)
        res = torch._decomp.decompositions.elu_backward(grad_out, 1.0, 1, 1, True, out)
        self.assertEqual(ref, res)

    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @onlyNativeDeviceTypes
    @skipIfCrossRef
    def test_threshold_backward_dtype(self, device):
        """The ``threshold_backward`` decomposition must keep ATen's result dtype.

        The inputs come from ``torch.randint``, so integer tensors are exercised.
        """
        grad = torch.randint(10, (4,), device=device)
        input_tensor = torch.randint(10, (4,), device=device)

        ref = torch.ops.aten.threshold_backward(grad, input_tensor, 1)
        res = torch._decomp.decompositions.threshold_backward(grad, input_tensor, 1)
        self.assertEqual(ref.dtype, res.dtype)

    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @onlyNativeDeviceTypes
    @skipIfCrossRef
    def test_weight_norm_interface(self, device):
        """The ``_weight_norm_interface`` decomposition must match the ATen op.

        Checked once with an explicit ``dim`` argument on 3-D inputs and once
        with the default ``dim`` on 2-D inputs.
        """
        g = torch.randn((3, 10, 10), device=device)
        v = torch.randn((1, 1, 10), device=device)

        ref = torch.ops.aten._weight_norm_interface(g, v, 2)
        res = torch._decomp.decompositions._weight_norm_interface(g, v, 2)
        self.assertTrue(torch.allclose(ref[0], res[0]))
        self.assertTrue(torch.allclose(ref[1], res[1]))

        inp = torch.rand([30, 10], device=device)
        inp2 = torch.rand([30, 1], device=device)

        self.assertEqual(
            torch.ops.aten._weight_norm_interface(inp, inp2),
            torch._decomp.decompositions._weight_norm_interface(inp, inp2),
        )

    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @onlyCPU
    @skipIfCrossRef
    @skipOps(
        "DecompOneOffTests",
        "test_sdpa",
        [
            xfail(
                "nn.functional.scaled_dot_product_attention",
                dtypes=[torch.half],
            ),
        ],
    )
    @ops(_sdpa_op_info)
    def test_sdpa(self, device, dtype, op):
        """Compare the CPU flash-attention decomposition with the eager op.

        Two configurations are checked: no mask with ``is_causal=True``, and an
        all-ones boolean mask with ``is_causal=False``.
        """
        # SDPA doesn't support float16, this is aligned with aten/src/ATen/native/transformers/attention.cpp. If we
        # add support for float16 over there we should update this test as well.

        query_layer = torch.randn(1, 128, 100, 64, device=device, dtype=dtype)
        key_layer = torch.randn(1, 128, 100, 64, device=device, dtype=dtype)
        value_layer = torch.randn(1, 128, 100, 64, device=device, dtype=dtype)
        masks = [None, torch.ones((1, 1, 100, 100), device=device, dtype=torch.bool)]

        atol, rtol = dtype_precisions[dtype]

        for mask in masks:
            # Causal attention is used exactly when no explicit mask is given.
            is_causal = mask is None
            decomposed_res = (
                torch._decomp.decompositions.scaled_dot_product_flash_attention_for_cpu(
                    query_layer, key_layer, value_layer, 0.0, is_causal, attn_mask=mask
                )
            )
            eager_res = op(
                query_layer,
                key_layer,
                value_layer,
                attn_mask=mask,
                dropout_p=0.0,
                is_causal=is_causal,
            )

            # Only the first element of the decomposition's result is compared.
            self.assertTrue(
                torch.allclose(decomposed_res[0], eager_res, atol=atol, rtol=rtol)
            )
1227

1228

1229
# Register DecompOneOffTests with the device-type test framework, using this
# module's globals as the target namespace.
instantiate_device_type_tests(DecompOneOffTests, globals())
1230

1231

1232
class HasDecompTest(TestCase):
    """Expect-tests tracking which ATen overloads have (core) decompositions."""

    def setUp(self):
        super().setUp()
        # Show full diffs when an expectation does not match.
        self.maxDiff = None

    @staticmethod
    def _can_appear_in_trace(op: torch._ops.OpOverload) -> bool:
        """Return whether ``op`` is considered able to show up in a trace.

        An overload qualifies when its schema mentions ``Tensor`` in at least one
        argument or return type and it has no CompositeImplicitAutograd kernel.
        Overloads for which the kernel query raises a "does not exist"
        ``RuntimeError`` are reported as ``False``; any other error propagates.
        """
        has_tensor_arg = any(
            "Tensor" in str(a.type)
            for a in itertools.chain(op._schema.arguments, op._schema.returns)
        )
        if not has_tensor_arg:
            return False

        try:
            # CompositeImplicitAutograd ops are transparent to the tracer, so don't need decompositions
            return not op.has_kernel_for_dispatch_key(
                DispatchKey.CompositeImplicitAutograd
            )
        except RuntimeError as e:
            # has_key fails for some jit-registered ops, which shouldn't be
            # relevant here anyway
            if "does not exist" in str(e):
                return False
            raise

    def test_has_decomposition(self):
        """Record the traceable ATen overloads that lack a decomposition."""
        aten_prefix = "aten::"

        def all_aten_overloads():
            # Yield the OpOverload object for every registered "aten::" op name.
            for name in torch._C._dispatch_get_all_op_names():
                if not name.startswith(aten_prefix):
                    continue

                name = name[len(aten_prefix) :]
                if "." in name:
                    packet_name, overload_name = name.split(".")
                else:
                    # Names without an explicit overload map to "default".
                    packet_name, overload_name = name, "default"

                packet = getattr(aten, packet_name)
                assert isinstance(packet, torch._ops.OpOverloadPacket)
                op = getattr(packet, overload_name)
                yield op

        # This is for operators that are only registered in some CI
        # configurations, so would cause the test to fail
        allow_list = {aten.get_gradients.default}

        overloads_wanting_decomp = {
            op for op in all_aten_overloads() if self._can_appear_in_trace(op)
        }
        ops_missing_decomp = overloads_wanting_decomp - decomposition_table.keys()
        ops_missing_decomp -= allow_list
        self.assertExpected(
            "".join(sorted(op.name() + "\n" for op in ops_missing_decomp))
        )

    def test_aten_core_operators(self):
        """Record the decomposed ops that are not part of the core decompositions."""
        # If a decomposition isn't included in the core decompositions,
        # then it must decompose a core ATen operator.
        #
        # See NOTE [Core ATen Ops]
        #
        # If this test fails then either:
        # - Add the decomposition to torch._decomp.core_aten_decompositions,
        #   if decomposition should be used by inductor (not a core operator).
        # - Run this test again with EXPECTTEST_ACCEPT=1 to update the list of
        #   core ATen operators (and inductor will not use the decomposition).

        # Some decompositions are registered for CompositeImplicitAutograd
        # operators, which never appear in AOTAutograd's graph so are never used.
        useful_decomps = {
            op
            for op in decomposition_table
            if isinstance(op, torch._ops.OpOverload) and self._can_appear_in_trace(op)
        }
        core_decomps = torch._decomp.core_aten_decompositions().keys()
        core_aten_ops = useful_decomps - core_decomps
        self.assertExpected("".join(sorted(op.name() + "\n" for op in core_aten_ops)))
1310

1311

1312
# Run the test suite when this file is executed directly.
if __name__ == "__main__":
    run_tests()
1314

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.