# Captured from the pytorch repository page for test/test_meta.py
# (1664 lines, 65.6 KB).
# Owner(s): ["module: decompositions"]

import atexit
import copy
import itertools
import os
import re
import sys
import unittest
import warnings
import weakref
from collections import defaultdict
from collections.abc import Iterable
from enum import Enum
from functools import partial, wraps

import numpy as np
import yaml

import torch
from torch.overrides import resolve_name
from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
from torch.utils import _pytree as pytree
from torch._subclasses.meta_utils import MetaConverter, assert_metadata_eq
import torch.utils._python_dispatch
from torch._dispatch.python import enable_python_dispatcher
from torch._ops import OpOverload, OpOverloadPacket
from torch.testing import make_tensor
from torch.testing._internal.common_utils import unMarkDynamoStrictTest
from torch.testing._internal.common_utils import (
    TestCase,
    skipIfCrossRef,
    skipIfTorchDynamo,
    suppress_warnings,
    TEST_WITH_ASAN,
    TEST_WITH_TORCHDYNAMO,
    run_tests,
    dtype_abbrs,
    parametrize
)
from torch.testing._internal.common_device_type import (
    ops,
    instantiate_device_type_tests,
    onlyCUDA,
    onlyCPU,
    OpDTypes,
)
from torch.testing._internal.common_methods_invocations import (
    binary_ufuncs, op_db, foreach_unary_op_db, foreach_binary_op_db,
    foreach_pointwise_op_db, foreach_reduce_op_db, foreach_other_op_db)
from torch.testing._internal.opinfo.core import S, SampleInput
from torchgen.yaml_utils import YamlLoader
from torchgen.model import OperatorName
53

54
# Short dtype aliases that keep the expected-failure / skip tables compact.
# Floating point:
bf16, f16, f32, f64 = torch.bfloat16, torch.float16, torch.float32, torch.float64
# Complex:
c32, c64, c128 = torch.complex32, torch.complex64, torch.complex128
# Signed / unsigned integers:
i8, i16, i32, i64 = torch.int8, torch.int16, torch.int32, torch.int64
u8 = torch.uint8
# Boolean:
b8 = torch.bool
67

68
# All foreach OpInfos in one sequence, concatenated in this order:
# unary, binary, pointwise, reduce, then the remaining ("other") ones.
foreach_op_db = (
    foreach_unary_op_db
    + foreach_binary_op_db
    + foreach_pointwise_op_db
    + foreach_reduce_op_db
    + foreach_other_op_db
)
75

76

77
class TestMetaConverter(TestCase):
    """Unit tests for MetaConverter (real tensor -> meta tensor conversion).

    Each test builds a real tensor in a particular configuration (leaf,
    non-leaf, view, channels_last, non-dense strides, complex views, ...),
    converts it with a MetaConverter, sanity-checks that the configuration is
    the one the test claims to exercise, and compares the metadata of the meta
    tensor with the original via ``assert_metadata_eq``.
    """

    def assertSameVersionCounter(self, m1, m2):
        """Assert that ``m1`` and ``m2`` share one version counter.

        Mutates ``m1._base`` in place, so ``m1`` must be a view.
        """
        # Cannot easily test m1 and m2 have same storage due to
        # lack of Storage bindings.  Use version counter.
        vc = m1._version
        self.assertEqual(m2._version, vc)
        # Doing it this way ensures that we get VC bump even with leaves
        with torch.no_grad():
            m1._base.add_(3)
        self.assertNotEqual(m1._version, vc)
        self.assertEqual(m2._version, m1._version)

    def assertMetadataMatches(self, m1, m2):
        """Assert that ``m1`` and ``m2`` agree on tensor metadata."""
        assert_metadata_eq(self.assertEqual, m1, m2)

    def test_view_of_non_leaf(self):
        """Two views of one non-leaf tensor become distinct meta views sharing a version counter."""
        x = torch.randn(4, requires_grad=True)
        y = x.neg()
        z1 = y[:]
        z2 = y[:]
        to_meta = MetaConverter()
        m1 = to_meta(z1)
        m2 = to_meta(z2)

        # check the test is actually testing what it claims
        self.assertTrue(m1._is_view())
        self.assertFalse(m1._base.is_leaf)

        self.assertIsNot(m1, m2)
        self.assertMetadataMatches(m1, z1)
        self.assertMetadataMatches(m2, z2)
        self.assertSameVersionCounter(m1, m2)

    def test_view_of_leaf(self):
        """Two views of one leaf tensor become distinct meta views sharing a version counter."""
        x = torch.randn(4, requires_grad=True)
        z1 = x[:]
        z2 = x[:]
        to_meta = MetaConverter()
        m1 = to_meta(z1)
        m2 = to_meta(z2)

        # check the test is actually testing what it claims
        self.assertTrue(m1._is_view())
        self.assertTrue(m1._base.is_leaf)

        self.assertIsNot(m1, m2)
        self.assertMetadataMatches(m1, z1)
        self.assertMetadataMatches(m2, z2)
        self.assertSameVersionCounter(m1, m2)

    def test_view_of_view_of_leaf(self):
        """A view of an intermediate view that was set to require grad converts correctly."""
        x = torch.randn(8)
        y = x.view(2, 4)
        y.requires_grad = True
        z = y.view(2, 2, 2)

        to_meta = MetaConverter()
        mx = to_meta(x)
        mz = to_meta(z)

        self.assertFalse(z.is_leaf)

        self.assertMetadataMatches(mx, x)
        self.assertMetadataMatches(mz, z)

    def test_leaf(self):
        """A requires_grad leaf converts to a requires_grad leaf."""
        x = torch.randn(4, requires_grad=True)
        to_meta = MetaConverter()
        m = to_meta(x)

        # check the test is actually testing what it claims
        self.assertTrue(m.is_leaf)
        self.assertTrue(m.requires_grad)

        self.assertMetadataMatches(m, x)

    def test_non_leaf(self):
        """The non-leaf output of an autograd op converts to a requires_grad non-leaf."""
        x = torch.randn(4, requires_grad=True)
        y = x.neg()
        to_meta = MetaConverter()
        m = to_meta(y)

        # check the test is actually testing what it claims
        self.assertFalse(m.is_leaf)
        self.assertTrue(m.requires_grad)

        self.assertMetadataMatches(m, y)

    def test_requires_grad_false(self):
        """requires_grad=False is preserved."""
        x = torch.randn(4, requires_grad=False)
        to_meta = MetaConverter()
        m = to_meta(x)

        # check the test is actually testing what it claims
        self.assertFalse(m.requires_grad)

        self.assertMetadataMatches(m, x)

    def test_channels_last(self):
        """A channels_last tensor converts to a leaf with matching metadata."""
        x = torch.empty(2, 3, 4, 5, memory_format=torch.channels_last)
        to_meta = MetaConverter()
        m = to_meta(x)

        # check the test is actually testing what it claims
        self.assertTrue(m.is_leaf)

        self.assertMetadataMatches(m, x)

    def test_channels_last_leaf(self):
        """A channels_last requires_grad leaf converts to a requires_grad leaf."""
        x = torch.empty(2, 3, 4, 5, memory_format=torch.channels_last, requires_grad=True)
        to_meta = MetaConverter()
        m = to_meta(x)

        # check the test is actually testing what it claims
        self.assertTrue(m.requires_grad)
        self.assertTrue(m.is_leaf)

        self.assertMetadataMatches(m, x)

    def test_channels_last_non_leaf(self):
        """A channels_last non-leaf converts and can be differentiated through."""
        x = torch.empty(2, 3, 4, 5, memory_format=torch.channels_last, requires_grad=True)
        y = x + 2

        # sanity
        self.assertEqual(x.stride(), y.stride())
        self.assertFalse(y.is_leaf)

        to_meta = MetaConverter()
        m = to_meta(y)

        # check the test is actually testing what it claims
        self.assertTrue(m.requires_grad)
        self.assertFalse(m.is_leaf)

        self.assertMetadataMatches(m, y)

        # Check that we can autograd with m as input without erroring;
        # see https://github.com/pytorch/pytorch/issues/87956
        loss = m.sum()
        torch.autograd.grad(loss, m)

    def test_empty_strided_non_dense_leaf(self):
        """A leaf with non-dense strides ((2, 2) sizes, (4, 2) strides) converts correctly."""
        x = torch.empty_strided((2, 2), (4, 2), requires_grad=True)

        to_meta = MetaConverter()
        m = to_meta(x)

        # check the test is actually testing what it claims
        self.assertTrue(m.requires_grad)
        self.assertTrue(m.is_leaf)

        self.assertMetadataMatches(m, x)

    def test_view_mutate(self):
        """An in-place add on a converted view must not raise; there are no other assertions."""
        x = torch.zeros(4)
        y = x.view(2, 2)

        to_meta = MetaConverter()
        m = to_meta(y)

        y.add_(torch.randn(2, 2, requires_grad=True))
        m.add_(torch.randn(2, 2, device='meta', requires_grad=True))

    def test_non_leaf_torture(self):
        """A leaf whose storage was re-set with an offset via set_() converts correctly."""
        x = torch.empty(20, requires_grad=True)
        with torch.no_grad():
            # Re-point x at its own storage with offset 10, size (2,), stride (2,).
            x.set_(x.storage(), 10, (2,), (2,))

        to_meta = MetaConverter()
        m = to_meta(x)

        # check the test is actually testing what it claims
        self.assertTrue(m.requires_grad)
        self.assertTrue(m.is_leaf)

        self.assertMetadataMatches(m, x)

    # NB: complex stuff is not actually exercised right now because
    # we have a blanket exclusion for complex conversion

    def test_view_as_real(self):
        """view_as_real of a complex tensor converts with matching metadata."""
        x = torch.randn(4, dtype=torch.complex64)
        y = torch.view_as_real(x)
        m = MetaConverter()(y)
        self.assertMetadataMatches(m, y)

    def test_complex_noncontiguous_bug(self):
        """A non-contiguous complex32 slice converts with matching metadata."""
        x = torch.randn((2, 2, 4, 9), dtype=torch.complex32)[:, 0, :, :]
        m = MetaConverter()(x)
        self.assertMetadataMatches(m, x)

    def test_view_as_complex(self):
        """view_as_complex of a float tensor converts with matching metadata."""
        x = torch.randn((4, 2), dtype=torch.float32)
        y = torch.view_as_complex(x)
        m = MetaConverter()(y)
        self.assertMetadataMatches(m, y)

    def test_view_dtype(self):
        """A dtype-reinterpreting view (float32 -> int32) converts with matching metadata."""
        x = torch.randn(4, dtype=torch.float32)
        y = x.view(dtype=torch.int32)
        m = MetaConverter()(y)
        self.assertMetadataMatches(m, y)

    def test_imag(self):
        """The .imag view of a complex tensor converts with matching metadata."""
        x = torch.randn(4, dtype=torch.complex64)
        y = x.imag
        m = MetaConverter()(y)
        self.assertMetadataMatches(m, y)

    @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991")
    def test_weakref(self):
        """The converter memoizes per input tensor and drops memo entries once inputs die."""
        x = torch.randn(4, 4, 4)
        m = MetaConverter()
        y = m(x)
        z = m(x)
        # Converting the same tensor twice returns the memoized meta tensor.
        self.assertIs(y, z)
        self.assertEqual(len(m.tensor_memo), 1)
        self.assertEqual(len(m.storage_memo), 1)
        del x
        self.assertEqual(len(m.tensor_memo), 0)
        m.check_for_expired_weak_storages()
        self.assertEqual(len(m.storage_memo), 0)
        li = []
        r = []
        for i in range(4):
            li.append(torch.rand([i]))
            r.append(m(li[-1]))
        self.assertEqual(len(m.tensor_memo), 4)
        # Only the inputs are released; the meta outputs in r stay alive.
        del li
        self.assertEqual(len(m.tensor_memo), 0)
        m.check_for_expired_weak_storages()
        self.assertEqual(len(m.storage_memo), 0)

    @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991")
    def test_tensor_outlives_converter(self):
        """The converter can be freed while a meta tensor it produced is still alive."""
        m = MetaConverter()
        ref = weakref.ref(m)
        x = torch.randn([4, 4])
        # y is intentionally kept alive past `del m`.
        y = m(x)
        del m
        self.assertIs(ref(), None)
318

319
aten = torch.ops.aten

# Ops whose outputs get CheckStrides.SIGNIFICANT even before the generic rules
# in should_check_strides are consulted.
CHECK_STRIDES = {
    torch.Tensor.__getitem__,
}

# Ops whose outputs must agree on *all* strides (CheckStrides.ALL).
CHECK_ALL_STRIDES = {
    aten.unsqueeze.default
}

# Ops for which stride mismatches are currently tolerated (CheckStrides.NONE).
CHECK_STRIDES_SKIPS = {
    aten._conj_physical.default,
    aten._fft_c2c.default,
    aten._fft_c2r.default,
    aten._fft_r2c.default,
    aten._linalg_svd.default,
    aten.binary_cross_entropy.default,
    aten.complex.default,
    aten.copysign.Tensor,
    aten.div.Tensor_mode,
    aten.floor_divide.default,
    aten.heaviside.default,
    aten.lerp.Scalar,
    aten.lerp.Tensor,
    aten.logaddexp.default,
    aten.logical_and.default,
    aten.logical_or.default,
    aten.logical_xor.default,
    aten.nll_loss2d_forward.default,
    aten.polar.default,
    aten.pow.Scalar,
    aten.prelu.default,
    aten.special_xlog1py.default,
    aten.xlogy.Tensor,

    # channel_last and channel_last_3d related failures
    aten.convolution.default,

    # following ops fails if include_storage_offset = True, but these are a bit edge casey
    # we should still fix them, leaving them here for tracking.
    # aten._reshape_alias.default,  # repro with test_dispatch_symbolic_meta_outplace_all_strides_matmul_cuda_float32
    # aten.view.default,  # repro with test_dispatch_symbolic_meta_outplace_all_strides_unflatten_cuda_float32
}

# Ops whose outputs are exempt from the is_conj() comparison.
CHECK_CONJ_SKIPS = {
    # The conj bit is not copied, see:
    # https://github.com/pytorch/pytorch/pull/101836
    aten.linalg_lu_solve.out,
}
368

369
class CheckStrides(Enum):
    """How strictly assert_ref_meta_equal compares meta strides with real strides."""

    NONE = 0         # strides are not compared
    SIGNIFICANT = 1  # compared with torch._prims_common.check_significant_strides
    ALL = 2          # compared with torch._prims_common.check_all_strides
373

374
def should_check_strides(func):
    """Return the CheckStrides mode used to compare the outputs of ``func``.

    The explicit per-op tables take precedence and are consulted in this
    order: CHECK_ALL_STRIDES (ALL), CHECK_STRIDES (SIGNIFICANT),
    CHECK_STRIDES_SKIPS (NONE).  Otherwise a non-OpOverload callable gets
    NONE, and every OpOverload gets SIGNIFICANT.
    """
    overrides = (
        (CHECK_ALL_STRIDES, CheckStrides.ALL),
        (CHECK_STRIDES, CheckStrides.SIGNIFICANT),
        (CHECK_STRIDES_SKIPS, CheckStrides.NONE),
    )
    for table, mode in overrides:
        if func in table:
            return mode

    if not isinstance(func, torch._ops.OpOverload):
        return CheckStrides.NONE

    if func.namespace == "prims":
        # Prims are expected to model strides correctly.
        return CheckStrides.SIGNIFICANT

    # Views are recognized by a return whose alias set is non-empty.
    if any(r.alias_info.before_set for r in func._schema.returns if r.alias_info):
        return CheckStrides.SIGNIFICANT

    # TODO: check for TensorIterator; until then all other ops are checked too.
    return CheckStrides.SIGNIFICANT
392

393
def assert_ref_meta_equal(test_case, func, meta_rs, rs, msg_callable):
    """Check that results computed on meta tensors agree with the real ones.

    ``meta_rs`` and ``rs`` are flattened with pytree and compared leaf by
    leaf.  Leaves whose real value is not a Tensor are skipped.  For a real
    Tensor, the meta leaf must also be a Tensor with the same dtype, shape,
    storage offset, requires_grad, conj bit (unless ``func`` is in
    CHECK_CONJ_SKIPS) and neg bit.  Strides are compared according to
    ``should_check_strides(func)``.

    Args:
        test_case: TestCase used to assert that both results have the same
            number of leaves.
        func: the operator that produced both results.
        meta_rs: result of running ``func`` on meta inputs.
        rs: result of running ``func`` on the original inputs.
        msg_callable: maps a short mismatch description to the full message.

    Raises:
        RuntimeError: on the first mismatching leaf, with message
            ``"output {i}: " + msg_callable(description)``.
    """
    flat_meta_rs = pytree.tree_leaves(meta_rs)
    flat_rs = pytree.tree_leaves(rs)
    test_case.assertEqual(len(flat_meta_rs), len(flat_rs))

    # Both of these depend only on func, so look them up once, not per output.
    stride_mode = should_check_strides(func)
    check_conj = func not in CHECK_CONJ_SKIPS

    def test_assert(i, cond, msg):
        if not cond:
            raise RuntimeError(f"output {i}: {msg_callable(msg)}")

    for i, (meta_r, r) in enumerate(zip(flat_meta_rs, flat_rs)):
        if not isinstance(r, torch.Tensor):
            continue
        test_assert(i, isinstance(meta_r, torch.Tensor),
                    f"for element {i}, was {type(meta_r).__name__} but real result was a Tensor")
        test_assert(i, meta_r.dtype == r.dtype, f"for element {i}, was {meta_r.dtype} but real dtype was {r.dtype}")
        test_assert(i, meta_r.shape == r.shape, f"for element {i}, was {meta_r.shape} but real shape was {r.shape}")
        # See https://github.com/pytorch/pytorch/issues/78050
        # The stride message is built only inside the branches, so stride()
        # is never called when strides are not being compared.
        if stride_mode == CheckStrides.ALL:
            same_strides, _ = torch._prims_common.check_all_strides(meta_r, r)
            test_assert(i, same_strides, f"for element {i}, was {meta_r.stride()} but real stride was {r.stride()}")
        elif stride_mode == CheckStrides.SIGNIFICANT:
            same_strides, _ = torch._prims_common.check_significant_strides(meta_r, r)
            test_assert(i, same_strides, f"for element {i}, was {meta_r.stride()} but real stride was {r.stride()}")
        test_assert(
            i,
            meta_r.storage_offset() == r.storage_offset(),
            f"for element {i}, was {meta_r.storage_offset()} but real storage_offset was {r.storage_offset()}")
        test_assert(i, meta_r.requires_grad == r.requires_grad,
                    f"for element {i}, was {meta_r.requires_grad} but real requires_grad was {r.requires_grad}")
        if check_conj:
            test_assert(i, meta_r.is_conj() == r.is_conj(),
                        f"for element {i}, was {meta_r.is_conj()} but real is_conj was {r.is_conj()}")
        test_assert(i, meta_r.is_neg() == r.is_neg(),
                    f"for element {i}, was {meta_r.is_neg()} but real is_neg was {r.is_neg()}")
422

423

424
# Setting PYTORCH_COLLECT_EXPECT=1 makes the suite print, at exit, the
# expected-failure and skip lists implied by what actually passed and failed.
# Intended workflow:
#
# 1. Run `PYTORCH_COLLECT_EXPECT=1 python test/test_meta.py` on a CUDA build
#    of PyTorch that has LAPACK/MAGMA installed.  Filter with `-k test_meta`
#    or `-k test_dispatch_meta` to focus on just one of the two lists.
# 2. Copy the printed skip/xfail entries into the corresponding lists:
#    torch.* entries go in meta_function and aten.* entries go in
#    meta_dispatch.  Merge them with any entries that are already present.
#
# This is somewhat manual, but it is normally only needed after a major change
# (e.g., adding a new dtype to PyTorch) that requires refreshing the lists.
# To regenerate from scratch, empty the existing lists before running.
#
# WARNING: a Python dict literal with a repeated key silently keeps only the
# last value for it, so check merged entries for duplicate keys.
COLLECT_EXPECT = os.environ.get('PYTORCH_COLLECT_EXPECT', '0') == '1'

# Bookkeeping filled in by run_meta_crossref and reported by print_seen:
# op -> set of dtypes that passed, op -> set of dtypes that failed, and
# op -> set of operator names extracted from NotImplementedError messages.
seen_succeeded = {}
seen_failed = {}
failed_reasons = defaultdict(set)
445
def print_seen():
    """Print expected-failure and skip tables derived from the recorded results.

    For each op in ``seen_failed``:

    * dtypes that failed and never succeeded are listed under
      ``expected_failures``, followed by the op's ``failed_reasons`` (if any)
      as a trailing comment;
    * dtypes that both failed and succeeded are listed under ``skips``.

    Both lists are sorted and printed as dict literals ready to paste into
    this file.  The function is registered with ``atexit`` when
    COLLECT_EXPECT is set.
    """
    expected_failures = []
    skips = []

    def fmt_dtypes(dtypes):
        r = ', '.join(sorted(dtype_abbrs[d] for d in dtypes))
        return '{' + r + '}'

    for op, failed_dtypes in seen_failed.items():
        # Named op_name (not `ops`) so the module-level `ops` decorator is not shadowed.
        op_name = resolve_name(op)
        succeeded_dtypes = seen_succeeded.get(op, set())
        expected_failures_dtypes = failed_dtypes - succeeded_dtypes
        skips_dtypes = failed_dtypes & succeeded_dtypes
        # .get() so that reporting does not insert empty sets into the defaultdict.
        op_reasons = failed_reasons.get(op)
        reasons = ("  # " + ", ".join(sorted(op_reasons))) if op_reasons else ""
        if expected_failures_dtypes:
            expected_failures.append(f"    {op_name}: {fmt_dtypes(expected_failures_dtypes)},{reasons}")
        if skips_dtypes:
            skips.append(f"    {op_name}: {fmt_dtypes(skips_dtypes)},")
    expected_failures.sort()
    skips.sort()
    # f-string expressions cannot contain a backslash before Python 3.12.
    nl = '\n'
    print(f"""\
expected_failures = {{
{nl.join(expected_failures)}
}}

skips = {{
{nl.join(skips)}
}}
""")
477
# When collecting expectations, dump the derived tables once the whole run
# has finished.
if COLLECT_EXPECT:
    atexit.register(print_seen)
479

480
# How run_meta_crossref treats one op/dtype combination:
#   SUCCESS  - the meta run must succeed and agree with the real result;
#   XFAILURE - the meta run is expected to fail; a success raises
#              "unexpected success" unless COLLECT_EXPECT is set;
#   SKIP     - the meta run is not attempted (the real call still runs).
TestExpect = Enum("TestExpect", ["SUCCESS", "XFAILURE", "SKIP"])
482

483
def verbose_print(e):
    """Return ``repr(e)`` with each tensor leaf shown as ``"<tensor> stride=<strides>"``.

    Plain ``repr`` of a tensor does not include its strides; this helper is
    used to make meta-vs-real mismatch messages show them.
    """

    class _Verbatim:
        # An object whose repr() is exactly the wrapped text (no quotes), so
        # it can stand in for a tensor inside the container's repr.
        def __init__(self, text):
            self.text = text

        def __repr__(self):
            return self.text

    def annotate(leaf):
        if not isinstance(leaf, torch.Tensor):
            return leaf
        return _Verbatim(f"{leaf} stride={leaf.stride()}")

    annotated = tree_map(annotate, e)
    return repr(annotated)
499

500
def run_meta_crossref(
    test_case,
    test_expect,
    func,
    args,
    kwargs,
    *,
    dtype,
    device_type,
    run_symbolic_meta: bool
):
    """Run ``func`` on the real inputs and on meta copies of them, then compare.

    The real call always runs, and its result is returned.  Unless
    ``test_expect`` is SKIP, the arguments are converted with a MetaConverter.
    If every tensor converted successfully, ``func`` is re-run on the meta
    arguments and its result is checked with ``assert_ref_meta_equal``.
    Outcomes are recorded per dtype in ``seen_succeeded`` / ``seen_failed``.

    Args:
        test_case: TestCase forwarded to ``assert_ref_meta_equal``.
        test_expect: the TestExpect member for this op/dtype combination.
        func: the callable under test.
        args, kwargs: the original (non-meta) arguments.
        dtype: dtype under which the outcome is recorded.
        device_type: not read by this function.
        run_symbolic_meta: if true, run the meta call under
            ``enable_python_dispatcher()``.

    Returns:
        The result of ``func(*args, **kwargs)``.

    Raises:
        AssertionError: if the real call itself raises.
        RuntimeError: if argument conversion fails, if the meta call fails
            when it was not expected to, or if an XFAILURE op unexpectedly
            succeeds.  Mismatch errors from ``assert_ref_meta_equal`` are
            re-raised as-is.  Except for conversion errors, these are
            suppressed when COLLECT_EXPECT is set.
    """
    to_meta = MetaConverter()
    do_meta = test_expect is not TestExpect.SKIP
    if do_meta:
        try:
            meta_args = tree_map(to_meta, args)
            meta_kwargs = tree_map(to_meta, kwargs)
        except Exception as e:
            raise RuntimeError(
                f"failed to convert args to meta; "
                f"originally (*{args}, **{kwargs})") from e
    try:
        rs = func(*args, **kwargs)
    except Exception as e:
        raise AssertionError("Original OpInfo is broken") from e

    # TODO: also handle cases where func raise an exception

    # For now, only attempt if we managed to convert all tensor types
    # (if any of them failed, we're in a mixed device situation and
    # this isn't well supported)
    if do_meta and to_meta.successful():
        # Special cases
        if func is torch.tensor_split:
            # Use original indices_or_sections, this argument is data dependent
            meta_args = (meta_args[0], args[1]) + meta_args[2:]
        elif func is torch.Tensor.__getitem__:
            # Ensure boolean tensors use original
            # NOTE(review): int8 (not uint8) is kept real alongside bool here;
            # confirm that is the intended index dtype.
            assert len(args) == 2
            flat_args = pytree.tree_leaves(args[1])
            flat_meta_args, spec = tree_flatten(meta_args[1])
            flat_new_args = [
                a if isinstance(a, torch.Tensor) and a.dtype in [torch.int8, torch.bool] else ma
                for a, ma in zip(flat_args, flat_meta_args)
            ]
            meta_args = (meta_args[0], tree_unflatten(flat_new_args, spec))
        elif func in (torch.ops.aten.repeat_interleave.Tensor, torch.ops.aten.repeat_interleave.Tensor_out):
            # Without output_size, run on the real args.
            if kwargs.get("output_size", None) is None:
                meta_args = args
            if func is torch.ops.aten.repeat_interleave.Tensor_out:
                meta_kwargs["out"] = kwargs["out"]
        elif func in (torch.ops.aten.index.Tensor, torch.ops.aten.index.Tensor_out):
            # Don't convert boolean tensors to meta as they will have nonzero
            # called on them
            indices = []
            for meta_index, real_index in zip(meta_args[1], args[1]):
                if meta_index is not None and meta_index.dtype in [torch.int8, torch.bool]:
                    indices.append(real_index)
                else:
                    indices.append(meta_index)
            meta_args = (meta_args[0], indices)
        elif func is torch.nn.functional.ctc_loss and isinstance(args[2], list) and isinstance(args[3], list):
            # torch.ops.aten._ctc_loss.IntList has a meta kernel but
            # torch.ops.aten._ctc_loss.Tensor does not
            test_expect = TestExpect.SUCCESS

        if kwargs.get("device", None) is not None:
            meta_kwargs["device"] = "meta"

        try:
            # Suppress warnings, this doesn't matter for test_meta.py
            # but it does matter if you want to use this decorator
            # for cross-ref testing, as some tests may be looking at
            # errors
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                if run_symbolic_meta:
                    # Run the decomps and meta kernels registered
                    # to the python dispatcher instead of the regular dispatcher.
                    # This should be the same set of kernels
                    # that fake tensor runs in dynamic shapes mode.
                    with enable_python_dispatcher():
                        meta_rs = func(*meta_args, **meta_kwargs)
                else:
                    meta_rs = func(*meta_args, **meta_kwargs)
        except Exception as e:
            if test_expect is TestExpect.XFAILURE:
                return rs
            seen_failed.setdefault(func, set()).add(dtype)
            if isinstance(e, NotImplementedError):
                # str(e) rather than e.args[0]: an arg-less NotImplementedError
                # would raise IndexError here and mask the real failure.
                m = RE_NOT_IMPLEMENTED_MSG.search(str(e))
                if m:
                    failed_reasons[func].add(m.group(1))
            if COLLECT_EXPECT:
                return rs
            raise RuntimeError(f"""\
failed to run: {resolve_name(func)}(
*{verbose_print(meta_args)},
**{verbose_print(meta_kwargs)}
)""") from e
        else:
            try:
                delim = ',\n  '
                assert_ref_meta_equal(test_case, func, meta_rs, rs, lambda msg: f"""\
meta disagrees with real impl:
{resolve_name(func)}(
  {delim.join(map(verbose_print, meta_args))},
  {delim.join(k + ": " + verbose_print(v) for k, v in meta_kwargs.items())}
) = (
  {verbose_print(meta_rs)}
)
{msg}
""")
            except Exception:
                if test_expect is TestExpect.XFAILURE:
                    return rs
                seen_failed.setdefault(func, set()).add(dtype)
                if COLLECT_EXPECT:
                    return rs
                raise
            else:
                seen_succeeded.setdefault(func, set()).add(dtype)
                if test_expect is TestExpect.XFAILURE and not COLLECT_EXPECT:
                    raise RuntimeError(f"unexpected success {resolve_name(func)} {meta_args} {meta_kwargs}")

    return rs
625

626

627

628
# Captures the quoted operator name from errors of the form
# "Could not run '<op>' with arguments ...".
RE_NOT_IMPLEMENTED_MSG = re.compile(
    r"Could not run '([^']+)' with arguments "
)
629

630
# torch-level function -> dtypes for which the meta run is expected to fail
# (TestExpect.XFAILURE), on every device.
meta_function_expected_failures = {
    torch.Tensor.to_sparse: {f64, i32, c128, i64, i16, f16, u8, c64, bf16, b8, i8, f32},
    torch.allclose: {f64, f16, c128, c64, bf16, f32},
    torch.argwhere: {f64, i32, c128, i64, i16, f16, u8, c64, bf16, b8, i8, f32},
    torch.combinations: {f64, i32, c128, i64, i16, f16, u8, c64, bf16, b8, i8, f32},
    torch.corrcoef: {f64, i32, c128, i64, i16, u8, c64, bf16, f16, i8, f32},
    torch.cov: {f64, i32, c128, i64, i16, u8, c64, bf16, i8, f32, f16},
    torch.functional.istft: {f64, c64, c128, f32},
    torch.geqrf: {f64, c64, c128, f32},
    torch.masked_select: {f64, i32, c128, i64, i16, f16, u8, c64, bf16, b8, i8, f32},
    torch.nonzero: {f64, i32, c128, i64, i16, c32, f16, u8, c64, bf16, b8, i8, f32},
    torch.Tensor.nonzero: {f64, i32, c128, i64, i16, c32, f16, u8, c64, bf16, b8, i8, f32},
    torch.Tensor.item: {f64, i32, c128, i64, i16, f16, u8, c32, c64, bf16, b8, i8, f32},
    torch.bincount: {i32, i64, u8, i16, i8},
    torch.functional.unique: {f64, i32, i64, u8, i16, f16, bf16, b8, i8, f32},
    torch.functional.unique_consecutive: {f64, i32, i64, u8, i16, f16, bf16, b8, i8, f32},
    torch.histc: {f64, f16, bf16, f32},
    torch.histogram: {f64, f32},
    torch.histogramdd: {f64, f32},
    torch.kthvalue: {f64, i32, i64, u8, i16, f16, bf16, i8, f32},
    torch.nn.functional.ctc_loss: {f64, f32},
    torch.nn.functional.gaussian_nll_loss: {f16, f64, bf16, f32},
    torch.linalg.eig: {f64, f32, c128, c64},
    torch.linalg.eigvals: {f64, f32, c128, c64},
    torch.linalg.lstsq: {f64, f32, c128, c64},
}
656

657
# Per-call expected failures: each predicate receives (dtype, *args, **kwargs)
# and returns True when that particular call is expected to fail.
meta_function_expected_failures_conditional = {
    # Expected to fail unless `repeats` is passed as an int keyword argument.
    torch.repeat_interleave: (
        lambda dtype, *args, **kwargs: not isinstance(kwargs.get("repeats", None), int)
    ),
}

# Sample code (kept for reference, never executed) showing how these dicts
# could be dumped to a YAML file for easier reading/writing:
#
#     import yaml
#     print(yaml.dump(
#       {resolve_name(k): [dtype_abbrs[d] for d in v]
#        for k, v in meta_function_expected_failures.items()}, default_flow_style=None))
#     import sys
#     sys.exit()
671

672
# torch-level function -> dtypes whose meta run is not attempted at all
# (TestExpect.SKIP), on every device.
meta_function_skips = {
    torch.Tensor.__rmatmul__: {bf16, c128, f64, f32, f16, c64},
    torch.Tensor.matmul: {f64, f32, c128, c64},
    torch.functional.atleast_2d: {bf16, i8, c32, i64, u8, c128, b8, f64, i16, i32, f32, f16, c64},
    torch.functional.atleast_3d: {bf16, i8, c32, i64, u8, c128, b8, f64, i16, i32, f32, f16, c64},
    torch.functional.cartesian_prod: {bf16, i8, i64, u8, c128, b8, f64, i16, i32, f32, f16, c64},
    torch.functional.einsum: {bf16, c128, f64, f32, f16, c64},
    torch.inner: {f16, bf16, i8, i64, u8, c128, f64, i16, f32, i32, c64},
    torch.linalg.matrix_norm: {c128, f32, c64, f64},
    torch.linalg.matrix_rank: {c128, c64},
    torch.linalg.svd: {c128, c64},
    torch.matmul: {bf16, c128, f64, f32, f16, c64},
    torch.nanquantile: {f64, f32},
    torch.narrow: {bf16, i8, i64, u8, c128, b8, f64, i16, i32, f32, f16, c32, c64},
    torch.nn.functional.batch_norm: {f64, f32},
    torch.nn.functional.binary_cross_entropy: {bf16, f64, f32, f16},
    torch.nn.functional.dropout3d: {bf16, f64, f32, f16},
    torch.nn.functional.local_response_norm: {bf16, f64, f32, f16},
    torch.svd: {c128, c64},
    torch.take_along_dim: {bf16, i8, i64, u8, c128, b8, f64, i16, i32, f32, f16, c64},
    torch.vstack: {bf16, i8, c32, i64, u8, c128, b8, f64, i16, i32, f32, f16, c64},
    torch.diff: {b8},
    torch.equal: {bf16, i8, c32, i64, u8, c128, b8, f64, i16, i32, f32, f16, c64},
    torch.nanmean: {bf16, f64, f32, f16, c32, c64, c128},
    torch.nn.functional.cross_entropy: {bf16, f64, f32},
    torch.nn.functional.nll_loss: {bf16, f64, f32},
    torch.linalg.cond: {c128, c64, f32, f64},
    torch.linalg.vecdot: {bf16, f64, f32, f16},
    torch.empty: {bf16, i8, c32, i64, u8, c128, b8, f64, i16, i32, f32, f16, c64},
    torch.Tensor.addbmm_: {bf16, c128, c64, f32, f64, i16, i32, i64, i8, u8},
    torch.nn.functional.one_hot: {i64},
}
704

705

706
# Device-specific tables, keyed by device type ('cpu' / 'cuda') and then by
# torch-level function; unknown device types map to an empty dict.

# Expected failures (TestExpect.XFAILURE) on a single device type.
meta_function_device_expected_failures = defaultdict(dict, {
    'cpu': {
        torch.native_batch_norm: {bf16, f16},
        torch._native_batch_norm_legit: {bf16, f16},
        torch.native_layer_norm: {bf16, f16},
    },
    'cuda': {
        torch.corrcoef: {bf16, f16},  # aten::_local_scalar_dense
        torch.cov: {f16},  # aten::_local_scalar_dense
        torch.functional.unique: {f16},  # aten::_unique2, aten::unique_dim
        torch.functional.unique_consecutive: {f16},  # aten::unique_consecutive
        torch.geqrf: {f32, f64},  # aten::geqrf
        torch.histc: {i16, i32, i64, i8},  # aten::histc, aten::histc.out
        torch.kthvalue: {f16},  # aten::kthvalue.values
    },
})

# Expected failures that apply only to the out-of-place variant (currently none).
meta_function_device_expected_failures_only_outplace = defaultdict(dict)

# Skips (TestExpect.SKIP) on a single device type.
meta_function_device_skips = defaultdict(dict, {
    'cpu': {
        torch.native_batch_norm: {f32, f64},
        torch._native_batch_norm_legit: {f32, f64},
    },
    'cuda': {
        torch.inner: {f16},
        torch.linalg.matrix_rank: {f32, f64},
        torch.linalg.svd: {f32, f64},
        torch.nn.functional.cross_entropy: {f16},
        torch.nn.functional.interpolate: {f16},
        torch.nn.functional.nll_loss: {f16},
        torch.svd: {f32, f64},
    },
})
740

741
# This is a __torch_function__ mode that, when enabled, intercepts every
# Torch API call, runs the operator as normal, reruns it with meta inputs,
# and then checks that everything about the output agrees.
# Most of the logic deals with faithfully replicating the original tensor
# as a meta tensor, which is nontrivial because there are a lot of subsystems
# that may potentially be exercised.
#
# That being said, this class is a little overkill for what it is doing in
# this test file (since I could have just inlined __torch_function__ on the
# OpInfo call, and OpInfos generally have very regular inputs), but it will be
# useful for more comprehensive testing, e.g., as seen in
# https://github.com/pytorch/pytorch/pull/75994.  The big benefit is that it is
# a LOT more efficient than torch dispatch mode (at the cost of less coverage).
754
class MetaCrossRefFunctionMode(torch.overrides.TorchFunctionMode):
    """Cross-check every intercepted torch API call against a meta-device run.

    For each call the expected outcome (skip / expected failure / success) is
    looked up in the module-level ``meta_function_*`` tables, keyed by the
    called function and the dtype under test; the actual comparison is
    delegated to ``run_meta_crossref``.
    """

    test_case: TestCase
    device_type: str
    dtype: torch.dtype
    inplace: bool

    def __init__(self, test_case, *, device, dtype, inplace):
        # Chain to the base mode so any state its constructor sets up exists
        # before the mode is entered.
        super().__init__()
        self.test_case = test_case
        # Only the device *type* ("cpu", "cuda", ...) indexes the per-device
        # tables, so e.g. "cuda:0" collapses to "cuda".
        self.device_type = torch.device(device).type
        self.dtype = dtype
        self.inplace = inplace

    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}

        if (
            torch.jit.is_tracing() or isinstance(func, torch.ScriptMethod) or
            # meta converter doesn't work correctly when no_dispatch() is on, so
            # skip running the crossref test in this case
            torch._C._dispatch_tls_local_exclude_set().has(torch._C.DispatchKey.Python)
        ):
            return func(*args, **kwargs)

        # The first matching table wins: skips take precedence over expected
        # failures, and anything unlisted is expected to succeed.
        if self.dtype in meta_function_skips.get(func, set()):
            test_expect = TestExpect.SKIP
        elif self.dtype in meta_function_device_skips[self.device_type].get(func, set()):
            test_expect = TestExpect.SKIP
        elif self.dtype in meta_function_expected_failures.get(func, set()):
            test_expect = TestExpect.XFAILURE
        elif self.dtype in meta_function_device_expected_failures[self.device_type].get(func, set()):
            test_expect = TestExpect.XFAILURE
        elif meta_function_expected_failures_conditional.get(func, lambda *_, **__: False)(self.dtype, *args, **kwargs):
            test_expect = TestExpect.XFAILURE
        elif not self.inplace and \
                self.dtype in meta_function_device_expected_failures_only_outplace[self.device_type].get(func, set()):
            test_expect = TestExpect.XFAILURE
        else:
            test_expect = TestExpect.SUCCESS

        return run_meta_crossref(
            self.test_case, test_expect, func, args,
            kwargs, dtype=self.dtype, device_type=self.device_type, run_symbolic_meta=False
        )
796

797
# these always fail
798
meta_dispatch_expected_failures = {
799
    aten.allclose.default: {f16, bf16, f32, f64, c64, c128},  # NotImplementedError: 'aten::_local_scalar_dense'
800
    aten.geqrf.default : {c64, c128, f64, f32},
801
    aten.linalg_eig.default : {c64, c128, f64, f32},
802
    aten.linalg_lstsq.default : {c64, c128, f64, f32},
803
    aten.masked_select.default : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, b8, i16, u8},
804
    aten.masked_select.out : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, b8, i16, u8},
805
    aten.nonzero.default : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, c32, b8, i16, u8},
806
    aten.nonzero.out : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, c32, b8, i16, u8},
807
    aten._to_sparse.default : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, b8, i16, u8},
808
    aten._to_sparse.sparse_dim : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, b8, i16, u8},
809
    aten._ctc_loss.Tensor : {f32, f64},  # Shape of second output depends on data.
810
    aten._histogramdd_bin_edges.default : {f32, f64},
811
    aten._histogramdd_from_bin_cts.default : {f32, f64},
812
    aten._histogramdd_from_bin_tensors.default : {f32, f64},
813
    aten._local_scalar_dense.default : {c32, c64, f16, i8, f64, c128, i64, bf16, f32, i32, b8, i16, u8},
814
    aten._unique2.default : {i8, f64, i64, f16, bf16, f32, i32, b8, i16, u8},
815
    aten.bincount.default : {i64, i8, i32, i16, u8},
816
    aten.equal.default : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, b8, i16, u8},
817
    aten.histc.default : {bf16, f32, f64},
818
    aten.histc.out : {bf16, f32, f64},
819
    aten.histogram.bin_ct : {f32, f64},
820
    aten.histogram.bins_tensor : {f32, f64},
821
    aten.kthvalue.default : {i8, f64, i64, f16, bf16, f32, i32, i16, u8},
822
    aten.unique_consecutive.default : {i8, f64, i64, f16, bf16, f32, i32, b8, i16, u8},
823
    aten.unique_dim.default : {i8, f64, i64, f16, bf16, f32, i32, b8, i16, u8},
824
    aten.upsample_nearest3d.vec : {bf16, f32, f64, u8},
825

826
}
827

828
# Ops whose dispatch-level meta crossref is flaky (sometimes passes, sometimes
# fails); mapped to TestExpect.SKIP by MetaCrossRefDispatchMode.
meta_dispatch_skips = {
    aten.index.Tensor: {b8, u8, i8, i16, i32, i64, f16, bf16, f32, f64, c32, c64, c128},  # at::nonzero doesn't have a Meta function
    aten._to_copy.default: {b8, u8, i8, i16, i32, i64, f16, bf16, f32, f64, c32, c64, c128},
    aten.empty.memory_format: {b8, u8, i8, i16, i32, i64, f16, bf16, f32, f64, c32, c64, c128},
    aten.addbmm_.default: {u8, i8, i16, i32, i64, bf16, f32, f64, c64, c128},
}
835

836
# For CompositeImplicitAutograd functions that fail before hitting the Mode.
# _run_dispatch_meta_test checks this set before running any sample.
meta_dispatch_early_skips = {
    torch.Tensor.float_power_,
    # Errors out in one of the tests, while ProxyTensor passes...
    torch.Tensor.cumprod_,
    torch.Tensor.cumsum_,
}

# In-place variants skipped by test_meta_inplace.
meta_inplace_skips = {
    # Errors out in one of the tests, while ProxyTensor passes...
    torch.Tensor.cumprod_,
    torch.Tensor.cumsum_,
}
849

850
# Per-device tables for the dispatch-mode crossref, keyed by device type
# ("cpu", "cuda", ...).  A device type with no entries reads back as {}.
meta_dispatch_device_expected_failures, meta_dispatch_device_skips = (
    defaultdict(dict),
    defaultdict(dict),
)
852

853
# CPU-only expected failures at the aten-overload level (looked up by
# MetaCrossRefDispatchMode._get_expected_test_result).
meta_dispatch_device_expected_failures['cpu'] = {
    aten.native_batch_norm.default: {f16, bf16},
    aten._native_batch_norm_legit.default: {f16, bf16},
    aten._native_batch_norm_legit.no_stats: {f16, bf16},
    aten.native_layer_norm.default: {f16, bf16},
    aten.histc.default: {f16},
    aten.histc.out: {f16},
}

# CUDA-only expected failures; trailing comments name the aten op involved.
meta_dispatch_device_expected_failures['cuda'] = {
    aten._unique2.default: {f16},  # aten::_unique2
    aten._use_cudnn_ctc_loss.default: {f32, f64},  # aten::_use_cudnn_ctc_loss
    aten._use_cudnn_ctc_loss.Tensor: {f32, f64},  # aten::_use_cudnn_ctc_loss.Tensor
    aten.cudnn_grid_sampler.default: {f16, f32, f64},  # aten::cudnn_grid_sampler
    aten.geqrf.default: {f32, f64},  # aten::geqrf
    aten.histc.default: {i8, i16, i32, i64},  # aten::histc
    aten.histc.out: {i8, i16, i32, i64},  # aten::histc.out
    aten.kthvalue.default: {f16},  # aten::kthvalue.values
    aten.linalg_eigvalsh.out: {f32, f64},  # aten::linalg_eigvalsh.out
    aten.log_sigmoid_forward.default: {f16, bf16, f32, f64},
    aten.log_sigmoid_forward.output: {f16, bf16, f32, f64},  # aten::log_sigmoid_forward.output
    aten.unique_consecutive.default: {f16},  # aten::unique_consecutive
    aten.unique_dim.default: {f16},  # aten::unique_dim
    aten.upsample_nearest3d.vec: {f16},  # aten::upsample_nearest3d.vec
}
878

879
# CPU-only skips at the aten-overload level (mapped to TestExpect.SKIP).
meta_dispatch_device_skips['cpu'] = {
    aten._embedding_bag_forward_only.default: {f16, bf16, f32, f64},
    aten.native_batch_norm.default: {f32, f64},
    aten._native_batch_norm_legit.default: {f32, f64},
    aten._native_batch_norm_legit.no_stats: {f32, f64},

    # This fails whenever the computation dtype differs from the input
    # dtype; CPU execution may also produce a different output than other
    # devices.
    aten.native_batch_norm.out: {f16, bf16, f32, f64},
}
890

891
# CUDA-only skips at the aten-overload level (mapped to TestExpect.SKIP by
# MetaCrossRefDispatchMode._get_expected_test_result).  Each key appears once.
meta_dispatch_device_skips['cuda'] = {
    aten._conj.default: {c32, f16},  # file issue
    # NOTE(review): the trailing comment names aten::linalg_eigvalsh.out rather
    # than _linalg_svd; it may have been copied from another entry -- confirm.
    aten._linalg_svd.default: {c64, c128},  # aten::linalg_eigvalsh.out
    aten.cudnn_batch_norm.default: {f32, f64},
    aten.log_softmax.int: {c32, c64},
    aten.softmax.int: {c32, c64},

    # ROCm stuff; technically this should be expected failure but it's
    # not worth it; these should get unified anyway
    aten.miopen_batch_norm.default: {f32},
}
903

904
def get_strided_args(args):
    """Yield every combination of memory-layout variants of ``args``.

    Each contiguous tensor with the dense ``torch.strided`` layout is expanded
    into several tensors holding the same values in different memory layouts
    (see ``get_strided_variants``).  Every other argument -- non-tensors,
    non-contiguous tensors, and tensors with a non-strided layout such as
    sparse COO/CSR/CSC -- is passed through unchanged.  The cartesian product
    of the per-argument choices is yielded as tuples.
    """

    def get_strided_variants(t, include_storage_offset=False):
        """Return ``t`` plus value-equal copies of it in other memory layouts.

        Depending on ``t.ndim`` this adds a copy with reversed dimension order
        in memory, a non-dense strided view, a channels-last copy (4-D), a
        channels-last-3d copy (5-D) and, optionally, a copy at storage offset 1.
        """
        variants = []

        # contiguous
        variants.append(t)

        # transposed
        if t.ndim > 1:
            perm = list(reversed(range(t.ndim)))
            transposed = torch.empty(
                t.shape[::-1], device=t.device, dtype=t.dtype, requires_grad=t.requires_grad
            ).permute(perm).copy_(t)
            variants.append(transposed)

        # nondense
        if t.ndim > 0:
            nondense = torch.repeat_interleave(t, 2, dim=-1)[..., ::2]
            variants.append(nondense)

        # channel_last
        if t.ndim == 4:
            variants.append(t.contiguous(memory_format=torch.channels_last))

        # channel_last_3d
        if t.ndim == 5:
            variants.append(t.contiguous(memory_format=torch.channels_last_3d))

        # storage_offset
        if include_storage_offset:
            buffer = torch.empty(t.numel() + 1, device=t.device, dtype=t.dtype, requires_grad=t.requires_grad)
            buffer = buffer.as_strided(t.shape, t.stride(), storage_offset=1)
            buffer.copy_(t)
            variants.append(buffer)

        return variants

    strided_args = []
    for arg in args:
        # Check the layout *before* calling is_contiguous(): sparse COO (and
        # other non-strided) tensors do not support is_contiguous(), and their
        # layout cannot be re-strided anyway.
        if (
            isinstance(arg, torch.Tensor)
            and arg.layout == torch.strided
            and arg.is_contiguous()
        ):
            strided_arg_variants = get_strided_variants(arg)
        else:
            strided_arg_variants = [arg]
        strided_args.append(strided_arg_variants)

    yield from itertools.product(*strided_args)
951

952
class MetaCrossRefDispatchMode(torch.utils._python_dispatch.TorchDispatchMode):
    """TorchDispatchMode that cross-checks every aten op against a meta run.

    Each intercepted op is handed to ``run_meta_crossref`` together with the
    expected outcome looked up in the module-level ``meta_dispatch_*`` tables.
    When the OpInfo under test does not itself support ``out=``, the mode also
    looks for an aten ``out`` overload of the same op.  If one matches, it runs
    that overload too, passing the first call's result(s) as the out argument(s).
    """

    test_case: TestCase
    device_type: str
    dtype: torch.dtype
    # Overload packets already found to have no out= overload.  Stored on the
    # class (not per instance), so the negative lookup is shared by all modes.
    aten_olp_no_out_overload: set = set()

    def __init__(self, test_case, *, device, dtype, symbolic_meta: bool, inplace: bool, supports_out: bool):
        # Let the base mode initialise whatever state it keeps before use.
        super().__init__()
        self.test_case = test_case
        # Snapshot the test case's tolerances; __torch_dispatch__ writes them
        # back onto the test case before every intercepted op.
        self.precision = test_case.precision
        self.rel_tol = test_case.rel_tol
        # Only the device type ("cpu", "cuda", ...) keys the per-device tables.
        self.device_type = torch.device(device).type
        self.dtype = dtype
        self.symbolic_meta = symbolic_meta
        self.inplace = inplace
        self.supports_out = supports_out

    @staticmethod
    def try_resolve_aten_out_overload(ol, args, kwargs, num_outputs):
        """Find an ``out=`` overload in ``ol``'s packet usable with these arguments.

        Args:
            ol: the functional ``OpOverload`` that was just dispatched.
            args: positional arguments it was called with.
            kwargs: keyword arguments it was called with.
            num_outputs: number of outputs the functional call produced.

        Returns:
            ``(out_overload, out_arg_names, new_kwargs)``.  ``new_kwargs`` holds
            each non-out parameter of ``out_overload`` after the positional
            ones, taken from ``kwargs`` or filled from a default.  If no
            candidate matches, returns ``(None, None, None)``.
        """
        # With no outputs there is nothing to pass as out=.  Bail out
        # explicitly: with num_outputs == 0 the slices below would become
        # [-0:] (the whole argument list) and [len(args):-0] (empty).
        if num_outputs <= 0:
            return (None, None, None)

        ol_args = ol._schema.arguments
        olp: OpOverloadPacket = ol._overloadpacket

        if olp in MetaCrossRefDispatchMode.aten_olp_no_out_overload:
            return (None, None, None)

        candidate_ols = []
        for candidate_ol_name in olp.overloads():
            candidate_ol = getattr(olp, candidate_ol_name)
            if any(arg.is_out for arg in candidate_ol._schema.arguments):
                candidate_ols.append(candidate_ol)

        if not candidate_ols:
            MetaCrossRefDispatchMode.aten_olp_no_out_overload.add(olp)
            return (None, None, None)

        # Now match based on args, kwargs and number of required outputs
        for candidate_ol in candidate_ols:
            candidate_ol_args = candidate_ol._schema.arguments

            # The candidate needs more parameters than the positional args
            # supplied, leaving room for its out argument(s).
            if len(args) >= len(candidate_ol_args):
                continue

            # Positional arguments must have the same type
            if not all(
                ol_args[pos_arg_ind].type == candidate_ol_args[pos_arg_ind].type
                for pos_arg_ind in range(len(args))
            ):
                continue

            # Number of outputs must match
            candidate_out_names = [out_arg.name for out_arg in candidate_ol_args[-num_outputs:] if out_arg.is_out]
            if len(candidate_out_names) != num_outputs:
                continue

            # Now try and match kwargs. Just need to ensure that the
            # remaining kwargs allow an out overload to be called. For example
            # we can throw away parameters like `dtype` that may be passed to the
            # functional version of the op since the `dtype` will already be present
            # in the `out` argument
            new_kwargs = {}
            kwargs_match = True
            for arg in candidate_ol_args[len(args):-num_outputs]:
                if arg.name not in kwargs:
                    if arg.has_default_value():
                        new_kwargs[arg.name] = arg.default_value
                    elif isinstance(arg.type, torch.OptionalType):
                        if isinstance(arg.type.getElementType(), torch.BoolType):
                            new_kwargs[arg.name] = False
                        else:
                            new_kwargs[arg.name] = None
                    else:
                        kwargs_match = False
                        break
                else:
                    new_kwargs[arg.name] = kwargs[arg.name]

            if kwargs_match:
                return candidate_ol, candidate_out_names, new_kwargs

        return None, None, None

    def _get_expected_test_result(self, func: OpOverload):
        """Map ``func``/dtype/device to SKIP, XFAILURE or SUCCESS (first match wins)."""
        if self.dtype in meta_dispatch_skips.get(func, set()):
            test_expect = TestExpect.SKIP
        elif self.dtype in meta_dispatch_device_skips[self.device_type].get(func, set()):
            test_expect = TestExpect.SKIP
        elif self.dtype in meta_dispatch_expected_failures.get(func, set()):
            test_expect = TestExpect.XFAILURE
        elif self.dtype in meta_dispatch_device_expected_failures[self.device_type].get(func, set()):
            test_expect = TestExpect.XFAILURE
        else:
            test_expect = TestExpect.SUCCESS
        return test_expect

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # Re-apply the tolerances captured in __init__.
        self.test_case.precision = self.precision
        self.test_case.rel_tol = self.rel_tol

        test_expect = self._get_expected_test_result(func)

        expected = run_meta_crossref(
            self.test_case,
            test_expect,
            func,
            args,
            kwargs,
            dtype=self.dtype,
            device_type=self.device_type,
            run_symbolic_meta=self.symbolic_meta,
        )

        # This is to test torch ops that do not have an out parameter but have
        # aten op overloads that have out parameters. Additionally, Python decompositions
        # may register OpOverloadPacket's so decompositions need to be tested
        # to ensure all OpOverloads still function for the Meta key (e.g. if a python decomposition
        # is registered for an aten op aten.foo with overloads [default, out], the python
        # function needs to support receiving `out` arguments)
        if (
            not self.inplace and
            not self.supports_out and
            test_expect == TestExpect.SUCCESS and
            (torch.is_tensor(expected) or isinstance(expected, Iterable))
        ):

            # check to see if there is a potential out overload
            num_outputs = 1 if torch.is_tensor(expected) else len(expected)
            func_out_overload, out_param_names, kwargs = self.try_resolve_aten_out_overload(func, args, kwargs, num_outputs)

            if func_out_overload:

                if num_outputs == 1:
                    kwargs[out_param_names[0]] = expected
                else:
                    for ind, out_param_name in enumerate(out_param_names):
                        kwargs[out_param_name] = expected[ind]

                test_expect = self._get_expected_test_result(func_out_overload)

                run_meta_crossref(
                    self.test_case,
                    test_expect,
                    func_out_overload,
                    args,
                    kwargs,
                    dtype=self.dtype,
                    device_type=self.device_type,
                    run_symbolic_meta=self.symbolic_meta,
                )

        return expected
1105

1106
# NB: we're running these tests only on CUDA because there are some
1107
# inconsistencies between CUDA and CPU, and running on CUDA makes it easier
1108
# to ignore the CPU case when inconsistencies arise.  Ideally we deal
1109
# with the inconsistencies but this takes time.
1110
@unMarkDynamoStrictTest
1111
class TestMeta(TestCase):
1112
    # Copies inputs to inplace operations to avoid inplace modifications
1113
    #   to leaves requiring gradient
1114
    def _get_safe_inplace(self, inplace_variant):
        """Wrap ``inplace_variant`` so that it mutates a clone of its first argument.

        The first argument may be a single tensor or a list of tensors; either
        way it is cloned before the call, so sample inputs (possibly leaves
        requiring grad) are never modified in place.
        """
        def _clone_first(first):
            if isinstance(first, list):
                return [x.clone() for x in first]
            return first.clone()

        @wraps(inplace_variant)
        def _fn(t, *args, **kwargs):
            return inplace_variant(_clone_first(t), *args, **kwargs)

        return _fn
1123

1124
    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @skipIfCrossRef
    @suppress_warnings
    @ops(itertools.chain(op_db, foreach_op_db))
    def test_meta_outplace(self, device, dtype, op):
        """Cross-reference each OpInfo sample with its meta implementation.

        All of the comparison work happens in MetaCrossRefFunctionMode; this
        test only feeds the samples through it, plus an ``out=`` call when the
        op supports one and returned a tensor.
        """
        flaky_under_dynamo = ("fft.ihfft", "fft.ihfft2", "linalg.lu_solve")
        if TEST_WITH_TORCHDYNAMO and op.name in flaky_under_dynamo:
            raise unittest.SkipTest("flaky")

        func = op.get_op()
        for sample_input in op.sample_inputs(device, dtype, requires_grad=False):
            args = [sample_input.input, *sample_input.args]
            kwargs = sample_input.kwargs
            with MetaCrossRefFunctionMode(self, dtype=dtype, device=device, inplace=False):
                expected = func(*args, **kwargs)
                if isinstance(expected, torch.Tensor) and op.supports_out:
                    func(*args, **kwargs, out=expected)

            # Extra check for *_like functions called with an explicit "device"
            # kwarg: replacing their tensor argument by a "meta" copy must give
            # the same result, under the same RNG seed.
            if "device" not in kwargs or "_like" not in op.name:
                continue

            with torch.random.fork_rng():
                torch.manual_seed(123)
                ref = func(*args, **kwargs)

            # *_like functions take a Tensor as first argument
            assert isinstance(args[0], torch.Tensor)
            with torch.random.fork_rng():
                torch.manual_seed(123)
                args[0] = args[0].to(device="meta")
                meta = func(*args, **kwargs)

            # empty_like is not deterministic
            if op.name != "empty_like":
                self.assertEqual(ref, meta)
1168

1169
    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @skipIfCrossRef
    @suppress_warnings
    @ops(itertools.chain(op_db, foreach_op_db))
    def test_meta_inplace(self, device, dtype, op):
        """Cross-reference the in-place variant of each OpInfo with meta."""
        inplace_fn = op.get_inplace()
        if not inplace_fn:
            self.skipTest("No inplace variable for this op")
        if op.promotes_int_to_float and not dtype.is_floating_point:
            self.skipTest("Op promotes to float, which is impossible for inplace with non-float input")
        if inplace_fn in meta_inplace_skips:
            self.skipTest("Skipped")

        # Operate on clones so sample inputs are never mutated.
        safe_fn = self._get_safe_inplace(inplace_fn)
        for sample_input in op.sample_inputs(device, dtype, requires_grad=False):
            # Samples flagged as broadcasting their input are not run in place.
            if sample_input.broadcasts_input:
                continue
            args = [sample_input.input, *sample_input.args]
            kwargs = sample_input.kwargs
            with MetaCrossRefFunctionMode(self, dtype=dtype, device=device, inplace=True):
                safe_fn(*args, **kwargs)
1190

1191
    def _run_dispatch_meta_test(self, device, dtype, op, symbolic_meta, inplace, all_stride_variants=False):
        """Run an OpInfo's samples under MetaCrossRefDispatchMode.

        Args:
            device, dtype, op: the OpInfo instantiation under test.
            symbolic_meta: forwarded to the mode's ``symbolic_meta`` flag.
            inplace: exercise ``op``'s in-place variant (on cloned inputs)
                instead of the functional one.
            all_stride_variants: additionally run every memory-layout
                combination of the sample's contiguous tensor arguments (see
                ``get_strided_args``), for samples with at most 5 tensors.
        """
        if inplace:
            func = op.get_inplace()
            if not func:
                self.skipTest("No inplace variable for this op")
            if op.promotes_int_to_float and not dtype.is_floating_point:
                self.skipTest("Op promotes to float, which is impossible for inplace with non-float input")
        else:
            func = op.get_op()

        if func in meta_dispatch_early_skips:
            self.skipTest("Function is in dispatch early skips")

        if inplace:
            func = self._get_safe_inplace(func)

        samples = op.sample_inputs(device, dtype, requires_grad=False)
        for sample_input in samples:
            if inplace and sample_input.broadcasts_input:
                continue

            sample_args = [sample_input.input] + list(sample_input.args)
            kwargs = sample_input.kwargs

            if all_stride_variants and sum(isinstance(arg, torch.Tensor) for arg in sample_args) <= 5:
                # test inputs <= 5 tensors to avoid combinatorial explosion
                strided_args = get_strided_args(sample_args)
            else:
                strided_args = [sample_args]

            for args in strided_args:
                # Construct the mode and enter it directly; the classmethod
                # TorchDispatchMode.push() is a deprecated alias that just
                # builds the instance (and emits a deprecation warning).
                with MetaCrossRefDispatchMode(
                    self, dtype=dtype, device=device,
                    symbolic_meta=symbolic_meta, inplace=inplace,
                    supports_out=op.supports_out,
                ):
                    expected = func(*args, **kwargs)

                    if not inplace and isinstance(expected, torch.Tensor) and op.supports_out:
                        func(*args, **kwargs, out=expected)
1230

1231

1232
    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @skipIfCrossRef
    @suppress_warnings
    @ops(itertools.chain(op_db, foreach_op_db))
    def test_dispatch_meta_outplace(self, device, dtype, op):
        """Dispatch-level meta crossref of the functional variant (no symbolic run)."""
        self._run_dispatch_meta_test(
            device, dtype, op,
            inplace=False,
            symbolic_meta=False,
        )
1238

1239
    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @skipIfCrossRef
    @suppress_warnings
    @ops(itertools.chain(op_db, foreach_op_db))
    def test_dispatch_meta_inplace(self, device, dtype, op):
        """Dispatch-level meta crossref of the in-place variant (no symbolic run)."""
        self._run_dispatch_meta_test(
            device, dtype, op,
            inplace=True,
            symbolic_meta=False,
        )
1245

1246
    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @skipIfCrossRef
    @suppress_warnings
    @ops(itertools.chain(op_db, foreach_op_db))
    def test_dispatch_symbolic_meta_outplace(self, device, dtype, op):
        """Dispatch-level meta crossref of the functional variant, with the symbolic run."""
        self._run_dispatch_meta_test(
            device, dtype, op,
            inplace=False,
            symbolic_meta=True,
        )
1252

1253

1254
    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @skipIfCrossRef
    @suppress_warnings
    @ops(itertools.chain(op_db, foreach_op_db))
    def test_dispatch_symbolic_meta_inplace(self, device, dtype, op):
        """Dispatch-level meta crossref of the in-place variant, with the symbolic run."""
        self._run_dispatch_meta_test(
            device, dtype, op,
            inplace=True,
            symbolic_meta=True,
        )
1260

1261
    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @skipIfCrossRef
    @suppress_warnings
    @ops(itertools.chain(op_db, foreach_op_db), dtypes=OpDTypes.any_common_cpu_cuda_one)
    @onlyCUDA
    def test_dispatch_symbolic_meta_outplace_all_strides(self, device, dtype, op):
        """Symbolic meta crossref of the functional variant over all input stride layouts.

        Only one dtype is instantiated, because output stride behavior is the
        same for all dtypes.  Only CUDA runs, because the CUDA kernel's strides
        are the reference.
        """
        self._run_dispatch_meta_test(
            device, dtype, op,
            inplace=False,
            symbolic_meta=True,
            all_stride_variants=True,
        )
1270

1271
    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @skipIfCrossRef
    @suppress_warnings
    @ops(itertools.chain(op_db, foreach_op_db), dtypes=OpDTypes.any_common_cpu_cuda_one)
    @onlyCUDA
    def test_dispatch_symbolic_meta_inplace_all_strides(self, device, dtype, op):
        """Symbolic meta crossref of the in-place variant over all input stride layouts.

        Only one dtype is instantiated, because output stride behavior is the
        same for all dtypes.  Only CUDA runs, because the CUDA kernel's strides
        are the reference.
        """
        self._run_dispatch_meta_test(
            device, dtype, op,
            inplace=True,
            symbolic_meta=True,
            all_stride_variants=True,
        )
1280

1281
    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @skipIfCrossRef
    @suppress_warnings
    @ops(binary_ufuncs, allowed_dtypes=(torch.float32,))
    @onlyCUDA
    def test_binary_ufuncs_mixed_dtype(self, device, dtype, op):
        """Symbolic meta crossref of binary ufuncs whose operands differ in dtype.

        The OpInfo's samples are replaced by one pair of shape-(S,) tensors: the
        first of ``dtype`` and the second of float16.  Only one dtype is tested,
        because output stride behavior is the same for all dtypes.  Only CUDA
        runs, because the CUDA kernel's strides are the reference.
        """
        # Tensors are created on the test's device, captured here, rather than
        # on the ``device`` argument the sample function itself receives.
        test_device = device

        def sample_input(op, device, dtype, requires_grad, **kwargs):
            lhs = make_tensor((S,), dtype=dtype, device=test_device)
            rhs = make_tensor((S,), dtype=torch.float16, device=test_device)
            yield SampleInput(lhs, rhs)

        mixed_op = copy.copy(op)
        mixed_op.sample_inputs_func = sample_input

        self._run_dispatch_meta_test(device, dtype, mixed_op, symbolic_meta=True, inplace=False)
1303

1304

1305
    def test_empty_quantized(self):
        """A 2**52-element qint8 tensor can be created on meta and stays there."""
        huge = torch.empty(2 ** 52, dtype=torch.qint8, device='meta')
        self.assertEqual(huge.device.type, 'meta')
1308

1309
    def test_nan_to_num(self):
        """nan_to_num on a meta tensor yields a meta tensor."""
        values = [float('nan'), float('inf'), -float('inf'), 3.14]
        source = torch.tensor(values, device='meta')
        self.assertEqual(source.nan_to_num().device.type, 'meta')
1313

1314
    def test_inplace_masked_fill_error(self):
        """In-place masked_fill_ on meta rejects a mask with an extra leading dim."""
        base = torch.randn(3, 3, device='meta')
        with self.assertRaisesRegex(RuntimeError, "doesn't match the broadcast"):
            bigger_mask = (base > 0).unsqueeze(0)
            base.masked_fill_(bigger_mask, 0.1)
1318

1319
    def test_inplace_bin_ops_error(self):
        """In-place binary ops on meta reject an operand that would grow ``self``."""
        base = torch.randn(3, 3, device='meta')
        inplace_ops = (
            torch.Tensor.add_,
            torch.Tensor.sub_,
            torch.Tensor.mul_,
            torch.Tensor.div_,
            torch.Tensor.logical_and_,
            torch.Tensor.logical_or_,
            torch.Tensor.logical_xor_,
        )
        for inplace_op in inplace_ops:
            with self.assertRaisesRegex(RuntimeError, "doesn't match the broadcast"):
                inplace_op(base, base.clone().unsqueeze(0))
1325

1326
    @onlyCPU
    def test_meta_autograd_no_error(self):
        """A custom op with only CPU and Meta kernels can be called on a meta tensor."""
        with torch.library._scoped_library("meta_test", "DEF") as lib, \
                torch.library._scoped_library("meta_test", "IMPL", "CPU") as impl_cpu, \
                torch.library._scoped_library("meta_test", "IMPL", "Meta") as impl_meta:

            def foo_impl(x):
                return x + 1

            lib.define("foo(Tensor a) -> Tensor")
            for impl_lib in (impl_meta, impl_cpu):
                impl_lib.impl("foo", foo_impl)

            a = torch.ones(2, device='meta')
            # The point of the test is that this should not error:
            # We have a fallthrough kernel registered to the AutogradMeta
            # key for custom ops, so it's fine that `foo()` doesn't have
            # an autograd kernel.
            torch.ops.meta_test.foo.default(a)
1344

1345
    def test_huber_loss_backward(self):
        """huber_loss_backward on huge meta inputs returns a meta tensor of input shape."""
        operands = [torch.rand(2 ** 52, device='meta') for _ in range(3)]
        result = torch.ops.aten.huber_loss_backward(*operands, 0, 1.0)
        self.assertEqual(result.device.type, 'meta')
        self.assertEqual(result.shape, operands[0].shape)
1350

1351
    def _norm_backwards_test_helper(self, op, args, output_mask, expected_shapes):
        """Check the gradient shapes a norm-backward op (e.g. native_layer_norm_backward) produces.

        ``op`` is called twice:

        * functionally -- each returned gradient must be ``None`` exactly when
          its ``expected_shapes`` entry is ``None``, and otherwise have that shape;
        * with ``out0``/``out1``/``out2`` float32 meta tensors -- only the out
          tensors whose expected shape is not ``None`` are checked.
        """
        dtype = torch.float32
        device = "meta"

        # functional call
        grads = op(*args, output_mask)
        for idx in range(3):
            exp = expected_shapes[idx]
            if exp is None:
                self.assertIsNone(grads[idx])
            else:
                self.assertEqual(exp, grads[idx].shape)

        out_kwargs = {
            f"out{i}": torch.empty(0, device=device, dtype=dtype)
            for i in range(len(output_mask))
        }

        # call with out= tensors; only the out tensors are inspected afterwards
        op(*args, output_mask, **out_kwargs)
        for idx in range(3):
            exp = expected_shapes[idx]
            if exp is not None:
                self.assertEqual(exp, out_kwargs[f"out{idx}"].shape)
1380

1381
    @onlyCPU
    @parametrize("output_mask", list(itertools.product([True, False], [True, False], [True, False])))
    def test_layer_norm_backward(self, output_mask):
        """Shape-check native_layer_norm_backward on meta for every output mask."""
        from torch.testing._internal.common_methods_invocations import sample_inputs_layer_norm

        device = "meta"
        dtype = torch.float32

        for sample in sample_inputs_layer_norm(None, device, dtype, requires_grad=False):
            with self.subTest(sample=sample):
                # weight and bias are optional: pad the args out to three entries
                missing = 3 - len(sample.args)
                if missing:
                    sample.args = (*sample.args, *([None] * missing))

                grad_out = torch.ones_like(sample.input)
                normalized_shape, weight, bias = sample.args
                # mean/rstd cover the leading dims that are not normalized over
                kept_ndim = sample.input.ndim - len(normalized_shape)
                stats_shape = grad_out.shape[:kept_ndim]
                mean = torch.zeros(stats_shape, device=device, dtype=dtype)
                rstd = torch.zeros(stats_shape, device=device, dtype=dtype)

                expected_shapes = (
                    sample.input.shape if output_mask[0] else None,
                    weight.shape if output_mask[1] and weight is not None else None,
                    bias.shape if output_mask[2] and bias is not None else None,
                )

                self._norm_backwards_test_helper(
                    torch.ops.aten.native_layer_norm_backward,
                    [grad_out, sample.input, normalized_shape, mean, rstd, weight, bias],
                    output_mask,
                    expected_shapes,
                )
1413

1414
    @onlyCPU
    @parametrize("output_mask", list(itertools.product([True, False], [True, False], [True, False])))
    def test_group_norm_backward(self, output_mask):
        """Shape-check native_group_norm_backward on meta for every output mask."""
        from torch.testing._internal.common_methods_invocations import sample_inputs_group_norm

        # input, (args) num_groups, (kwargs) weight, bias eps
        device = "meta"
        dtype = torch.float32

        for sample in sample_inputs_group_norm(None, device, dtype, requires_grad=False):
            with self.subTest(sample=sample):
                grad_out = torch.ones_like(sample.input)
                N, C = sample.input.shape[:2]
                # HxW: product of all dims after the first two
                HxW = torch.prod(torch.as_tensor(sample.input.shape[2:]), dtype=torch.int32).item()
                group = sample.args[0]
                mean = torch.zeros((N, group), device=device, dtype=dtype)
                rstd = torch.zeros((N, group), device=device, dtype=dtype)
                weight = torch.zeros(C, device=device, dtype=dtype)

                # Both the weight and the bias gradients are expected to have
                # the weight's shape.
                expected_shapes = (
                    sample.input.shape if output_mask[0] else None,
                    weight.shape if output_mask[1] else None,
                    weight.shape if output_mask[2] else None,
                )

                self._norm_backwards_test_helper(
                    torch.ops.aten.native_group_norm_backward,
                    [grad_out, sample.input, mean, rstd, weight, N, C, HxW, group],
                    output_mask,
                    expected_shapes,
                )
1444

1445
    @onlyCPU
    @parametrize("output_mask", list(itertools.product([True], [True, False], [True, False])))
    def test_batch_norm_backward(self, output_mask):
        """Check meta output shapes of ``native_batch_norm_backward`` for OpInfo batch_norm samples.

        ``output_mask[0]`` (grad_input) is always True in this parametrization, so the
        expected grad_input shape is always the input shape; grad_weight/grad_bias are
        expected to be per-channel (``input.shape[1]``) or ``None`` when not requested.
        """
        from torch.testing._internal.common_methods_invocations import sample_inputs_batch_norm

        # Samples look like: input, (args) running_mean, running_var, weight, bias,
        # (kwargs) optionally "training" and "eps".
        device = "meta"
        dtype = torch.float32
        samples = sample_inputs_batch_norm(None, device, dtype, requires_grad=False)

        for sample in samples:
            with self.subTest(sample=sample):
                # Samples without a channel dim (dim 1) are skipped; shape[1] is used below.
                if sample.input.dim() < 2:
                    continue

                grad_out = torch.ones_like(sample.input)
                # bias is not an input of the backward op; unpack only to check arity.
                running_mean, running_var, weight, _bias = sample.args
                train = sample.kwargs.get("training", True)
                num_channels = sample.input.shape[1]
                # In eval mode the saved statistics are passed as None.
                save_mean = torch.zeros((num_channels,), device=device, dtype=dtype) if train else None
                save_invstd = torch.zeros((num_channels,), device=device, dtype=dtype) if train else None

                args = [grad_out, sample.input, weight, running_mean, running_var,
                        save_mean, save_invstd, train, sample.kwargs.get("eps", 1e-5)]

                expected_shapes = (
                    sample.input.shape,
                    torch.Size([num_channels]) if output_mask[1] else None,
                    torch.Size([num_channels]) if output_mask[2] else None)

                self._norm_backwards_test_helper(torch.ops.aten.native_batch_norm_backward,
                                                 args, output_mask, expected_shapes)
1477

1478
    def test_fill__alias_relationship(self):
1479
        inps = torch.rand(2**52, device='meta')
1480
        r = torch.ops.aten.fill_(inps, 1.0)
1481
        # aten.fill_ returns an aliase
1482
        self.assertEqual(id(inps), id(r))
1483

1484
        # aten.fill returns a new tensor
1485
        r2 = torch.ops.aten.fill(inps, 1.0)
1486
        self.assertNotEqual(id(inps), id(r2))
1487

1488
    def test_meta__fused_moving_avg_obs_fq_helper(self, device):
        """The meta ``_fused_moving_avg_obs_fq_helper`` must match the real kernel's size and
        stride for its first two outputs, across several keyword-argument combinations."""
        from torch.ao.quantization import FusedMovingAvgObsFakeQuantize
        to_meta = MetaConverter()

        x = torch.randn(5, 5, device=device)
        running_min_op = torch.tensor(float("inf"), device=device)
        running_max_op = torch.tensor(float("-inf"), device=device)
        avg_const = 0.01
        scale = torch.tensor([1.0], device=device)
        zero_point = torch.tensor([0], dtype=torch.int, device=device)

        mod = FusedMovingAvgObsFakeQuantize()
        torch.ao.quantization.enable_fake_quant(mod)
        torch.ao.quantization.enable_observer(mod)
        mod.to(device)

        meta_x = to_meta(x)

        # Everything after the input tensor; the trailing ints are presumably
        # quant_min, quant_max and ch_axis (TODO confirm against the op schema).
        rest = [
            mod.observer_enabled,
            mod.fake_quant_enabled,
            running_min_op,
            running_max_op,
            scale,
            zero_point,
            avg_const,
            0,
            255,
            0,
        ]
        args = [x, *rest]
        # Only the input is swapped for its meta counterpart.
        meta_args = [meta_x, *rest]

        kwarg_variants = (
            {},
            {"per_row_fake_quant": False, "symmetric_quant": False},
            {"per_row_fake_quant": False, "symmetric_quant": True},
        )

        op = aten._fused_moving_avg_obs_fq_helper.default
        for kwargs in kwarg_variants:
            ref_out = op(*args, **kwargs)
            meta_out = op(*meta_args, **kwargs)

            for idx in (0, 1):
                self.assertEqual(ref_out[idx].size(), meta_out[idx].size())
                self.assertEqual(ref_out[idx].stride(), meta_out[idx].stride())
1537

1538
    def test_cdist_forward(self, device):
        """The meta ``_cdist_forward`` must yield a meta tensor shaped like the real result
        for every compute_mode."""
        to_meta = MetaConverter()
        x1 = torch.rand([3, 2], device=device)
        x2 = torch.rand([2, 2], device=device)
        p = 2.0
        cdist = aten._cdist_forward.default

        for compute_mode in (None, 1, 2):
            expected = cdist(x1, x2, p, compute_mode)
            actual = cdist(to_meta(x1), to_meta(x2), p, compute_mode)
            self.assertEqual(actual.device.type, 'meta')
            self.assertEqual(expected.shape, actual.shape)
1548

1549
    def test_quantized_embedding_bag(self):
1550
        tab_shape = [8, 128]
1551
        emb_size, ind_len, off_len = tab_shape[0], 32, 33
1552
        f_table = torch.from_numpy((np.random.random_sample(tab_shape) + 1).astype(np.float32))
1553
        q_table = torch.ops.quantized.embedding_bag_byte_prepack(f_table)
1554
        indices = torch.from_numpy(np.random.randint(low=0, high=emb_size, size=ind_len)).int()
1555
        max_length = len(indices) // (off_len - 1)
1556
        if max_length > 20:
1557
            max_length = 20
1558
        np_lengths = np.random.randint(0, max_length + 1, size=off_len - 1).astype(np.int32)
1559
        offsets = torch.cat([torch.zeros([1]), torch.cumsum(torch.from_numpy(np_lengths), 0)]).int()
1560

1561
        eb = torch.ops.quantized.embedding_bag_byte_rowwise_offsets(
1562
            q_table.to(device="meta"),
1563
            indices.to(device="meta"),
1564
            offsets.to(device="meta"),
1565
            mode=0,  # sum
1566
            per_sample_weights=None,
1567
            include_last_offset=True,
1568
        )
1569
        self.assertEqual(eb.shape, [32, 128])
1570
        self.assertEqual(eb.dtype, torch.float32)
1571
        self.assertEqual(eb.untyped_storage().data_ptr(), 0)
1572

1573
    # The OpInfo-based tests exercise aten.fill_, not the out-of-place aten.fill; cover it here.
    @onlyCUDA
    def test_fill_stride(self):
        """The meta ``aten.fill`` must reproduce the real kernel's size and stride for strided inputs."""
        to_meta = MetaConverter()
        base_args = [torch.rand(2, 2, 2, 2), 1.0]

        for strided_args in get_strided_args(base_args):
            meta_args = to_meta(strided_args)
            expected = torch.ops.aten.fill(*strided_args)
            actual = torch.ops.aten.fill(*meta_args)
            self.assertEqual(expected.size(), actual.size())
            self.assertEqual(expected.stride(), actual.stride())
1585

1586

1587
    def test_map_location_deserialize(self):
1588
        import io
1589

1590
        t = torch.rand(10)
1591
        b = io.BytesIO()
1592

1593
        torch.save(t, b)
1594
        b.seek(0)
1595
        r = torch.load(b, map_location=torch.device("meta"))
1596
        self.assertEqual(r.device.type, 'meta')
1597
        self.assertEqual(r.shape, t.shape)
1598
        self.assertEqual(r.dtype, t.dtype)
1599
        self.assertEqual(r.storage().data_ptr(), 0)
1600

1601
    def test_embedding_bag_byte_prepack(self):
1602
        batch_size = 10
1603
        num_embeddings = 80
1604
        embedding_dim = [128, 256, 512]
1605
        res_shape = [[batch_size, num_embeddings, ed + 8] for ed in embedding_dim]
1606
        for ed, rs in zip(embedding_dim, res_shape):
1607
            weight = torch.randn(batch_size, num_embeddings, ed, dtype=torch.float32)
1608
            res = torch.ops.quantized.embedding_bag_byte_prepack(weight.to(device="meta"))
1609
            self.assertEqual(res.shape, rs)
1610
            self.assertEqual(res.dtype, torch.float32)
1611
            self.assertEqual(res.untyped_storage().data_ptr(), 0)
1612

1613
    def test_embedding_bag_byte_unpack(self):
1614
        batch_size = 10
1615
        num_embeddings = 80
1616
        embedding_dim = [128, 256, 512]
1617
        res_shape = [[batch_size, num_embeddings, ed] for ed in embedding_dim]
1618
        for ed, rs in zip(embedding_dim, res_shape):
1619
            packed_weight = torch.randn(batch_size, num_embeddings, ed + 8, dtype=torch.float32)
1620
            res = torch.ops.quantized.embedding_bag_byte_unpack(packed_weight.to(device="meta"))
1621
            self.assertEqual(res.shape, rs)
1622
            self.assertEqual(res.dtype, torch.float32)
1623
            self.assertEqual(res.untyped_storage().data_ptr(), 0)
1624

1625
    def test_index_select_out(self):
1626
        def f():
1627
            input = torch.randn([8, 16], device='meta')
1628
            index = torch.tensor([2, 1, 6, 7, 3, 1, 7, 5, 6, 7], device='meta')
1629
            out = torch.empty([10, 16], device='meta')
1630
            return torch.index_select(input=input, dim=0, index=index, out=out)
1631
        with enable_python_dispatcher():
1632
            out = f()
1633
            self.assertEqual(out.shape, [10, 16])
1634

1635
# Instantiate device-specific versions of TestMeta into this module's globals()
# (presumably what supplies the `device` argument several tests above take).
instantiate_device_type_tests(TestMeta, globals())
1636

1637
def print_op_str_if_not_supported(op_str):
    """Print the aten overload for ``op_str`` if it is skipped or expected to fail on meta.

    Overloads found in the (CUDA) skip tables are printed with a ``# SKIP`` suffix;
    overloads found in the (CUDA) expected-failure tables are printed bare. An overload
    present in both kinds of table is printed twice.
    """
    parsed = OperatorName.parse(op_str)
    overload_name = parsed.overload_name or "default"
    overload = getattr(getattr(torch.ops.aten, str(parsed.name)), overload_name)

    skip_tables = (meta_dispatch_skips, meta_dispatch_device_skips['cuda'])
    if any(overload in table for table in skip_tables):
        print(f"{overload}  # SKIP")

    failure_tables = (meta_dispatch_expected_failures, meta_dispatch_device_expected_failures['cuda'])
    if any(overload in table for table in failure_tables):
        print(overload)
1645

1646

1647
if __name__ == "__main__":
    # Optional report modes: list which ops from an external op list are skipped or
    # expected to fail under meta dispatch, then exit without running the tests.
    COMPARE_XLA = os.getenv('PYTORCH_COMPARE_XLA', None)
    if COMPARE_XLA is not None:
        with open(COMPARE_XLA) as f:
            # An empty file parses as None; treat it as having no sections.
            d = yaml.load(f, Loader=YamlLoader) or {}
        # A section that is present but empty parses as None, so fall back to [].
        # Named op_strs (not `ops`) to avoid rebinding the `ops` test decorator import.
        op_strs = (d.get("full_codegen") or []) + (d.get("supported") or []) + (d.get("autograd") or [])
        for op_str in op_strs:
            print_op_str_if_not_supported(op_str)
        sys.exit(0)

    COMPARE_TEXT = os.getenv('PYTORCH_COMPARE_TEXT', None)
    if COMPARE_TEXT is not None:
        with open(COMPARE_TEXT) as f:
            for line in f:
                op_str = line.strip()
                # Blank lines (e.g. a trailing newline) are not op names and would fail to parse.
                if op_str:
                    print_op_str_if_not_supported(op_str)
        sys.exit(0)

    run_tests()
1665

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.