# Owner(s): ["module: custom-operators"]
from typing import *  # noqa: F403

import torch._custom_ops as custom_ops
import torch.testing._internal.optests as optests
import torch.utils._pytree as pytree
import torch.utils.cpp_extension
from functorch import make_fx
from torch import Tensor
from torch._custom_op.impl import CustomOp, infer_schema
from torch._library.infer_schema import tuple_to_list
from torch._utils_internal import get_file_path_2
from torch.testing._internal import custom_op_db
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
from torch.testing._internal.custom_op_db import numpy_nonzero

# Shadowed by `torch.testing._internal.common_utils.custom_op`
from torch._custom_op.impl import custom_op  # usort: skip


def requires_compile(fun):
    fun = unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")(fun)


class CustomOpTestCaseBase(TestCase):
    test_ns = "_test_custom_op"

        import torch._custom_op

        keys = list(torch._custom_op.impl.global_registry.keys())
            if not key.startswith(f"{self.test_ns}::"):
            torch._custom_op.impl.global_registry[key]._destroy()
        if hasattr(torch.ops, self.test_ns):
            delattr(torch.ops, self.test_ns)
        for lib in self.libraries:

        return getattr(torch.ops, self.test_ns)

        result = torch.library.Library(self.test_ns, "FRAGMENT")  # noqa: TOR901
        self.libraries.append(result)

    def get_op(self, qualname):
        return torch._custom_op.impl.get_op(qualname)


class TestCustomOpTesting(CustomOpTestCaseBase):
    @parametrize("check_gradients", (False, "auto"))
    @parametrize("dynamic", (True, False))
    def test_aot_autograd_check_degenerate_cases(
        self, device, dynamic, check_gradients
        x = torch.randn(3, device=device)
        optests.aot_autograd_check(
            simple, (x,), {}, dynamic=dynamic, check_gradients=check_gradients

        def outputs_dont_require_grad(x):

        y = torch.randn(3, device=device, requires_grad=True)
        optests.aot_autograd_check(
            simple, (y,), {}, dynamic=dynamic, check_gradients=check_gradients

        x = torch.randn(3, device=device, requires_grad=True)
        y = torch.randn(3, device=device, requires_grad=False)
        optests.aot_autograd_check(
            no_outputs, (x,), {}, dynamic=dynamic, check_gradients=check_gradients
        optests.aot_autograd_check(
            no_outputs, (y,), {}, dynamic=dynamic, check_gradients=check_gradients

    def test_incorrect_schema_mutation(self, device):
        lib.define("foo(Tensor x) -> Tensor")
        op = self.ns().foo.default

        class Foo(torch.autograd.Function):
                guard = torch._C._AutoDispatchBelowAutograd()

            def backward(ctx, gx):

        lib.impl("foo", Foo.apply, "Autograd")
        lib.impl("foo", foo_impl, "CPU")
        lib.impl("foo", foo_impl, "CUDA")

        x = torch.tensor(3.14159 / 3, requires_grad=True, device=device)
        with self.assertRaisesRegex(
            optests.OpCheckError, "Argument x is not defined as mutable but was mutated"
            torch.library.opcheck(op, (x,), {})
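
    # Illustrative sketch, not part of the original suite: the failure above is
    # fixed by declaring the mutation in the schema. With the modern
    # torch.library.custom_op API that means listing the argument in
    # `mutates_args`; the op name `sketch_sin_` is hypothetical.
    def _sketch_declared_mutation(self):
        @torch.library.custom_op(f"{self.test_ns}::sketch_sin_", mutates_args={"x"})
        def sketch_sin_(x: Tensor) -> None:
            x.sin_()  # in-place update is now visible to opcheck/functionalization

        x = torch.randn(3)
        sketch_sin_(x)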
    def test_incorrect_schema_view(self, device):
        lib.define("foo(Tensor x) -> Tensor")
        op = self.ns().foo.default

        class Foo(torch.autograd.Function):
                # Emulate AutoDispatchBelowADInplaceOrView, which is not bound into python
                with torch._C._AutoDispatchBelowAutograd():
                    with torch._C._ExcludeDispatchKeyGuard(
                        torch._C.DispatchKeySet(torch._C.DispatchKey.ADInplaceOrView)

            def backward(ctx, gx):

        lib.impl("foo", Foo.apply, "Autograd")
        lib.impl("foo", foo_impl, "CPU")
        lib.impl("foo", foo_meta, "Meta")

        x = torch.tensor(3.14159 / 3, requires_grad=True)
        with self.assertRaisesRegex(
            optests.OpCheckError,
            "Argument x is not defined to alias output but was aliasing",
            torch.library.opcheck(op, (x,), {})

    def test_missing_abstract_impl(self, device):
        lib.define("foo(Tensor x) -> Tensor")
        op = self.ns().foo.default

        class Foo(torch.autograd.Function):
                with torch._C._AutoDispatchBelowAutograd():

            def backward(ctx, gx):

            return torch.tensor(x.cpu().numpy() ** 2, device=x.device)

        lib.impl("foo", Foo.apply, "Autograd")
        lib.impl("foo", foo_impl, "CPU")
        lib.impl("foo", foo_impl, "CUDA")

        x = torch.tensor([0, 1.0], requires_grad=True)
        with self.assertRaisesRegex(
            optests.OpCheckError,
            "_test_custom_op.foo.default",
            torch.library.opcheck(op, (x,), {})

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_incorrect_abstract_impl(self, device):
        lib.define("foo(Tensor x) -> Tensor")
        op = self.ns().foo.default

        class Foo(torch.autograd.Function):
                # Emulate AutoDispatchBelowADInplaceOrView, which is not bound into python
                guard = torch._C._AutoDispatchBelowAutograd()
                guard2 = torch._C._ExcludeDispatchKeyGuard(
                    torch._C.DispatchKeySet(torch._C.DispatchKey.ADInplaceOrView)

            def backward(ctx, gx):

            return x.unsqueeze(1) ** 2

        lib.impl("foo", Foo.apply, "Autograd")
        lib.impl("foo", foo_impl, "CPU")
        lib.impl("foo", foo_impl, "CUDA")
        lib.impl("foo", foo_meta, "Meta")

        x = torch.tensor([0, 1.0], requires_grad=True)
        with self.assertRaisesRegex(optests.OpCheckError, "Shapes .* are not equal"):
            torch.library.opcheck(op, (x,), {})
    def test_missing_functionalization(self, device):
        lib.define("foo(Tensor(a!) x) -> Tensor(a!)")
        op = self.ns().foo.default

        class Foo(torch.autograd.Function):
                with torch._C._AutoDispatchBelowAutograd():

            def backward(ctx, gx):

        lib.impl("foo", Foo.apply, "Autograd")
        lib.impl("foo", foo_impl, "CPU")
        lib.impl("foo", foo_impl, "CUDA")
        lib.impl("foo", foo_meta, "Meta")

        x = torch.tensor([0, 1.0])
        with self.assertRaisesRegex(
            optests.OpCheckError,
            "We only support functionalizing operators whose outputs do not have alias annotations",
            torch.library.opcheck(op, (y,), {})

    def test_autograd_registered_at_backend(self, device):
        lib.define("foo(Tensor x) -> Tensor")
        op = self.ns().foo.default

        class Foo(torch.autograd.Function):
            def backward(ctx, gx):

        lib.impl("foo", Foo.apply, "CPU")
        lib.impl("foo", Foo.apply, "CUDA")
        lib.impl("foo", lambda x: x.clone(), "Meta")

        x = torch.randn([], requires_grad=True)
        with self.assertRaisesRegex(
            torch.testing._internal.optests.OpCheckError,
            "does not have an autograd kernel",
            torch.library.opcheck(op, (x,), {})

        # I'm not sure why this is necessary

    def test_global_state_mutation(self, device):
        lib.define("foo(Tensor x) -> Tensor")
        op = self.ns().foo.default

        class Foo(torch.autograd.Function):
                return x.clone() * Foo.invoked

            def backward(ctx, gx):

        lib.impl("foo", Foo.apply, "CompositeImplicitAutograd")

        x = torch.tensor(3.14159 / 3, requires_grad=True)
        with self.assertRaisesRegex(
            optests.OpCheckError, "eager-mode PyTorch vs AOTAutograd"
            torch.library.opcheck(op, (x,), {})

    @ops(custom_op_db.custom_op_db, dtypes=OpDTypes.any_one)
    def test_opcheck_opinfo(self, device, dtype, op):
        for sample_input in op.sample_inputs(
            device, dtype, requires_grad=op.supports_autograd
            args = [sample_input.input] + list(sample_input.args)
            kwargs = sample_input.kwargs
            torch.library.opcheck(op.op, args, kwargs)
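
    # Illustrative sketch, not part of the original suite: outside of an OpInfo
    # database, torch.library.opcheck is usually called directly on a custom op
    # with a few representative inputs. `sketch_mul2` is a hypothetical op.
    def _sketch_opcheck_usage(self):
        @torch.library.custom_op(f"{self.test_ns}::sketch_mul2", mutates_args=())
        def sketch_mul2(x: Tensor) -> Tensor:
            return x * 2

        @sketch_mul2.register_fake
        def _(x):
            return torch.empty_like(x)

        # Runs schema, fake-tensor, and autograd-registration checks.
        torch.library.opcheck(sketch_mul2, (torch.randn(3),), {})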
    def test_opcheck_fails_basic(self, device):
        @custom_op(f"{self.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor: ...

        @foo.impl(["cpu", "cuda"])

        x = torch.randn(3, device=device, requires_grad=True)
        # Triggers the CustomOp autograd NYI error
        with self.assertRaisesRegex(
            optests.OpCheckError, "Autograd has not been implemented for operator"
            torch.library.opcheck(self.get_op(f"{self.test_ns}::foo"), (x,), {})

    def test_autograd_registration_check_autograd_kernel(self, device):
        lib.define("foo(Tensor x) -> Tensor")
        op = self.ns().foo.default

        class Foo(torch.autograd.Function):
                with torch._C._AutoDispatchBelowAutograd():

            def backward(ctx, gx):

        lib.impl("foo", Foo.apply, "Autograd")
        lib.impl("foo", foo_impl, "CPU")
        lib.impl("foo", foo_impl, "CUDA")

        x = torch.randn(3, requires_grad=True, device=device)
        optests.autograd_registration_check(op, (x,), {})

    def test_autograd_registration_check_compositeimplicitautograd(self, device):
        lib.define("foo(Tensor x) -> Tensor")
        op = self.ns().foo.default
        lib.impl("foo", foo_impl, "CompositeImplicitAutograd")

        x = torch.randn(3, requires_grad=True, device=device)
        optests.autograd_registration_check(op, (x,), {})

    def test_autograd_registration_check_incorrect_composite(self, device):
        lib.define("foo(Tensor x) -> Tensor")
        op = self.ns().foo.default
        lib.impl("foo", foo_impl, "CompositeExplicitAutograd")

        x = torch.randn(3, requires_grad=True, device=device)
        with self.assertRaisesRegex(AssertionError, "incorrectly registered"):
            optests.autograd_registration_check(op, (x,), {})

    def test_autograd_registration_check_incorrect(self, device):
        lib.define("foo(Tensor x) -> Tensor")
        op = self.ns().foo.default

        class Foo(torch.autograd.Function):
            def backward(ctx, gx):

        lib.impl("foo", Foo.apply, "CPU")
        lib.impl("foo", Foo.apply, "CUDA")

        x = torch.randn(3, requires_grad=True, device=device)
        with self.assertRaisesRegex(AssertionError, "incorrectly registered"):
            optests.autograd_registration_check(op, (x,), {})

    def test_assert_raises_regex(self, device):
        from torch.testing._internal.optests.aot_autograd import assert_raises_regex

        with assert_raises_regex(RuntimeError, "c"):
            raise RuntimeError("abcd")
        with assert_raises_regex(RuntimeError, "c.*"):
            raise RuntimeError("abcd")
        with self.assertRaisesRegex(AssertionError, "instead got"):
            with assert_raises_regex(RuntimeError, "c.*"):
                raise ValueError("abcd")
        with self.assertRaisesRegex(AssertionError, "Expected exception"):
            with assert_raises_regex(RuntimeError, "c.*"):
        with self.assertRaisesRegex(AssertionError, "to match regex"):
            with assert_raises_regex(RuntimeError, "f"):
                raise RuntimeError("abcd")
class TestCustomOp(CustomOpTestCaseBase):
    test_ns = "_test_custom_op"

    def test_functionalize_error(self):
        with torch.library._scoped_library(self.test_ns, "FRAGMENT") as lib:
            lib.define("foo(Tensor(a!) x) -> Tensor(a!)")
            lib.impl("foo", foo, "CompositeExplicitAutograd")
            foo_op = self.get_op(f"{self.test_ns}::foo")

            lib.define("bar(Tensor(a) x) -> Tensor(a)")
            lib.impl("bar", bar, "CompositeExplicitAutograd")
            bar_op = self.get_op(f"{self.test_ns}::bar")

            msg = r".*We only support functionalizing operators whose outputs do not have alias annotations"

            @torch.compile(backend="aot_eager", fullgraph=True)

            @torch.compile(backend="aot_eager", fullgraph=True)

            with self.assertRaisesRegex(RuntimeError, msg):
            with self.assertRaisesRegex(RuntimeError, msg):

    def test_invalid_schemas(self):
        # function schema validation goes through torchgen, so this is just a
        with self.assertRaisesRegex(AssertionError, "Invalid function schema: foo"):
            custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo", "(")

    def test_invalid_qualname(self):
        with self.assertRaisesRegex(ValueError, "overload"):
            custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo.Tensor", "() -> ()")

    def test_name_must_match(self):
        with self.assertRaisesRegex(ValueError, "to have name"):

            @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
            def baz(x: Tensor) -> Tensor:
                raise NotImplementedError

    def test_unsupported_schemas(self):
        with self.assertRaisesRegex(ValueError, "only supports functional"):
            custom_ops.custom_op(
                f"{TestCustomOp.test_ns}::foo", "(Tensor(a!) x) -> Tensor(a)"
        with self.assertRaisesRegex(ValueError, "only supports functional"):
            custom_ops.custom_op(
                f"{TestCustomOp.test_ns}::foo", "(Tensor(a) x) -> Tensor(a)"
        with self.assertRaisesRegex(ValueError, "only supports functional"):
            custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo", "(Tensor x) -> ()")(
        with self.assertRaisesRegex(ValueError, "self"):
            custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo", "(Tensor self) -> ()")(

    # Tests for the older custom_op API
    def test_schema_matches_signature(self):
        with self.assertRaisesRegex(ValueError, "signature to match"):

            @custom_op(f"{TestCustomOp.test_ns}::blah", "(Tensor y) -> Tensor")

        with self.assertRaisesRegex(ValueError, "signature to match"):
                f"{TestCustomOp.test_ns}::blah2", "(Tensor x, *, Tensor y) -> Tensor"

        with self.assertRaisesRegex(ValueError, "signature to match"):
                f"{TestCustomOp.test_ns}::blah3",
                "(Tensor x, *, Tensor w, Tensor z) -> Tensor",
            def blah3(x, *, y, z):

        with self.assertRaisesRegex(ValueError, "signature to match"):
                f"{TestCustomOp.test_ns}::blah4",
                "(Tensor x, *, Tensor z, Tensor y) -> Tensor",
            def blah4(x, *, y, z):

        with self.assertRaisesRegex(ValueError, "not supported"):

            @custom_op(f"{TestCustomOp.test_ns}::blah5", "(Tensor x) -> Tensor")

        with self.assertRaisesRegex(ValueError, "not supported"):
                f"{TestCustomOp.test_ns}::blah6", "(*, Tensor z, Tensor y) -> Tensor"

        with self.assertRaisesRegex(ValueError, "default arguments"):
                f"{TestCustomOp.test_ns}::blah7", "(Tensor x, *, Tensor y) -> Tensor"
            def blah7(x=1, *, y):

        with self.assertRaisesRegex(ValueError, "default arguments"):
                f"{TestCustomOp.test_ns}::blah8", "(Tensor x, *, Tensor y) -> Tensor"
            def blah8(x, *, y=1):

            f"{TestCustomOp.test_ns}::blah9", "(Tensor x, *, Tensor y) -> Tensor"

    def test_infer_schema_no_return(self):
        with self.assertRaisesRegex(
            ValueError, "No return type annotation was provided. Please add one."

            @torch.library.custom_op("mylib::foo", mutates_args={})
            def foo(x: torch.Tensor, y: int):
    def test_infer_schema_supported(self):
        def a(x: Tensor) -> Tensor:
            return torch.empty([])

        self.assertExpectedInline(
            infer_schema(a, mutates_args=()), """(Tensor x) -> Tensor"""

        def kwonly1(x: Tensor, *, y: int, z: float) -> Tensor:
            return torch.empty([])

        self.assertExpectedInline(
            infer_schema(kwonly1, mutates_args=()),
            """(Tensor x, *, SymInt y, float z) -> Tensor""",

        def kwonly2(*, y: Tensor) -> Tensor:
            return torch.empty([])

        self.assertExpectedInline(
            infer_schema(kwonly2, mutates_args=()), """(*, Tensor y) -> Tensor"""

            d: torch.types.Number,
        ) -> Tuple[Tensor, int, float, bool]:
            return torch.empty([]), 1, 0.1, True

        self.assertExpectedInline(
            infer_schema(b, mutates_args=()),
            """(Tensor x, SymInt y, bool z, float a, ScalarType b, Device c, Scalar d) -> (Tensor, SymInt, float, bool)""",

            w: Sequence[Optional[Tensor]],
            return [torch.empty([])]

        self.assertExpectedInline(
            infer_schema(c, mutates_args=()),
            """(Tensor x, Tensor[] y, Tensor? z, Tensor?[] w) -> Tensor[]""",

        def d(x: Tensor) -> Tuple[List[Tensor], Tensor]:
            return [torch.empty([])], torch.empty([])

        self.assertExpectedInline(
            infer_schema(d, mutates_args=()), """(Tensor x) -> (Tensor[], Tensor)"""

            return torch.empty([])

        self.assertExpectedInline(infer_schema(e, mutates_args=()), """() -> Tensor""")

        def f(x: Tensor) -> None:

        self.assertExpectedInline(
            infer_schema(f, mutates_args=()), """(Tensor x) -> ()"""

            x: Tensor, y: List[Tensor], z: List[Tensor], w: List[Optional[Tensor]]

        self.assertExpectedInline(
            infer_schema(g, mutates_args=()),
            """(Tensor x, Tensor[] y, Tensor[] z, Tensor?[] w) -> ()""",

        self.assertExpectedInline(
            infer_schema(g, mutates_args={"x", "w", "z"}),
            """(Tensor(a0!) x, Tensor[] y, Tensor(a2!)[] z, Tensor(a3!)?[] w) -> ()""",

        self.assertExpectedInline(
            infer_schema(g, mutates_args="unknown"),
            """(Tensor(a0!) x, Tensor(a1!)[] y, Tensor(a2!)[] z, Tensor(a3!)?[] w) -> ()""",

            a: Optional[int] = None,
            f: torch.dtype = torch.float,
            g: torch.dtype = torch.float32,
            h: torch.dtype = torch.int,
            i: torch.device = torch.device("cpu:0"),
            j: torch.device = "cpu",

        self.assertExpectedInline(
            infer_schema(h, mutates_args=()),
            """(Tensor x, SymInt? a=None, float b=3.14, bool c=True, SymInt d=3, str e="foo", """
            """ScalarType f=float32, ScalarType g=float32, ScalarType h=int32, Device i="cpu:0", Device j="cpu") -> ()"""

        def foo_impl(x: torch.Tensor) -> torch.Tensor:

        schema = torch.library.infer_schema(foo_impl, op_name="myop", mutates_args={})
        self.assertExpectedInline(schema, "myop(Tensor x) -> Tensor")
    def test_infer_schema_unsupported(self):
        with self.assertRaisesRegex(ValueError, "varargs"):
                raise NotImplementedError

            infer_schema(foo, mutates_args=())

        with self.assertRaisesRegex(ValueError, "varkwargs"):
                raise NotImplementedError

            infer_schema(foo, mutates_args=())

        with self.assertRaisesRegex(ValueError, "must have a type annotation"):
                raise NotImplementedError

            infer_schema(foo, mutates_args=())

        with self.assertRaisesRegex(ValueError, "unsupported"):

            def foo(x: Tensor) -> Tuple[Tensor, ...]:
                raise NotImplementedError

            infer_schema(foo, mutates_args=())

        with self.assertRaisesRegex(ValueError, "can be mutated"):

            def foo(x: Tensor, y: int) -> Tensor:
                raise NotImplementedError

            infer_schema(foo, mutates_args={"y"})

    def _generate_examples(self, typ):
        if typ is torch.dtype:
            return [torch.float32]
        if typ is torch.device:
            return [torch.device("cpu")]
        if typ == torch.types.Number:
        if typ is torch.Tensor:
            return [torch.tensor(3)]
        if typ == Optional[torch.types.Number]:
        origin = typing.get_origin(typ)
            args = typing.get_args(typ)
            assert len(args) == 2 and (args[0] is type(None) or args[1] is type(None))
            elt = args[0] if args[1] is type(None) else args[1]
            return self._generate_examples(elt) + [None]
            args = typing.get_args(typ)
            assert len(args) == 1
                self._generate_examples(elt),
                self._generate_examples(elt),
                self._generate_examples(elt),
        if origin is collections.abc.Sequence:
            args = typing.get_args(typ)
            assert len(args) == 1
            examples = self._generate_examples(args[0])
            return list(itertools.product(examples, examples)) + []
        raise NotImplementedError(
            f"testrunner cannot generate instance of type {typ}"
    def test_supported_return_types_single_return(self):
        for typ in torch._library.infer_schema.SUPPORTED_RETURN_TYPES:
            for example in self._generate_examples(typ):

                @custom_ops.custom_op(f"{self.test_ns}::foo")
                def foo(x: Tensor) -> typ:
                    raise NotImplementedError

                @custom_ops.impl(f"{self.test_ns}::foo")
                def foo_impl(x: Tensor) -> typ:

                op = self.get_op(f"{self.test_ns}::foo")
                result = op(torch.randn([]))
                self.assertEqual(result, example, msg=f"{typ} {example}")

                custom_ops._destroy(f"{self.test_ns}::foo")

    def test_supported_return_types_multi_return(self):
        for typ in torch._library.infer_schema.SUPPORTED_RETURN_TYPES:
            for example in self._generate_examples(typ):

                @custom_ops.custom_op(f"{self.test_ns}::foo")
                def foo(x: Tensor) -> Tuple[typ, typ]:
                    raise NotImplementedError

                @custom_ops.impl(f"{self.test_ns}::foo")
                def foo_impl(x: Tensor) -> Tuple[typ, typ]:
                    return (example, example)

                op = self.get_op(f"{self.test_ns}::foo")
                result = op(torch.randn([]))
                expected = (example, example)
                self.assertEqual(result, expected, msg=f"{typ} {example}")

                custom_ops._destroy(f"{self.test_ns}::foo")

    def test_supported_param_types(self):
        for typ in torch._library.infer_schema.SUPPORTED_PARAM_TYPES:

            @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
            def foo(x: Tensor, y: typ) -> Tensor:
                raise NotImplementedError

            @custom_ops.impl(f"{TestCustomOp.test_ns}::foo", device_types=["cpu"])

            for example in self._generate_examples(typ):
                op = self.get_op(f"{self.test_ns}::foo")
                op(torch.randn([]), example)
                self.assertEqual(yeet, example, msg=f"{typ} {example}")

            custom_ops._destroy(f"{TestCustomOp.test_ns}::foo")

    def test_sequences(self):
        # Sequence[int] gets automagically turned into int[] in the schema.
        # This test checks that we actually do support arbitrary sequence types.
        class MySequence(collections.abc.Sequence):
            def __init__(self) -> None:
                self._container = [1, 2, 3]

            def __getitem__(self, idx):
                return self._container[idx]

                return len(self._container)

        @custom_ops.custom_op(f"{self.test_ns}::foo")
        def foo(x: torch.Tensor, sizes: Sequence[int]) -> torch.Tensor:
            raise NotImplementedError

        @custom_ops.impl(f"{self.test_ns}::foo", device_types="cpu")
        def foo_cpu(x, sizes):
            # Dispatcher will normalize the sequence type into a List
            self.assertEqual(sizes, [1, 2, 3])

        op = self.get_op(f"{self.test_ns}::foo")
        self.assertEqual(called, 1)
    def test_unsupported_param_types(self):
        # Not comprehensive (it doesn't need to be), just a check that our mechanism works
        with self.assertRaisesRegex(ValueError, "unsupported type"):

            @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
            def foo(x: Tensor, y: List[Optional[int]]) -> Tensor:
                raise NotImplementedError

        with self.assertRaisesRegex(ValueError, "unsupported type"):
            # int[N] in Dispatcher is a bit wild, so we don't try to support it.
            @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
            def foo(x: Tensor, y: Tuple[int, int]) -> Tensor:
                raise NotImplementedError

        with self.assertRaisesRegex(ValueError, r"For example, typing.List\[int\]"):
            # test that we propose a correct and supported type.
            @torch.library.custom_op(f"{TestCustomOp.test_ns}::foo", mutates_args={})
            def foo(x: Tensor, y: Tuple[int, int]) -> Tensor:
                raise NotImplementedError

        with self.assertRaises(ValueError) as cm:

            @torch.library.custom_op(f"{TestCustomOp.test_ns}::foo", mutates_args={})
            def foo(x: Tensor, y: Tuple[int, float]) -> Tensor:
                raise NotImplementedError

        self.assertNotIn("example", str(cm.exception), "")

        with self.assertRaisesRegex(ValueError, "unsupported type"):

            @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
            def foo(x: Tensor, y: Callable) -> Tensor:
                raise NotImplementedError

    def test_supported_schemas(self):
        # All of these should already be tested by PyTorch codegen
        # (we share the same mechanism), but here's a sanity check.
            "(Tensor x) -> Tensor",
            "(Tensor x) -> Tensor y",
            "(Tensor[] x) -> Tensor y",
            "(Tensor x) -> (Tensor, Tensor)",
            "(Tensor x) -> (Tensor y, Tensor z)",
            "(Tensor x) -> (Tensor y, Tensor z)",

            "(Tensor x, Tensor w) -> (Tensor y, Tensor z)",
            "(Tensor x, Tensor w) -> (Tensor, Tensor)",
            "(Tensor x, Tensor w) -> Tensor",
            "(Tensor? x, Tensor w) -> Tensor",
            "(Tensor? x, Tensor[] w) -> Tensor",
            "(Tensor x, int[] w) -> Tensor",
            "(Tensor x, SymInt[] w) -> Tensor",
            "(Tensor x, Scalar w) -> Tensor",
            "(Tensor x, float w) -> Tensor",
            "(Tensor x, float? w) -> Tensor",
            "(Tensor x, bool[] w) -> Tensor",

        for schema in schemas:
            custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo", schema)
            custom_ops._destroy(f"{TestCustomOp.test_ns}::foo")
        for schema in other_schemas:
            custom_ops.custom_op(f"{TestCustomOp.test_ns}::bar", schema)
            custom_ops._destroy(f"{TestCustomOp.test_ns}::bar")
    def test_reserved_ns(self):
        from torch._custom_op.impl import RESERVED_NS

        for ns in RESERVED_NS:
            with self.assertRaisesRegex(ValueError, "is a reserved namespace"):
                custom_ops.custom_op(f"{ns}::foo", "(Tensor x) -> Tensor")

            with self.assertRaisesRegex(ValueError, "is a reserved namespace"):

                @custom_ops.custom_op(f"{ns}::foo2")
                def foo2(x: torch.Tensor) -> torch.Tensor:
                    raise NotImplementedError

    def test_private_ctor(self):
        with self.assertRaisesRegex(RuntimeError, "CustomOp constructor is private"):
            CustomOp(None, None, None, None, None)

    def test_lifetime(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError

        custom_op = torch._custom_op.impl.get_op(f"{TestCustomOp.test_ns}::foo")

        # We can't define an op multiple times,
        with self.assertRaisesRegex(RuntimeError, "multiple times"):

            @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
            def foo(x: torch.Tensor) -> torch.Tensor:  # noqa: F811
                raise NotImplementedError

        # Unless we delete the original op.
        custom_ops._destroy(f"{TestCustomOp.test_ns}::foo")

        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:  # noqa: F811
            raise NotImplementedError

        custom_ops._destroy(f"{TestCustomOp.test_ns}::foo")

    def test_autograd_notimplemented(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:  # noqa: F811
            raise NotImplementedError

        x = torch.randn(3, requires_grad=True)
        op = self.get_op(f"{self.test_ns}::foo")
        with self.assertRaisesRegex(RuntimeError, "Autograd has not been implemented"):
        custom_ops._destroy(f"{TestCustomOp.test_ns}::foo")

        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: Sequence[torch.Tensor]) -> torch.Tensor:
            raise NotImplementedError

        x = torch.randn(3, requires_grad=True)
        op = self.get_op(f"{self.test_ns}::foo")
        with self.assertRaisesRegex(RuntimeError, "Autograd has not been implemented"):
        custom_ops._destroy(f"{TestCustomOp.test_ns}::foo")

        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError

        x = torch.randn(3, requires_grad=True)
        op = self.get_op(f"{self.test_ns}::foo")
        with self.assertRaisesRegex(RuntimeError, "Autograd has not been implemented"):
        custom_ops._destroy(f"{TestCustomOp.test_ns}::foo")

    def test_autograd_notimplemented_gradmode(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")

        x = torch.randn(3, requires_grad=True)
        op = self.get_op(f"{self.test_ns}::foo")
        with torch.no_grad():
            # Shouldn't raise, because we are in no_grad
    def test_impl_cpu(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo", device_types="cpu")

        op = self.get_op(f"{self.test_ns}::foo")
        self.assertEqual(result, foo_cpu(x))

    def test_impl_invalid_devices(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError

        from torch._custom_op.impl import SUPPORTED_DEVICE_TYPE_TO_KEY

        for device_type in SUPPORTED_DEVICE_TYPE_TO_KEY.keys():
            # Smoke test: should not raise error
            custom_ops.impl(f"{TestCustomOp.test_ns}::foo", device_types=device_type)(

        # Not supported by this API: we can either support them in the future
        # or provide some other CustomOp.def_* function. This depends on how
        # common the use cases are.
        for invalid_type in ["hip", "xla", "mkldnn", ["cpu", "hip"]]:
            with self.assertRaisesRegex(ValueError, "we only support device_type"):
                    f"{TestCustomOp.test_ns}::foo", device_types=invalid_type

    def test_backward_partially_registered(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")

        @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_backward(ctx, saved, grad):
            return grad * saved.cos()

        x = torch.randn([], requires_grad=True)
        op = self.get_op(f"{self.test_ns}::foo")
        with self.assertRaisesRegex(
            RuntimeError, "unable to find a 'save_for_backward'"

    def test_save_for_backward_inputs_are_namedtuple(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")

        @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_save_for_backward(inputs, output):
            self.assertTrue(isinstance(inputs, tuple))
            self.assertEqual(list(inputs._asdict().keys()), ["x"])

        @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_backward(ctx, saved, grad):
            return {"x": grad * saved.cos()}

        x = torch.randn([], requires_grad=True)
        op = self.get_op(f"{self.test_ns}::foo")
        self.assertEqual(hit, 1)
        self.assertEqual(hit, 1)

    def test_backward_returns_dict(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")

        @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_save_for_backward(inputs, output):

        @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_backward(ctx, saved, grad):
            return grad * saved.cos()

        x = torch.randn([], requires_grad=True)
        op = self.get_op(f"{self.test_ns}::foo")
        with self.assertRaisesRegex(RuntimeError, "to be a dict"):
    def test_backward_dict_invalid_keys(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")

        @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_save_for_backward(inputs, output):

        @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_backward(ctx, saved, grad):
            return {"x": grad * saved.cos(), "y": None}

        x = torch.randn([], requires_grad=True)
        op = self.get_op(f"{self.test_ns}::foo")
        with self.assertRaisesRegex(RuntimeError, "to have keys {'x'}"):

    def test_backward_dict_grad_for_nontensor(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor, dim: int) -> torch.Tensor:
            raise NotImplementedError

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
        def foo_impl(x, dim):

        @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_save_for_backward(inputs, output):

        @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_backward(ctx, saved, grad):
            return {"x": grad * saved.cos(), "dim": None}

        x = torch.randn([], requires_grad=True)
        op = self.get_op(f"{self.test_ns}::foo")
        with self.assertRaisesRegex(RuntimeError, "non-Tensor-like types"):

    def test_backward_dict_requires_keys_for_input_tensors(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")

        @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_save_for_backward(inputs, output):

        @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_backward(ctx, saved, grad):
            return {"x": grad * saved.cos()}

        x = torch.randn([], requires_grad=True)
        op = self.get_op(f"{self.test_ns}::foo")
        with self.assertRaisesRegex(RuntimeError, r"to have keys {.*'y'.*}"):

    def test_backward_dict_requires_keys_for_input_optional_tensors(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor, y: Optional[torch.Tensor]) -> torch.Tensor:
            raise NotImplementedError

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")

        @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_save_for_backward(inputs, output):

        @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_backward(ctx, saved, grad):
            return {"x": grad * saved.cos()}

        x = torch.randn([], requires_grad=True)
        op = self.get_op(f"{self.test_ns}::foo")
        with self.assertRaisesRegex(RuntimeError, r"to have keys {.*'y'.*}"):

    def test_backward_grads_are_tensor_or_none(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")

        @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_save_for_backward(inputs, output):

        @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_backward(ctx, saved, grad):
            return {"x": (grad * saved.cos(),)}

        x = torch.randn([], requires_grad=True)
        op = self.get_op(f"{self.test_ns}::foo")
        with self.assertRaisesRegex(RuntimeError, "either None or a Tensor"):
    def test_backward_tensorlist_input_requires_list_grads_with_same_numel(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(xs: Sequence[torch.Tensor]) -> torch.Tensor:
            raise NotImplementedError

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")

        @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_save_for_backward(inputs, output):

        @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_backward(ctx, saved, grad):
            return {"xs": [grad * saved.cos(), None]}

        xs = [torch.randn([], requires_grad=True) for _ in range(3)]
        op = self.get_op(f"{self.test_ns}::foo")
        with self.assertRaisesRegex(RuntimeError, "3 gradients but got 2"):

    def test_backward_tensorlist_input_requires_list_grads_none_or_Tensor(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(xs: Sequence[torch.Tensor]) -> torch.Tensor:
            raise NotImplementedError

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")

        @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_save_for_backward(inputs, output):

        @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_backward(ctx, saved, grad):
            return {"xs": [grad * saved.cos(), None, (None,)]}

        xs = [torch.randn([], requires_grad=True) for _ in range(3)]
        op = self.get_op(f"{self.test_ns}::foo")
        with self.assertRaisesRegex(RuntimeError, "None or Tensor"):

    def test_backward_tensorlist_input_requires_list_grads(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(xs: Sequence[torch.Tensor]) -> torch.Tensor:
            raise NotImplementedError

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")

        @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_save_for_backward(inputs, output):

        @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_backward(ctx, saved, grad):

        xs = [torch.randn([], requires_grad=True) for _ in range(3)]
        op = self.get_op(f"{self.test_ns}::foo")
        with self.assertRaisesRegex(RuntimeError, "list of gradients"):

    def test_backward_output_differentiability_type(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(xs: Sequence[torch.Tensor]) -> torch.Tensor:
            raise NotImplementedError

        with self.assertRaisesRegex(RuntimeError, "output_differentiability"):

            @custom_ops.impl_backward(
                f"{TestCustomOp.test_ns}::foo", output_differentiability=True
            def foo_backward(ctx, saved, grad):

    def test_backward_output_differentiability_numel(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(xs: Sequence[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
            raise NotImplementedError

        with self.assertRaisesRegex(RuntimeError, "output_differentiability"):

            @custom_ops.impl_backward(
                f"{TestCustomOp.test_ns}::foo", output_differentiability=[True]
            def foo_backward(ctx, saved, grad):

    def test_backward_output_differentiability_tensorlist(self):
        @custom_ops.custom_op(f"{self.test_ns}::foo")
        def foo(x: Tensor) -> Tuple[List[Tensor], Tensor]:
            raise NotImplementedError

        @custom_ops.impl(f"{self.test_ns}::foo")
            return [x.clone(), x.clone()], x.clone()

        @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_save_for_backward(inputs, output):

        @custom_ops.impl_backward(
            f"{TestCustomOp.test_ns}::foo", output_differentiability=[False, True]
        def foo_backward(ctx, saved, grad_lst, grad):

        op = self.get_op(f"{self.test_ns}::foo")
        x = torch.randn(3, requires_grad=True)
        self.assertFalse(a.requires_grad)
        self.assertFalse(b.requires_grad)
        self.assertTrue(c.requires_grad)
    def test_backward_output_differentiability_non_tensor(self):
        @custom_ops.custom_op(f"{self.test_ns}::foo")
        def foo(x: Tensor) -> Tuple[Tensor, int]:
            raise NotImplementedError

        @custom_ops.impl(f"{self.test_ns}::foo")

        @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_save_for_backward(inputs, output):

        @custom_ops.impl_backward(
            f"{TestCustomOp.test_ns}::foo", output_differentiability=[True, True]
        def foo_backward(ctx, saved, grad0, grad1):

        op = self.get_op(f"{self.test_ns}::foo")
        x = torch.randn(3, requires_grad=True)
        with self.assertRaisesRegex(RuntimeError, "is not a Tensor"):

    @unittest.skipIf(not TEST_CUDA, "requires CUDA")
    def test_impl_separate(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo", device_types="cpu")

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo", device_types="cuda")

        op = self.get_op(f"{self.test_ns}::foo")
        self.assertEqual(result, foo_cpu(x))

        op = self.get_op(f"{self.test_ns}::foo")
        self.assertEqual(result, foo_cuda(x_cuda))

    @unittest.skipIf(not TEST_CUDA, "requires CUDA")
    def test_impl_multiple(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")

        op = self.get_op(f"{self.test_ns}::foo")
        self.assertEqual(result, foo_impl(x))
        self.assertEqual(result, foo_impl(x_cuda))

    def test_impl_abstract_overload(self):
        lib.define("sin.blah(Tensor x) -> Tensor")
        torch.library.impl_abstract(
            f"{self.test_ns}::sin.blah", torch.empty_like, lib=lib

        op = self.ns().sin.blah
        x = torch.randn(3, device="meta")

    def test_impl_meta(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor, dim: int) -> torch.Tensor:
            raise NotImplementedError

        @torch.library.impl_abstract(f"{TestCustomOp.test_ns}::foo", lib=self.lib())
        def foo_meta(x, dim):
            output_shape = list(x.shape)
            del output_shape[dim]
            return x.new_empty(output_shape)

        x = torch.randn(2, 3, device="meta")
        op = self.get_op(f"{self.test_ns}::foo")
        self.assertEqual(result.shape, foo_meta(x, 1).shape)
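
    # Illustrative sketch, not part of the original suite: the modern spelling
    # of the meta/abstract kernel above is register_fake on a
    # torch.library.custom_op. `sketch_narrow_dim` is a hypothetical op.
    def _sketch_register_fake(self):
        @torch.library.custom_op(f"{self.test_ns}::sketch_narrow_dim", mutates_args=())
        def sketch_narrow_dim(x: Tensor, dim: int) -> Tensor:
            shape = list(x.shape)
            del shape[dim]
            return x.new_zeros(shape)

        @sketch_narrow_dim.register_fake
        def _(x, dim):
            # Same shape logic as the real kernel, but runs on fake tensors.
            shape = list(x.shape)
            del shape[dim]
            return x.new_empty(shape)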
    def test_duplicate_impl(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor, dim: int) -> torch.Tensor:
            raise NotImplementedError

        @torch.library.impl_abstract(f"{TestCustomOp.test_ns}::foo", lib=self.lib())
        def foo_meta(x, dim):
            output_shape = list(x.shape)
            del output_shape[dim]
            return x.new_empty(output_shape)

        with self.assertRaisesRegex(RuntimeError, r"test_custom_ops.py:\d+"):

            @torch.library.impl_abstract(f"{TestCustomOp.test_ns}::foo", lib=self.lib())
            def foo_meta2(x, dim):
                output_shape = list(x.shape)
                del output_shape[dim]
                return x.new_empty(output_shape)

    def test_new_data_dependent_symint(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError

        @torch.library.impl_abstract(f"{TestCustomOp.test_ns}::foo", lib=self.lib())
            ctx = torch.library.get_ctx()
            r = ctx.new_dynamic_size(min=1)
            with self.assertRaisesRegex(ValueError, "greater than or equal to 0"):
                ctx.new_dynamic_size(min=-1)
            with self.assertRaisesRegex(ValueError, "SymInt"):
                ctx.new_dynamic_size(max=x.numel())
            # NB: You must return dynamic sizes!
            return x.new_empty(r)

        x = torch.randn(2, 3, device="cpu")
        op = self.get_op(f"{self.test_ns}::foo")
        make_fx(op, tracing_mode="symbolic")(x)
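
    # Illustrative sketch, not part of the original suite: an abstract/fake
    # impl for an op whose output shape depends on data uses
    # ctx.new_dynamic_size() to allocate an unbacked SymInt. The op name
    # `sketch_nonzero_count` is hypothetical.
    def _sketch_data_dependent_fake(self):
        lib = self.lib()
        lib.define("sketch_nonzero_count(Tensor x) -> Tensor")

        @torch.library.impl_abstract(f"{self.test_ns}::sketch_nonzero_count", lib=lib)
        def _(x):
            ctx = torch.library.get_ctx()
            nnz = ctx.new_dynamic_size()  # unbacked SymInt standing in for the count
            return x.new_empty(nnz)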
    def test_meta_for_data_dependent_shape_operation(self):
        x = torch.randn(10, device="meta")
        with self.assertRaisesRegex(RuntimeError, "data-dependent output shape"):

    def test_basic_make_fx(self):
        # More serious tests are in our CustomOp opinfo db,
        # this one is just a sanity check.
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError

        @torch.library.impl_abstract(f"{TestCustomOp.test_ns}::foo", lib=self.lib())

        op = self.get_op(f"{self.test_ns}::foo")
        gm = make_fx(op, tracing_mode="symbolic")(x)
        self.assertTrue(f"{TestCustomOp.test_ns}.foo" in gm.code)

    def test_not_implemented_error(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError

        op = self.get_op(f"{self.test_ns}::foo")
        with self.assertRaisesRegex(NotImplementedError, "cpu impl registered"):

        x = torch.randn(3, device="meta")
        with self.assertRaisesRegex(NotImplementedError, "no fake impl or Meta kernel"):

        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::bar")
        def bar(sizes: Sequence[int]) -> torch.Tensor:
            raise NotImplementedError

        op = self.get_op(f"{self.test_ns}::bar")
        with self.assertRaisesRegex(NotImplementedError, "no Tensor inputs"):

    def test_data_dependent_basic(self):
        x = torch.randn(5, 5)
        gm = make_fx(numpy_nonzero, tracing_mode="symbolic")(x)
        self.assertTrue("nonzero" in gm.code)

    def test_data_dependent_fake_tracing(self):
        x = torch.randn(5, 5)
        # We've updated to attempt to use unbacked symints even for fake
        make_fx(numpy_nonzero, tracing_mode="fake")(x)

    def test_symints(self):
            return torch.ops._torch_testing.numpy_view_copy(x, x.shape)

        x = torch.randn(2, 3, 4)
        gm = make_fx(f, tracing_mode="symbolic")(x)
        self.assertEqual(result, f(x))
        self.assertExpectedInline(
def forward(self, x_1):
    sym_size_int = torch.ops.aten.sym_size.int(x_1, 0)
    sym_size_int_1 = torch.ops.aten.sym_size.int(x_1, 1)
    sym_size_int_2 = torch.ops.aten.sym_size.int(x_1, 2)
    numpy_view_copy = torch.ops._torch_testing.numpy_view_copy.default(x_1, [sym_size_int, sym_size_int_1, sym_size_int_2]); x_1 = sym_size_int = sym_size_int_1 = sym_size_int_2 = None
    return numpy_view_copy""",  # noqa: B950
    @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work on windows")
    def test_data_dependent_compile(self):
        import torch._dynamo.testing
        from torch._dynamo.utils import counters

        cnt = torch._dynamo.testing.CompileCounter()

        @torch.compile(backend=cnt)
            return numpy_nonzero(x.clone()).clone()

        self.assertEqual(len(counters["graph_break"]), 1)
        self.assertEqual(next(iter(counters["graph_break"].values())), 1)
        self.assertExpectedInline(
            next(iter(counters["graph_break"].keys())).replace(";", "\n"),
dynamic shape operator: _torch_testing.numpy_nonzero.default
 to enable, set torch._dynamo.config.capture_dynamic_output_shape_ops = True""",

    # pre-existing problem: torch.compile(dynamic=True) will, by default,
    # graph break on data-dependent operations. Eventually we'll make it so
    # that it never graph breaks on data-dependent operations.
    @unittest.expectedFailure
    def test_data_dependent_nms_dynamic_compile(self):
        import torch._dynamo.testing
        from torch._dynamo.utils import counters

        cnt = torch._dynamo.testing.CompileCounter()

        @torch.compile(backend=cnt, dynamic=True)
            return torch.ops._torch_testing.numpy_nms(x.clone(), s, i).clone()

        f(torch.randn(20, 4), torch.randn(20), 0.1)
        self.assertEqual(len(counters["graph_break"]), 0)

    def test_impl_on_existing_op(self):
        lib.define("foo(Tensor x) -> Tensor")
        qualname = f"{self.test_ns}::foo"

        @torch._custom_ops.impl(qualname)

        op = self.get_op(qualname)
        self.assertEqual(result, x.sin())

        "key", ["CPU", "CUDA", "CompositeImplicitAutograd", "CompositeExplicitAutograd"]
    def test_impl_on_existing_op_with_cpu_registration(self, key):
        lib.define("foo(Tensor x) -> Tensor")
        qualname = f"{self.test_ns}::foo"
        lib.impl("foo", foo_impl, key)
        op = self.get_op(qualname)

        with self.assertRaisesRegex(RuntimeError, "already has an implementation"):
            custom_ops.impl(qualname, func=foo_impl)
    def test_abstract_impl_on_existing_op(self):
        lib.define("foo(Tensor x) -> Tensor")
        qualname = f"{self.test_ns}::foo"

        @torch.library.impl_abstract(qualname, lib=self.lib())

        op = self.get_op(qualname)
        with torch._subclasses.FakeTensorMode():
            self.assertEqual(result.shape, x.shape)
            self.assertEqual(result.stride(), x.stride())

    def test_abstract_impl_on_existing_op_with_meta(self):
        lib.define("foo(Tensor x) -> Tensor")
        qualname = f"{self.test_ns}::foo"
        lib.impl("foo", foo_impl, "Meta")
        op = self.get_op(qualname)

        with self.assertRaisesRegex(RuntimeError, r"already has .*Meta implementation"):
            torch.library.impl_abstract(qualname, func=foo_impl, lib=self.lib())

    def test_abstract_impl_on_existing_op_with_CompositeImplicitAutograd(self):
        lib.define("foo(Tensor x) -> Tensor")
        qualname = f"{self.test_ns}::foo"
        lib.impl("foo", foo_impl, "CompositeImplicitAutograd")
        op = self.get_op(qualname)

        with self.assertRaisesRegex(RuntimeError, "CompositeImplicitAutograd"):
            torch.library.impl_abstract(qualname, func=foo_impl, lib=self.lib())

    def test_abstract_impl_on_existing_op_with_CompositeExplicitAutograd(self):
        lib.define("foo(Tensor x) -> Tensor")
        qualname = f"{self.test_ns}::foo"
        lib.impl("foo", foo_impl, "CompositeExplicitAutograd")
        op = self.get_op(qualname)

        torch.library.impl_abstract(qualname, func=lambda x: x.sum(), lib=self.lib())
        with torch._subclasses.FakeTensorMode():
            self.assertEqual(result.shape, ())

    def _test_backward_impl_raises(self, qualname, err_regex):
        with self.assertRaisesRegex(RuntimeError, err_regex):

            @custom_ops.impl_save_for_backward(qualname)

        with self.assertRaisesRegex(RuntimeError, err_regex):

            @custom_ops.impl_backward(qualname)

    def test_backward_impl_on_existing_op_incorrect_schema_views(self):
        lib.define("foo(Tensor(a) x) -> Tensor(a)")
        qualname = f"{self.test_ns}::foo"
        self._test_backward_impl_raises(qualname, "operator that returns views")

    def test_backward_impl_on_existing_op_incorrect_schema_mutable(self):
        lib.define("foo(Tensor(a!) x) -> Tensor")
        qualname = f"{self.test_ns}::foo"
        self._test_backward_impl_raises(qualname, "non-functional")

    def test_backward_impl_on_existing_op_incorrect_schema_no_output(self):
        lib.define("foo(Tensor x) -> ()")
        qualname = f"{self.test_ns}::foo"
        self._test_backward_impl_raises(qualname, "no returns")

    def test_backward_impl_on_existing_op_CompositeImplicitAutograd(self):
        lib.define("foo(Tensor x) -> Tensor")
        qualname = f"{self.test_ns}::foo"
        lib.impl("foo", lambda x: x.sin().cos(), "CompositeImplicitAutograd")
        self._test_backward_impl_raises(qualname, "CompositeImplicitAutograd")

    @parametrize("key", ["Autograd", "AutogradCPU", "AutogradCUDA"])
    def test_backward_impl_on_existing_op_with_key(self, key):
        lib.define("foo(Tensor x) -> Tensor")
        qualname = f"{self.test_ns}::foo"
        lib.impl("foo", lambda x: x.sin().cos(), key)
        self._test_backward_impl_raises(qualname, key)
    def test_is_functional_schema(self):
            "foo(Tensor x) -> Tensor": True,
            "foo(Tensor(a) x) -> Tensor": True,
            "foo(Tensor(a!) x) -> Tensor": False,
            "foo(Tensor(a) x) -> Tensor(a)": False,
            "foo(Tensor x) -> ()": False,

        for schema_str, expected in tests.items():
            res = torch._library.utils.is_functional_schema(schema_str)
            self.assertEqual(res, expected)

            from torchgen.model import FunctionSchema

            schema = FunctionSchema.parse(schema_str)
            res = torch._library.utils.is_functional_schema(schema)
            self.assertEqual(res, expected)

            schema = torch._C.parse_schema(schema_str)
            res = torch._library.utils.is_functional_schema(schema)
            self.assertEqual(res, expected)

    def test_incorrect_schema_types(self):
        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
            with self.assertRaisesRegex(RuntimeError, "unknown type specifier"):
                lib.define("foo12(Tensor a) -> asdfasdf")
            with self.assertRaisesRegex(RuntimeError, "unknown type specifier"):
                lib.define("foo12(asdf a) -> Tensor")
            with self.assertRaisesRegex(RuntimeError, "Use `SymInt` or `int`"):
                lib.define("foo12(int64_t a) -> Tensor")
            with self.assertRaisesRegex(RuntimeError, "Use `float`"):
                lib.define("foo12(double a) -> Tensor")

    def test_is_tensorlist_like_type(self):
            torch.ops.aten.where.default._schema.returns[0].type,
            torch.ops.aten.index.Tensor._schema.arguments[1].type,
            torch._C.parse_schema("foo(Tensor[]? x) -> ()").arguments[0].type,
            torch._C.parse_schema("foo(Tensor?[]? x) -> ()").arguments[0].type,

            torch.ops.aten.sin.default._schema.arguments[0].type,
            torch.ops.aten.sum.dim_IntList._schema.arguments[1].type,

        for a in tensorlists:
            self.assertTrue(torch._library.utils.is_tensorlist_like_type(a))
        for a in non_tensorlists:
            self.assertFalse(torch._library.utils.is_tensorlist_like_type(a))

    def test_backward_impl_on_existing_op(self):
        lib.define("foo(Tensor x) -> Tensor")
        qualname = f"{self.test_ns}::foo"

        @custom_ops.impl(qualname)
            with torch.no_grad():

        @custom_ops.impl_save_for_backward(qualname)
        def foo_save_for_backward(inputs, output):

        @custom_ops.impl_backward(qualname)
        def foo_backward(ctx, saved, grad_out):
            return {"x": grad_out * saved.cos()}

        op = self.get_op(qualname)
        x = torch.randn([], requires_grad=True)
        (gx,) = torch.autograd.grad(y, x)
        self.assertEqual(gx, x.cos())
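
    # Illustrative sketch, not part of the original suite: the modern
    # equivalent of the impl_save_for_backward/impl_backward pair above is
    # torch.library.register_autograd with a setup_context function. The op
    # name `sketch_sin` is hypothetical.
    def _sketch_register_autograd(self):
        @torch.library.custom_op(f"{self.test_ns}::sketch_sin", mutates_args=())
        def sketch_sin(x: Tensor) -> Tensor:
            return x.sin()

        def setup_context(ctx, inputs, output):
            (x,) = inputs
            ctx.save_for_backward(x)

        def backward(ctx, grad_out):
            (x,) = ctx.saved_tensors
            return grad_out * x.cos()

        torch.library.register_autograd(
            f"{self.test_ns}::sketch_sin", backward, setup_context=setup_context
        )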
            subtest(torch.Tag.pointwise, "single"),
            subtest((torch.Tag.pointwise,), "tuple"),
            subtest([torch.Tag.pointwise], "list"),
    def test_define_with_tags(self, tags):
        tags = (torch.Tag.pointwise,)
        torch.library.define(
            f"{self.test_ns}::foo", "(Tensor x) -> Tensor", lib=lib, tags=tags

        actual = self.ns().foo.default.tags
        self.assertTrue(isinstance(actual, list))
        self.assertEqual(actual, list(tags))

    def test_builtin_aten_ops_are_pt2_compliant(self):
        for op in [torch.ops.aten.sin.default, torch.ops.aten.sum.dim_IntList]:
            self.assertIn(torch.Tag.pt2_compliant_tag, op.tags)

    def test_builtin_torchscript_ops(self):
        for op in [torch.ops.aten.sub.complex, torch.ops.aten.mul.complex]:
            self.assertIn(torch.Tag.pt2_compliant_tag, op.tags)

    def test_autogen_aten_ops_are_pt2_compliant(self):
        for op in [torch.ops.aten.fill.Tensor_out]:
            self.assertIn(torch.Tag.generated, op.tags)
            self.assertIn(torch.Tag.pt2_compliant_tag, op.tags)

    def test_resolve_packet(self):
        result = torch._C._jit_resolve_packet("aten::sum", x)
        self.assertEqual(result, "default")

        result = torch._C._jit_resolve_packet("aten::sum", x, dim=1)
        self.assertEqual(result, "dim_IntList")

        with self.assertRaisesRegex(RuntimeError, "failed to match any schema"):
            result = torch._C._jit_resolve_packet("aten::sum", x, x, x)

    def test_define_bad_schema(self):
        with self.assertRaisesRegex(ValueError, "expected schema to look like"):
            torch.library.define(f"{self.test_ns}::foo", "foo(Tensor x) -> Tensor")

    def test_define_and_impl(self):
        torch.library.define(f"{self.test_ns}::foo", "(Tensor x) -> Tensor", lib=lib)

        @torch.library.impl(f"{self.test_ns}::foo", "CPU", lib=lib)
            return torch.from_numpy(np.sin(x.numpy()))

        y = self.ns().foo(x)
        assert torch.allclose(y, x.sin())
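
    # Illustrative sketch, not part of the original suite: a
    # torch.library.define based op typically also wants a "Meta" registration
    # so it works under FakeTensor/tracing. `sketch_cos` is a hypothetical op.
    def _sketch_define_with_meta(self):
        lib = self.lib()
        torch.library.define(
            f"{self.test_ns}::sketch_cos", "(Tensor x) -> Tensor", lib=lib
        )

        @torch.library.impl(f"{self.test_ns}::sketch_cos", "CPU", lib=lib)
        def _(x):
            return torch.from_numpy(np.cos(x.numpy()))

        def meta_fn(x):
            # Shape-only computation; never touches data.
            return torch.empty_like(x)

        lib.impl("sketch_cos", meta_fn, "Meta")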
    def test_define_validation(self):
        with self.assertRaisesRegex(ValueError, "namespace"):
            torch.library.define("foo", "(Tensor x) -> Tensor")

    def test_legacy_define(self):
        @torch.library.define(lib, "foo(Tensor x) -> Tensor")
            return torch.from_numpy(np.sin(x.numpy()))

        y = self.ns().foo(x)
        assert torch.allclose(y, x.sin())

    def test_impl_function(self):
        torch.library.define(f"{self.test_ns}::foo", "(Tensor x) -> Tensor", lib=lib)
            return torch.from_numpy(np.sin(x.numpy()))

        torch.library.impl(f"{self.test_ns}::foo", "CPU", f, lib=lib)
        y = self.ns().foo(x)
        assert torch.allclose(y, x.sin())

    def test_legacy_impl(self):
        torch.library.define(f"{self.test_ns}::foo", "(Tensor x) -> Tensor", lib=lib)

        @torch.library.impl(lib, "foo", "CPU")
            return torch.from_numpy(np.sin(x.numpy()))

        y = self.ns().foo(x)
        assert torch.allclose(y, x.sin())

    def test_defined_in_python(self):
        self.assertFalse(torch.ops.aten.sin.default._defined_in_python)
        self.assertFalse(torch.ops.aten.sum.dim_IntList._defined_in_python)
        torch.library.define(f"{self.test_ns}::foo", "(Tensor x) -> Tensor", lib=lib)
        self.assertTrue(ns.foo.default._defined_in_python)
        torch.library.define(
            f"{self.test_ns}::bar.overload", "(Tensor x) -> Tensor", lib=lib
        self.assertTrue(ns.bar.overload._defined_in_python)
2010
def _test_impl_device(self, name, types, device):
2012
torch.library.define(f"{self.test_ns}::{name}", "(Tensor x) -> Tensor", lib=lib)
2014
@torch.library.impl(f"{self.test_ns}::{name}", types)
2016
x_np = x.cpu().numpy()
2017
y = torch.from_numpy(np.sin(x_np))
2018
return y.to(device=x.device)
2020
x = torch.randn(3, device=device)
2021
y = getattr(self.ns(), name)(x)
2022
assert torch.allclose(y, x.sin())
2024
def test_impl_device_cpu(self):
2025
self._test_impl_device("foo1", "default", "cpu")
2026
self._test_impl_device("foo2", ["cpu"], "cpu")
2027
self._test_impl_device("foo3", ["cpu", "cuda"], "cpu")
2029
@unittest.skipIf(not TEST_CUDA, "requires cuda")
2030
def test_impl_device_cuda(self):
2031
self._test_impl_device("foo4", "default", "cuda")
2032
self._test_impl_device("foo5", ["cuda"], "cuda")
2033
self._test_impl_device("foo6", ["cpu", "cuda"], "cuda")
2035
def test_impl_device_function(self):
2037
torch.library.define(f"{self.test_ns}::foo", "(Tensor x) -> Tensor", lib=lib)
2040
x_np = x.cpu().numpy()
2041
y = torch.from_numpy(np.sin(x_np))
2042
return y.to(device=x.device)
2044
torch.library.impl(f"{self.test_ns}::foo", "default", f, lib=lib)
2046
y = self.ns().foo(x)
2047
assert torch.allclose(y, x.sin())
2049
def test_impl_device_invalid(self):
2050
with self.assertRaisesRegex(RuntimeError, "Expected one of cpu, cuda"):
2051
torch.library.impl("blah::blah", "somethingsomething")
2053
def test_autograd_function_backed_op(self):
2055
struct CustomOpAutogradFunction : public torch::autograd::Function<CustomOpAutogradFunction> {
2056
static constexpr bool is_traceable = true;
2058
static torch::Tensor forward(
2059
torch::autograd::AutogradContext* ctx,
2060
const torch::Tensor& x) {
2064
static torch::autograd::variable_list backward(
2065
torch::autograd::AutogradContext *ctx,
2066
torch::autograd::variable_list grad_output) {
2071
torch::Tensor custom_op_backed_by_autograd_fn(const torch::Tensor& x) {
2072
return CustomOpAutogradFunction::apply(x);
2075
TORCH_LIBRARY(mylib, m) {
2076
m.def("custom_op_backed_by_autograd_fn", custom_op_backed_by_autograd_fn);
2080
module = torch.utils.cpp_extension.load_inline(
2082
cpp_sources=cpp_source,
2083
functions="custom_op_backed_by_autograd_fn",
2087
x = torch.ones(2, 2, requires_grad=True)
2088
temp = x.clone().detach()
2089
out = torch.ops.mylib.custom_op_backed_by_autograd_fn(x)
2092
self.assertEqual(x.grad, temp)


def op_with_incorrect_schema(testcase, name):
    lib = testcase.lib()
    lib.define(f"{name}(Tensor x) -> Tensor")
    qualname = f"{testcase.test_ns}::{name}"
    lib.impl(name, lambda x: x[:], "CompositeExplicitAutograd")
    return testcase.get_op(qualname)
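
# Note: the schema above promises a freshly allocated output, but `x[:]`
# returns a view of the input, so the kernel aliases its input without
# declaring it. opcheck's test_schema util is expected to flag exactly this
# mismatch (see test_opcheck_bad_op below).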


class MiniOpTest(CustomOpTestCaseBase):
    test_ns = "mini_op_test"

    def _init_op_delayed_backward_error(self):
        name = "delayed_error"
        qualname = f"{self.test_ns}::{name}"
        lib = self.lib()
        lib.define(f"{name}(Tensor x) -> Tensor")
        lib.impl(name, lambda x: x.clone(), "CompositeExplicitAutograd")
        op = self.get_op(qualname)

        class Op(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                with torch._C._AutoDispatchBelowAutograd():
                    return op(x)

            @staticmethod
            def backward(ctx, grad):
                raise NotImplementedError

        def autograd_impl(x):
            return Op.apply(x)

        lib.impl(name, autograd_impl, "Autograd")
        return op

    def _init_op_with_no_abstract_impl(self):
        name = "no_abstract"
        qualname = f"{self.test_ns}::{name}"
        lib = self.lib()
        lib.define(f"{name}(Tensor x) -> Tensor", tags=(torch.Tag.pt2_compliant_tag,))
        lib.impl(name, lambda x: x.clone(), "CPU")
        return torch._library.utils.lookup_op(qualname)

    def setUp(self):
        super().setUp()
        self._op_with_no_abstract_impl = self._init_op_with_no_abstract_impl()
        self._op_delayed_backward_error = self._init_op_delayed_backward_error()

    @optests.dontGenerateOpCheckTests("Testing this API")
    def test_dont_generate(self):
        op = op_with_incorrect_schema(self, "incorrect_schema")
        x = torch.randn(3)
        op(x)

    def test_mm(self):
        x = torch.randn(2, 3, requires_grad=True)
        y = torch.randn(3, 5)
        result = torch.ops.aten.mm.default(x, y)
        self.assertEqual(result, x @ y)

    def test_mm_meta(self):
        x = torch.randn(2, 3, requires_grad=True, device="meta")
        y = torch.randn(3, 5, device="meta")
        result = torch.ops.aten.mm.default(x, y)
        self.assertEqual(result.shape, (x @ y).shape)

    def test_mm_fake(self):
        with torch._subclasses.fake_tensor.FakeTensorMode():
            x = torch.randn(2, 3, requires_grad=True, device="cpu")
            y = torch.randn(3, 5, device="cpu")
            result = torch.ops.aten.mm.default(x, y)
            self.assertEqual(result.shape, (x @ y).shape)

    def test_mm_errors(self):
        x = torch.randn(2, 3, requires_grad=True)
        y = torch.randn(4, 5)
        with self.assertRaisesRegex(RuntimeError, "cannot be multiplied"):
            result = torch.ops.aten.mm.default(x, y)

    def test_nonzero(self):
        x = torch.tensor([0, 1, 2, 0, 0])
        y = torch.ops.aten.nonzero.default(x)
        self.assertEqual(y, torch.tensor([[1], [2]]))

    def test_inplace(self):
        x = torch.randn(3)
        x_clone = x.clone()
        y = torch.ops.aten.sin_(x)
        self.assertEqual(x, x_clone.sin())

    def test_incorrect_schema(self):
        op = op_with_incorrect_schema(self, "incorrect_schema")
        x = torch.randn(3)
        op(x)

    def test_no_abstract(self):
        op = self._op_with_no_abstract_impl
        x = torch.randn(3)
        op(x)

    def test_delayed_error(self):
        op = self._op_delayed_backward_error
        x = torch.randn([], requires_grad=True)
        y = op(x)
        with self.assertRaises(NotImplementedError):
            y.backward()

    def test_delayed_error_no_requires_grad(self):
        op = self._op_delayed_backward_error
        x = torch.randn([])
        y = op(x)


class TestCustomOpAPI(TestCase):
    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_basic(self):
        @torch.library.custom_op("_torch_testing::add", mutates_args=())
        def add(x: Tensor, y: float) -> Tensor:
            x_np = x.numpy(force=True)
            out_np = x_np + y
            return torch.from_numpy(out_np).to(x.device)

        x = torch.randn(3)
        y = 3.14
        z = add(x, y)
        self.assertEqual(z, x + y)

        cpu_called = False

        @add.register_kernel("cpu")
        def _(x, y):
            nonlocal cpu_called
            cpu_called = True
            x_np = x.numpy()
            out_np = x_np + y
            return torch.from_numpy(out_np)

        z = add(x, y)
        self.assertEqual(z, x + y)
        self.assertTrue(cpu_called)
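
    # torch.library.custom_op wraps `add` in a CustomOpDef: calling it routes
    # through the dispatcher, and register_kernel swaps in a device-specific
    # implementation without touching the original definition.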

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_no_grad_skips_autograd(self):
        @torch.library.custom_op("_torch_testing::add", mutates_args=())
        def add(x: Tensor, y: float) -> Tensor:
            x_np = x.numpy(force=True)
            out_np = x_np + y
            return torch.from_numpy(out_np).to(x.device)

        called = 0

        def setup_context(ctx, inputs, output):
            nonlocal called
            called += 1

        def backward(ctx, grad):
            raise AssertionError("should not be reached")

        add.register_autograd(backward, setup_context=setup_context)

        x = torch.randn(3, requires_grad=True)
        with torch.no_grad():
            y = add(x, 2.0)
        self.assertEqual(called, 0)
        self.assertEqual(y, x + 2.0)

        x.requires_grad_(False)
        y = add(x, 2.0)
        self.assertEqual(called, 0)
        self.assertEqual(y, x + 2.0)

        x = torch.randn(3, requires_grad=True)
        y = add(x, 2.0)
        self.assertEqual(called, 1)
        self.assertEqual(y, x + 2.0)

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_manual_schema(self):
        @torch.library.custom_op(
            "_torch_testing::add",
            mutates_args=(),
            schema="(Tensor x, float y) -> Tensor",
        )
        def add(x, y):
            x_np = x.numpy(force=True)
            out_np = x_np + y
            return torch.from_numpy(out_np).to(x.device)

        x = torch.randn(3)
        y = 3.14
        z = add(x, y)
        self.assertEqual(z, x + y)

        @torch.library.custom_op(
            "_torch_testing::sin_",
            mutates_args=["x"],
            schema="(Tensor(a!) x) -> ()",
        )
        def sin_(x):
            x_np = x.numpy()
            np.sin(x_np, out=x_np)

        x = torch.randn(3)
        expected = x.sin()
        sin_(x)
        self.assertEqual(x, expected)
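
    # In a manual schema, "Tensor(a!) x" declares that the kernel writes to
    # `x` in place; the annotation has to agree with mutates_args, which
    # test_manual_schema_error below exercises.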

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_kwarg_only_tensors(self):
        with self.assertRaisesRegex(NotImplementedError, "kwarg-only Tensor args"):

            @torch.library.custom_op("_torch_testing::foo", mutates_args=())
            def foo(x: Tensor, *, y: int, z: Tensor) -> Tensor:
                pass

        with self.assertRaisesRegex(NotImplementedError, "kwarg-only Tensor args"):

            @torch.library.custom_op("_torch_testing::foo", mutates_args=())
            def foo2(x: Tensor, *, y: int, z: Optional[Tensor]) -> Tensor:
                pass

        with self.assertRaisesRegex(NotImplementedError, "kwarg-only Tensor args"):

            @torch.library.custom_op("_torch_testing::foo", mutates_args=())
            def foo3(x: Tensor, *, y: int, z: List[Tensor]) -> Tensor:
                pass

        with torch.library._scoped_library("_torch_testing", "FRAGMENT") as lib:
            lib.define("foo(Tensor x, *, Tensor y) -> Tensor")
            with self.assertRaisesRegex(NotImplementedError, "kwarg-only Tensor args"):
                torch.library.register_autograd(
                    "_torch_testing::foo",
                    lambda grad: grad,
                    setup_context=lambda ctx, inputs, keyword_only_inputs, output: None,
                )

            with self.assertRaisesRegex(NotImplementedError, "kwarg-only Tensor args"):
                torch.library.register_vmap(
                    "_torch_testing::foo",
                    lambda info, in_dims, x, *, y: (x, 0),
                )

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_register_autograd_kwargonly_low_level(self):
        with torch.library._scoped_library("_torch_testing", "FRAGMENT") as lib:
            lib.define("foo(Tensor x, *, float y) -> Tensor")
            called = False

            def foo_impl(x, *, y):
                return x * y

            lib.impl("foo", foo_impl, "CPU")

            def backward(ctx, grad):
                nonlocal called
                called = True
                return grad * ctx.y

            def setup_context(ctx, inputs, keyword_only_inputs, output):
                assert tuple(keyword_only_inputs.keys()) == ("y",)
                ctx.y = keyword_only_inputs["y"]

            torch.library.register_autograd(
                "_torch_testing::foo", backward, setup_context=setup_context, lib=lib
            )

            x = torch.randn(3, requires_grad=True)
            torch.ops._torch_testing.foo(x, y=3.14).sum().backward()
            self.assertTrue(called)
            self.assertEqual(x.grad, torch.tensor([3.14, 3.14, 3.14]))
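
    # For ops with kwarg-only non-Tensor arguments, setup_context receives
    # them in a separate `keyword_only_inputs` dict rather than in `inputs`,
    # so the backward formula can still recover them (here y=3.14 becomes
    # x.grad, since d(x * y)/dx = y).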

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_register_autograd_defaults(self):
        with torch.library._scoped_library("_torch_testing", "FRAGMENT") as lib:
            lib.define("foo(Tensor w, int x = 2, *, int y = 3, int z) -> Tensor")

            def foo_impl(w, x=2, *, y=3, z):
                return w * x * y * z

            lib.impl("foo", foo_impl, "CPU")

            called = False

            def backward(ctx, grad):
                nonlocal called
                called = True
                return grad * ctx.c

            def setup_context(ctx, inputs, keyword_only_inputs, output):
                assert len(inputs) == 2
                assert inputs[1] == 2
                assert keyword_only_inputs == {"y": 3, "z": 42}
                ctx.c = keyword_only_inputs["y"] * keyword_only_inputs["z"] * inputs[1]

            torch.library.register_autograd(
                "_torch_testing::foo", backward, setup_context=setup_context, lib=lib
            )

            w = torch.randn(3, requires_grad=True)
            torch.ops._torch_testing.foo(w, z=42).sum().backward()
            self.assertTrue(called)
            self.assertEqual(w.grad, torch.full_like(w, 2 * 3 * 42))

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_manual_schema_error(self):
        with self.assertRaisesRegex(ValueError, "the op mutates {'x'}"):

            @torch.library.custom_op(
                "_torch_testing::sin_",
                mutates_args=(),
                schema="(Tensor(a!) x) -> ()",
            )
            def sin_(x):
                x_np = x.numpy()
                np.sin(x_np, out=x_np)

    def test_supports_tensorlist(self):
        @torch._library.autograd.supports_tensorlist
        class Stack(torch.autograd.Function):
            @staticmethod
            def forward(ctx, xs):
                ctx.num_xs = len(xs)
                return torch.stack(xs)

            @staticmethod
            def backward(ctx, grad):
                expected = ([True] * ctx.num_xs,)
                self.assertEqual(ctx.needs_input_grad, expected)
                return list(grad.unbind(0))

        # call two applys, do a backward on the first
        def t():
            return torch.randn([], requires_grad=True)

        xs0 = [t(), t(), t()]
        xs1 = [t(), t(), t(), t()]
        y0 = Stack.apply(xs0)
        y1 = Stack.apply(xs1)
        grads = torch.autograd.grad(y0.sum(), xs0)
        self.assertEqual(grads, [torch.tensor(1.0) for _ in range(3)])

        # call one apply, do multiple backwards
        xs = [t(), t(), t()]
        y = Stack.apply(xs)
        _ = torch.autograd.grad(y.sum(), xs, retain_graph=True)
        _ = torch.autograd.grad(y.sum(), xs, retain_graph=True)
        grads = torch.autograd.grad(y.sum(), xs, retain_graph=True)
        self.assertEqual(grads, [torch.tensor(1.0) for _ in range(3)])

        # error on accessing forward/backward directly
        with self.assertRaisesRegex(NotImplementedError, "Function.forward directly"):
            Stack.forward(None, xs)
        with self.assertRaisesRegex(NotImplementedError, "Function.backward directly"):
            Stack.backward(None, xs)

        # the recursive case
        @torch._library.autograd.supports_tensorlist
        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, xs):
                if len(xs) > 1:
                    return Foo.apply(xs[1:])
                ctx.len_xs = len(xs)
                return xs[0].sin()

            @staticmethod
            def backward(ctx, grad):
                result = [None] * ctx.len_xs
                result[-1] = grad.cos()
                return result

        result = Foo.apply(xs)
        expected = xs[-1].sin()
        self.assertEqual(result, expected)

        # recursive on backward
        @torch._library.autograd.supports_tensorlist
        class Bar(torch.autograd.Function):
            @staticmethod
            def forward(ctx, xs):
                return [xs[i] + i for i in range(len(xs))]

            @staticmethod
            def backward(ctx, grads):
                f1 = Bar.apply(grads[:2])
                f2 = Bar.apply(grads[2:])
                return f1 + f2

        xs = [torch.tensor(0.0, requires_grad=True) for _ in range(5)]
        ys = Bar.apply(xs)
        sum(ys).backward()
        result = [xi.grad for xi in xs]
        self.assertEqual(result, torch.tensor([1.0, 2, 1, 2, 3]).unbind(0))
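
    # supports_tensorlist lets an autograd.Function accept and return lists of
    # tensors, even recursively. Gradient check for Bar: forward computes
    # y_i = x_i + i, and backward re-applies Bar to the incoming ones-grads,
    # yielding [1+0, 1+1] ++ [1+0, 1+1, 1+2] = [1, 2, 1, 2, 3].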

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_default_values(self):
        defaults = []

        @torch.library.custom_op("_torch_testing::f", mutates_args=())
        def f(
            x: Tensor,
            a: Optional[int] = None,
            b: float = 3.14,
            c: bool = True,
            d: int = 3,
            e: str = "foo",
            f: torch.dtype = torch.float,
            g: torch.dtype = torch.float32,
            h: torch.dtype = torch.int,
            i: torch.device = torch.device("cpu:0"),
            j: torch.device = "cpu",
        ) -> Tensor:
            defaults.extend([a, b, c, d, e, f, g, h, i, j])
            return x.clone()

        x = torch.randn(3)
        f(x)
        self.assertEqual(
            defaults,
            [
                None,
                3.14,
                True,
                3,
                "foo",
                torch.float,
                torch.float32,
                torch.int,
                torch.device("cpu:0"),
                "cpu",
            ],
        )
        default_values = [
            arg.default_value
            for arg in torch.ops._torch_testing.f.default._schema.arguments
        ]
        # enum values taken from c10/core/ScalarType.h
        self.assertEqual(
            default_values,
            [
                None,
                None,
                3.14,
                True,
                3,
                "foo",
                6,  # torch.float
                6,  # torch.float32
                3,  # torch.int
                torch.device("cpu:0"),
                torch.device("cpu"),
            ],
        )

    def test_mutated_error(self):
        with self.assertRaisesRegex(
            ValueError, r".*{'y'} in mutates_args were not found"
        ):

            @torch.library.custom_op(
                "_torch_testing::numpy_sin_inplace",
                mutates_args={"y"},
                device_types="cpu",
            )
            def numpy_sin_inplace(x: Tensor) -> None:
                x_np = x.numpy()
                np.sin(x_np, out=x_np)

    def test_mutated(self):
        @torch.library.custom_op(
            "_torch_testing::numpy_sin_inplace", mutates_args={"x"}, device_types="cpu"
        )
        def numpy_sin_inplace(x: Tensor) -> None:
            x_np = x.numpy()
            np.sin(x_np, out=x_np)

        x = torch.randn(3)
        version = x._version
        expected = x.sin()
        numpy_sin_inplace(x)
        self.assertEqual(x, expected)
        self.assertGreater(x._version, version)

        @torch.library.custom_op("_torch_testing::f", mutates_args={"y", "z", "w"})
        def f(
            x: Tensor, y: Optional[Tensor], z: List[Tensor], w: List[Optional[Tensor]]
        ) -> None:
            # Mutate everything declared in mutates_args; leave x untouched.
            for t in [y, *z, *w]:
                if t is not None:
                    t.fill_(3.14)

        x = torch.randn(3)
        y = torch.randn(3)
        z = [torch.randn(3), torch.randn(3)]
        w = [torch.randn(3), None, torch.randn(3)]
        initial_versions = pytree.tree_map_only(
            torch.Tensor, lambda x: x._version, (x, y, z, w)
        )
        f(x, y, z, w)
        new_versions = pytree.tree_map_only(
            torch.Tensor, lambda x: x._version, (x, y, z, w)
        )

        self.assertEqual(initial_versions[0], new_versions[0])
        initial_versions, _ = pytree.tree_flatten(initial_versions[1:])
        new_versions, _ = pytree.tree_flatten(new_versions[1:])
        for prev, after in zip(initial_versions, new_versions):
            if prev is None and after is None:
                continue
            self.assertGreater(after, prev)
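
    # Every in-place write bumps a tensor's `_version` counter; these tests
    # use that counter to verify that arguments declared in mutates_args were
    # actually mutated while undeclared arguments (x above) were left alone.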

    def test_mutated_unknown(self):
        @torch.library.custom_op(
            "_torch_testing::f", mutates_args="unknown", device_types="cpu"
        )
        def f(x: Tensor) -> None:
            x_np = x.numpy()
            np.sin(x_np, out=x_np)

        x = torch.randn(3)
        version = x._version
        expected = x.sin()
        f(x)
        self.assertEqual(x, expected)
        self.assertGreater(x._version, version)

        @torch.library.custom_op("_torch_testing::f2", mutates_args="unknown")
        def f2(
            x: Tensor, y: Optional[Tensor], z: List[Tensor], w: List[Optional[Tensor]]
        ) -> None:
            # With mutates_args="unknown", every argument may be mutated.
            for t in [x, y, *z, *w]:
                if t is not None:
                    t.fill_(3.14)

        x = torch.randn(3)
        y = torch.randn(3)
        z = [torch.randn(3), torch.randn(3)]
        w = [torch.randn(3), None, torch.randn(3)]
        initial_versions = pytree.tree_map_only(
            torch.Tensor, lambda x: x._version, (x, y, z, w)
        )
        f2(x, y, z, w)
        new_versions = pytree.tree_map_only(
            torch.Tensor, lambda x: x._version, (x, y, z, w)
        )

        initial_versions, _ = pytree.tree_flatten(initial_versions)
        new_versions, _ = pytree.tree_flatten(new_versions)
        for prev, after in zip(initial_versions, new_versions):
            if prev is None and after is None:
                continue
            self.assertGreater(after, prev)

        with self.assertRaisesRegex(ValueError, "string"):

            @torch.library.custom_op("_torch_testing::f3", mutates_args="x")
            def f3(x: Tensor) -> None:
                pass

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_library_register_torch_dispatch_rule_subclass(self):
        from torch.testing._internal.two_tensor import TwoTensor

        @torch.library.custom_op("mylib::foo", mutates_args={})
        def f(x: torch.Tensor) -> torch.Tensor:
            return x.sin()

        x = torch.randn(3)
        y = torch.randn(3)
        called = 0

        with torch.library._scoped_library("mylib", "FRAGMENT") as m:

            def TwoTensor_foo(cls, func, types, args, kwargs):
                nonlocal called
                called += 1
                assert cls is TwoTensor
                return x.sin()

            m._register_torch_dispatch_rule("foo", TwoTensor, TwoTensor_foo)

            out = f(TwoTensor(x, y))

        self.assertEqual(called, 1)

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_library_register_torch_dispatch_rule_mode(self):
        from torch.testing._internal.two_tensor import TwoTensorMode

        @torch.library.custom_op("mylib::foo", mutates_args={})
        def f(x: torch.Tensor) -> torch.Tensor:
            return x.sin()

        x = torch.randn(3)
        called = 0

        with torch.library._scoped_library("mylib", "FRAGMENT") as m:

            def TwoTensor_foo(mode, func, types, args, kwargs):
                nonlocal called
                called += 1
                return x.sin()

            m._register_torch_dispatch_rule("foo", TwoTensorMode, TwoTensor_foo)

            with TwoTensorMode():
                out = f(x)

        self.assertEqual(called, 1)

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    @parametrize("idx", [0, 1, 2, 3, 4, 5])
    def test_library_register_fake_source(self, idx):
        opname = f"source{idx}"
        op = getattr(torch.ops._torch_testing, opname).default
        entry = torch._library.simple_registry.singleton.find(op._name)
        source = entry.fake_impl.kernel.source
        assert source is not None
        self.assertTrue("custom_op_db.py" in source)

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_library_register_fake(self):
        for mode in ["function", "qualname", "opoverload"]:

            @torch.library.custom_op("_torch_testing::add", mutates_args=())
            def add(x: Tensor, y: float) -> Tensor:
                x_np = x.cpu().numpy()
                out_np = x_np + y
                return torch.from_numpy(out_np).to(x.device)

            called = False

            if mode == "function":
                dec = torch.library.register_fake(add)
                self.assertIsNotNone(dec)
            elif mode == "qualname":
                dec = torch.library.register_fake("_torch_testing::add")
                self.assertIsNotNone(dec)
            elif mode == "opoverload":
                dec = torch.library.register_fake(torch.ops._torch_testing.add.default)
                self.assertIsNotNone(dec)
            else:
                raise AssertionError("should not get here")

            @dec
            def _(x, y):
                nonlocal called
                called = True
                return torch.empty_like(x)

            with torch._subclasses.fake_tensor.FakeTensorMode():
                x = torch.randn(3)
                y = 3.14
                z = add(x, y)
                self.assertEqual(z.shape, x.shape)
                self.assertTrue(called)
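
    # register_fake attaches a meta/"fake" implementation that computes output
    # shapes and dtypes without real data; FakeTensorMode (and torch.compile
    # tracing) invokes it instead of the real kernel.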

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_library_register_torch_dispatch(self):
        for mode in ["function", "qualname", "opoverload"]:

            class MyMode(torch.utils._python_dispatch.TorchDispatchMode):
                def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                    return func(*args, **kwargs)

            @torch.library.custom_op("_torch_testing::add", mutates_args=())
            def add(x: Tensor, y: float) -> Tensor:
                x_np = x.cpu().numpy()
                out_np = x_np + y
                return torch.from_numpy(out_np).to(x.device)

            called = False

            if mode == "function":
                dec = torch.library.register_torch_dispatch(add, MyMode)
                self.assertIsNotNone(dec)
            elif mode == "qualname":
                dec = torch.library.register_torch_dispatch(
                    "_torch_testing::add", MyMode
                )
                self.assertIsNotNone(dec)
            elif mode == "opoverload":
                dec = torch.library.register_torch_dispatch(
                    torch.ops._torch_testing.add.default, MyMode
                )
                self.assertIsNotNone(dec)
            else:
                raise AssertionError("should not get here")

            @dec
            def _(mode, func, types, args, kwargs):
                nonlocal called
                called = True
                return func(*args, **kwargs)

            with MyMode():
                x = torch.randn(3)
                y = 3.14
                z = add(x, y)
            self.assertEqual(z.shape, x.shape)
            self.assertTrue(called)

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_library_register_torch_dispatch_low_level(self):
        modes = ["qualname", "opoverload"]
        calls = ["decorator", "function"]
        device_types_options = [("cpu", "cuda"), "cpu", None]

        for mode, call, device_types in itertools.product(
            modes, calls, device_types_options
        ):
            with torch.library._scoped_library("_torch_testing", "FRAGMENT") as lib:
                lib.define("add10(Tensor x, float y) -> Tensor")

                if mode == "qualname":
                    op = "_torch_testing::add10"
                else:
                    assert mode == "opoverload"
                    op = torch.ops._torch_testing.add10.default

                called = False

                class MyMode(torch.utils._python_dispatch.TorchDispatchMode):
                    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                        return func(*args, **kwargs)

                if call == "decorator":

                    @torch.library.register_torch_dispatch(op, MyMode, lib=lib)
                    def _(mode, func, types, args, kwargs):
                        nonlocal called
                        called = True
                        x, y = args
                        return x + y

                else:
                    assert call == "function"

                    def add_stuff(mode, func, types, args, kwargs):
                        nonlocal called
                        called = True
                        x, y = args
                        return x + y

                    torch.library.register_torch_dispatch(
                        op, MyMode, add_stuff, lib=lib
                    )

                x = torch.randn(3)
                y = 3.14
                with MyMode():
                    z = torch.ops._torch_testing.add10.default(x, y)
                self.assertEqual(z, x + y)
                self.assertTrue(called)

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_library_register_kernel(self):
        modes = ["function", "qualname", "opoverload"]
        calls = ["decorator", "function"]
        device_types_options = ["cpu", None]

        for mode, call, device_types in itertools.product(
            modes, calls, device_types_options
        ):

            @torch.library.custom_op(
                "_torch_testing::add", mutates_args=(), device_types="cuda"
            )
            def add(x: Tensor, y: float) -> Tensor:
                x_np = x.cpu().numpy()
                out_np = x_np + y
                return torch.from_numpy(out_np).to(x.device)

            if mode == "function":
                op = add
            elif mode == "qualname":
                op = "_torch_testing::add"
            else:
                assert mode == "opoverload"
                op = torch.ops._torch_testing.add.default

            called = False

            if call == "decorator":

                @torch.library.register_kernel(op, device_types)
                def _(x, y):
                    nonlocal called
                    called = True
                    x_np = x.numpy()
                    out_np = x_np + y
                    return torch.from_numpy(out_np)

            else:
                assert call == "function"

                def add_cpu(x, y):
                    nonlocal called
                    called = True
                    x_np = x.numpy()
                    out_np = x_np + y
                    return torch.from_numpy(out_np)

                torch.library.register_kernel(op, device_types, add_cpu)

            x = torch.randn(3)
            y = 3.14
            z = add(x, y)
            self.assertEqual(z, x + y)
            self.assertTrue(called)

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_library_register_kernel_low_level(self):
        modes = ["qualname", "opoverload"]
        calls = ["decorator", "function"]
        device_types_options = [("cpu", "cuda"), "cpu", None]

        for mode, call, device_types in itertools.product(
            modes, calls, device_types_options
        ):
            with torch.library._scoped_library("_torch_testing", "FRAGMENT") as lib:
                lib.define("add9(Tensor x, float y) -> Tensor")

                if mode == "qualname":
                    op = "_torch_testing::add9"
                else:
                    assert mode == "opoverload"
                    op = torch.ops._torch_testing.add9.default

                called = False

                if call == "decorator":

                    @torch.library.register_kernel(op, device_types, lib=lib)
                    def _(x, y):
                        nonlocal called
                        called = True
                        x_np = x.numpy()
                        out_np = x_np + y
                        return torch.from_numpy(out_np)

                else:
                    assert call == "function"

                    def add_cpu(x, y):
                        nonlocal called
                        called = True
                        x_np = x.numpy()
                        out_np = x_np + y
                        return torch.from_numpy(out_np)

                    torch.library.register_kernel(op, device_types, add_cpu, lib=lib)

                x = torch.randn(3)
                y = 3.14
                z = torch.ops._torch_testing.add9.default(x, y)
                self.assertEqual(z, x + y)
                self.assertTrue(called)

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_library_register_autograd(self):
        for mode in ["function", "qualname", "opoverload"]:

            @torch.library.custom_op("mylib::numpy_sin", mutates_args=())
            def numpy_sin(x: Tensor) -> Tensor:
                x_np = x.cpu().numpy()
                y_np = np.sin(x_np)
                return torch.from_numpy(y_np).to(device=x.device)

            def setup_context(ctx, inputs, output) -> Tensor:
                (x,) = inputs
                ctx.save_for_backward(x)

            called = False

            def backward(ctx, grad):
                nonlocal called
                called = True
                (x,) = ctx.saved_tensors
                return grad * x.cos()

            if mode == "function":
                torch.library.register_autograd(
                    numpy_sin, backward, setup_context=setup_context
                )
            elif mode == "qualname":
                torch.library.register_autograd(
                    "mylib::numpy_sin", backward, setup_context=setup_context
                )
            elif mode == "opoverload":
                torch.library.register_autograd(
                    torch.ops.mylib.numpy_sin.default,
                    backward,
                    setup_context=setup_context,
                )

            x = torch.randn(3, requires_grad=True)
            y = numpy_sin(x)
            (grad_x,) = torch.autograd.grad(y, x, torch.ones_like(y))
            self.assertTrue(called)
            self.assertEqual(grad_x, x.cos())
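
    # register_autograd splits the formula in two: setup_context runs during
    # the forward pass and is the place to stash saved tensors/values on ctx,
    # while backward consumes ctx plus the incoming gradient, mirroring
    # torch.autograd.Function's setup_context/backward protocol.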

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_library_register_autograd_low_level(self):
        for mode in ["qualname", "opoverload"]:
            with torch.library._scoped_library("_torch_testing", "FRAGMENT") as lib:
                lib.define("sin5(Tensor x) -> Tensor")

                def numpy_sin(x: Tensor) -> Tensor:
                    x_np = x.cpu().detach().numpy()
                    y_np = np.sin(x_np)
                    return torch.from_numpy(y_np).to(device=x.device)

                def setup_context(ctx, inputs, output) -> Tensor:
                    (x,) = inputs
                    ctx.save_for_backward(x)

                called = False

                def backward(ctx, grad):
                    nonlocal called
                    called = True
                    (x,) = ctx.saved_tensors
                    return grad * x.cos()

                lib.impl("sin5", numpy_sin, "CPU")

                if mode == "qualname":
                    torch.library.register_autograd(
                        "_torch_testing::sin5",
                        backward,
                        setup_context=setup_context,
                        lib=lib,
                    )
                elif mode == "opoverload":
                    torch.library.register_autograd(
                        torch.ops._torch_testing.sin5.default,
                        backward,
                        setup_context=setup_context,
                        lib=lib,
                    )

                x = torch.randn(3, requires_grad=True)
                y = torch.ops._torch_testing.sin5(x)
                (grad_x,) = torch.autograd.grad(y, x, torch.ones_like(y))
                self.assertTrue(called)
                self.assertEqual(grad_x, x.cos())

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_fake(self):
        @torch.library.custom_op("_torch_testing::add", mutates_args=())
        def add(x: Tensor, y: float) -> Tensor:
            x_np = x.cpu().numpy()
            out_np = x_np + y
            return torch.from_numpy(out_np).to(x.device)

        x = torch.randn(3)
        y = 3.14
        z = add(x, y)
        self.assertEqual(z, x + y)

        try:
            with torch._subclasses.fake_tensor.FakeTensorMode():
                x = torch.randn(3)
                add(x, y)
            raise AssertionError("should not be hit")
        except RuntimeError as e:
            abstract_impl_error_msg = str(e)
        abstract_impl_error_msg = re.sub(
            r"0x.*>\)>", "0xDEADBEEF>)>", abstract_impl_error_msg
        ).replace(". ", ".\n")
        self.assertExpectedInline(
            abstract_impl_error_msg,
            """\
There was no fake impl registered for <CustomOpDef(_torch_testing::add)>.
This is necessary for torch.compile/export/fx tracing to work.
Please use `add.register_fake` to add an fake impl.""",
        )

        x = torch.randn(3)

        @torch.compile(backend="eager")
        def f(x):
            return add(x, y)

        with self.assertRaisesRegex(RuntimeError, "no fake impl"):
            f(x)

        abstract_called = False

        @add.register_fake
        def _(x, y):
            nonlocal abstract_called
            abstract_called = True
            return torch.empty_like(x)

        with torch._subclasses.fake_tensor.FakeTensorMode():
            x = torch.randn(3)
            z = add(x, y)
            self.assertEqual(z.shape, x.shape)
            self.assertTrue(abstract_called)
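
    # Without a fake impl the op is opaque to FakeTensorMode, hence the error
    # above; once add.register_fake provides a shape function, fake execution
    # succeeds without ever running the real numpy kernel.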

    @skipIfTorchDynamo("recursive dynamo")
    @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work on windows")
    def test_compile(self):
        called_impl = False
        called_abstract = False

        @torch.library.custom_op("_torch_testing::linear", mutates_args=())
        def custom_linear(x: Tensor, weight: Tensor, bias: Tensor) -> Tensor:
            nonlocal called_impl
            called_impl = True
            x_np = x.numpy()
            w_np = weight.numpy()
            out_np = np.add(x_np @ w_np.T, bias)
            return torch.from_numpy(out_np)

        @custom_linear.register_fake
        def _(x, weight, bias):
            nonlocal called_abstract
            called_abstract = True
            assert x.dim() == 2
            assert weight.dim() == 2
            assert bias.dim() == 1
            assert x.shape[1] == weight.shape[1]
            assert weight.shape[0] == bias.shape[0]
            assert x.device == weight.device
            return x.new_empty(x.size(0), weight.size(0))

        x = torch.randn(2, 2)
        weight = torch.randn(2, 2)
        bias = torch.randn(2)
        out = torch.compile(custom_linear, backend="eager", fullgraph=True)(
            x, weight, bias
        )
        self.assertEqual(out, torch.nn.functional.linear(x, weight, bias))
        self.assertTrue(called_impl)
        self.assertTrue(called_abstract)

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_register_autograd_error_cases(self):
        @torch.library.custom_op("_torch_testing::g", mutates_args=())
        def g(x: Tensor) -> Tensor:
            return x.sin()

        x = torch.randn(3, requires_grad=True)
        y = g(x)
        with self.assertRaisesRegex(RuntimeError, "no autograd formula"):
            y.sum().backward()

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_replacement(self):
        @torch.library.custom_op("_torch_testing::f", mutates_args=())
        def f(x: Tensor) -> Tensor:
            return x.sin()

        x = torch.randn(3)
        y = f(x)
        self.assertEqual(y, x.sin())

        @torch.library.custom_op("_torch_testing::f", mutates_args=())
        def f(x: Tensor) -> Tensor:
            return x.cos()

        y = f(x)
        self.assertEqual(y, x.cos())

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    @unittest.skipIf(not TEST_CUDA, "requires CUDA")
    def test_split_device(self):
        cpu_call_count = 0
        cuda_call_count = 0

        @torch.library.custom_op(
            "_torch_testing::f", mutates_args=(), device_types="cpu"
        )
        def f(x: Tensor) -> Tensor:
            nonlocal cpu_call_count
            cpu_call_count += 1
            x_np = x.numpy()
            out_np = np.sin(x_np)
            return torch.from_numpy(out_np)

        @f.register_kernel("cuda")
        def _(x: Tensor) -> Tensor:
            nonlocal cuda_call_count
            cuda_call_count += 1
            x_np = x.cpu().numpy()
            out_np = np.sin(x_np)
            return torch.from_numpy(out_np).to(x.device)

        x = torch.randn(3)
        y = f(x)
        self.assertEqual(y, x.sin())
        self.assertEqual(cpu_call_count, 1)
        self.assertEqual(cuda_call_count, 0)

        x = x.cuda()
        y = f(x)
        self.assertEqual(y, x.sin())
        self.assertEqual(cpu_call_count, 1)
        self.assertEqual(cuda_call_count, 1)

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    @unittest.skipIf(not TEST_CUDA, "requires CUDA")
    def test_multi_types(self):
        @torch.library.custom_op(
            "_torch_testing::f", mutates_args=(), device_types=("cpu", "cuda")
        )
        def f(x: Tensor) -> Tensor:
            x_np = x.cpu().numpy()
            out_np = np.sin(x_np)
            return torch.from_numpy(out_np).to(x.device)

        x = torch.randn(3)
        y = f(x)
        self.assertEqual(y, x.sin())
        x = x.cuda()
        y = f(x)
        self.assertEqual(y, x.sin())

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_overloading(self):
        called_f = 0
        called_f1 = 0

        @torch.library.custom_op("_torch_testing::f", mutates_args=())
        def f(x: Tensor) -> Tensor:
            nonlocal called_f
            called_f += 1
            return x.clone()

        x = torch.randn(2, 3)
        torch.ops._torch_testing.f(x)
        self.assertEqual(called_f, 1)

        @torch.library.custom_op("_torch_testing::f.overload", mutates_args=())
        def f1(x: Tensor, y: Tensor) -> Tensor:
            nonlocal called_f1
            called_f1 += 1
            return x.clone()

        torch.ops._torch_testing.f(x, x)
        self.assertEqual(called_f1, 1)

    def test_disallows_output_aliasing(self):
        @torch.library.custom_op("_torch_testing::f", mutates_args=())
        def f(x: Tensor) -> Tensor:
            return x.view(-1)

        x = torch.randn(3)
        with self.assertRaisesRegex(RuntimeError, "may not alias"):
            f(x)

        @torch.library.custom_op("_torch_testing::f", mutates_args=())
        def f(x: Tensor) -> Tensor:
            return x

        x = torch.randn(3)
        with self.assertRaisesRegex(RuntimeError, "may not alias"):
            f(x)

        @torch.library.custom_op(
            "_torch_testing::f", mutates_args={"x"}, device_types="cpu"
        )
        def numpy_sin_inplace(x: Tensor) -> Tensor:
            x_np = x.numpy()
            np.sin(x_np, out=x_np)
            return x

        x = torch.randn(3)
        with self.assertRaisesRegex(RuntimeError, "may not alias"):
            numpy_sin_inplace(x)
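
    # Outputs of a custom op must be freshly allocated: returning a view of an
    # input, the input itself, or a mutated input triggers the runtime
    # "may not alias" check in all three cases above.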

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_factory_function(self):
        @torch.library.custom_op(
            "_torch_testing::f", mutates_args={}, device_types="cpu"
        )
        def f(device: torch.device) -> Tensor:
            return torch.ones(3)

        result = f(device="cpu")
        self.assertEqual(result.device, torch.device("cpu"))
        self.assertEqual(result, torch.ones(3))

        with self.assertRaisesRegex(
            RuntimeError, "f does not have a kernel registered for cuda"
        ):
            f(device="cuda")

        with self.assertRaisesRegex(
            ValueError,
            "Functions without tensor inputs are required to have a `device: torch.device` argument",
        ):

            @torch.library.custom_op(
                "_torch_testing::f2", mutates_args={}, device_types="cpu"
            )
            def f2() -> Tensor:
                return torch.ones(3)

        @torch.library.custom_op("_torch_testing::f3", mutates_args={})
        def f3() -> Tensor:
            raise NotImplementedError("NYI")

        with self.assertRaisesRegex(
            ValueError,
            "Functions without tensor inputs are required to have a `device: torch.device` argument",
        ):

            @f3.register_kernel("cpu")
            def _():
                return torch.zeros(3)

        @torch.library.custom_op("_torch_testing::f4", mutates_args={})
        def f4(device: torch.device) -> Tensor:
            raise NotImplementedError("NYI")

        @f4.register_kernel("cpu")
        def _(device: torch.device):
            return torch.zeros(3)

        result = f(device="cpu")
        self.assertEqual(result.device, torch.device("cpu"))
        self.assertEqual(result, torch.ones(3))

    def test_library_schema_infer(self):
        def foo_impl(x: torch.Tensor) -> torch.Tensor:
            return x.sin()

        schema = torch.library.infer_schema(foo_impl, op_name="myop", mutates_args={})
        self.assertExpectedInline(schema, "myop(Tensor x) -> Tensor")

        schema = torch.library.infer_schema(foo_impl, mutates_args={})
        self.assertExpectedInline(schema, "(Tensor x) -> Tensor")
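
    # infer_schema derives the schema string purely from the type annotations:
    # with op_name it produces a full declaration ("myop(Tensor x) -> Tensor"),
    # without it just the anonymous signature.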

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_set_kernel_enabled(self):
        x = torch.ones(1)

        @torch.library.custom_op("mylib::f", mutates_args=())
        def f(x: Tensor) -> Tensor:
            return x + 1

        self.assertEqual(f(x), x + 1)
        with self.assertLogs("torch._library.custom_ops") as captured:
            with f.set_kernel_enabled("gpu", enabled=False):
                self.assertEqual(f(x), x + 1)
            self.assertIn(
                "no kernel was registered for this device type", captured.output[0]
            )

        @f.register_kernel("cpu")
        def _(x):
            return x + 2

        self.assertEqual(f(x), x + 2)

        with self.assertLogs("torch._library.custom_ops") as captured:
            with f.set_kernel_enabled("cpu", enabled=True):
                self.assertEqual(f(x), x + 2)
            self.assertIn("already enabled", captured.output[0])

        with f.set_kernel_enabled("cpu", enabled=False):
            self.assertEqual(f(x), x + 1)

            with self.assertLogs("torch._library.custom_ops") as captured:
                with f.set_kernel_enabled("cpu", enabled=False):
                    self.assertEqual(f(x), x + 1)
                self.assertIn("already disabled", captured.output[0])

            self.assertEqual(f(x), x + 1)

        with f.set_kernel_enabled("cpu", enabled=True):
            self.assertEqual(f(x), x + 2)

        with f.set_kernel_enabled("cpu", enabled=False):
            self.assertEqual(f(x), x + 1)

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_register_vmap_kwargonly_low_level(self):
        with torch.library._scoped_library("_torch_testing", "FRAGMENT") as lib:
            lib.define("foo(Tensor x, *, float y) -> Tensor")
            called = False

            def foo_impl(x, *, y):
                return x * y

            lib.impl("foo", foo_impl, "CPU")

            def vmap(info, in_dims, x, *, y):
                nonlocal called
                called = True
                return x * y, 0

            torch.library.register_vmap("_torch_testing::foo", vmap, lib=lib)

            x = torch.ones(3)
            result = torch.vmap(torch.ops._torch_testing.foo)(x, y=3.14)
            self.assertTrue(called)
            self.assertEqual(result, torch.tensor([3.14, 3.14, 3.14]))
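
    # A vmap rule receives (info, in_dims, *args, **kwargs) and must return a
    # (batched_output, out_dim) pair; returning out_dim=0 here says the batch
    # dimension of the result is dim 0.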

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_register_vmap_defaults(self):
        with torch.library._scoped_library("_torch_testing", "FRAGMENT") as lib:
            lib.define("foo(Tensor w, int x = 2, *, int y = 3, int z) -> Tensor")

            def foo_impl(w, x=2, *, y=3, z):
                return w * x * y * z

            lib.impl("foo", foo_impl, "CPU")

            called = False

            def vmap(info, in_dims, w, x=2, *, y=3, z):
                nonlocal called
                called = True
                return w * x * y * z, 0

            torch.library.register_vmap("_torch_testing::foo", vmap, lib=lib)

            w = torch.randn(3)
            result = torch.vmap(torch.ops._torch_testing.foo)(w, z=42)
            self.assertTrue(called)
            self.assertEqual(result, w * 2 * 3 * 42)

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_library_register_vmap(self):
        for mode in ["function", "qualname", "opoverload", "c_opdef"]:

            @torch.library.custom_op("mylib::f", mutates_args=())
            def f(x: Tensor, y: Tensor) -> Tensor:
                return x * y

            called = False

            def fvmap(info, in_dims, x, y):
                nonlocal called
                called = True
                x_bdim, y_bdim = in_dims
                x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
                y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1)
                result = x * y
                result = result.movedim(-1, 0)
                return result, 0

            if mode == "function":
                torch.library.register_vmap(f, fvmap)
            elif mode == "qualname":
                torch.library.register_vmap("mylib::f", fvmap)
            elif mode == "opoverload":
                torch.library.register_vmap(torch.ops.mylib.f.default, fvmap)
            elif mode == "c_opdef":
                f.register_vmap(fvmap)

            x = torch.randn(2, 2)
            y = torch.randn(2, 2)

            result = torch.vmap(f)(x, y)
            self.assertTrue(called)
            self.assertEqual(result, x * y)

            called = False
            result = torch.vmap(f, out_dims=1)(x, y)
            self.assertEqual(result, (x * y).T)
            self.assertTrue(called)

            called = False
            result = torch.vmap(f, in_dims=1)(x, y)
            self.assertEqual(result, (x * y).T)
            self.assertTrue(called)

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_library_register_vmap_library_decorator(self):
        @torch.library.custom_op("mylib::f", mutates_args=())
        def f(x: Tensor, y: Tensor) -> Tensor:
            return x * y

        called = False

        @torch.library.register_vmap("mylib::f")
        def fvmap(info, in_dims, x, y):
            nonlocal called
            called = True
            x_bdim, y_bdim = in_dims
            x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
            y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1)
            result = x * y
            result = result.movedim(-1, 0)
            return result, 0

        x = torch.randn(2, 2)
        y = torch.randn(2, 2)

        result = torch.vmap(f)(x, y)
        self.assertTrue(called)
        self.assertEqual(result, x * y)

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_library_register_vmap_op_decorator(self):
        @torch.library.custom_op("mylib::f", mutates_args=())
        def f(x: Tensor, y: Tensor) -> Tensor:
            return x * y

        called = False

        @f.register_vmap
        def fvmap(info, in_dims, x, y):
            nonlocal called
            called = True
            x_bdim, y_bdim = in_dims
            x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
            y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1)
            result = x * y
            result = result.movedim(-1, 0)
            return result, 0

        x = torch.randn(2, 2)
        y = torch.randn(2, 2)

        result = torch.vmap(f)(x, y)
        self.assertTrue(called)
        self.assertEqual(result, x * y)

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_library_register_vmap_register_multiple_times(self):
        @torch.library.custom_op("mylib::f", mutates_args=())
        def f(x: Tensor, y: Tensor) -> Tensor:
            return x * y

        called = False

        @f.register_vmap
        def fvmap(info, in_dims, x, y):
            nonlocal called
            called = True
            x_bdim, y_bdim = in_dims
            x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
            y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1)
            result = x * y
            result = result.movedim(-1, 0)
            return result, 0

        x = torch.randn(2, 2)
        y = torch.randn(2, 2)

        result = torch.vmap(f)(x, y)
        self.assertTrue(called)
        self.assertEqual(result, x * y)

        called = False

        @f.register_vmap
        def fvmap2(info, in_dims, x, y):
            nonlocal called
            called = True
            x_bdim, y_bdim = in_dims
            x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
            y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1)
            result = x + y
            result = result.movedim(-1, 0)
            return result, 0

        result = torch.vmap(f)(x, y)
        self.assertTrue(called)
        self.assertEqual(result, x + y)

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_library_register_vmap_register_multiple_times_2(self):
        @torch.library.custom_op("mylib::f", mutates_args=())
        def f(x: Tensor, y: Tensor) -> Tensor:
            return x * y

        called = False

        @torch.library.register_vmap("mylib::f")
        def fvmap(info, in_dims, x, y):
            nonlocal called
            called = True
            x_bdim, y_bdim = in_dims
            x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
            y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1)
            result = x * y
            result = result.movedim(-1, 0)
            return result, 0

        x = torch.randn(2, 2)
        y = torch.randn(2, 2)

        result = torch.vmap(f)(x, y)
        self.assertTrue(called)
        self.assertEqual(result, x * y)

        called = False

        @torch.library.register_vmap("mylib::f")
        def fvmap2(info, in_dims, x, y):
            nonlocal called
            called = True
            x_bdim, y_bdim = in_dims
            x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
            y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1)
            result = x + y
            result = result.movedim(-1, 0)
            return result, 0

        result = torch.vmap(f)(x, y)
        self.assertTrue(called)
        self.assertEqual(result, x + y)


class MiniOpTestOther(CustomOpTestCaseBase):
    test_ns = "mini_op_test"

    def test_nonzero_again(self):
        x = torch.tensor([0, 1, 2, 0, 0])
        y = torch.ops.aten.nonzero.default(x)
        self.assertEqual(y, torch.tensor([[1], [2]]))


optests.generate_opcheck_tests(
    MiniOpTest,
    ["aten", "mini_op_test"],
    get_file_path_2(os.path.dirname(__file__), "minioptest_failures_dict.json"),
    additional_decorators={
        "test_pt2_compliant_tag_mini_op_test_no_abstract": [unittest.expectedFailure]
    },
    test_utils=optests.generate_tests.DEPRECATED_DEFAULT_TEST_UTILS,
)

optests.generate_opcheck_tests(
    MiniOpTestOther,
    ["aten", "mini_op_test"],
    get_file_path_2(os.path.dirname(__file__), "minioptest_failures_dict.json"),
    test_utils=optests.generate_tests.DEPRECATED_DEFAULT_TEST_UTILS,
)
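
# generate_opcheck_tests scans the given TestCase for tests that invoke ops in
# the listed namespaces and synthesizes `{test_util}__{orig_test}` variants
# (e.g. test_schema__test_mm), consulting the failures-dict JSON for expected
# failures; TestGenerateOpcheckTests below asserts this expansion happened.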


class TestGenerateOpcheckTests(CustomOpTestCaseBase):
    def test_MiniOpTest(self):
        for orig_test in ["test_mm", "test_nonzero"]:
            for test in torch.testing._internal.optests.generate_tests.DEFAULT_TEST_UTILS:
                expected_test = f"{test}__{orig_test}"
                self.assertTrue(hasattr(MiniOpTest, expected_test), msg=expected_test)

    def test_generate_repro_save_data(self):
        from torch.testing._internal.optests.generate_tests import generate_repro

        args = (torch.ones(2, 2),)
        kwargs = {"mat2": torch.zeros(2, 2)}
        actual = generate_repro(
            "test_schema",
            torch.ops.aten.sin.default,
            args,
            kwargs,
            save_data=True,
            dry_run=True,
        )
        actual = re.sub(r"torch.load\(\".*\.pt\"\)", 'torch.load("repro.pt")', actual)
        self.assertExpectedInline(
            actual,
            """\
# =========================================================
# BEGIN REPRO SCRIPT
# =========================================================
import torch
from torch.testing._internal.optests import opcheck

# Make sure you have loaded the library that contains the op
# via an import or torch.ops.load_library(...)
op = torch.ops.aten.sin.default

args, kwargs = torch.load("repro.pt")
opcheck(op, args, kwargs, test_utils="test_schema")
# =========================================================
# END REPRO SCRIPT
# =========================================================
""",
        )

    def test_generate_repro_no_save_data(self):
        from torch.testing._internal.optests.generate_tests import generate_repro

        args = (torch.ones(2, 2),)
        kwargs = {"mat2": torch.zeros(2, 2)}
        actual = generate_repro(
            "test_schema",
            torch.ops.aten.sin.default,
            args,
            kwargs,
            save_data=False,
            dry_run=True,
        )
        self.assertExpectedInline(
            actual,
            """\
# =========================================================
# BEGIN REPRO SCRIPT
# =========================================================
import torch
from torch.testing._internal.optests import opcheck

# Make sure you have loaded the library that contains the op
# via an import or torch.ops.load_library(...)
op = torch.ops.aten.sin.default

# If you rerun your test with PYTORCH_OPCHECK_PRINT_BETTER_REPRO=1
# we will fill them in same (args, kwargs) as in your test
args = ()  # args to the operator
kwargs = {}  # kwargs to the operator
opcheck(op, args, kwargs, test_utils="test_schema")
# =========================================================
# END REPRO SCRIPT
# =========================================================
""",
        )

    def test_failures_dict_validation(self):
        from torch.testing._internal.optests.generate_tests import (
            FailuresDict,
            validate_failures_dict_structure,
        )

        failures = {
            "mini_op_test::incorrect_schema": {
                "MiniOpTest.test_aot_dispatch_dynamic__test_delayed_error": {
                    "comment": "",
                    "status": "success",
                }
            }
        }
        with self.assertRaisesRegex(RuntimeError, "got status=success"):
            validate_failures_dict_structure(
                FailuresDict("", failures),
                torch.testing._internal.optests.generate_tests.DEFAULT_TEST_UTILS,
                MiniOpTest,
            )

        failures = {
            "mini_op_test::incorrect_schema": {
                "MiniOpTest.test_aot_dispatch__test_delayed_error": {
                    "comment": "",
                    "status": "xfail",
                }
            }
        }
        with self.assertRaisesRegex(RuntimeError, "should begin with one of"):
            validate_failures_dict_structure(
                FailuresDict("", failures),
                torch.testing._internal.optests.generate_tests.DEFAULT_TEST_UTILS,
                MiniOpTest,
            )

        failures = {
            "mini_op_test::incorrect_schema": {
                "MiniOpTest.test_aot_dispatch_dynamic__test_delayed_error_nopenopenope": {
                    "comment": "",
                    "status": "xfail",
                }
            }
        }
        with self.assertRaisesRegex(RuntimeError, "does not exist on the TestCase"):
            validate_failures_dict_structure(
                FailuresDict("", failures),
                torch.testing._internal.optests.generate_tests.DEFAULT_TEST_UTILS,
                MiniOpTest,
            )

    def test_dont_generate_decorator(self):
        self.assertTrue(hasattr(MiniOpTest, "test_dont_generate"))
        self.assertFalse(hasattr(MiniOpTest, "test_schema__test_dont_generate"))

    def test_opcheck(self):
        x = torch.randn(3, requires_grad=True)
        with self.assertRaisesRegex(ValueError, "OpOverload"):
            torch.library.opcheck(torch.sin, (x,))
        with self.assertRaisesRegex(ValueError, "test_utils to be subset of"):
            torch.library.opcheck(torch.ops.aten.sin.default, (x,), test_utils="blah")
        result = torch.library.opcheck(torch.ops.aten.sin.default, (x,))

        self.assertEqual(
            result,
            {
                "test_schema": "SUCCESS",
                "test_autograd_registration": "SUCCESS",
                "test_faketensor": "SUCCESS",
                "test_aot_dispatch_dynamic": "SUCCESS",
            },
        )

        result = torch.library.opcheck(
            torch.ops.aten.sin.default, (x,), test_utils="test_schema"
        )
        self.assertEqual(result, {"test_schema": "SUCCESS"})

        result = torch.library.opcheck(
            torch.ops.aten.sin.default,
            (x,),
            test_utils=["test_schema", "test_faketensor"],
        )
        self.assertEqual(
            result,
            {
                "test_schema": "SUCCESS",
                "test_faketensor": "SUCCESS",
            },
        )

    def test_opcheck_customopdef(self):
        sample_inputs = [
            (torch.randn(3),),
            (torch.randn(3, requires_grad=True),),
        ]
        if torch.cuda.is_available():
            sample_inputs.extend(
                [
                    (torch.randn(3, device="cuda"),),
                    (torch.randn(3, device="cuda", requires_grad=True),),
                ]
            )
        for args in sample_inputs:
            torch.library.opcheck(custom_op_db.numpy_cube, args)
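
    # torch.library.opcheck returns a dict mapping each requested test util to
    # "SUCCESS"; with raise_exception=False a failing util maps to the raised
    # exception instead, as test_opcheck_bad_op below demonstrates.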

    def test_is_inside_opcheck_mode(self):
        self.assertFalse(optests.is_inside_opcheck_mode())
        with optests.generate_tests.OpCheckMode(
            ["foo"], "bar", lambda x: x, None, "baz", "brr"
        ):
            self.assertTrue(optests.is_inside_opcheck_mode())

    def test_opcheck_bad_op(self):
        op = op_with_incorrect_schema(self, "foo")
        x = torch.randn(3)
        with self.assertRaisesRegex(Exception, "is not defined to alias output"):
            torch.library.opcheck(op, (x,))

        result = torch.library.opcheck(op, (x,), raise_exception=False)
        self.assertTrue(isinstance(result["test_schema"], RuntimeError))
        del result["test_schema"]
        self.assertEqual(
            result,
            {
                "test_autograd_registration": "SUCCESS",
                "test_faketensor": "SUCCESS",
                "test_aot_dispatch_dynamic": "SUCCESS",
            },
        )

    def test_opcheck_does_not_require_extra_deps(self):
        # torch.testing._internal.common_utils comes with a lot of additional
        # test-time dependencies. Since opcheck is public API, it should be
        # usable only with pytorch install-time dependencies.
        cmd = [
            sys.executable,
            "-c",
            "import torch; import sys; \
x = torch.randn(3, requires_grad=True); \
torch.library.opcheck(torch.ops.aten.sin.default, (x,)); \
assert 'expecttest' not in sys.modules; \
assert 'torch.testing._internal.common_utils' not in sys.modules",
        ]
        subprocess.check_output(cmd, shell=False)


class TestTypeConversion(TestCase):
    """In infer_schema(), we try to suggest a correct type when the type annotation is wrong."""

    def setUp(self):
        self.supported_base_types = [
            int,
            float,
            bool,
            str,
            torch.device,
            torch.Tensor,
            torch.dtype,
        ]

    def test_simple_tuple(self):
        self.assertEqual(List, tuple_to_list(Tuple))

    def test_supported_types(self):
        for t in self.supported_base_types:
            result_type = tuple_to_list(Tuple[t, t, t])
            self.assertEqual(result_type, List[t])

            result_type = tuple_to_list(Tuple[t])
            self.assertEqual(result_type, List[t])

    def test_optional(self):
        for t in self.supported_base_types:
            result_type = tuple_to_list(Tuple[t, Optional[t]])
            self.assertEqual(result_type, List[Optional[t]])

            result_type = tuple_to_list(Tuple[t, t, Optional[t]])
            self.assertEqual(result_type, List[Optional[t]])

            result_type = tuple_to_list(Tuple[t, ...])
            self.assertEqual(result_type, List[t])

    def test_mixed_types(self):
        result_type = tuple_to_list(Tuple[int, float])
        self.assertEqual(result_type, List[typing.Union[int, float]])

        result_type = tuple_to_list(Tuple[int, float, str])
        self.assertEqual(result_type, List[typing.Union[int, float, str]])
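
    # tuple_to_list powers infer_schema's error suggestions: when a signature
    # is annotated with an unsupported Tuple[...] type, the closest supported
    # List[...] equivalent can be recommended instead.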


only_for = ("cpu", "cuda")
instantiate_device_type_tests(TestCustomOpTesting, globals(), only_for=only_for)
instantiate_parametrized_tests(TestCustomOp)
instantiate_parametrized_tests(TestCustomOpAPI)

if __name__ == "__main__":
    run_tests()