# Owner(s): ["module: custom-operators"]

import collections
import contextlib
import itertools
import os
import re
import sys
import typing
import unittest

import numpy as np

import torch
import torch._custom_ops as custom_ops
import torch.testing._internal.custom_op_db
import torch.testing._internal.optests as optests
import torch.utils.cpp_extension
from functorch import make_fx
from torch import Tensor
from torch._custom_op.impl import custom_op, CustomOp
from torch._utils_internal import get_file_path_2
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_utils import *  # noqa: F403
from torch.testing._internal.common_device_type import *  # noqa: F403
from torch.testing._internal.custom_op_db import custom_op_db
from typing import *  # noqa: F403
class CustomOpTestCaseBase(TestCase):
    test_ns = "_test_custom_op"

    def tearDown(self):
        import torch._custom_op

        keys = list(torch._custom_op.impl.global_registry.keys())
        for key in keys:
            if not key.startswith(f"{self.test_ns}::"):
                continue
            torch._custom_op.impl.global_registry[key]._destroy()
        if hasattr(torch.ops, self.test_ns):
            delattr(torch.ops, self.test_ns)
        for lib in self.libraries:
            lib._destroy()

    def ns(self):
        return getattr(torch.ops, self.test_ns)

    def lib(self):
        result = torch.library.Library(self.test_ns, "FRAGMENT")  # noqa: TOR901
        self.libraries.append(result)
        return result

    def get_op(self, qualname):
        return torch._custom_op.impl.get_op(qualname)
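    # Hedged sketch (added commentary, not in the original file): how a test
    # typically uses this harness. `_example_*` names are hypothetical and are
    # never invoked by the suite.
    def _example_harness_usage(self):
        # lib() returns a FRAGMENT-kind torch.library.Library scoped to
        # test_ns; everything defined through it is destroyed in tearDown().
        lib = self.lib()
        lib.define("mysin(Tensor x) -> Tensor")
        lib.impl("mysin", lambda x: x.sin(), "CPU")
        # Resolve the qualified name back to an OpOverload for direct calls.
        return self.get_op(f"{self.test_ns}::mysin")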
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
@unittest.skipIf(
    sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+"
)
class TestCustomOpTesting(CustomOpTestCaseBase):
    @parametrize("check_gradients", (False, "auto"))
    @parametrize("dynamic", (True, False))
    def test_aot_autograd_check_degenerate_cases(
        self, device, dynamic, check_gradients
    ):
        def simple(x):
            return x.clone()

        # Should not raise
        x = torch.randn(3, device=device)
        optests.aot_autograd_check(
            simple, (x,), {}, dynamic=dynamic, check_gradients=check_gradients
        )

        def outputs_dont_require_grad(x):
            return x.detach()

        # Should not raise
        y = torch.randn(3, device=device, requires_grad=True)
        optests.aot_autograd_check(
            simple, (y,), {}, dynamic=dynamic, check_gradients=check_gradients
        )

        def no_outputs(x):
            return None

        # Should not raise
        x = torch.randn(3, device=device, requires_grad=True)
        y = torch.randn(3, device=device, requires_grad=False)
        optests.aot_autograd_check(
            no_outputs, (x,), {}, dynamic=dynamic, check_gradients=check_gradients
        )
        optests.aot_autograd_check(
            no_outputs, (y,), {}, dynamic=dynamic, check_gradients=check_gradients
        )
    def test_incorrect_schema_mutation(self, device):
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        op = self.ns().foo.default

        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                guard = torch._C._AutoDispatchBelowAutograd()
                ...

            @staticmethod
            def backward(ctx, gx):
                ...

        lib.impl("foo", Foo.apply, "Autograd")
        lib.impl("foo", foo_impl, "CPU")
        lib.impl("foo", foo_impl, "CUDA")

        x = torch.tensor(3.14159 / 3, requires_grad=True, device=device)
        with self.assertRaisesRegex(
            optests.OpCheckError, "Argument x is not defined as mutable but was mutated"
        ):
            optests.opcheck(op, (x,), {})
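    # Hedged sketch (added): the fix opcheck is pointing at above is to declare
    # the mutation in the schema. `Tensor(a!)` marks `x` as mutated in place.
    # Hypothetical example, never run by the suite:
    def _example_declared_mutation(self):
        lib = self.lib()
        lib.define("foo_inplace(Tensor(a!) x) -> Tensor(a!)")
        # An in-place kernel now matches what the schema promises.
        lib.impl("foo_inplace", lambda x: x.sin_(), "CPU")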
    def test_incorrect_schema_view(self, device):
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        op = self.ns().foo.default

        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                # Emulate AutoDispatchBelowADInplaceOrView, which is not bound into python
                with torch._C._AutoDispatchBelowAutograd():
                    with torch._C._ExcludeDispatchKeyGuard(
                        torch._C.DispatchKeySet(torch._C.DispatchKey.ADInplaceOrView)
                    ):
                        ...

            @staticmethod
            def backward(ctx, gx):
                ...

        lib.impl("foo", Foo.apply, "Autograd")
        lib.impl("foo", foo_impl, "CPU")
        lib.impl("foo", foo_meta, "Meta")

        x = torch.tensor(3.14159 / 3, requires_grad=True)
        with self.assertRaisesRegex(
            optests.OpCheckError,
            "Argument x is not defined to alias output but was aliasing",
        ):
            optests.opcheck(op, (x,), {})
    def test_missing_abstract_impl(self, device):
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        op = self.ns().foo.default

        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                with torch._C._AutoDispatchBelowAutograd():
                    ...

            @staticmethod
            def backward(ctx, gx):
                ...

        def foo_impl(x):
            return torch.tensor(x.cpu().numpy() ** 2, device=x.device)

        lib.impl("foo", Foo.apply, "Autograd")
        lib.impl("foo", foo_impl, "CPU")
        lib.impl("foo", foo_impl, "CUDA")

        x = torch.tensor([0, 1.0], requires_grad=True)
        with self.assertRaisesRegex(
            optests.OpCheckError,
            "_test_custom_op.foo.default",
        ):
            optests.opcheck(op, (x,), {})
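    # Hedged sketch (added): the missing piece in the test above is an abstract
    # (meta) impl, which computes output metadata without touching real data.
    # Hypothetical example, never run by the suite:
    def _example_register_abstract_impl(self):
        lib = self.lib()
        lib.define("foo_with_meta(Tensor x) -> Tensor")
        lib.impl("foo_with_meta", lambda x: torch.tensor(x.numpy() ** 2), "CPU")
        # torch.empty_like suffices as the abstract impl of a shape-preserving op.
        torch.library.impl_abstract(
            f"{self.test_ns}::foo_with_meta", torch.empty_like, lib=lib
        )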
    def test_incorrect_abstract_impl(self, device):
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        op = self.ns().foo.default

        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                # Emulate AutoDispatchBelowADInplaceOrView, which is not bound into python
                guard = torch._C._AutoDispatchBelowAutograd()
                guard2 = torch._C._ExcludeDispatchKeyGuard(
                    torch._C.DispatchKeySet(torch._C.DispatchKey.ADInplaceOrView)
                )
                ...

            @staticmethod
            def backward(ctx, gx):
                ...

        def foo_meta(x):
            return x.unsqueeze(1) ** 2

        lib.impl("foo", Foo.apply, "Autograd")
        lib.impl("foo", foo_impl, "CPU")
        lib.impl("foo", foo_impl, "CUDA")
        lib.impl("foo", foo_meta, "Meta")

        x = torch.tensor([0, 1.0], requires_grad=True)
        with self.assertRaisesRegex(optests.OpCheckError, "Shapes .* are not equal"):
            optests.opcheck(op, (x,), {})
    def test_missing_functionalization(self, device):
        lib = self.lib()
        lib.define("foo(Tensor(a!) x) -> Tensor(a!)")
        op = self.ns().foo.default

        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                with torch._C._AutoDispatchBelowAutograd():
                    ...

            @staticmethod
            def backward(ctx, gx):
                ...

        lib.impl("foo", Foo.apply, "Autograd")
        lib.impl("foo", foo_impl, "CPU")
        lib.impl("foo", foo_impl, "CUDA")
        lib.impl("foo", foo_meta, "Meta")

        x = torch.tensor([0, 1.0])
        with self.assertRaisesRegex(
            optests.OpCheckError,
            "Getting these operators to work with functionalization requires some extra work",
        ):
            optests.opcheck(op, (y,), {})
    def test_autograd_registered_at_backend(self, device):
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        op = self.ns().foo.default

        class Foo(torch.autograd.Function):
            @staticmethod
            def backward(ctx, gx):
                ...

        lib.impl("foo", Foo.apply, "CPU")
        lib.impl("foo", Foo.apply, "CUDA")
        lib.impl("foo", lambda x: x.clone(), "Meta")

        x = torch.randn([], requires_grad=True)
        with self.assertRaisesRegex(
            torch.testing._internal.optests.OpCheckError,
            "does not have an autograd kernel",
        ):
            optests.opcheck(op, (x,), {})

        # I'm not sure why this is necessary
        del lib

    def test_global_state_mutation(self, device):
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        op = self.ns().foo.default

        class Foo(torch.autograd.Function):
            invoked = 0

            @staticmethod
            def forward(ctx, x):
                Foo.invoked += 1
                return x.clone() * Foo.invoked

            @staticmethod
            def backward(ctx, gx):
                ...

        lib.impl("foo", Foo.apply, "CompositeImplicitAutograd")

        x = torch.tensor(3.14159 / 3, requires_grad=True)
        with self.assertRaisesRegex(
            optests.OpCheckError, "eager-mode PyTorch vs AOTAutograd"
        ):
            optests.opcheck(op, (x,), {})
    @ops(custom_op_db, dtypes=OpDTypes.any_one)
    def test_opcheck_opinfo(self, device, dtype, op):
        for sample_input in op.sample_inputs(
            device, dtype, requires_grad=op.supports_autograd
        ):
            args = [sample_input.input] + list(sample_input.args)
            kwargs = sample_input.kwargs
            if op.op in (
                torch.ops._torch_testing.numpy_nonzero,
                torch.ops._torch_testing.numpy_nms,
            ):
                ctx = self.assertRaisesRegex(optests.OpCheckError, "failed with")
            else:
                ctx = contextlib.nullcontext()
            with ctx:
                optests.opcheck(op.op, args, kwargs)
    def test_opcheck_fails_basic(self, device):
        @custom_op(f"{self.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:
            ...

        @foo.impl(["cpu", "cuda"])
        def foo_impl(x):
            return x.sum()

        x = torch.randn(3, device=device, requires_grad=True)
        # Triggers the CustomOp autograd NYI error
        with self.assertRaisesRegex(
            optests.OpCheckError, "Autograd has not been implemented for operator"
        ):
            optests.opcheck(self.get_op(f"{self.test_ns}::foo"), (x,), {})
    def test_autograd_registration_check_autograd_kernel(self, device):
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        op = self.ns().foo.default

        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                with torch._C._AutoDispatchBelowAutograd():
                    ...

            @staticmethod
            def backward(ctx, gx):
                ...

        lib.impl("foo", Foo.apply, "Autograd")
        lib.impl("foo", foo_impl, "CPU")
        lib.impl("foo", foo_impl, "CUDA")

        x = torch.randn(3, requires_grad=True, device=device)
        # Should not raise
        optests.autograd_registration_check(op, (x,), {})
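    # Hedged sketch (added): the registration pattern the check above accepts.
    # CompositeImplicitAutograd decomposes into ops that already support
    # autograd, so no hand-written autograd kernel is needed. Hypothetical:
    def _example_correct_autograd_registration(self):
        lib = self.lib()
        lib.define("foo_ok(Tensor x) -> Tensor")
        lib.impl("foo_ok", lambda x: x.sin().cos(), "CompositeImplicitAutograd")
        op = self.get_op(f"{self.test_ns}::foo_ok")
        x = torch.randn(3, requires_grad=True)
        optests.autograd_registration_check(op, (x,), {})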
    def test_autograd_registration_check_compositeimplicitautograd(self, device):
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        op = self.ns().foo.default

        lib.impl("foo", foo_impl, "CompositeImplicitAutograd")

        x = torch.randn(3, requires_grad=True, device=device)
        # Should not raise
        optests.autograd_registration_check(op, (x,), {})

    def test_autograd_registration_check_incorrect_composite(self, device):
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        op = self.ns().foo.default

        lib.impl("foo", foo_impl, "CompositeExplicitAutograd")

        x = torch.randn(3, requires_grad=True, device=device)
        with self.assertRaisesRegex(AssertionError, "incorrectly registered"):
            optests.autograd_registration_check(op, (x,), {})

    def test_autograd_registration_check_incorrect(self, device):
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        op = self.ns().foo.default

        class Foo(torch.autograd.Function):
            @staticmethod
            def backward(ctx, gx):
                ...

        lib.impl("foo", Foo.apply, "CPU")
        lib.impl("foo", Foo.apply, "CUDA")

        x = torch.randn(3, requires_grad=True, device=device)
        with self.assertRaisesRegex(AssertionError, "incorrectly registered"):
            optests.autograd_registration_check(op, (x,), {})
    def test_assert_raises_regex(self, device):
        from torch.testing._internal.optests.aot_autograd import assert_raises_regex

        with assert_raises_regex(RuntimeError, "c"):
            raise RuntimeError("abcd")
        with assert_raises_regex(RuntimeError, "c.*"):
            raise RuntimeError("abcd")
        with self.assertRaisesRegex(AssertionError, "instead got"):
            with assert_raises_regex(RuntimeError, "c.*"):
                raise ValueError("abcd")
        with self.assertRaisesRegex(AssertionError, "Expected exception"):
            with assert_raises_regex(RuntimeError, "c.*"):
                pass
        with self.assertRaisesRegex(AssertionError, "to match regex"):
            with assert_raises_regex(RuntimeError, "f"):
                raise RuntimeError("abcd")
class TestCustomOp(CustomOpTestCaseBase):
    test_ns = "_test_custom_op"

    def test_invalid_schemas(self):
        # function schema validation goes through torchgen, so this is just a
        # basic sanity check.
        with self.assertRaisesRegex(AssertionError, "Invalid function schema: foo"):
            custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo", "(")

    def test_invalid_qualname(self):
        with self.assertRaisesRegex(ValueError, "overload"):
            custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo.Tensor", "() -> ()")

    def test_name_must_match(self):
        with self.assertRaisesRegex(ValueError, "to have name"):

            @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
            def baz(x: Tensor) -> Tensor:
                raise NotImplementedError()

    def test_unsupported_schemas(self):
        def foo(x):
            raise NotImplementedError()

        with self.assertRaisesRegex(ValueError, "only supports functional"):
            custom_ops.custom_op(
                f"{TestCustomOp.test_ns}::foo", "(Tensor(a!) x) -> Tensor(a)"
            )
        with self.assertRaisesRegex(ValueError, "only supports functional"):
            custom_ops.custom_op(
                f"{TestCustomOp.test_ns}::foo", "(Tensor(a) x) -> Tensor(a)"
            )
        with self.assertRaisesRegex(ValueError, "only supports functional"):
            custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo", "(Tensor x) -> ()")(
                foo
            )
        with self.assertRaisesRegex(ValueError, "self"):
            custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo", "(Tensor self) -> ()")(
                foo
            )
    # Tests for the older custom_op API
    def test_schema_matches_signature(self):
        with self.assertRaisesRegex(ValueError, "signature to match"):

            @custom_op(f"{TestCustomOp.test_ns}::blah", "(Tensor y) -> Tensor")
            def blah(x):
                pass

        with self.assertRaisesRegex(ValueError, "signature to match"):

            @custom_op(
                f"{TestCustomOp.test_ns}::blah2", "(Tensor x, *, Tensor y) -> Tensor"
            )
            def blah2(x, y):
                pass

        with self.assertRaisesRegex(ValueError, "signature to match"):

            @custom_op(
                f"{TestCustomOp.test_ns}::blah3",
                "(Tensor x, *, Tensor w, Tensor z) -> Tensor",
            )
            def blah3(x, *, y, z):
                pass

        with self.assertRaisesRegex(ValueError, "signature to match"):

            @custom_op(
                f"{TestCustomOp.test_ns}::blah4",
                "(Tensor x, *, Tensor z, Tensor y) -> Tensor",
            )
            def blah4(x, *, y, z):
                pass

        with self.assertRaisesRegex(ValueError, "not supported"):

            @custom_op(f"{TestCustomOp.test_ns}::blah5", "(Tensor x) -> Tensor")
            def blah5(*args):
                pass

        with self.assertRaisesRegex(ValueError, "not supported"):

            @custom_op(
                f"{TestCustomOp.test_ns}::blah6", "(*, Tensor z, Tensor y) -> Tensor"
            )
            def blah6(**kwargs):
                pass

        with self.assertRaisesRegex(ValueError, "default arguments"):

            @custom_op(
                f"{TestCustomOp.test_ns}::blah7", "(Tensor x, *, Tensor y) -> Tensor"
            )
            def blah7(x=1, *, y):
                pass

        with self.assertRaisesRegex(ValueError, "default arguments"):

            @custom_op(
                f"{TestCustomOp.test_ns}::blah8", "(Tensor x, *, Tensor y) -> Tensor"
            )
            def blah8(x, *, y=1):
                pass

        # A matching signature should be accepted.
        @custom_op(
            f"{TestCustomOp.test_ns}::blah9", "(Tensor x, *, Tensor y) -> Tensor"
        )
        def blah9(x, *, y):
            pass
    # Tests for the older custom_op API
    def test_unsupported_annotation_categories(self):
        with self.assertRaisesRegex(ValueError, "varargs"):

            @custom_op(f"{TestCustomOp.test_ns}::foo")
            def foo(*args):
                raise NotImplementedError()

        with self.assertRaisesRegex(ValueError, "varkwargs"):

            @custom_op(f"{TestCustomOp.test_ns}::foo")
            def foo(**kwargs):
                raise NotImplementedError()

        with self.assertRaisesRegex(ValueError, "must have a type annotation"):

            @custom_op(f"{TestCustomOp.test_ns}::foo")
            def foo(x):
                raise NotImplementedError()

        with self.assertRaisesRegex(ValueError, "default value"):

            @custom_op(f"{TestCustomOp.test_ns}::foo")
            def foo(x: Optional[Tensor] = None):
                raise NotImplementedError()

        with self.assertRaisesRegex(ValueError, "default value"):

            @custom_op(f"{TestCustomOp.test_ns}::foo")
            def foo(x: Optional[Tensor] = None):
                raise NotImplementedError()

        with self.assertRaisesRegex(ValueError, "unsupported"):

            @custom_op(f"{TestCustomOp.test_ns}::foo")
            def foo(x: Tensor) -> Tuple[Tensor, ...]:
                raise NotImplementedError()
    def _generate_examples(self, typ):
        if typ is torch.dtype:
            return [torch.float32]
        if typ is torch.device:
            return [torch.device("cpu")]
        if typ == torch.types.Number:
            ...
        if typ is torch.Tensor:
            return [torch.tensor(3)]
        if typ == Optional[torch.types.Number]:
            ...
        origin = typing.get_origin(typ)
        if origin is Union:
            args = typing.get_args(typ)
            assert len(args) == 2 and (args[0] is type(None) or args[1] is type(None))
            elt = args[0] if args[1] is type(None) else args[1]
            return self._generate_examples(elt) + [None]
        if origin is list:
            args = typing.get_args(typ)
            assert len(args) == 1
            elt = args[0]
            return [
                self._generate_examples(elt),
                self._generate_examples(elt),
                self._generate_examples(elt),
            ]
        if origin is collections.abc.Sequence:
            args = typing.get_args(typ)
            assert len(args) == 1
            examples = self._generate_examples(args[0])
            return list(itertools.product(examples, examples)) + []
        raise NotImplementedError(
            f"testrunner cannot generate instance of type {typ}"
        )
    def test_supported_return_types_single_return(self):
        for typ in torch._custom_op.impl.SUPPORTED_RETURN_TYPES:
            for example in self._generate_examples(typ):

                @custom_ops.custom_op(f"{self.test_ns}::foo")
                def foo(x: Tensor) -> typ:
                    raise NotImplementedError()

                @custom_ops.impl(f"{self.test_ns}::foo")
                def foo_impl(x: Tensor) -> typ:
                    return example

                op = self.get_op(f"{self.test_ns}::foo")
                result = op(torch.randn([]))
                self.assertEqual(result, example, msg=f"{typ} {example}")

                custom_ops._destroy(f"{self.test_ns}::foo")

    def test_supported_return_types_multi_return(self):
        for typ in torch._custom_op.impl.SUPPORTED_RETURN_TYPES:
            for example in self._generate_examples(typ):

                @custom_ops.custom_op(f"{self.test_ns}::foo")
                def foo(x: Tensor) -> Tuple[typ, typ]:
                    raise NotImplementedError()

                @custom_ops.impl(f"{self.test_ns}::foo")
                def foo_impl(x: Tensor) -> Tuple[typ, typ]:
                    return (example, example)

                op = self.get_op(f"{self.test_ns}::foo")
                result = op(torch.randn([]))
                expected = (example, example)
                self.assertEqual(result, expected, msg=f"{typ} {example}")

                custom_ops._destroy(f"{self.test_ns}::foo")

    def test_supported_param_types(self):
        for typ in torch._custom_op.impl.SUPPORTED_PARAM_TYPES:

            @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
            def foo(x: Tensor, y: typ) -> Tensor:
                raise NotImplementedError()

            yeet = None

            @custom_ops.impl(f"{TestCustomOp.test_ns}::foo", device_types=["cpu"])
            def foo_cpu(x, y):
                nonlocal yeet
                yeet = y
                return x.clone()

            for example in self._generate_examples(typ):
                op = self.get_op(f"{self.test_ns}::foo")
                op(torch.randn([]), example)
                self.assertEqual(yeet, example, msg=f"{typ} {example}")
                yeet = None

            custom_ops._destroy(f"{TestCustomOp.test_ns}::foo")
    def test_sequences(self):
        # Sequence[int] gets automagically turned into int[] in the schema.
        # This test checks that we actually do support arbitrary sequence types.
        class MySequence(collections.abc.Sequence):
            def __init__(self):
                self._container = [1, 2, 3]

            def __getitem__(self, idx):
                return self._container[idx]

            def __len__(self):
                return len(self._container)

        @custom_ops.custom_op(f"{self.test_ns}::foo")
        def foo(x: torch.Tensor, sizes: Sequence[int]) -> torch.Tensor:
            raise NotImplementedError()

        called = 0

        @custom_ops.impl(f"{self.test_ns}::foo", device_types="cpu")
        def foo_cpu(x, sizes):
            nonlocal called
            called += 1
            # Dispatcher will normalize the sequence type into a List
            self.assertEqual(sizes, [1, 2, 3])
            return x.clone()

        op = self.get_op(f"{self.test_ns}::foo")
        op(torch.randn([]), MySequence())
        self.assertEqual(called, 1)
    def test_unsupported_param_types(self):
        # Not comprehensive (it doesn't need to be), just a check that our mechanism works
        with self.assertRaisesRegex(ValueError, "unsupported type"):

            @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
            def foo(x: Tensor, y: List[Optional[int]]) -> Tensor:
                raise NotImplementedError()

        with self.assertRaisesRegex(ValueError, "unsupported type"):
            # int[N] in Dispatcher is a bit wild, so we don't try to support it.
            @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
            def foo(x: Tensor, y: Tuple[int, int]) -> Tensor:
                raise NotImplementedError()

        with self.assertRaisesRegex(ValueError, "unsupported type"):
            # We could theoretically support this, but the syntax for supporting
            # int[] is Sequence[int]
            @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
            def foo(x: Tensor, y: List[int]) -> Tensor:
                raise NotImplementedError()

        with self.assertRaisesRegex(ValueError, "unsupported type"):

            @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
            def foo(x: Tensor, y: Callable) -> Tensor:
                raise NotImplementedError()
    def test_supported_schemas(self):
        # All of these should already be tested by PyTorch codegen
        # (we share the same mechanism), but here's a sanity check.
        schemas = [
            "(Tensor x) -> Tensor",
            "(Tensor x) -> Tensor y",
            "(Tensor[] x) -> Tensor y",
            "(Tensor x) -> (Tensor, Tensor)",
            "(Tensor x) -> (Tensor y, Tensor z)",
            "(Tensor x) -> (Tensor y, Tensor z)",
        ]
        other_schemas = [
            "(Tensor x, Tensor w) -> (Tensor y, Tensor z)",
            "(Tensor x, Tensor w) -> (Tensor, Tensor)",
            "(Tensor x, Tensor w) -> Tensor",
            "(Tensor? x, Tensor w) -> Tensor",
            "(Tensor? x, Tensor[] w) -> Tensor",
            "(Tensor x, int[] w) -> Tensor",
            "(Tensor x, SymInt[] w) -> Tensor",
            "(Tensor x, Scalar w) -> Tensor",
            "(Tensor x, float w) -> Tensor",
            "(Tensor x, float? w) -> Tensor",
            "(Tensor x, bool[] w) -> Tensor",
        ]

        for schema in schemas:
            custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo", schema)
            custom_ops._destroy(f"{TestCustomOp.test_ns}::foo")
        for schema in other_schemas:
            custom_ops.custom_op(f"{TestCustomOp.test_ns}::bar", schema)
            custom_ops._destroy(f"{TestCustomOp.test_ns}::bar")
    def test_reserved_ns(self):
        from torch._custom_op.impl import RESERVED_NS

        for ns in RESERVED_NS:
            with self.assertRaisesRegex(ValueError, "is a reserved namespace"):
                custom_ops.custom_op(f"{ns}::foo", "(Tensor x) -> Tensor")

            with self.assertRaisesRegex(ValueError, "is a reserved namespace"):

                @custom_ops.custom_op(f"{ns}::foo2")
                def foo2(x: torch.Tensor) -> torch.Tensor:
                    raise NotImplementedError()

    def test_private_ctor(self):
        with self.assertRaisesRegex(RuntimeError, "CustomOp constructor is private"):
            CustomOp(None, None, None, None, None)
    def test_lifetime(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError()

        custom_op = torch._custom_op.impl.get_op(f"{TestCustomOp.test_ns}::foo")

        # We can't define an op multiple times,
        with self.assertRaisesRegex(RuntimeError, "multiple times"):

            @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
            def foo(x: torch.Tensor) -> torch.Tensor:  # noqa: F811
                raise NotImplementedError()

        # Unless we delete the original op.
        custom_ops._destroy(f"{TestCustomOp.test_ns}::foo")

        # Smoke test: redefining should now succeed.
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:  # noqa: F811
            raise NotImplementedError()

        custom_ops._destroy(f"{TestCustomOp.test_ns}::foo")
    def test_autograd_notimplemented(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:  # noqa: F811
            raise NotImplementedError()

        x = torch.randn(3, requires_grad=True)
        op = self.get_op(f"{self.test_ns}::foo")
        with self.assertRaisesRegex(RuntimeError, "Autograd has not been implemented"):
            op(x)
        custom_ops._destroy(f"{TestCustomOp.test_ns}::foo")

        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: Sequence[torch.Tensor]) -> torch.Tensor:
            raise NotImplementedError()

        x = torch.randn(3, requires_grad=True)
        op = self.get_op(f"{self.test_ns}::foo")
        with self.assertRaisesRegex(RuntimeError, "Autograd has not been implemented"):
            op([x])
        custom_ops._destroy(f"{TestCustomOp.test_ns}::foo")

        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError()

        x = torch.randn(3, requires_grad=True)
        op = self.get_op(f"{self.test_ns}::foo")
        with self.assertRaisesRegex(RuntimeError, "Autograd has not been implemented"):
            op(x, x)
        custom_ops._destroy(f"{TestCustomOp.test_ns}::foo")

    def test_autograd_notimplemented_gradmode(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError()

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
        def foo_impl(x, y):
            return x * y

        x = torch.randn(3, requires_grad=True)
        y = torch.randn(3)
        op = self.get_op(f"{self.test_ns}::foo")
        with torch.no_grad():
            # Shouldn't raise, because we are in no_grad
            op(x, y)
    def test_impl_cpu(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError()

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo", device_types="cpu")
        def foo_cpu(x):
            return x.sin()

        x = torch.randn(3)
        op = self.get_op(f"{self.test_ns}::foo")
        result = op(x)
        self.assertEqual(result, foo_cpu(x))

    def test_impl_invalid_devices(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError()

        def foo_impl(x):
            return x.sin()

        from torch._custom_op.impl import SUPPORTED_DEVICE_TYPE_TO_KEY

        for device_type in SUPPORTED_DEVICE_TYPE_TO_KEY.keys():
            # Smoke test: should not raise error
            custom_ops.impl(f"{TestCustomOp.test_ns}::foo", device_types=device_type)(
                foo_impl
            )

        # Not supported by this API: we can either support them in the future
        # or provide some other CustomOp.def_* function. This depends on how
        # common the use cases are.
        for invalid_type in ["hip", "xla", "mkldnn", ["cpu", "hip"]]:
            with self.assertRaisesRegex(ValueError, "we only support device_type"):
                custom_ops.impl(
                    f"{TestCustomOp.test_ns}::foo", device_types=invalid_type
                )(foo_impl)
    def test_backward_partially_registered(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError()

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
        def foo_impl(x):
            return x.sin()

        @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_backward(ctx, saved, grad):
            return grad * saved.cos()

        x = torch.randn([], requires_grad=True)
        op = self.get_op(f"{self.test_ns}::foo")
        with self.assertRaisesRegex(
            RuntimeError, "unable to find a 'save_for_backward'"
        ):
            y = op(x)
            y.backward()
    def test_save_for_backward_inputs_are_namedtuple(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError()

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
        def foo_impl(x):
            return x.sin()

        hit = 0

        @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_save_for_backward(inputs, output):
            nonlocal hit
            hit += 1
            self.assertTrue(isinstance(inputs, tuple))
            self.assertEqual(list(inputs._asdict().keys()), ["x"])
            return inputs.x

        @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_backward(ctx, saved, grad):
            return {"x": grad * saved.cos()}

        x = torch.randn([], requires_grad=True)
        op = self.get_op(f"{self.test_ns}::foo")
        y = op(x)
        self.assertEqual(hit, 1)
        y.backward()
        self.assertEqual(hit, 1)
    def test_backward_returns_dict(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError()

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
        def foo_impl(x):
            return x.sin()

        @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_save_for_backward(inputs, output):
            return inputs.x

        @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_backward(ctx, saved, grad):
            return grad * saved.cos()

        x = torch.randn([], requires_grad=True)
        op = self.get_op(f"{self.test_ns}::foo")
        y = op(x)
        with self.assertRaisesRegex(RuntimeError, "to be a dict"):
            y.backward()
    def test_backward_dict_invalid_keys(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError()

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
        def foo_impl(x):
            return x.sin()

        @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_save_for_backward(inputs, output):
            return inputs.x

        @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_backward(ctx, saved, grad):
            return {"x": grad * saved.cos(), "y": None}

        x = torch.randn([], requires_grad=True)
        op = self.get_op(f"{self.test_ns}::foo")
        with self.assertRaisesRegex(RuntimeError, "to have keys {'x'}"):
            op(x).backward()
    def test_backward_dict_grad_for_nontensor(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor, dim: int) -> torch.Tensor:
            raise NotImplementedError()

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
        def foo_impl(x, dim):
            return x.sin()

        @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_save_for_backward(inputs, output):
            return inputs.x

        @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_backward(ctx, saved, grad):
            return {"x": grad * saved.cos(), "dim": None}

        x = torch.randn([], requires_grad=True)
        op = self.get_op(f"{self.test_ns}::foo")
        with self.assertRaisesRegex(RuntimeError, "non-Tensor-like types"):
            op(x, 0).backward()
    def test_backward_dict_requires_keys_for_input_tensors(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError()

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
        def foo_impl(x, y):
            return x.sin()

        @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_save_for_backward(inputs, output):
            return inputs.x

        @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_backward(ctx, saved, grad):
            return {"x": grad * saved.cos()}

        x = torch.randn([], requires_grad=True)
        op = self.get_op(f"{self.test_ns}::foo")
        with self.assertRaisesRegex(RuntimeError, r"to have keys {.*'y'.*}"):
            op(x, x).backward()
    def test_backward_dict_requires_keys_for_input_optional_tensors(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor, y: Optional[torch.Tensor]) -> torch.Tensor:
            raise NotImplementedError()

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
        def foo_impl(x, y):
            return x.sin()

        @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_save_for_backward(inputs, output):
            return inputs.x

        @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_backward(ctx, saved, grad):
            return {"x": grad * saved.cos()}

        x = torch.randn([], requires_grad=True)
        op = self.get_op(f"{self.test_ns}::foo")
        with self.assertRaisesRegex(RuntimeError, r"to have keys {.*'y'.*}"):
            op(x, None).backward()
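    # Hedged sketch (added): a backward registration that satisfies the two
    # tests above returns a dict keyed by every input name, including optional
    # tensor inputs; non-differentiable inputs simply map to None. Hypothetical:
    def _example_complete_backward_dict(self):
        @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_backward(ctx, saved, grad):
            # 'y' must be present even though no gradient flows to it.
            return {"x": grad * saved.cos(), "y": None}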
    def test_backward_grads_are_tensor_or_none(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError()

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
        def foo_impl(x):
            return x.sin()

        @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_save_for_backward(inputs, output):
            return inputs.x

        @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_backward(ctx, saved, grad):
            return {"x": (grad * saved.cos(),)}

        x = torch.randn([], requires_grad=True)
        op = self.get_op(f"{self.test_ns}::foo")
        with self.assertRaisesRegex(RuntimeError, "either None or a Tensor"):
            op(x).backward()
    def test_backward_tensorlist_input_requires_list_grads_with_same_numel(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(xs: Sequence[torch.Tensor]) -> torch.Tensor:
            raise NotImplementedError()

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
        def foo_impl(xs):
            return xs[0].sin()

        @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_save_for_backward(inputs, output):
            return inputs.xs[0]

        @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_backward(ctx, saved, grad):
            return {"xs": [grad * saved.cos(), None]}

        xs = [torch.randn([], requires_grad=True) for _ in range(3)]
        op = self.get_op(f"{self.test_ns}::foo")
        with self.assertRaisesRegex(RuntimeError, "3 gradients but got 2"):
            op(xs).backward()
    def test_backward_tensorlist_input_requires_list_grads_none_or_Tensor(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(xs: Sequence[torch.Tensor]) -> torch.Tensor:
            raise NotImplementedError()

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
        def foo_impl(xs):
            return xs[0].sin()

        @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_save_for_backward(inputs, output):
            return inputs.xs[0]

        @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_backward(ctx, saved, grad):
            return {"xs": [grad * saved.cos(), None, (None,)]}

        xs = [torch.randn([], requires_grad=True) for _ in range(3)]
        op = self.get_op(f"{self.test_ns}::foo")
        with self.assertRaisesRegex(RuntimeError, "None or Tensor"):
            op(xs).backward()
    def test_backward_tensorlist_input_requires_list_grads(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(xs: Sequence[torch.Tensor]) -> torch.Tensor:
            raise NotImplementedError()

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
        def foo_impl(xs):
            return xs[0].sin()

        @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_save_for_backward(inputs, output):
            return inputs.xs[0]

        @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_backward(ctx, saved, grad):
            return {"xs": grad * saved.cos()}

        xs = [torch.randn([], requires_grad=True) for _ in range(3)]
        op = self.get_op(f"{self.test_ns}::foo")
        with self.assertRaisesRegex(RuntimeError, "list of gradients"):
            op(xs).backward()
    def test_backward_output_differentiability_type(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(xs: Sequence[torch.Tensor]) -> torch.Tensor:
            raise NotImplementedError()

        with self.assertRaisesRegex(RuntimeError, "output_differentiability"):

            @custom_ops.impl_backward(
                f"{TestCustomOp.test_ns}::foo", output_differentiability=True
            )
            def foo_backward(ctx, saved, grad):
                raise NotImplementedError()

    def test_backward_output_differentiability_numel(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(xs: Sequence[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
            raise NotImplementedError()

        with self.assertRaisesRegex(RuntimeError, "output_differentiability"):

            @custom_ops.impl_backward(
                f"{TestCustomOp.test_ns}::foo", output_differentiability=[True]
            )
            def foo_backward(ctx, saved, grad):
                raise NotImplementedError()
    def test_backward_output_differentiability_tensorlist(self):
        @custom_ops.custom_op(f"{self.test_ns}::foo")
        def foo(x: Tensor) -> Tuple[List[Tensor], Tensor]:
            raise NotImplementedError()

        @custom_ops.impl(f"{self.test_ns}::foo")
        def foo_impl(x):
            return [x.clone(), x.clone()], x.clone()

        @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_save_for_backward(inputs, output):
            return []

        @custom_ops.impl_backward(
            f"{TestCustomOp.test_ns}::foo", output_differentiability=[False, True]
        )
        def foo_backward(ctx, saved, grad_lst, grad):
            return {"x": grad}

        op = self.get_op(f"{self.test_ns}::foo")
        x = torch.randn(3, requires_grad=True)
        [a, b], c = op(x)
        self.assertFalse(a.requires_grad)
        self.assertFalse(b.requires_grad)
        self.assertTrue(c.requires_grad)
    def test_backward_output_differentiability_non_tensor(self):
        @custom_ops.custom_op(f"{self.test_ns}::foo")
        def foo(x: Tensor) -> Tuple[Tensor, int]:
            raise NotImplementedError()

        @custom_ops.impl(f"{self.test_ns}::foo")
        def foo_impl(x):
            return x.clone(), 3

        @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_save_for_backward(inputs, output):
            return []

        @custom_ops.impl_backward(
            f"{TestCustomOp.test_ns}::foo", output_differentiability=[True, True]
        )
        def foo_backward(ctx, saved, grad0, grad1):
            return {"x": grad0}

        op = self.get_op(f"{self.test_ns}::foo")
        x = torch.randn(3, requires_grad=True)
        with self.assertRaisesRegex(RuntimeError, "is not a Tensor"):
            op(x)
    @unittest.skipIf(not TEST_CUDA, "requires CUDA")
    def test_impl_separate(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError()

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo", device_types="cpu")
        def foo_cpu(x):
            return x.sin()

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo", device_types="cuda")
        def foo_cuda(x):
            return x.cos()

        x = torch.randn(3)
        op = self.get_op(f"{self.test_ns}::foo")
        result = op(x)
        self.assertEqual(result, foo_cpu(x))

        x_cuda = x.cuda()
        op = self.get_op(f"{self.test_ns}::foo")
        result = op(x_cuda)
        self.assertEqual(result, foo_cuda(x_cuda))

    @unittest.skipIf(not TEST_CUDA, "requires CUDA")
    def test_impl_multiple(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError()

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
        def foo_impl(x):
            return x.cos()

        x = torch.randn(3)
        op = self.get_op(f"{self.test_ns}::foo")
        result = op(x)
        self.assertEqual(result, foo_impl(x))

        x_cuda = x.cuda()
        result = op(x_cuda)
        self.assertEqual(result, foo_impl(x_cuda))
    def test_impl_abstract_overload(self):
        lib = self.lib()
        lib.define("sin.blah(Tensor x) -> Tensor")

        torch.library.impl_abstract(
            f"{self.test_ns}::sin.blah", torch.empty_like, lib=lib
        )

        op = self.ns().sin.blah
        x = torch.randn(3, device="meta")
        op(x)

    def test_impl_meta(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor, dim: int) -> torch.Tensor:
            raise NotImplementedError()

        @torch.library.impl_abstract(f"{TestCustomOp.test_ns}::foo", lib=self.lib())
        def foo_meta(x, dim):
            output_shape = list(x.shape)
            del output_shape[dim]
            return x.new_empty(output_shape)

        x = torch.randn(2, 3, device="meta")
        op = self.get_op(f"{self.test_ns}::foo")
        result = op(x, 1)
        self.assertEqual(result.shape, foo_meta(x, 1).shape)
    def test_duplicate_impl(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor, dim: int) -> torch.Tensor:
            raise NotImplementedError()

        @torch.library.impl_abstract(f"{TestCustomOp.test_ns}::foo", lib=self.lib())
        def foo_meta(x, dim):
            output_shape = list(x.shape)
            del output_shape[dim]
            return x.new_empty(output_shape)

        with self.assertRaisesRegex(RuntimeError, r"test_custom_ops.py:\d+"):

            @torch.library.impl_abstract(f"{TestCustomOp.test_ns}::foo", lib=self.lib())
            def foo_meta2(x, dim):
                output_shape = list(x.shape)
                del output_shape[dim]
                return x.new_empty(output_shape)
    def test_new_data_dependent_symint(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError()

        @torch.library.impl_abstract(f"{TestCustomOp.test_ns}::foo", lib=self.lib())
        def foo_meta(x):
            ctx = torch.library.get_ctx()
            ctx.new_dynamic_size(min=1)
            with self.assertRaisesRegex(ValueError, "greater than or equal to 0"):
                ctx.new_dynamic_size(min=-1)
            with self.assertRaisesRegex(ValueError, "SymInt"):
                ctx.new_dynamic_size(max=x.numel())
            return torch.clone(x)

        x = torch.randn(2, 3, device="cpu")
        op = self.get_op(f"{self.test_ns}::foo")
        make_fx(op, tracing_mode="symbolic")(x)
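    # Hedged sketch (added): new_dynamic_size() is the tool for modeling
    # data-dependent output shapes in an abstract impl, e.g. a nonzero-like
    # operator. This mirrors the documented torch.library pattern; the operator
    # name is hypothetical and the method is never run by the suite.
    def _example_data_dependent_abstract_impl(self):
        @torch.library.impl_abstract(f"{TestCustomOp.test_ns}::mynonzero", lib=self.lib())
        def mynonzero_meta(x):
            # Allocate an unbacked SymInt for the unknown number of nonzeros.
            nnz = torch.library.get_ctx().new_dynamic_size()
            return x.new_empty([nnz, x.dim()], dtype=torch.int64)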
    def test_meta_for_data_dependent_shape_operation(self):
        x = torch.randn(10, device="meta")
        with self.assertRaisesRegex(RuntimeError, "data-dependent output shape"):
            torch.ops._torch_testing.numpy_nonzero(x)

    def test_basic_make_fx(self):
        # More serious tests are in our CustomOp opinfo db,
        # this one is just a sanity check.
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError()

        @torch.library.impl_abstract(f"{TestCustomOp.test_ns}::foo", lib=self.lib())
        def foo_meta(x):
            return x.sum()

        x = torch.randn(3)
        op = self.get_op(f"{self.test_ns}::foo")
        gm = make_fx(op, tracing_mode="symbolic")(x)
        self.assertTrue(f"{TestCustomOp.test_ns}.foo" in gm.code)
    def test_not_implemented_error(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError()

        x = torch.randn(3)
        op = self.get_op(f"{self.test_ns}::foo")
        with self.assertRaisesRegex(NotImplementedError, "cpu impl registered"):
            op(x)

        x = torch.randn(3, device="meta")
        with self.assertRaisesRegex(
            NotImplementedError, "no abstract impl or Meta kernel"
        ):
            op(x)

        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::bar")
        def bar(sizes: Sequence[int]) -> torch.Tensor:
            raise NotImplementedError()

        op = self.get_op(f"{self.test_ns}::bar")
        with self.assertRaisesRegex(NotImplementedError, "no Tensor inputs"):
            op((1, 2, 3))
    def test_abstract_registration_location(self):
        custom_op = torch._custom_op.impl._find_custom_op(
            "_torch_testing::numpy_nonzero"
        )
        source = torch._library.simple_registry.singleton.find(
            "_torch_testing::numpy_nonzero"
        ).abstract_impl.kernel.source
        self.assertRegex(source, r".*custom_op_db.py:\d+")

    def test_data_dependent_basic(self):
        def f(x):
            return torch.ops._torch_testing.numpy_nonzero(x)

        x = torch.randn(5, 5)
        gm = make_fx(f, tracing_mode="symbolic")(x)
        self.assertTrue("nonzero" in gm.code)
    def test_data_dependent_fake_tracing(self):
        def f(x):
            return torch.ops._torch_testing.numpy_nonzero(x)

        x = torch.randn(5, 5)
        # We've updated to attempt to use unbacked symints even for fake
        # tracing.
        make_fx(f, tracing_mode="fake")(x)

    def test_symints(self):
        def f(x):
            return torch.ops._torch_testing.numpy_view_copy(x, x.shape)

        x = torch.randn(2, 3, 4)
        gm = make_fx(f, tracing_mode="symbolic")(x)
        result = gm(x)
        self.assertEqual(result, f(x))
        self.assertExpectedInline(
            gm.code.strip(),
            """\
def forward(self, x_1):
    sym_size_int = torch.ops.aten.sym_size.int(x_1, 0)
    sym_size_int_1 = torch.ops.aten.sym_size.int(x_1, 1)
    sym_size_int_2 = torch.ops.aten.sym_size.int(x_1, 2)
    numpy_view_copy = torch.ops._torch_testing.numpy_view_copy.default(x_1, [sym_size_int, sym_size_int_1, sym_size_int_2]);  x_1 = sym_size_int = sym_size_int_1 = sym_size_int_2 = None
    return numpy_view_copy""",  # noqa: B950
        )
    @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work on windows")
    @unittest.skipIf(
        sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+"
    )
    def test_data_dependent_compile(self):
        import torch._dynamo.testing
        from torch._dynamo.utils import counters

        counters.clear()
        cnt = torch._dynamo.testing.CompileCounter()

        @torch.compile(backend=cnt)
        def f(x):
            return torch.ops._torch_testing.numpy_nonzero(x.clone()).clone()

        f(torch.randn(10))

        self.assertEqual(
            dict(counters["graph_break"]),
            {
                "dynamic shape operator: _torch_testing.numpy_nonzero.default; to enable, set torch._dynamo.config.capture_dynamic_output_shape_ops = True": 1  # noqa: B950
            },
        )

    # pre-existing problem: torch.compile(dynamic=True) will, by default,
    # graph break on data-dependent operations. Eventually we'll make it so
    # that it never graph breaks on data-dependent operations.
    @unittest.expectedFailure
    def test_data_dependent_nms_dynamic_compile(self):
        import torch._dynamo.testing
        from torch._dynamo.utils import counters

        counters.clear()
        cnt = torch._dynamo.testing.CompileCounter()

        @torch.compile(backend=cnt, dynamic=True)
        def f(x, s, i):
            return torch.ops._torch_testing.numpy_nms(x.clone(), s, i).clone()

        f(torch.randn(20, 4), torch.randn(20), 0.1)

        self.assertEqual(len(counters["graph_break"]), 0)
    def test_impl_on_existing_op(self):
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        qualname = f"{self.test_ns}::foo"

        @torch._custom_ops.impl(qualname)
        def foo_impl(x):
            return x.sin()

        x = torch.randn(3)
        op = self.get_op(qualname)
        result = op(x)
        self.assertEqual(result, x.sin())

    @parametrize(
        "key", ["CPU", "CUDA", "CompositeImplicitAutograd", "CompositeExplicitAutograd"]
    )
    def test_impl_on_existing_op_with_cpu_registration(self, key):
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        qualname = f"{self.test_ns}::foo"

        def foo_impl(x):
            return x.sin()

        lib.impl("foo", foo_impl, key)
        op = self.get_op(qualname)

        with self.assertRaisesRegex(RuntimeError, "already has an implementation"):
            custom_ops.impl(qualname, func=foo_impl)
    def test_abstract_impl_on_existing_op(self):
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        qualname = f"{self.test_ns}::foo"

        @torch.library.impl_abstract(qualname, lib=self.lib())
        def foo_impl(x):
            return x.sin()

        op = self.get_op(qualname)
        with torch._subclasses.FakeTensorMode():
            x = torch.randn(3)
            result = op(x)
            self.assertEqual(result.shape, x.shape)
            self.assertEqual(result.stride(), x.stride())

    def test_abstract_impl_on_existing_op_with_meta(self):
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        qualname = f"{self.test_ns}::foo"

        def foo_impl(x):
            return x.sin()

        lib.impl("foo", foo_impl, "Meta")
        op = self.get_op(qualname)

        with self.assertRaisesRegex(RuntimeError, r"already has .*Meta implementation"):
            torch.library.impl_abstract(qualname, func=foo_impl, lib=self.lib())

    def test_abstract_impl_on_existing_op_with_CompositeImplicitAutograd(self):
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        qualname = f"{self.test_ns}::foo"

        def foo_impl(x):
            return x.sin()

        lib.impl("foo", foo_impl, "CompositeImplicitAutograd")
        op = self.get_op(qualname)

        with self.assertRaisesRegex(RuntimeError, "CompositeImplicitAutograd"):
            torch.library.impl_abstract(qualname, func=foo_impl, lib=self.lib())
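    # Hedged sketch (added): CompositeImplicitAutograd kernels decompose into
    # other PyTorch ops, so fake/meta behavior falls out of the decomposition;
    # a separate abstract impl could silently disagree, hence the error above.
    # Hypothetical example, never run by the suite:
    def _example_composite_gives_meta_for_free(self):
        lib = self.lib()
        lib.define("foo_cia(Tensor x) -> Tensor")
        lib.impl("foo_cia", lambda x: x.sin().cos(), "CompositeImplicitAutograd")
        op = self.get_op(f"{self.test_ns}::foo_cia")
        with torch._subclasses.FakeTensorMode():
            # No abstract impl registered; the decomposition provides metadata.
            self.assertEqual(op(torch.randn(3)).shape, (3,))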
    def test_abstract_impl_on_existing_op_with_CompositeExplicitAutograd(self):
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        qualname = f"{self.test_ns}::foo"

        def foo_impl(x):
            return x.sin()

        lib.impl("foo", foo_impl, "CompositeExplicitAutograd")
        op = self.get_op(qualname)

        torch.library.impl_abstract(qualname, func=lambda x: x.sum(), lib=self.lib())
        with torch._subclasses.FakeTensorMode():
            x = torch.randn(10)
            result = op(x)
            self.assertEqual(result.shape, ())
    def _test_backward_impl_raises(self, qualname, err_regex):
        with self.assertRaisesRegex(RuntimeError, err_regex):

            @custom_ops.impl_save_for_backward(qualname)
            def foo2(inputs, output):
                pass

        with self.assertRaisesRegex(RuntimeError, err_regex):

            @custom_ops.impl_backward(qualname)
            def foo3(ctx, saved, grad):
                pass

    def test_backward_impl_on_existing_op_incorrect_schema_views(self):
        lib = self.lib()
        lib.define("foo(Tensor(a) x) -> Tensor(a)")
        qualname = f"{self.test_ns}::foo"
        self._test_backward_impl_raises(qualname, "operator that returns views")

    def test_backward_impl_on_existing_op_incorrect_schema_mutable(self):
        lib = self.lib()
        lib.define("foo(Tensor(a!) x) -> Tensor")
        qualname = f"{self.test_ns}::foo"
        self._test_backward_impl_raises(qualname, "non-functional")

    def test_backward_impl_on_existing_op_incorrect_schema_no_output(self):
        lib = self.lib()
        lib.define("foo(Tensor x) -> ()")
        qualname = f"{self.test_ns}::foo"
        self._test_backward_impl_raises(qualname, "no returns")

    def test_backward_impl_on_existing_op_CompositeImplicitAutograd(self):
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        qualname = f"{self.test_ns}::foo"
        lib.impl("foo", lambda x: x.sin().cos(), "CompositeImplicitAutograd")
        self._test_backward_impl_raises(qualname, "CompositeImplicitAutograd")

    @parametrize("key", ["Autograd", "AutogradCPU", "AutogradCUDA"])
    def test_backward_impl_on_existing_op_with_key(self, key):
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        qualname = f"{self.test_ns}::foo"
        lib.impl("foo", lambda x: x.sin().cos(), key)
        self._test_backward_impl_raises(qualname, key)
    def test_backward_impl_on_existing_op(self):
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        qualname = f"{self.test_ns}::foo"

        @custom_ops.impl(qualname)
        def foo_impl(x):
            with torch.no_grad():
                return x.sin()

        @custom_ops.impl_save_for_backward(qualname)
        def foo_save_for_backward(inputs, output):
            return inputs.x

        @custom_ops.impl_backward(qualname)
        def foo_backward(ctx, saved, grad_out):
            return {"x": grad_out * saved.cos()}

        op = self.get_op(qualname)
        x = torch.randn([], requires_grad=True)
        y = op(x)
        (gx,) = torch.autograd.grad(y, x)
        self.assertEqual(gx, x.cos())
    @parametrize(
        "tags",
        [
            subtest(torch.Tag.pointwise, "single"),
            subtest((torch.Tag.pointwise,), "tuple"),
            subtest([torch.Tag.pointwise], "list"),
        ],
    )
    def test_define_with_tags(self, tags):
        lib = self.lib()
        tags = (torch.Tag.pointwise,)
        torch.library.define(
            f"{self.test_ns}::foo", "(Tensor x) -> Tensor", lib=lib, tags=tags
        )
        actual = self.ns().foo.default.tags
        self.assertTrue(isinstance(actual, list))
        self.assertEqual(actual, list(tags))

    def test_builtin_aten_ops_are_pt2_compliant(self):
        for op in [torch.ops.aten.sin.default, torch.ops.aten.sum.dim_IntList]:
            self.assertIn(torch.Tag.pt2_compliant_tag, op.tags)
    def test_builtin_torchscript_ops(self):
        for op in [torch.ops.aten.sub.complex, torch.ops.aten.mul.complex]:
            self.assertIn(torch.Tag.pt2_compliant_tag, op.tags)

    def test_autogen_aten_ops_are_pt2_compliant(self):
        for op in [
            torch.ops.aten._foreach_copy.default,
            torch.ops.aten.fill.Tensor_out,
        ]:
            self.assertIn(torch.Tag.generated, op.tags)
            self.assertIn(torch.Tag.pt2_compliant_tag, op.tags)

    def test_resolve_packet(self):
        x = torch.randn(3)
        result = torch._C._jit_resolve_packet("aten::sum", x)
        self.assertEqual(result, "default")

        result = torch._C._jit_resolve_packet("aten::sum", x, dim=1)
        self.assertEqual(result, "dim_IntList")

        with self.assertRaisesRegex(RuntimeError, "failed to match any schema"):
            result = torch._C._jit_resolve_packet("aten::sum", x, x, x)
    def test_define_bad_schema(self):
        lib = self.lib()
        with self.assertRaisesRegex(ValueError, "expected schema to look like"):
            torch.library.define(f"{self.test_ns}::foo", "foo(Tensor x) -> Tensor")

    def test_define_and_impl(self):
        lib = self.lib()
        torch.library.define(f"{self.test_ns}::foo", "(Tensor x) -> Tensor", lib=lib)

        @torch.library.impl(f"{self.test_ns}::foo", "CPU", lib=lib)
        def f(x):
            return torch.from_numpy(np.sin(x.numpy()))

        x = torch.randn(3)
        y = self.ns().foo(x)
        assert torch.allclose(y, x.sin())

    def test_define_validation(self):
        with self.assertRaisesRegex(ValueError, "namespace"):
            torch.library.define("foo", "(Tensor x) -> Tensor")

    def test_legacy_define(self):
        lib = self.lib()

        @torch.library.define(lib, "foo(Tensor x) -> Tensor")
        def f(x):
            return torch.from_numpy(np.sin(x.numpy()))

        x = torch.randn(3)
        y = self.ns().foo(x)
        assert torch.allclose(y, x.sin())
    def test_impl_function(self):
        lib = self.lib()
        torch.library.define(f"{self.test_ns}::foo", "(Tensor x) -> Tensor", lib=lib)

        def f(x):
            return torch.from_numpy(np.sin(x.numpy()))

        torch.library.impl(f"{self.test_ns}::foo", "CPU", f, lib=lib)
        x = torch.randn(3)
        y = self.ns().foo(x)
        assert torch.allclose(y, x.sin())

    def test_legacy_impl(self):
        lib = self.lib()
        torch.library.define(f"{self.test_ns}::foo", "(Tensor x) -> Tensor", lib=lib)

        @torch.library.impl(lib, "foo", "CPU")
        def f(x):
            return torch.from_numpy(np.sin(x.numpy()))

        x = torch.randn(3)
        y = self.ns().foo(x)
        assert torch.allclose(y, x.sin())

    def test_defined_in_python(self):
        self.assertFalse(torch.ops.aten.sin.default._defined_in_python)
        self.assertFalse(torch.ops.aten.sum.dim_IntList._defined_in_python)

        lib = self.lib()
        torch.library.define(f"{self.test_ns}::foo", "(Tensor x) -> Tensor", lib=lib)
        ns = self.ns()
        self.assertTrue(ns.foo.default._defined_in_python)

        torch.library.define(
            f"{self.test_ns}::bar.overload", "(Tensor x) -> Tensor", lib=lib
        )
        self.assertTrue(ns.bar.overload._defined_in_python)
    def _test_impl_device(self, name, types, device):
        lib = self.lib()
        torch.library.define(f"{self.test_ns}::{name}", "(Tensor x) -> Tensor", lib=lib)

        @torch.library.impl(f"{self.test_ns}::{name}", types)
        def f(x):
            x_np = x.cpu().numpy()
            y = torch.from_numpy(np.sin(x_np))
            return y.to(device=x.device)

        x = torch.randn(3, device=device)
        y = getattr(self.ns(), name)(x)
        assert torch.allclose(y, x.sin())

    def test_impl_device_cpu(self):
        self._test_impl_device("foo1", "default", "cpu")
        self._test_impl_device("foo2", ["cpu"], "cpu")
        self._test_impl_device("foo3", ["cpu", "cuda"], "cpu")

    @unittest.skipIf(not TEST_CUDA, "requires cuda")
    def test_impl_device_cuda(self):
        self._test_impl_device("foo4", "default", "cuda")
        self._test_impl_device("foo5", ["cuda"], "cuda")
        self._test_impl_device("foo6", ["cpu", "cuda"], "cuda")

    def test_impl_device_function(self):
        lib = self.lib()
        torch.library.define(f"{self.test_ns}::foo", "(Tensor x) -> Tensor", lib=lib)

        def f(x):
            x_np = x.cpu().numpy()
            y = torch.from_numpy(np.sin(x_np))
            return y.to(device=x.device)

        torch.library.impl(f"{self.test_ns}::foo", "default", f, lib=lib)
        x = torch.randn(3)
        y = self.ns().foo(x)
        assert torch.allclose(y, x.sin())

    def test_impl_device_invalid(self):
        with self.assertRaisesRegex(RuntimeError, "Expected one of cpu, cuda"):
            torch.library.impl("blah::blah", "somethingsomething")
    def test_autograd_function_backed_op(self):
        cpp_source = """
        struct CustomOpAutogradFunction : public torch::autograd::Function<CustomOpAutogradFunction> {
          static constexpr bool is_traceable = true;

          static torch::Tensor forward(
              torch::autograd::AutogradContext* ctx,
              const torch::Tensor& x) {
            return x;
          }

          static torch::autograd::variable_list backward(
              torch::autograd::AutogradContext *ctx,
              torch::autograd::variable_list grad_output) {
            return grad_output;
          }
        };

        torch::Tensor custom_op_backed_by_autograd_fn(torch::Tensor x) {
          return CustomOpAutogradFunction::apply(x);
        }

        TORCH_LIBRARY(mylib, m) {
            m.def("custom_op_backed_by_autograd_fn", custom_op_backed_by_autograd_fn);
        }
        """

        module = torch.utils.cpp_extension.load_inline(
            name="mylib",
            cpp_sources=cpp_source,
            functions="custom_op_backed_by_autograd_fn",
        )

        x = torch.ones(2, 2, requires_grad=True)
        temp = x.clone().detach()
        out = torch.ops.mylib.custom_op_backed_by_autograd_fn(x)
        loss = out.sum()
        loss.backward()
        self.assertEqual(x.grad, temp)
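    # Hedged aside (added): the C++ struct above is the torch::autograd::Function
    # analogue of this pure-Python sketch (identity forward, pass-through
    # backward), for readers without a C++ toolchain. Never used by the suite.
    class _ExampleIdentityFn(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            return x

        @staticmethod
        def backward(ctx, grad_output):
            return grad_output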
def op_with_incorrect_schema(testcase, name):
    lib = testcase.lib()
    lib.define(f"{name}(Tensor x) -> Tensor")
    qualname = f"{testcase.test_ns}::{name}"
    lib.impl(name, lambda x: x[:], "CompositeExplicitAutograd")
    return testcase.get_op(qualname)
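# Added note (hedged): `x[:]` returns a view of its input, while the schema
# above promises a fresh, non-aliasing Tensor; opcheck's schema test therefore
# reports "is not defined to alias output" (see test_opcheck_bad_op below).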
class MiniOpTest(CustomOpTestCaseBase):
    test_ns = "mini_op_test"

    def _init_op_delayed_backward_error(self):
        name = "delayed_error"
        qualname = f"{self.test_ns}::{name}"
        lib = self.lib()
        lib.define(f"{name}(Tensor x) -> Tensor")
        lib.impl(name, lambda x: x.clone(), "CompositeExplicitAutograd")
        op = self.get_op(qualname)

        class Op(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                with torch._C._AutoDispatchBelowAutograd():
                    return op(x)

            @staticmethod
            def backward(ctx, grad):
                raise NotImplementedError()

        def autograd_impl(x):
            return Op.apply(x)

        lib.impl(name, autograd_impl, "Autograd")
        return op

    def _init_op_with_no_abstract_impl(self):
        name = "no_abstract"
        qualname = f"{self.test_ns}::{name}"
        lib = self.lib()
        lib.define(f"{name}(Tensor x) -> Tensor", tags=(torch.Tag.pt2_compliant_tag,))
        lib.impl(name, lambda x: x.clone(), "CPU")
        return torch._library.utils.lookup_op(qualname)

    def setUp(self):
        super().setUp()
        self._op_with_no_abstract_impl = self._init_op_with_no_abstract_impl()
        self._op_delayed_backward_error = self._init_op_delayed_backward_error()
    @optests.dontGenerateOpCheckTests("Testing this API")
    def test_dont_generate(self):
        op = op_with_incorrect_schema(self, "incorrect_schema")
        op(torch.randn(3))

    def test_mm(self):
        x = torch.randn(2, 3, requires_grad=True)
        y = torch.randn(3, 5)
        result = torch.ops.aten.mm.default(x, y)
        self.assertEqual(result, x @ y)

    def test_mm_meta(self):
        x = torch.randn(2, 3, requires_grad=True, device="meta")
        y = torch.randn(3, 5, device="meta")
        result = torch.ops.aten.mm.default(x, y)
        self.assertEqual(result.shape, (x @ y).shape)

    def test_mm_fake(self):
        with torch._subclasses.fake_tensor.FakeTensorMode():
            x = torch.randn(2, 3, requires_grad=True, device="cpu")
            y = torch.randn(3, 5, device="cpu")
            result = torch.ops.aten.mm.default(x, y)
            self.assertEqual(result.shape, (x @ y).shape)

    def test_mm_errors(self):
        x = torch.randn(2, 3, requires_grad=True)
        y = torch.randn(4, 5)
        with self.assertRaisesRegex(RuntimeError, "cannot be multiplied"):
            result = torch.ops.aten.mm.default(x, y)

    def test_nonzero(self):
        x = torch.tensor([0, 1, 2, 0, 0])
        y = torch.ops.aten.nonzero.default(x)
        self.assertEqual(y, torch.tensor([[1], [2]]))

    def test_inplace(self):
        x = torch.randn([])
        x_clone = x.clone()
        y = torch.ops.aten.sin_(x)
        self.assertEqual(x, x_clone.sin())

    def test_incorrect_schema(self):
        op = op_with_incorrect_schema(self, "incorrect_schema")
        op(torch.randn(3))

    def test_no_abstract(self):
        op = self._op_with_no_abstract_impl
        op(torch.randn(3))

    def test_delayed_error(self):
        op = self._op_delayed_backward_error
        x = torch.randn([], requires_grad=True)
        y = op(x)
        with self.assertRaises(NotImplementedError):
            y.sum().backward()

    def test_delayed_error_no_requires_grad(self):
        op = self._op_delayed_backward_error
        x = torch.randn([])
        # Should not raise: backward is never invoked.
        op(x)
class MiniOpTestOther(CustomOpTestCaseBase):
    test_ns = "mini_op_test"

    def test_nonzero_again(self):
        x = torch.tensor([0, 1, 2, 0, 0])
        y = torch.ops.aten.nonzero.default(x)
        self.assertEqual(y, torch.tensor([[1], [2]]))


optests.generate_opcheck_tests(
    MiniOpTest,
    ["aten", "mini_op_test"],
    get_file_path_2(
        os.path.dirname(__file__),
        "minioptest_failures_dict.json",
    ),
    additional_decorators={
        "test_pt2_compliant_tag_mini_op_test_no_abstract": [unittest.expectedFailure]
    },
)

optests.generate_opcheck_tests(
    MiniOpTestOther,
    ["aten", "mini_op_test"],
    get_file_path_2(
        os.path.dirname(__file__),
        "minioptest_failures_dict.json",
    ),
)
class TestGenerateOpcheckTests(CustomOpTestCaseBase):
    def test_MiniOpTest(self):
        for orig_test in ["test_mm", "test_nonzero"]:
            for (
                test
            ) in torch.testing._internal.optests.generate_tests.DEFAULT_TEST_UTILS:
                expected_test = f"{test}__{orig_test}"
                self.assertTrue(hasattr(MiniOpTest, expected_test), msg=expected_test)
    def test_generate_repro_save_data(self):
        from torch.testing._internal.optests.generate_tests import generate_repro

        args = (torch.ones(2, 2),)
        kwargs = {"mat2": torch.zeros(2, 2)}
        actual = generate_repro(
            "test_schema",
            torch.ops.aten.sin.default,
            args,
            kwargs,
            save_data=True,
            dry_run=True,
        )
        actual = re.sub(r"torch.load\(\".*\.pt\"\)", 'torch.load("repro.pt")', actual)
        self.assertExpectedInline(
            actual,
            """\
# =========================================================
# BEGIN REPRO SCRIPT
# =========================================================
import torch
from torch.testing._internal.optests import opcheck

# Make sure you have loaded the library that contains the op
# via an import or torch.ops.load_library(...)
op = torch.ops.aten.sin.default

args, kwargs = torch.load("repro.pt")
opcheck(op, args, kwargs, test_utils="test_schema")
# =========================================================
# END REPRO SCRIPT
# =========================================================
""",
        )
    def test_generate_repro_no_save_data(self):
        from torch.testing._internal.optests.generate_tests import generate_repro

        args = (torch.ones(2, 2),)
        kwargs = {"mat2": torch.zeros(2, 2)}
        actual = generate_repro(
            "test_schema",
            torch.ops.aten.sin.default,
            args,
            kwargs,
            save_data=False,
            dry_run=True,
        )
        self.assertExpectedInline(
            actual,
            """\
# =========================================================
# BEGIN REPRO SCRIPT
# =========================================================
import torch
from torch.testing._internal.optests import opcheck

# Make sure you have loaded the library that contains the op
# via an import or torch.ops.load_library(...)
op = torch.ops.aten.sin.default

# If you rerun your test with PYTORCH_OPCHECK_PRINT_BETTER_REPRO=1
# we will fill them in same (args, kwargs) as in your test
args = ()  # args to the operator
kwargs = {}  # kwargs to the operator
opcheck(op, args, kwargs, test_utils="test_schema")
# =========================================================
# END REPRO SCRIPT
# =========================================================
""",
        )
    def test_failures_dict_validation(self):
        from torch.testing._internal.optests.generate_tests import (
            FailuresDict,
            validate_failures_dict_structure,
        )

        failures = {
            "mini_op_test::incorrect_schema": {
                "MiniOpTest.test_aot_dispatch_static__test_delayed_error": {
                    "comment": "",
                    "status": "success",
                }
            }
        }
        with self.assertRaisesRegex(RuntimeError, "got status=success"):
            validate_failures_dict_structure(
                FailuresDict("", failures),
                torch.testing._internal.optests.generate_tests.DEFAULT_TEST_UTILS,
                MiniOpTest,
            )

        failures = {
            "mini_op_test::incorrect_schema": {
                "MiniOpTest.test_aot_dispatch__test_delayed_error": {
                    "comment": "",
                    "status": "xfail",
                }
            }
        }
        with self.assertRaisesRegex(RuntimeError, "should begin with one of"):
            validate_failures_dict_structure(
                FailuresDict("", failures),
                torch.testing._internal.optests.generate_tests.DEFAULT_TEST_UTILS,
                MiniOpTest,
            )

        failures = {
            "mini_op_test::incorrect_schema": {
                "MiniOpTest.test_aot_dispatch_static__test_delayed_error_nopenopenope": {
                    "comment": "",
                    "status": "xfail",
                }
            }
        }
        with self.assertRaisesRegex(RuntimeError, "does not exist on the TestCase"):
            validate_failures_dict_structure(
                FailuresDict("", failures),
                torch.testing._internal.optests.generate_tests.DEFAULT_TEST_UTILS,
                MiniOpTest,
            )
    def test_dont_generate_decorator(self):
        self.assertTrue(hasattr(MiniOpTest, "test_dont_generate"))
        self.assertFalse(hasattr(MiniOpTest, "test_schema__test_dont_generate"))
    def test_opcheck(self):
        x = torch.randn(3, requires_grad=True)
        with self.assertRaisesRegex(ValueError, "OpOverload"):
            optests.opcheck(torch.sin, (x,))
        with self.assertRaisesRegex(ValueError, "test_utils to be subset of"):
            optests.opcheck(torch.ops.aten.sin.default, (x,), test_utils="blah")
        result = optests.opcheck(torch.ops.aten.sin.default, (x,))
        self.assertEqual(
            result,
            {
                "test_schema": "SUCCESS",
                "test_autograd_registration": "SUCCESS",
                "test_faketensor": "SUCCESS",
                "test_aot_dispatch_static": "SUCCESS",
                "test_aot_dispatch_dynamic": "SUCCESS",
            },
        )

        result = optests.opcheck(
            torch.ops.aten.sin.default, (x,), test_utils="test_schema"
        )
        self.assertEqual(
            result,
            {
                "test_schema": "SUCCESS",
            },
        )

        result = optests.opcheck(
            torch.ops.aten.sin.default,
            (x,),
            test_utils=["test_schema", "test_faketensor"],
        )
        self.assertEqual(
            result,
            {
                "test_schema": "SUCCESS",
                "test_faketensor": "SUCCESS",
            },
        )
    def test_is_inside_opcheck_mode(self):
        self.assertFalse(optests.is_inside_opcheck_mode())
        with optests.generate_tests.OpCheckMode(
            ["foo"], "bar", lambda x: x, None, "baz", "brr"
        ):
            self.assertTrue(optests.is_inside_opcheck_mode())

    def test_opcheck_bad_op(self):
        op = op_with_incorrect_schema(self, "foo")
        x = torch.randn(3)
        with self.assertRaisesRegex(Exception, "is not defined to alias output"):
            optests.opcheck(op, (x,))

        result = optests.opcheck(op, (x,), raise_exception=False)
        self.assertTrue(isinstance(result["test_schema"], RuntimeError))
        del result["test_schema"]
        self.assertEqual(
            result,
            {
                "test_autograd_registration": "SUCCESS",
                "test_faketensor": "SUCCESS",
                "test_aot_dispatch_static": "SUCCESS",
                "test_aot_dispatch_dynamic": "SUCCESS",
            },
        )
only_for = ("cpu", "cuda")
instantiate_device_type_tests(TestCustomOpTesting, globals(), only_for=only_for)
instantiate_parametrized_tests(TestCustomOp)

if __name__ == "__main__":
    run_tests()