8
from torch.overrides import resolve_name
9
from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
10
from torch.utils import _pytree as pytree
11
from torch._subclasses.meta_utils import MetaConverter, assert_metadata_eq
12
import torch.utils._python_dispatch
13
from torch._dispatch.python import enable_python_dispatcher
14
from torch._ops import OpOverload, OpOverloadPacket
15
from torch.testing import make_tensor
16
from torch.testing._internal.common_utils import unMarkDynamoStrictTest
17
from torch.testing._internal.common_utils import (
23
TEST_WITH_TORCHDYNAMO,
28
from torch.testing._internal.common_device_type import (
30
instantiate_device_type_tests,
35
from torch.testing._internal.common_methods_invocations import (
36
binary_ufuncs, op_db, foreach_unary_op_db, foreach_binary_op_db,
37
foreach_pointwise_op_db, foreach_reduce_op_db, foreach_other_op_db)
38
from torch.testing._internal.opinfo.core import S, SampleInput
39
from torchgen.yaml_utils import YamlLoader
40
from torchgen.model import OperatorName
47
from collections import defaultdict
48
from collections.abc import Iterable
52
from functools import partial, wraps
60
c128 = torch.complex128
70
foreach_binary_op_db +
71
foreach_pointwise_op_db +
72
foreach_reduce_op_db +
77
class TestMetaConverter(TestCase):
78
def assertSameVersionCounter(self, m1, m2):
82
self.assertEqual(m2._version, vc)
86
self.assertNotEqual(m1._version, vc)
87
self.assertEqual(m2._version, m1._version)
89
def assertMetadataMatches(self, m1, m2):
    """Assert that ``m1`` and ``m2`` carry matching tensor metadata.

    The comparison itself is done by ``assert_metadata_eq``; this wrapper
    only supplies the test case's own ``assertEqual`` as the equality
    callback so mismatches surface as ordinary test failures.
    """
    equality_check = self.assertEqual
    assert_metadata_eq(equality_check, m1, m2)
92
def test_view_of_non_leaf(self):
93
x = torch.randn(4, requires_grad=True)
97
to_meta = MetaConverter()
102
self.assertTrue(m1._is_view())
103
self.assertFalse(m1._base.is_leaf)
105
self.assertIsNot(m1, m2)
106
self.assertMetadataMatches(m1, z1)
107
self.assertMetadataMatches(m2, z2)
108
self.assertSameVersionCounter(m1, m2)
110
def test_view_of_leaf(self):
111
x = torch.randn(4, requires_grad=True)
114
to_meta = MetaConverter()
119
self.assertTrue(m1._is_view())
120
self.assertTrue(m1._base.is_leaf)
122
self.assertIsNot(m1, m2)
123
self.assertMetadataMatches(m1, z1)
124
self.assertMetadataMatches(m2, z2)
125
self.assertSameVersionCounter(m1, m2)
127
def test_view_of_view_of_leaf(self):
130
y.requires_grad = True
133
to_meta = MetaConverter()
137
self.assertFalse(z.is_leaf)
139
self.assertMetadataMatches(mx, x)
140
self.assertMetadataMatches(mz, z)
143
x = torch.randn(4, requires_grad=True)
144
to_meta = MetaConverter()
148
self.assertTrue(m.is_leaf)
149
self.assertTrue(m.requires_grad)
151
self.assertMetadataMatches(m, x)
153
def test_non_leaf(self):
154
x = torch.randn(4, requires_grad=True)
156
to_meta = MetaConverter()
160
self.assertFalse(m.is_leaf)
161
self.assertTrue(m.requires_grad)
163
self.assertMetadataMatches(m, y)
165
def test_requires_grad_false(self):
166
x = torch.randn(4, requires_grad=False)
167
to_meta = MetaConverter()
171
self.assertFalse(m.requires_grad)
173
self.assertMetadataMatches(m, x)
175
def test_channels_last(self):
176
x = torch.empty(2, 3, 4, 5, memory_format=torch.channels_last)
177
to_meta = MetaConverter()
181
self.assertTrue(m.is_leaf)
183
self.assertMetadataMatches(m, x)
185
def test_channels_last_leaf(self):
186
x = torch.empty(2, 3, 4, 5, memory_format=torch.channels_last, requires_grad=True)
187
to_meta = MetaConverter()
191
self.assertTrue(m.requires_grad)
192
self.assertTrue(m.is_leaf)
194
self.assertMetadataMatches(m, x)
196
def test_channels_last_non_leaf(self):
197
x = torch.empty(2, 3, 4, 5, memory_format=torch.channels_last, requires_grad=True)
201
self.assertEqual(x.stride(), y.stride())
202
self.assertFalse(y.is_leaf)
204
to_meta = MetaConverter()
208
self.assertTrue(m.requires_grad)
209
self.assertFalse(m.is_leaf)
211
self.assertMetadataMatches(m, y)
216
torch.autograd.grad(loss, m)
218
def test_empty_strided_non_dense_leaf(self):
219
x = torch.empty_strided((2, 2), (4, 2), requires_grad=True)
221
to_meta = MetaConverter()
225
self.assertTrue(m.requires_grad)
226
self.assertTrue(m.is_leaf)
228
self.assertMetadataMatches(m, x)
230
def test_view_mutate(self):
234
to_meta = MetaConverter()
237
y.add_(torch.randn(2, 2, requires_grad=True))
238
m.add_(torch.randn(2, 2, device='meta', requires_grad=True))
240
def test_non_leaf_torture(self):
241
x = torch.empty(20, requires_grad=True)
242
with torch.no_grad():
243
x.set_(x.storage(), 10, (2,), (2,))
245
to_meta = MetaConverter()
249
self.assertTrue(m.requires_grad)
250
self.assertTrue(m.is_leaf)
252
self.assertMetadataMatches(m, x)
257
def test_view_as_real(self):
    """A ``torch.view_as_real`` view keeps its metadata through MetaConverter."""
    complex_src = torch.randn(4, dtype=torch.complex64)
    real_view = torch.view_as_real(complex_src)
    # A fresh converter per test, so no memoized state is shared.
    converter = MetaConverter()
    converted = converter(real_view)
    self.assertMetadataMatches(converted, real_view)
263
def test_complex_noncontiguous_bug(self):
    """A sliced complex32 tensor keeps its metadata through MetaConverter."""
    full = torch.randn((2, 2, 4, 9), dtype=torch.complex32)
    # Fixing index 0 on dim 1 drops that dimension and leaves a strided view.
    sliced = full[:, 0, :, :]
    converter = MetaConverter()
    converted = converter(sliced)
    self.assertMetadataMatches(converted, sliced)
268
def test_view_as_complex(self):
    """A ``torch.view_as_complex`` view keeps its metadata through MetaConverter."""
    # The trailing dimension of size 2 supplies the (real, imag) pairs.
    real_src = torch.randn((4, 2), dtype=torch.float32)
    complex_view = torch.view_as_complex(real_src)
    converter = MetaConverter()
    converted = converter(complex_view)
    self.assertMetadataMatches(converted, complex_view)
274
def test_view_dtype(self):
    """A dtype-reinterpreting view (float32 -> int32) keeps its metadata."""
    float_src = torch.randn(4, dtype=torch.float32)
    # Same element width, so this is a pure reinterpretation of the data.
    int_view = float_src.view(dtype=torch.int32)
    converter = MetaConverter()
    converted = converter(int_view)
    self.assertMetadataMatches(converted, int_view)
281
x = torch.randn(4, dtype=torch.complex64)
283
m = MetaConverter()(y)
284
self.assertMetadataMatches(m, y)
286
@skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991")
287
def test_weakref(self):
288
x = torch.randn(4, 4, 4)
293
self.assertEqual(len(m.tensor_memo), 1)
294
self.assertEqual(len(m.storage_memo), 1)
296
self.assertEqual(len(m.tensor_memo), 0)
297
m.check_for_expired_weak_storages()
298
self.assertEqual(len(m.storage_memo), 0)
302
li.append(torch.rand([i]))
304
self.assertEqual(len(m.tensor_memo), 4)
306
self.assertEqual(len(m.tensor_memo), 0)
307
m.check_for_expired_weak_storages()
308
self.assertEqual(len(m.storage_memo), 0)
310
@skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991")
311
def test_tensor_outlives_converter(self):
314
x = torch.randn([4, 4])
317
self.assertIs(ref(), None)
322
torch.Tensor.__getitem__,
326
aten.unsqueeze.default
329
CHECK_STRIDES_SKIPS = {
330
aten._conj_physical.default,
331
aten._fft_c2c.default,
332
aten._fft_c2r.default,
333
aten._fft_r2c.default,
334
aten._linalg_svd.default,
335
aten.binary_cross_entropy.default,
336
aten.complex.default,
338
aten.copysign.Tensor,
339
aten.div.Tensor_mode,
340
aten.floor_divide.default,
341
aten.heaviside.default,
344
aten.logaddexp.default,
345
aten.logical_and.default,
346
aten.logical_or.default,
347
aten.logical_xor.default,
350
aten.special_xlog1py.default,
352
aten.nll_loss2d_forward.default,
355
aten.convolution.default,
366
aten.linalg_lu_solve.out,
369
class CheckStrides(Enum):
374
def should_check_strides(func):
    """Return the ``CheckStrides`` level to apply when validating ``func``.

    The explicit tables are consulted first, in order: ``CHECK_ALL_STRIDES``
    gives ``ALL``, ``CHECK_STRIDES`` gives ``SIGNIFICANT`` and
    ``CHECK_STRIDES_SKIPS`` gives ``NONE``.  Anything that is not a
    ``torch._ops.OpOverload`` also gives ``NONE``.  Every remaining overload
    resolves to ``SIGNIFICANT``.
    """
    if func in CHECK_ALL_STRIDES:
        level = CheckStrides.ALL
    elif func in CHECK_STRIDES:
        level = CheckStrides.SIGNIFICANT
    elif func in CHECK_STRIDES_SKIPS or not isinstance(func, torch._ops.OpOverload):
        level = CheckStrides.NONE
    elif func.namespace == "prims":
        level = CheckStrides.SIGNIFICANT
    elif any(ret.alias_info.before_set for ret in func._schema.returns if ret.alias_info):
        # At least one return carries alias info with a non-empty before_set
        # (presumably a view op -- confirm against the schema docs).
        level = CheckStrides.SIGNIFICANT
    else:
        # The default currently matches the two branches above; they are kept
        # separate so each case can be adjusted on its own.
        level = CheckStrides.SIGNIFICANT
    return level
393
def assert_ref_meta_equal(test_case, func, meta_rs, rs, msg_callable):
394
flat_meta_rs = pytree.tree_leaves(meta_rs)
395
flat_rs = pytree.tree_leaves(rs)
396
test_case.assertEqual(len(flat_meta_rs), len(flat_rs))
397
for i, meta_r, r in zip(range(len(flat_rs)), flat_meta_rs, flat_rs):
398
def test_assert(cond, msg):
400
raise RuntimeError(f"output {i}: {msg_callable(msg)}")
401
if not isinstance(r, torch.Tensor):
403
test_assert(isinstance(meta_r, torch.Tensor), f"but real {i}th result is Tensor")
404
test_assert(meta_r.dtype == r.dtype, f"for element {i}, was {meta_r.dtype} but real dtype was {r.dtype}")
405
test_assert(meta_r.shape == r.shape, f"for element {i}, was {meta_r.shape} but real shape was {r.shape}")
407
if should_check_strides(func) == CheckStrides.ALL:
408
same_strides, _ = torch._prims_common.check_all_strides(meta_r, r)
409
test_assert(same_strides, f"for element {i}, was {meta_r.stride()} but real stride was {r.stride()}")
410
elif should_check_strides(func) == CheckStrides.SIGNIFICANT:
411
same_strides, _ = torch._prims_common.check_significant_strides(meta_r, r)
412
test_assert(same_strides, f"for element {i}, was {meta_r.stride()} but real stride was {r.stride()}")
414
meta_r.storage_offset() == r.storage_offset(),
415
f"for element {i}, was {meta_r.storage_offset()} but real storage_offset was {r.storage_offset()}")
416
test_assert(meta_r.requires_grad == r.requires_grad,
417
f"for element {i}, was {meta_r.requires_grad} but real requires_grad was {r.requires_grad}")
418
if func not in CHECK_CONJ_SKIPS:
419
test_assert(meta_r.is_conj() == r.is_conj(),
420
f"for element {i}, was {meta_r.is_conj()} but real is_conj was {r.is_conj()}")
421
test_assert(meta_r.is_neg() == r.is_neg(), f"for element {i}, was {meta_r.is_neg()} but real is_neg was {r.is_neg()}")
440
COLLECT_EXPECT = os.getenv('PYTORCH_COLLECT_EXPECT', '0') == '1'
444
failed_reasons = defaultdict(set)
446
expected_failures = []
449
def fmt_dtypes(dtypes):
450
r = ', '.join(sorted(dtype_abbrs[d] for d in dtypes))
453
for op, failed_dtypes in seen_failed.items():
454
ops = resolve_name(op)
455
succeeded_dtypes = seen_succeeded.get(op, set())
456
expected_failures_dtypes = failed_dtypes - succeeded_dtypes
457
skips_dtypes = failed_dtypes & succeeded_dtypes
459
if failed_reasons[op]:
460
reasons = " # " + ", ".join(sorted(failed_reasons[op]))
461
if expected_failures_dtypes:
462
expected_failures.append(f" {ops}: {fmt_dtypes(expected_failures_dtypes)},{reasons}")
464
skips.append(f" {ops}: {fmt_dtypes(skips_dtypes)},")
465
expected_failures.sort()
469
expected_failures = {{
470
{nl.join(expected_failures)}
478
atexit.register(print_seen)
481
TestExpect = Enum("TestExpect", ("SUCCESS", "XFAILURE", "SKIP"))
486
def __init__(self, s):
493
if isinstance(t, torch.Tensor):
494
return Lit(f"{t} stride={t.stride()}")
498
return repr(tree_map(go, e))
500
def run_meta_crossref(
509
run_symbolic_meta: bool
511
to_meta = MetaConverter()
512
do_meta = test_expect is not TestExpect.SKIP
515
meta_args = tree_map(to_meta, args)
516
meta_kwargs = tree_map(to_meta, kwargs)
517
except Exception as e:
519
f"failed to convert args to meta; "
520
f"originally (*{args}, **{kwargs})") from e
522
rs = func(*args, **kwargs)
523
except Exception as e:
524
raise AssertionError("Original OpInfo is broken") from e
531
if do_meta and to_meta.successful():
533
if func is torch.tensor_split:
535
meta_args = (meta_args[0], args[1]) + meta_args[2:]
536
elif func is torch.Tensor.__getitem__:
538
assert len(args) == 2
539
flat_args = pytree.tree_leaves(args[1])
540
flat_meta_args, spec = tree_flatten(meta_args[1])
542
for a, ma in zip(flat_args, flat_meta_args):
543
flat_new_args.append(a if isinstance(a, torch.Tensor) and a.dtype in [torch.int8, torch.bool] else ma)
544
meta_args = (meta_args[0], tree_unflatten(flat_new_args, spec))
545
elif func in (torch.ops.aten.repeat_interleave.Tensor, torch.ops.aten.repeat_interleave.Tensor_out):
546
if kwargs.get("output_size", None) is None:
548
if func is torch.ops.aten.repeat_interleave.Tensor_out:
549
meta_kwargs["out"] = kwargs["out"]
550
elif func in (torch.ops.aten.index.Tensor, torch.ops.aten.index.Tensor_out):
554
for meta_index, real_index in zip(meta_args[1], args[1]):
555
if meta_index is not None and meta_index.dtype in [torch.int8, torch.bool]:
556
indices.append(real_index)
558
indices.append(meta_index)
559
meta_args = (meta_args[0], indices)
560
elif func is torch.nn.functional.ctc_loss and all([isinstance(args[2], list), isinstance(args[3], list)]):
563
test_expect = TestExpect.SUCCESS
565
if kwargs.get("device", None) is not None:
566
meta_kwargs["device"] = "meta"
573
with warnings.catch_warnings():
574
warnings.simplefilter("ignore")
575
if run_symbolic_meta:
580
with enable_python_dispatcher():
581
meta_rs = func(*meta_args, **meta_kwargs)
583
meta_rs = func(*meta_args, **meta_kwargs)
584
except Exception as e:
585
if test_expect is TestExpect.XFAILURE:
587
seen_failed.setdefault(func, set()).add(dtype)
588
if isinstance(e, NotImplementedError):
589
m = RE_NOT_IMPLEMENTED_MSG.search(e.args[0])
591
failed_reasons[func].add(m.group(1))
594
raise RuntimeError(f"""\
595
failed to run: {resolve_name(func)}(
596
*{verbose_print(meta_args)},
597
**{verbose_print(meta_kwargs)}
602
assert_ref_meta_equal(test_case, func, meta_rs, rs, lambda msg: f"""\
603
meta disagrees with real impl:
605
{delim.join(map(verbose_print, meta_args))},
606
{delim.join(k + ": " + verbose_print(v) for k, v in meta_kwargs.items())}
608
{verbose_print(meta_rs)}
613
if test_expect is TestExpect.XFAILURE:
615
seen_failed.setdefault(func, set()).add(dtype)
620
seen_succeeded.setdefault(func, set()).add(dtype)
621
if test_expect is TestExpect.XFAILURE and not COLLECT_EXPECT:
622
raise RuntimeError(f"unexpected success {resolve_name(func)} {meta_args} {meta_kwargs}")
628
RE_NOT_IMPLEMENTED_MSG = re.compile(r"Could not run '([^']+)' with arguments ")
630
meta_function_expected_failures = {
631
torch.Tensor.to_sparse : {f64, i32, c128, i64, i16, f16, u8, c64, bf16, b8, i8, f32},
632
torch.allclose : {f64, f16, c128, c64, bf16, f32},
633
torch.argwhere : {f64, i32, c128, i64, i16, f16, u8, c64, bf16, b8, i8, f32},
634
torch.combinations : {f64, i32, c128, i64, i16, f16, u8, c64, bf16, b8, i8, f32},
635
torch.corrcoef : {f64, i32, c128, i64, i16, u8, c64, bf16, f16, i8, f32},
636
torch.cov : {f64, i32, c128, i64, i16, u8, c64, bf16, i8, f32, f16},
637
torch.functional.istft : {f64, c64, c128, f32},
638
torch.geqrf : {f64, c64, c128, f32},
639
torch.masked_select : {f64, i32, c128, i64, i16, f16, u8, c64, bf16, b8, i8, f32},
640
torch.nonzero : {f64, i32, c128, i64, i16, c32, f16, u8, c64, bf16, b8, i8, f32},
641
torch.Tensor.nonzero : {f64, i32, c128, i64, i16, c32, f16, u8, c64, bf16, b8, i8, f32},
642
torch.Tensor.item : {f64, i32, c128, i64, i16, f16, u8, c32, c64, bf16, b8, i8, f32},
643
torch.bincount : {i32, i64, u8, i16, i8},
644
torch.functional.unique : {f64, i32, i64, u8, i16, f16, bf16, b8, i8, f32},
645
torch.functional.unique_consecutive : {f64, i32, i64, u8, i16, f16, bf16, b8, i8, f32},
646
torch.histc : {f64, f16, bf16, f32},
647
torch.histogram : {f64, f32},
648
torch.histogramdd : {f64, f32},
649
torch.kthvalue : {f64, i32, i64, u8, i16, f16, bf16, i8, f32},
650
torch.nn.functional.ctc_loss : {f64, f32},
651
torch.nn.functional.gaussian_nll_loss : {f16, f64, bf16, f32},
652
torch.linalg.eig : {f64, f32, c128, c64},
653
torch.linalg.eigvals : {f64, f32, c128, c64},
654
torch.linalg.lstsq : {f64, f32, c128, c64},
657
meta_function_expected_failures_conditional = {
658
torch.repeat_interleave : (lambda dtype, *args, **kwargs: not isinstance(kwargs.get("repeats", None), int)),
662
# This is some sample code for how we could dump these dicts into a YAML
663
# file for easier reading/writing
666
{resolve_name(k): [dtype_abbrs[d] for d in v]
667
for k, v in meta_function_expected_failures.items()}, default_flow_style=None))
672
meta_function_skips = {
673
torch.Tensor.__rmatmul__ : {bf16, c128, f64, f32, f16, c64},
674
torch.Tensor.matmul : {f64, f32, c128, c64},
675
torch.functional.atleast_2d : {bf16, i8, c32, i64, u8, c128, b8, f64, i16, i32, f32, f16, c64},
676
torch.functional.atleast_3d : {bf16, i8, c32, i64, u8, c128, b8, f64, i16, i32, f32, f16, c64},
677
torch.functional.cartesian_prod : {bf16, i8, i64, u8, c128, b8, f64, i16, i32, f32, f16, c64},
678
torch.functional.einsum : {bf16, c128, f64, f32, f16, c64},
679
torch.inner : {f16, bf16, i8, i64, u8, c128, f64, i16, f32, i32, c64},
680
torch.linalg.matrix_norm : {c128, f32, c64, f64},
681
torch.linalg.matrix_rank : {c128, c64},
682
torch.linalg.svd : {c128, c64},
683
torch.matmul : {bf16, c128, f64, f32, f16, c64},
684
torch.nanquantile : {f64, f32},
685
torch.narrow : {bf16, i8, i64, u8, c128, b8, f64, i16, i32, f32, f16, c32, c64},
686
torch.nn.functional.batch_norm : {f64, f32},
687
torch.nn.functional.binary_cross_entropy : {bf16, f64, f32, f16},
688
torch.nn.functional.dropout3d : {bf16, f64, f32, f16},
689
torch.nn.functional.local_response_norm : {bf16, f64, f32, f16},
690
torch.svd : {c128, c64},
691
torch.take_along_dim : {bf16, i8, i64, u8, c128, b8, f64, i16, i32, f32, f16, c64},
692
torch.vstack : {bf16, i8, c32, i64, u8, c128, b8, f64, i16, i32, f32, f16, c64},
694
torch.equal : {bf16, i8, c32, i64, u8, c128, b8, f64, i16, i32, f32, f16, c64},
695
torch.nanmean : {bf16, f64, f32, f16, c32, c64, c128},
696
torch.nn.functional.cross_entropy : {bf16, f64, f32},
697
torch.nn.functional.nll_loss : {bf16, f64, f32},
698
torch.linalg.cond : {c128, c64, f32, f64},
699
torch.linalg.vecdot : {bf16, f64, f32, f16},
700
torch.empty : {bf16, i8, c32, i64, u8, c128, b8, f64, i16, i32, f32, f16, c64},
701
torch.Tensor.addbmm_: {bf16, c128, c64, f32, f64, i16, i32, i64, i8, u8},
702
torch.nn.functional.one_hot : {i64},
706
meta_function_device_expected_failures = defaultdict(dict)
707
meta_function_device_expected_failures_only_outplace = defaultdict(dict)
708
meta_function_device_skips = defaultdict(dict)
710
meta_function_device_expected_failures['cpu'] = {
711
torch.native_batch_norm: {bf16, f16},
712
torch._native_batch_norm_legit: {bf16, f16},
713
torch.native_layer_norm: {bf16, f16},
716
meta_function_device_expected_failures['cuda'] = {
717
torch.corrcoef: {bf16, f16},
719
torch.functional.unique: {f16},
720
torch.functional.unique_consecutive: {f16},
721
torch.geqrf: {f32, f64},
722
torch.histc: {i16, i32, i64, i8},
723
torch.kthvalue: {f16},
726
meta_function_device_skips['cpu'] = {
727
torch.native_batch_norm: {f32, f64},
728
torch._native_batch_norm_legit: {f32, f64},
731
meta_function_device_skips['cuda'] = {
733
torch.linalg.matrix_rank: {f32, f64},
734
torch.linalg.svd: {f32, f64},
735
torch.nn.functional.cross_entropy: {f16},
736
torch.nn.functional.interpolate: {f16},
737
torch.nn.functional.nll_loss: {f16},
738
torch.svd: {f32, f64},
754
class MetaCrossRefFunctionMode(torch.overrides.TorchFunctionMode):
759
def __init__(self, test_case, *, device, dtype, inplace):
760
self.test_case = test_case
761
self.device_type = torch.device(device).type
763
self.inplace = inplace
765
def __torch_function__(self, func, types, args=(), kwargs=None):
766
kwargs = kwargs or {}
769
torch.jit.is_tracing() or isinstance(func, torch.ScriptMethod) or
772
torch._C._dispatch_tls_local_exclude_set().has(torch._C.DispatchKey.Python)
774
return func(*args, **kwargs)
776
if self.dtype in meta_function_skips.get(func, set()):
777
test_expect = TestExpect.SKIP
778
elif self.dtype in meta_function_device_skips[self.device_type].get(func, set()):
779
test_expect = TestExpect.SKIP
780
elif self.dtype in meta_function_expected_failures.get(func, set()):
781
test_expect = TestExpect.XFAILURE
782
elif self.dtype in meta_function_device_expected_failures[self.device_type].get(func, set()):
783
test_expect = TestExpect.XFAILURE
784
elif meta_function_expected_failures_conditional.get(func, lambda *_, **__: False)(self.dtype, *args, **kwargs):
785
test_expect = TestExpect.XFAILURE
786
elif not self.inplace and \
787
self.dtype in meta_function_device_expected_failures_only_outplace[self.device_type].get(func, set()):
788
test_expect = TestExpect.XFAILURE
790
test_expect = TestExpect.SUCCESS
792
return run_meta_crossref(
793
self.test_case, test_expect, func, args,
794
kwargs, dtype=self.dtype, device_type=self.device_type, run_symbolic_meta=False
798
meta_dispatch_expected_failures = {
799
aten.allclose.default: {f16, bf16, f32, f64, c64, c128},
800
aten.geqrf.default : {c64, c128, f64, f32},
801
aten.linalg_eig.default : {c64, c128, f64, f32},
802
aten.linalg_lstsq.default : {c64, c128, f64, f32},
803
aten.masked_select.default : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, b8, i16, u8},
804
aten.masked_select.out : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, b8, i16, u8},
805
aten.nonzero.default : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, c32, b8, i16, u8},
806
aten.nonzero.out : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, c32, b8, i16, u8},
807
aten._to_sparse.default : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, b8, i16, u8},
808
aten._to_sparse.sparse_dim : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, b8, i16, u8},
809
aten._ctc_loss.Tensor : {f32, f64},
810
aten._histogramdd_bin_edges.default : {f32, f64},
811
aten._histogramdd_from_bin_cts.default : {f32, f64},
812
aten._histogramdd_from_bin_tensors.default : {f32, f64},
813
aten._local_scalar_dense.default : {c32, c64, f16, i8, f64, c128, i64, bf16, f32, i32, b8, i16, u8},
814
aten._unique2.default : {i8, f64, i64, f16, bf16, f32, i32, b8, i16, u8},
815
aten.bincount.default : {i64, i8, i32, i16, u8},
816
aten.equal.default : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, b8, i16, u8},
817
aten.histc.default : {bf16, f32, f64},
818
aten.histc.out : {bf16, f32, f64},
819
aten.histogram.bin_ct : {f32, f64},
820
aten.histogram.bins_tensor : {f32, f64},
821
aten.kthvalue.default : {i8, f64, i64, f16, bf16, f32, i32, i16, u8},
822
aten.unique_consecutive.default : {i8, f64, i64, f16, bf16, f32, i32, b8, i16, u8},
823
aten.unique_dim.default : {i8, f64, i64, f16, bf16, f32, i32, b8, i16, u8},
824
aten.upsample_nearest3d.vec : {bf16, f32, f64, u8},
829
meta_dispatch_skips = {
830
aten.index.Tensor: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32, c32, c64, c128},
831
aten._to_copy.default: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32, c32, c64, c128},
832
aten.empty.memory_format: {b8, bf16, c128, c64, c32, f16, f32, f64, i16, i32, i64, i8, u8},
833
aten.addbmm_.default: {bf16, c128, c64, f32, f64, i16, i32, i64, i8, u8},
837
meta_dispatch_early_skips = set({
838
torch.Tensor.float_power_,
840
torch.Tensor.cumprod_,
841
torch.Tensor.cumsum_,
844
meta_inplace_skips = set({
846
torch.Tensor.cumprod_,
847
torch.Tensor.cumsum_,
850
meta_dispatch_device_expected_failures = defaultdict(dict)
851
meta_dispatch_device_skips = defaultdict(dict)
853
meta_dispatch_device_expected_failures['cpu'] = {
854
aten.native_batch_norm.default: {bf16, f16},
855
aten._native_batch_norm_legit.default: {bf16, f16},
856
aten._native_batch_norm_legit.no_stats: {bf16, f16},
857
aten.native_layer_norm.default: {bf16, f16},
858
aten.histc.default: {f16},
859
aten.histc.out: {f16},
862
meta_dispatch_device_expected_failures['cuda'] = {
863
aten._unique2.default: {f16},
864
aten._use_cudnn_ctc_loss.default: {f32, f64},
865
aten._use_cudnn_ctc_loss.Tensor: {f32, f64},
866
aten.cudnn_grid_sampler.default: {f16, f32, f64},
867
aten.geqrf.default: {f32, f64},
868
aten.histc.default: {i16, i32, i64, i8},
869
aten.histc.out: {i16, i32, i64, i8},
870
aten.kthvalue.default: {f16},
871
aten.linalg_eigvalsh.out: {f32, f64},
872
aten.log_sigmoid_forward.default: {bf16, f16, f64, f32},
873
aten.log_sigmoid_forward.output : {bf16, f16, f64, f32},
874
aten.unique_consecutive.default: {f16},
875
aten.unique_dim.default: {f16},
876
aten.upsample_nearest3d.vec: {f16},
879
meta_dispatch_device_skips['cpu'] = {
880
aten._embedding_bag_forward_only.default: {bf16, f16, f32, f64},
881
aten.native_batch_norm.default: {f32, f64},
882
aten._native_batch_norm_legit.default: {f32, f64},
883
aten._native_batch_norm_legit.no_stats: {f32, f64},
888
aten.native_batch_norm.out: {bf16, f16, f32, f64}
891
meta_dispatch_device_skips['cuda'] = {
892
aten._conj.default: {c32, f16},
893
aten._linalg_svd.default: {c64, c128},
894
aten.cudnn_batch_norm.default: {f32, f64},
895
aten.log_softmax.int : {c32, c64},
896
aten.softmax.int : {c32, c64},
897
aten.softmax.int : {c32, c64},
901
aten.miopen_batch_norm.default: {f32},
904
def get_strided_args(args):
906
def get_strided_variants(t, include_storage_offset=False):
914
perm = list(reversed(range(t.ndim)))
915
transposed = torch.empty(
916
t.shape[::-1], device=t.device, dtype=t.dtype, requires_grad=t.requires_grad
917
).permute(perm).copy_(t)
918
variants.append(transposed)
922
nondense = torch.repeat_interleave(t, 2, dim=-1)[..., ::2]
923
variants.append(nondense)
927
variants.append(t.contiguous(memory_format=torch.channels_last))
931
variants.append(t.contiguous(memory_format=torch.channels_last_3d))
934
if include_storage_offset:
935
buffer = torch.empty(t.numel() + 1, device=t.device, dtype=t.dtype, requires_grad=t.requires_grad)
936
buffer = buffer.as_strided(t.shape, t.stride(), storage_offset=1)
938
variants.append(buffer)
944
if isinstance(arg, torch.Tensor) and not arg.is_sparse_csr and arg.is_contiguous():
945
strided_arg_variants = get_strided_variants(arg)
947
strided_arg_variants = [arg]
948
strided_args.append(strided_arg_variants)
950
yield from itertools.product(*strided_args)
952
class MetaCrossRefDispatchMode(torch.utils._python_dispatch.TorchDispatchMode):
956
aten_olp_no_out_overload: set = set()
958
def __init__(self, test_case, *, device, dtype, symbolic_meta: bool, inplace: bool, supports_out: bool):
959
self.test_case = test_case
961
self.precision = test_case.precision
962
self.rel_tol = test_case.rel_tol
963
self.device_type = torch.device(device).type
965
self.symbolic_meta = symbolic_meta
966
self.inplace = inplace
967
self.supports_out = supports_out
970
def try_resolve_aten_out_overload(ol, args, kwargs, num_outputs):
972
ol_args = ol._schema.arguments
973
olp: OpOverloadPacket = ol._overloadpacket
975
if olp in MetaCrossRefDispatchMode.aten_olp_no_out_overload:
976
return (None, None, None)
979
for candidate_ol_name in olp.overloads():
980
candidate_ol = getattr(olp, candidate_ol_name)
981
if any(arg.is_out for arg in candidate_ol._schema.arguments):
982
candidate_ols.append(candidate_ol)
984
if not candidate_ols:
985
MetaCrossRefDispatchMode.aten_olp_no_out_overload.add(olp)
986
return (None, None, None)
989
candidate_ol: OpOverload = None
990
for candidate_ol in candidate_ols:
991
candidate_ol_args = candidate_ol._schema.arguments
993
if (len(args) >= len(candidate_ol_args)):
998
ol_args[pos_arg_ind].type == candidate_ol_args[pos_arg_ind].type
999
for pos_arg_ind in range(len(args))
1004
candidate_out_names = [out_arg.name for out_arg in candidate_ol_args[-num_outputs:] if out_arg.is_out]
1005
if len(candidate_out_names) != num_outputs:
1015
for arg in candidate_ol_args[len(args):-num_outputs]:
1016
if arg.name not in kwargs:
1017
if arg.has_default_value():
1018
new_kwargs[arg.name] = arg.default_value
1019
elif isinstance(arg.type, torch.OptionalType):
1020
if isinstance(arg.type.getElementType(), torch.BoolType):
1021
new_kwargs[arg.name] = False
1023
new_kwargs[arg.name] = None
1025
kwargs_match = False
1028
new_kwargs[arg.name] = kwargs[arg.name]
1031
return candidate_ol, candidate_out_names, new_kwargs
1033
return None, None, None
1035
def _get_expected_test_result(self, func: OpOverload):
1036
if self.dtype in meta_dispatch_skips.get(func, set()):
1037
test_expect = TestExpect.SKIP
1038
elif self.dtype in meta_dispatch_device_skips[self.device_type].get(func, set()):
1039
test_expect = TestExpect.SKIP
1040
elif self.dtype in meta_dispatch_expected_failures.get(func, set()):
1041
test_expect = TestExpect.XFAILURE
1042
elif self.dtype in meta_dispatch_device_expected_failures[self.device_type].get(func, set()):
1043
test_expect = TestExpect.XFAILURE
1045
test_expect = TestExpect.SUCCESS
1048
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
1049
kwargs = kwargs or {}
1050
self.test_case.precision = self.precision
1051
self.test_case.rel_tol = self.rel_tol
1053
test_expect = self._get_expected_test_result(func)
1055
expected = run_meta_crossref(
1062
device_type=self.device_type,
1063
run_symbolic_meta=self.symbolic_meta,
1073
not self.inplace and
1074
not self.supports_out and
1075
test_expect == TestExpect.SUCCESS and
1076
(torch.is_tensor(expected) or isinstance(expected, Iterable))
1080
num_outputs = 1 if torch.is_tensor(expected) else len(expected)
1081
func_out_overload, out_param_names, kwargs = self.try_resolve_aten_out_overload(func, args, kwargs, num_outputs)
1083
if func_out_overload:
1085
if num_outputs == 1:
1086
kwargs[out_param_names[0]] = expected
1088
for ind, out_param_name in enumerate(out_param_names):
1089
kwargs[out_param_name] = expected[ind]
1091
test_expect = self._get_expected_test_result(func_out_overload)
1100
device_type=self.device_type,
1101
run_symbolic_meta=self.symbolic_meta,
1110
@unMarkDynamoStrictTest
1111
class TestMeta(TestCase):
1114
def _get_safe_inplace(self, inplace_variant):
1115
@wraps(inplace_variant)
1116
def _fn(t, *args, **kwargs):
1117
if isinstance(t, list):
1118
return inplace_variant([x.clone() for x in t], *args, **kwargs)
1120
return inplace_variant(t.clone(), *args, **kwargs)
1124
@unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
1127
@ops(itertools.chain(op_db, foreach_op_db))
1128
def test_meta_outplace(self, device, dtype, op):
1134
if TEST_WITH_TORCHDYNAMO and op.name in skip_op_names:
1135
raise unittest.SkipTest("flaky")
1140
samples = op.sample_inputs(device, dtype, requires_grad=False)
1141
for sample_input in samples:
1142
args = [sample_input.input] + list(sample_input.args)
1143
kwargs = sample_input.kwargs
1144
with MetaCrossRefFunctionMode(self, dtype=dtype, device=device, inplace=False):
1145
expected = func(*args, **kwargs)
1146
if isinstance(expected, torch.Tensor) and op.supports_out:
1147
func(*args, **kwargs, out=expected)
1153
if "device" in kwargs and "_like" in op.name:
1154
with torch.random.fork_rng():
1155
torch.manual_seed(123)
1156
ref = func(*args, **kwargs)
1159
assert isinstance(args[0], torch.Tensor)
1160
with torch.random.fork_rng():
1161
torch.manual_seed(123)
1162
args[0] = args[0].to(device="meta")
1163
meta = func(*args, **kwargs)
1166
if op.name != "empty_like":
1167
self.assertEqual(ref, meta)
1169
@unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
1172
@ops(itertools.chain(op_db, foreach_op_db))
1173
def test_meta_inplace(self, device, dtype, op):
1174
func = op.get_inplace()
1176
self.skipTest("No inplace variable for this op")
1177
if op.promotes_int_to_float and not dtype.is_floating_point:
1178
self.skipTest("Op promotes to float, which is impossible for inplace with non-float input")
1179
if func in meta_inplace_skips:
1180
self.skipTest("Skipped")
1181
func = self._get_safe_inplace(func)
1182
samples = op.sample_inputs(device, dtype, requires_grad=False)
1183
for sample_input in samples:
1184
if sample_input.broadcasts_input:
1186
args = [sample_input.input] + list(sample_input.args)
1187
kwargs = sample_input.kwargs
1188
with MetaCrossRefFunctionMode(self, dtype=dtype, device=device, inplace=True):
1189
expected = func(*args, **kwargs)
1191
def _run_dispatch_meta_test(self, device, dtype, op, symbolic_meta, inplace, all_stride_variants=False):
1193
func = op.get_inplace()
1195
self.skipTest("No inplace variable for this op")
1196
if op.promotes_int_to_float and not dtype.is_floating_point:
1197
self.skipTest("Op promotes to float, which is impossible for inplace with non-float input")
1201
if func in meta_dispatch_early_skips:
1202
self.skipTest("Function is in dispatch early skips")
1205
func = self._get_safe_inplace(func)
1207
samples = op.sample_inputs(device, dtype, requires_grad=False)
1208
for sample_input in samples:
1209
if inplace and sample_input.broadcasts_input:
1212
sample_args = [sample_input.input] + list(sample_input.args)
1213
kwargs = sample_input.kwargs
1215
if all_stride_variants and sum(isinstance(arg, torch.Tensor) for arg in sample_args) <= 5:
1217
strided_args = get_strided_args(sample_args)
1219
strided_args = [sample_args]
1221
for args in strided_args:
1222
with MetaCrossRefDispatchMode.push(
1223
self, dtype=dtype, device=device,
1224
symbolic_meta=symbolic_meta, inplace=inplace,
1225
supports_out=op.supports_out):
1226
expected = func(*args, **kwargs)
1228
if not inplace and isinstance(expected, torch.Tensor) and op.supports_out:
1229
func(*args, **kwargs, out=expected)
1232
# Out-of-place dispatch cross-check: for every op in op_db and foreach_op_db,
# delegates to _run_dispatch_meta_test with symbolic_meta=False, inplace=False.
# Skipped when TEST_WITH_ASAN is set.
@unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
1235
@ops(itertools.chain(op_db, foreach_op_db))
1236
def test_dispatch_meta_outplace(self, device, dtype, op):
1237
self._run_dispatch_meta_test(device, dtype, op, symbolic_meta=False, inplace=False)
1239
# In-place dispatch cross-check: for every op in op_db and foreach_op_db,
# delegates to _run_dispatch_meta_test with symbolic_meta=False, inplace=True.
# Skipped when TEST_WITH_ASAN is set.
@unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
1242
@ops(itertools.chain(op_db, foreach_op_db))
1243
def test_dispatch_meta_inplace(self, device, dtype, op):
1244
self._run_dispatch_meta_test(device, dtype, op, symbolic_meta=False, inplace=True)
1246
# Same as the out-of-place dispatch cross-check, but with symbolic_meta=True:
# delegates to _run_dispatch_meta_test(symbolic_meta=True, inplace=False) for
# every op in op_db and foreach_op_db.  Skipped when TEST_WITH_ASAN is set.
@unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
1249
@ops(itertools.chain(op_db, foreach_op_db))
1250
def test_dispatch_symbolic_meta_outplace(self, device, dtype, op):
1251
self._run_dispatch_meta_test(device, dtype, op, symbolic_meta=True, inplace=False)
1254
# Same as the in-place dispatch cross-check, but with symbolic_meta=True:
# delegates to _run_dispatch_meta_test(symbolic_meta=True, inplace=True) for
# every op in op_db and foreach_op_db.  Skipped when TEST_WITH_ASAN is set.
@unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
1257
@ops(itertools.chain(op_db, foreach_op_db))
1258
def test_dispatch_symbolic_meta_inplace(self, device, dtype, op):
1259
self._run_dispatch_meta_test(device, dtype, op, symbolic_meta=True, inplace=True)
1261
# Out-of-place, symbolic_meta=True run of _run_dispatch_meta_test with
# all_stride_variants=True, so the helper also exercises the argument lists
# produced by get_strided_args.  The op list is op_db + foreach_op_db with
# dtypes limited by OpDTypes.any_common_cpu_cuda_one.  Skipped when
# TEST_WITH_ASAN is set.
@unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
1265
@ops(itertools.chain(op_db, foreach_op_db), dtypes=OpDTypes.any_common_cpu_cuda_one)
1268
def test_dispatch_symbolic_meta_outplace_all_strides(self, device, dtype, op):
1269
self._run_dispatch_meta_test(device, dtype, op, symbolic_meta=True, inplace=False, all_stride_variants=True)
1271
# In-place, symbolic_meta=True run of _run_dispatch_meta_test with
# all_stride_variants=True, so the helper also exercises the argument lists
# produced by get_strided_args.  The op list is op_db + foreach_op_db with
# dtypes limited by OpDTypes.any_common_cpu_cuda_one.  Skipped when
# TEST_WITH_ASAN is set.
@unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
1275
@ops(itertools.chain(op_db, foreach_op_db), dtypes=OpDTypes.any_common_cpu_cuda_one)
1278
def test_dispatch_symbolic_meta_inplace_all_strides(self, device, dtype, op):
1279
self._run_dispatch_meta_test(device, dtype, op, symbolic_meta=True, inplace=True, all_stride_variants=True)
1281
# Mixed-dtype check for binary ufuncs, parametrized with
# allowed_dtypes=(torch.float32,).  A local sample_input function pairs a tensor
# of the test dtype with a torch.float16 tensor, both of shape (S,), and is
# installed as op.sample_inputs_func before delegating to
# _run_dispatch_meta_test(symbolic_meta=True, inplace=False).
# Skipped when TEST_WITH_ASAN is set.
#
# NOTE(review): sample_inputs_func is assigned on the `op` object itself;
# confirm `op` is a private copy at this point so the override cannot leak into
# other tests that share the OpInfo.
@unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
1285
@ops(binary_ufuncs, allowed_dtypes=(torch.float32,))
1288
def test_binary_ufuncs_mixed_dtype(self, device, dtype, op):
1294
def sample_input(op, device, dtype, requires_grad, **kwargs):
1296
make_arg((S,), dtype=dtype), make_arg((S,), dtype=torch.float16)
1300
op.sample_inputs_func = sample_input
1302
self._run_dispatch_meta_test(device, dtype, op, symbolic_meta=True, inplace=False)
1305
def test_empty_quantized(self):
    # torch.empty must accept a qint8 dtype on the meta device, even for a very
    # large element count (2 ** 52), and report 'meta' as the device type.
    element_count = 2 ** 52
    quantized = torch.empty(element_count, device='meta', dtype=torch.qint8)
    self.assertEqual(quantized.device.type, 'meta')
1309
# Builds a meta tensor from nan, +inf, -inf and an ordinary float, then checks
# that the result `r` reports device type 'meta'.
# NOTE(review): the statement that produces `r` from `t` is not visible in this
# span (the test name suggests nan_to_num) -- confirm against the full method.
def test_nan_to_num(self):
1310
t = torch.tensor([float('nan'), float('inf'), -float('inf'), 3.14], device='meta')
1312
self.assertEqual(r.device.type, 'meta')
1314
def test_inplace_masked_fill_error(self):
    # An in-place masked_fill_ on a (3, 3) meta tensor with a mask that carries
    # an extra leading dimension must raise a RuntimeError whose message
    # mentions the broadcast mismatch.
    target = torch.randn(3, 3, device='meta')
    with self.assertRaisesRegex(RuntimeError, "doesn't match the broadcast"):
        mask = (target > 0).unsqueeze(0)
        target.masked_fill_(mask, 0.1)
1319
def test_inplace_bin_ops_error(self):
    # Each in-place binary op below is applied to a (3, 3) meta tensor with a
    # right-hand side that has an extra leading dimension; every one of them
    # must raise a RuntimeError that mentions the broadcast mismatch.
    lhs = torch.randn(3, 3, device='meta')
    inplace_ops = (
        torch.Tensor.add_,
        torch.Tensor.sub_,
        torch.Tensor.mul_,
        torch.Tensor.div_,
        torch.Tensor.logical_and_,
        torch.Tensor.logical_or_,
        torch.Tensor.logical_xor_,
    )
    for inplace_op in inplace_ops:
        with self.assertRaisesRegex(RuntimeError, "doesn't match the broadcast"):
            rhs = lhs.clone().unsqueeze(0)
            inplace_op(lhs, rhs)
1327
# Opens three scoped libraries for the "meta_test" namespace (a DEF library plus
# IMPL libraries for the CPU and Meta keys), defines `foo(Tensor a) -> Tensor`,
# registers the same `foo_impl` for both keys, and calls
# torch.ops.meta_test.foo.default on a meta tensor.
# NOTE(review): `foo_impl` and anything done with `a` between its creation and
# the call are not visible in this span; the test name suggests the call is
# expected to complete without raising -- confirm against the full method.
def test_meta_autograd_no_error(self):
1328
with torch.library._scoped_library("meta_test", "DEF") as lib:
1329
with torch.library._scoped_library("meta_test", "IMPL", "CPU") as impl_cpu:
1330
with torch.library._scoped_library("meta_test", "IMPL", "Meta") as impl_meta:
1334
lib.define("foo(Tensor a) -> Tensor")
1335
impl_meta.impl("foo", foo_impl)
1336
impl_cpu.impl("foo", foo_impl)
1338
a = torch.ones(2, device='meta')
1343
b = torch.ops.meta_test.foo.default(a)
1345
def test_huber_loss_backward(self):
    # aten.huber_loss_backward on three very large (2**52 element) meta tensors
    # must return a meta tensor with the same shape as its first tensor input.
    element_count = 2**52
    meta_inputs = [torch.rand(element_count, device='meta') for _ in range(3)]
    result = torch.ops.aten.huber_loss_backward(*meta_inputs, 0, 1.0)
    first_input = meta_inputs[0]
    self.assertEqual(result.device.type, 'meta')
    self.assertEqual(result.shape, first_input.shape)
1351
# Shared shape checker for the *_norm_backward tests.
#
# op:              backward operator to call
# args:            positional arguments placed before output_mask
# output_mask:     passed to op after args
# expected_shapes: three entries, one per gradient; None means "no gradient"
#
# Two passes are made:
#   1. op(*args, output_mask): each of the three returned gradients must be None
#      when its expected shape is None, otherwise its shape must equal it.
#   2. op(*args, output_mask, **out_kwargs), where out_kwargs holds "out0",
#      "out1", ... tensors created empty (size 0), one per output_mask entry:
#      afterwards each out tensor's shape must equal the expected shape; entries
#      whose expected shape is None are not checked.
#
# NOTE(review): the bindings of `device` and `out_kwargs` are not visible in
# this span -- confirm against the full method.
def _norm_backwards_test_helper(self, op, args, output_mask, expected_shapes):
1353
dtype = torch.float32
1357
grads = op(*args, output_mask)
1359
def assertEqualShapes(res, exp):
1360
self.assertIsNone(res) if exp is None else self.assertEqual(exp, res.shape)
1362
assertEqualShapes(grads[0], expected_shapes[0])
1363
assertEqualShapes(grads[1], expected_shapes[1])
1364
assertEqualShapes(grads[2], expected_shapes[2])
1367
f"out{i}": torch.empty(0, device=device, dtype=dtype)
1368
for i in range(len(output_mask))
1372
grads = op(*args, output_mask, **out_kwargs)
1374
def assertEqualShapes(res, exp):
1375
self.assertEqual(exp, res.shape) if exp is not None else True
1377
assertEqualShapes(out_kwargs["out0"], expected_shapes[0])
1378
assertEqualShapes(out_kwargs["out1"], expected_shapes[1])
1379
assertEqualShapes(out_kwargs["out2"], expected_shapes[2])
1382
# Shape test for aten.native_layer_norm_backward, parametrized over all eight
# True/False combinations of the three-entry output_mask.
#
# For each sample from sample_inputs_layer_norm (float32, no grad):
#   * sample.args is padded with None up to three entries
#     (normalized_shape, weight, bias);
#   * grad_out is ones_like(input); mean and rstd are zero tensors whose shape
#     is the leading input.ndim - len(normalized_shape) dims of grad_out;
#   * the expected gradient shapes are input.shape / weight.shape / bias.shape,
#     each replaced by None when its mask entry is False (and, for weight and
#     bias, also when the tensor itself is None).
# The actual comparison is done by _norm_backwards_test_helper.
#
# NOTE(review): the binding of `device` is not visible in this span.
@parametrize("output_mask", list(itertools.product([True, False], [True, False], [True, False])))
1383
def test_layer_norm_backward(self, output_mask):
1384
from torch.testing._internal.common_methods_invocations import sample_inputs_layer_norm
1387
dtype = torch.float32
1389
samples = sample_inputs_layer_norm(None, device, dtype, requires_grad=False)
1391
for sample in samples:
1392
with self.subTest(sample=sample):
1394
if len(sample.args) != 3:
1395
sample.args = (*sample.args, *([None] * (3 - len(sample.args))))
1397
grad_out = torch.ones_like(sample.input)
1398
normalized_shape, weight, bias = sample.args
1399
ndims_after_reduction = sample.input.ndim - len(normalized_shape)
1400
mean_shape = grad_out.shape[:ndims_after_reduction]
1401
mean = torch.zeros(mean_shape, device=device, dtype=dtype)
1402
rstd = torch.zeros(mean_shape, device=device, dtype=dtype)
1405
sample.input.shape if output_mask[0] else None,
1406
weight.shape if output_mask[1] and weight is not None else None,
1407
bias.shape if output_mask[2] and bias is not None else None)
1409
args = [grad_out, sample.input, normalized_shape, mean, rstd, weight, bias]
1411
self._norm_backwards_test_helper(torch.ops.aten.native_layer_norm_backward,
1412
args, output_mask, expected_shapes)
1415
# Shape test for aten.native_group_norm_backward, parametrized over all eight
# True/False combinations of the three-entry output_mask.
#
# For each sample from sample_inputs_group_norm (float32, no grad):
#   * N and C are the first two input dims; HxW is the product of the remaining
#     dims; `group` is sample.args[0];
#   * grad_out is ones_like(input); mean and rstd are zeros of shape (N, group);
#     weight is a zero tensor built from C (note `(C)` is just C, not a tuple);
#   * the expected gradient shapes are input.shape, weight.shape, weight.shape,
#     each replaced by None when its mask entry is False.
# The actual comparison is done by _norm_backwards_test_helper.
#
# NOTE(review): the binding of `device` is not visible in this span.
@parametrize("output_mask", list(itertools.product([True, False], [True, False], [True, False])))
1416
def test_group_norm_backward(self, output_mask):
1417
from torch.testing._internal.common_methods_invocations import sample_inputs_group_norm
1421
dtype = torch.float32
1422
samples = sample_inputs_group_norm(None, device, dtype, requires_grad=False)
1424
for sample in samples:
1425
with self.subTest(sample=sample):
1426
grad_out = torch.ones_like(sample.input)
1427
N, C = sample.input.shape[:2]
1428
HxW = torch.prod(torch.as_tensor(sample.input.shape[2:]), dtype=torch.int32).item()
1429
group = sample.args[0]
1430
mean = torch.zeros((N, group), device=device, dtype=dtype)
1431
rstd = torch.zeros((N, group), device=device, dtype=dtype)
1432
weight = torch.zeros((C), device=device, dtype=dtype)
1434
args = [grad_out, sample.input, mean, rstd, weight, N, C, HxW, group]
1437
sample.input.shape if output_mask[0] else None,
1438
weight.shape if output_mask[1] else None,
1439
weight.shape if output_mask[2] else None)
1442
self._norm_backwards_test_helper(torch.ops.aten.native_group_norm_backward,
1443
args, output_mask, expected_shapes)
1446
# Shape test for aten.native_batch_norm_backward.  output_mask is parametrized
# with its first entry fixed to True and the other two taking both values.
#
# For each sample from sample_inputs_batch_norm (float32, no grad):
#   * sample.args unpacks to (running_mean, running_var, weight, bias);
#   * `train` comes from sample.kwargs["training"], defaulting to True;
#   * save_mean / save_invstd are zero tensors of length input.shape[1] when
#     training, otherwise None;
#   * eps comes from sample.kwargs["eps"], defaulting to 1e-5;
#   * the visible expected shapes for the second and third gradients are
#     torch.Size([input.shape[1]]), or None when the mask entry is False.
# The actual comparison is done by _norm_backwards_test_helper.
#
# NOTE(review): the binding of `device`, the action taken for inputs with fewer
# than 2 dims, and the first expected-shape entry are not visible in this span.
@parametrize("output_mask", list(itertools.product([True], [True, False], [True, False])))
1447
def test_batch_norm_backward(self, output_mask):
1448
from torch.testing._internal.common_methods_invocations import sample_inputs_batch_norm
1452
dtype = torch.float32
1453
samples = sample_inputs_batch_norm(None, device, dtype, requires_grad=False)
1455
for sample in samples:
1456
with self.subTest(sample=sample):
1458
if sample.input.dim() < 2:
1461
grad_out = torch.ones_like(sample.input)
1462
running_mean, running_var, weight, bias = sample.args
1463
train = sample.kwargs.get("training", True)
1464
save_mean = torch.zeros((sample.input.shape[1], ), device=device, dtype=dtype) if train else None
1465
save_invstd = torch.zeros((sample.input.shape[1], ), device=device, dtype=dtype) if train else None
1467
args = [grad_out, sample.input, weight, running_mean, running_var,
1468
save_mean, save_invstd, train, sample.kwargs.get("eps", 1e-5)]
1472
torch.Size([sample.input.shape[1]]) if output_mask[1] else None,
1473
torch.Size([sample.input.shape[1]]) if output_mask[2] else None)
1475
self._norm_backwards_test_helper(torch.ops.aten.native_batch_norm_backward,
1476
args, output_mask, expected_shapes)
1478
def test_fill__alias_relationship(self):
    # Checks object identity of what the two fill overloads return on a meta
    # tensor.  Identity is asserted directly with assertIs / assertIsNot
    # rather than by comparing id() values, which states the intent and gives
    # a clearer failure message.
    inps = torch.rand(2**52, device='meta')

    # aten.fill_ is expected to hand back the very tensor it was given.
    r = torch.ops.aten.fill_(inps, 1.0)
    self.assertIs(r, inps)

    # aten.fill is expected to return a different tensor object.
    r2 = torch.ops.aten.fill(inps, 1.0)
    self.assertIsNot(r2, inps)
1488
# Compares aten._fused_moving_avg_obs_fq_helper on real tensors against the same
# call where the first argument is replaced by `meta_x`.
#
# Visible setup: a 5x5 random input on `device`, running min/max initialised to
# +inf / -inf, a scale of [1.0], an int zero_point of [0], and a
# FusedMovingAvgObsFakeQuantize module with fake-quant and observer enabled
# (its observer_enabled / fake_quant_enabled are part of the argument list).
# For each kwargs dict in `kwargss` (per_row_fake_quant=False with
# symmetric_quant False and True are visible), the first two outputs of the
# reference and meta calls must agree in size and stride.
#
# NOTE(review): the construction of `args`, `meta_x` and the `kwargss` list is
# only partly visible in this span -- confirm against the full method.
def test_meta__fused_moving_avg_obs_fq_helper(self, device):
1489
from torch.ao.quantization import FusedMovingAvgObsFakeQuantize
1490
to_meta = MetaConverter()
1492
x = torch.randn(5, 5, device=device)
1493
running_min_op = torch.tensor(float("inf"), device=device)
1494
running_max_op = torch.tensor(float("-inf"), device=device)
1496
scale = torch.tensor([1.0], device=device)
1497
zero_point = torch.tensor([0], dtype=torch.int, device=device)
1499
mod = FusedMovingAvgObsFakeQuantize()
1500
torch.ao.quantization.enable_fake_quant(mod)
1501
torch.ao.quantization.enable_observer(mod)
1508
mod.observer_enabled,
1509
mod.fake_quant_enabled,
1520
meta_args = args.copy()
1521
meta_args[0] = meta_x
1525
{"per_row_fake_quant": False, "symmetric_quant": False},
1526
{"per_row_fake_quant": False, "symmetric_quant": True},
1529
for kwargs in kwargss:
1530
ref_out = aten._fused_moving_avg_obs_fq_helper.default(*args, **kwargs)
1531
meta_out = aten._fused_moving_avg_obs_fq_helper.default(*meta_args, **kwargs)
1533
self.assertEqual(ref_out[0].size(), meta_out[0].size())
1534
self.assertEqual(ref_out[0].stride(), meta_out[0].stride())
1535
self.assertEqual(ref_out[1].size(), meta_out[1].size())
1536
self.assertEqual(ref_out[1].stride(), meta_out[1].stride())
1538
# Runs aten._cdist_forward on a (3, 2) and a (2, 2) random tensor for each
# compute_mode in (None, 1, 2), once on the real tensors and once on their
# MetaConverter counterparts.  The meta result must be on the meta device and
# have the same shape as the reference result.
# NOTE(review): the binding of `p` is not visible in this span.
def test_cdist_forward(self, device):
1539
to_meta = MetaConverter()
1540
x1 = torch.rand([3, 2], device=device)
1541
x2 = torch.rand([2, 2], device=device)
1543
for compute_mode in (None, 1, 2):
1544
ref = aten._cdist_forward.default(x1, x2, p, compute_mode)
1545
res = aten._cdist_forward.default(to_meta(x1), to_meta(x2), p, compute_mode)
1546
self.assertEqual(res.device.type, 'meta')
1547
self.assertEqual(ref.shape, res.shape)
1549
# Calls quantized.embedding_bag_byte_rowwise_offsets with all tensor inputs
# moved to the meta device and checks the result has shape [32, 128], dtype
# float32 and a storage data pointer of 0.
#
# Inputs: an 8x128 float32 table with random values shifted by +1, prepacked
# with embedding_bag_byte_prepack; 32 random int indices in [0, 8); and offsets
# formed as a leading zero followed by the cumulative sum of 32 random lengths,
# each between 0 and max_length inclusive.
#
# NOTE(review): not every argument of the embedding_bag call is visible in this
# span -- confirm against the full method.
def test_quantized_embedding_bag(self):
1550
tab_shape = [8, 128]
1551
emb_size, ind_len, off_len = tab_shape[0], 32, 33
1552
f_table = torch.from_numpy((np.random.random_sample(tab_shape) + 1).astype(np.float32))
1553
q_table = torch.ops.quantized.embedding_bag_byte_prepack(f_table)
1554
indices = torch.from_numpy(np.random.randint(low=0, high=emb_size, size=ind_len)).int()
1555
max_length = len(indices) // (off_len - 1)
1558
np_lengths = np.random.randint(0, max_length + 1, size=off_len - 1).astype(np.int32)
1559
offsets = torch.cat([torch.zeros([1]), torch.cumsum(torch.from_numpy(np_lengths), 0)]).int()
1561
eb = torch.ops.quantized.embedding_bag_byte_rowwise_offsets(
1562
q_table.to(device="meta"),
1563
indices.to(device="meta"),
1564
offsets.to(device="meta"),
1566
per_sample_weights=None,
1567
include_last_offset=True,
1569
self.assertEqual(eb.shape, [32, 128])
1570
self.assertEqual(eb.dtype, torch.float32)
1571
self.assertEqual(eb.untyped_storage().data_ptr(), 0)
1575
def test_fill_stride(self):
    # For every argument list produced by get_strided_args from a 2x2x2x2
    # random tensor and the scalar 1.0, aten.fill on the MetaConverter output
    # must report the same size and stride as aten.fill on the real arguments.
    converter = MetaConverter()
    base_args = [torch.rand(2, 2, 2, 2), 1.0]

    for real_args in get_strided_args(base_args):
        converted_args = converter(real_args)
        expected = torch.ops.aten.fill(*real_args)
        actual = torch.ops.aten.fill(*converted_args)
        self.assertEqual(expected.size(), actual.size())
        self.assertEqual(expected.stride(), actual.stride())
1587
# Loads `b` with torch.load(map_location=torch.device("meta")) and checks the
# result is on the meta device, keeps the shape and dtype of `t`, and has a
# storage data pointer of 0.
# NOTE(review): the creation of `t` and `b` is not visible in this span.
# NOTE(review): this uses r.storage() while the sibling tests in this file use
# untyped_storage() for the same data_ptr check; Tensor.storage() returns the
# deprecated TypedStorage -- consider aligning once confirmed safe.
def test_map_location_deserialize(self):
1595
r = torch.load(b, map_location=torch.device("meta"))
1596
self.assertEqual(r.device.type, 'meta')
1597
self.assertEqual(r.shape, t.shape)
1598
self.assertEqual(r.dtype, t.dtype)
1599
self.assertEqual(r.storage().data_ptr(), 0)
1601
# For embedding dims 128, 256 and 512: runs quantized.embedding_bag_byte_prepack
# on a float32 weight of shape (batch_size, num_embeddings, ed) moved to the
# meta device.  The result must have last dim ed + 8 (other dims unchanged),
# dtype float32, and a storage data pointer of 0.
# NOTE(review): the bindings of `batch_size` and `num_embeddings` are not
# visible in this span.
def test_embedding_bag_byte_prepack(self):
1604
embedding_dim = [128, 256, 512]
1605
res_shape = [[batch_size, num_embeddings, ed + 8] for ed in embedding_dim]
1606
for ed, rs in zip(embedding_dim, res_shape):
1607
weight = torch.randn(batch_size, num_embeddings, ed, dtype=torch.float32)
1608
res = torch.ops.quantized.embedding_bag_byte_prepack(weight.to(device="meta"))
1609
self.assertEqual(res.shape, rs)
1610
self.assertEqual(res.dtype, torch.float32)
1611
self.assertEqual(res.untyped_storage().data_ptr(), 0)
1613
# For embedding dims 128, 256 and 512: runs quantized.embedding_bag_byte_unpack
# on a float32 tensor of shape (batch_size, num_embeddings, ed + 8) moved to the
# meta device.  The result must have last dim ed (other dims unchanged), dtype
# float32, and a storage data pointer of 0.
# NOTE(review): the bindings of `batch_size` and `num_embeddings` are not
# visible in this span.
def test_embedding_bag_byte_unpack(self):
1616
embedding_dim = [128, 256, 512]
1617
res_shape = [[batch_size, num_embeddings, ed] for ed in embedding_dim]
1618
for ed, rs in zip(embedding_dim, res_shape):
1619
packed_weight = torch.randn(batch_size, num_embeddings, ed + 8, dtype=torch.float32)
1620
res = torch.ops.quantized.embedding_bag_byte_unpack(packed_weight.to(device="meta"))
1621
self.assertEqual(res.shape, rs)
1622
self.assertEqual(res.dtype, torch.float32)
1623
self.assertEqual(res.untyped_storage().data_ptr(), 0)
1625
# Runs torch.index_select with an explicit out= tensor on the meta device: an
# [8, 16] input, 10 indices along dim 0 and a [10, 16] out tensor.  Under
# enable_python_dispatcher(), the resulting `out` must have shape [10, 16].
# NOTE(review): the `return` statement implies an inner function whose header
# and call site are not visible in this span.  The local name `input` shadows
# the builtin of the same name.
def test_index_select_out(self):
1627
input = torch.randn([8, 16], device='meta')
1628
index = torch.tensor([2, 1, 6, 7, 3, 1, 7, 5, 6, 7], device='meta')
1629
out = torch.empty([10, 16], device='meta')
1630
return torch.index_select(input=input, dim=0, index=index, out=out)
1631
with enable_python_dispatcher():
1633
self.assertEqual(out.shape, [10, 16])
1635
# Hands TestMeta and this module's globals() to instantiate_device_type_tests
# (imported from torch.testing._internal.common_device_type).
instantiate_device_type_tests(TestMeta, globals())
1637
# Reports whether an aten operator, given as a schema-style name string, is in
# this file's dispatch skip or expected-failure tables.
#
# op_str is parsed with OperatorName.parse; the matching overload is looked up
# on torch.ops.aten, using "default" when the parsed overload name is empty.
# If the overload appears in meta_dispatch_skips or in the 'cuda' entry of
# meta_dispatch_device_skips, a "# SKIP" line is printed.  The second check
# covers meta_dispatch_expected_failures and the 'cuda' entry of
# meta_dispatch_device_expected_failures.
# NOTE(review): the statement executed when the second check matches is not
# visible in this span.
def print_op_str_if_not_supported(op_str):
1638
op = OperatorName.parse(op_str)
1639
packet = getattr(torch.ops.aten, str(op.name))
1640
overload = getattr(packet, op.overload_name if op.overload_name else "default")
1641
if any(overload in d for d in [meta_dispatch_skips, meta_dispatch_device_skips['cuda']]):
1642
print(f"{overload} # SKIP")
1643
if any(overload in d for d in [meta_dispatch_expected_failures, meta_dispatch_device_expected_failures['cuda']]):
1647
# Script entry point.  Two optional environment variables switch the file into
# a reporting mode built on print_op_str_if_not_supported:
#   PYTORCH_COMPARE_XLA  - path to a YAML file; its "full_codegen", "supported"
#                          and "autograd" lists (each defaulting to empty) are
#                          concatenated into `ops`.
#   PYTORCH_COMPARE_TEXT - path to a text file; each op string is stripped
#                          before being reported.
# NOTE(review): the loops that bind `op_str`, and whatever runs when neither
# variable is set, are not visible in this span.
# NOTE(review): yaml.load is given YamlLoader from torchgen.yaml_utils; confirm
# it is a safe loader, since the file path comes from the environment.
# NOTE(review): `ops` here rebinds the module-level name also used as the @ops
# test decorator earlier in this file.
if __name__ == "__main__":
1648
COMPARE_XLA = os.getenv('PYTORCH_COMPARE_XLA', None)
1649
if COMPARE_XLA is not None:
1650
with open(COMPARE_XLA) as f:
1651
d = yaml.load(f, Loader=YamlLoader)
1652
ops = d.get("full_codegen", []) + d.get("supported", []) + d.get("autograd", [])
1654
print_op_str_if_not_supported(op_str)
1657
COMPARE_TEXT = os.getenv('PYTORCH_COMPARE_TEXT', None)
1658
if COMPARE_TEXT is not None:
1659
with open(COMPARE_TEXT) as f:
1661
print_op_str_if_not_supported(op_str.strip())