# pytorch · Fork: 0 / test_custom_ops.py · 2242 lines · 77.2 KB
1
# Owner(s): ["module: custom-operators"]
2

3
from torch.testing._internal.common_utils import *  # noqa: F403
4
from torch.testing._internal.common_device_type import *  # noqa: F403
5
import collections
6

7
import itertools
8
import os
9
import re
10
import sys
11
import typing
12

13
import torch._custom_ops as custom_ops
14

15
import torch.testing._internal.custom_op_db
16
import torch.testing._internal.optests as optests
17
import torch.utils.cpp_extension
18
from functorch import make_fx
19
from torch import Tensor
20
from torch._custom_op.impl import custom_op, CustomOp
21
from torch._utils_internal import get_file_path_2
22
from torch.testing._internal.common_cuda import TEST_CUDA
23
from torch.testing._internal.custom_op_db import custom_op_db
24
from typing import *  # noqa: F403
25

26

27
class CustomOpTestCaseBase(TestCase):
    """Shared fixture for the custom-operator tests.

    Operators and library fragments created through the helpers below are
    registered under the ``test_ns`` namespace and are destroyed again in
    ``tearDown``.
    """

    # Namespace under which the tests register their operators.
    test_ns = "_test_custom_op"

    def setUp(self):
        """Run the base-class fixture, then start with no tracked libraries."""
        super().setUp()
        self.libraries = []

    def tearDown(self):
        """Destroy every op and library fragment created under ``test_ns``."""
        import torch._custom_op

        try:
            registry = torch._custom_op.impl.global_registry
            prefix = f"{self.test_ns}::"
            # Snapshot the matching keys up front so the loop never iterates the
            # registry itself, in case ``_destroy()`` mutates it.
            owned_keys = [key for key in registry.keys() if key.startswith(prefix)]
            for key in owned_keys:
                registry[key]._destroy()
            if hasattr(torch.ops, self.test_ns):
                delattr(torch.ops, self.test_ns)
            for lib in self.libraries:
                lib._destroy()
            del self.libraries
        finally:
            # Always give the base class a chance to run its own cleanup.
            super().tearDown()

    def ns(self):
        """Return the ``torch.ops`` namespace object for ``test_ns``."""
        return getattr(torch.ops, self.test_ns)

    def lib(self):
        """Create a ``FRAGMENT`` library for ``test_ns`` and track it for teardown."""
        result = torch.library.Library(self.test_ns, "FRAGMENT")  # noqa: TOR901
        self.libraries.append(result)
        return result

    def get_op(self, qualname):
        """Look up the op registered under ``qualname`` via ``torch._custom_op.impl``."""
        return torch._custom_op.impl.get_op(qualname)
57

58

59
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
60
@unittest.skipIf(
61
    sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+"
62
)
63
class TestCustomOpTesting(CustomOpTestCaseBase):
64
    @parametrize("check_gradients", (False, "auto"))
    @parametrize("dynamic", (True, False))
    def test_aot_autograd_check_degenerate_cases(
        self, device, dynamic, check_gradients
    ):
        """Run ``aot_autograd_check`` on degenerate functions.

        None of the calls below is expected to raise.
        """

        def simple(x):
            return x.clone()

        # Should not raise
        x = torch.randn(3, device=device)
        optests.aot_autograd_check(
            simple, (x,), {}, dynamic=dynamic, check_gradients=check_gradients
        )

        def outputs_dont_require_grad(x):
            return x.detach()

        # Should not raise.  This case must exercise
        # ``outputs_dont_require_grad``; previously ``simple`` was passed again,
        # which left the helper above unused.
        y = torch.randn(3, device=device, requires_grad=True)
        optests.aot_autograd_check(
            outputs_dont_require_grad,
            (y,),
            {},
            dynamic=dynamic,
            check_gradients=check_gradients,
        )

        def no_outputs(x):
            return x.detach()

        # Should not raise
        x = torch.randn(3, device=device, requires_grad=True)
        y = torch.randn(3, device=device, requires_grad=False)
        optests.aot_autograd_check(
            no_outputs, (x,), {}, dynamic=dynamic, check_gradients=check_gradients
        )
        optests.aot_autograd_check(
            no_outputs, (y,), {}, dynamic=dynamic, check_gradients=check_gradients
        )
99

100
    def test_incorrect_schema_mutation(self, device):
        """Expect opcheck to reject an op that mutates an input not declared mutable."""
        library = self.lib()
        library.define("foo(Tensor x) -> Tensor")
        foo_op = self.ns().foo.default

        def foo_impl(x):
            # In-place update of the input, although the schema above does not
            # mark ``x`` as mutable.
            x.sin_()
            return x.clone()

        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                guard = torch._C._AutoDispatchBelowAutograd()
                try:
                    return foo_op(x)
                finally:
                    del guard

            @staticmethod
            def backward(ctx, gx):
                return gx

        library.impl("foo", Foo.apply, "Autograd")
        for backend in ("CPU", "CUDA"):
            library.impl("foo", foo_impl, backend)

        expected_error = "Argument x is not defined as mutable but was mutated"
        inp = torch.tensor(3.14159 / 3, requires_grad=True, device=device)
        with self.assertRaisesRegex(optests.OpCheckError, expected_error):
            optests.opcheck(foo_op, (inp,), {})
131

132
    def test_incorrect_schema_view(self, device):
        """Expect opcheck to reject an op whose output aliases an undeclared input."""
        library = self.lib()
        library.define("foo(Tensor x) -> Tensor")
        foo_op = self.ns().foo.default

        def foo_impl(x):
            return x.view_as(x)

        def foo_meta(x):
            return x.view_as(x)

        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                # Emulate AutoDispatchBelowADInplaceOrView, which is not bound into python
                with torch._C._AutoDispatchBelowAutograd():
                    excluded = torch._C.DispatchKeySet(
                        torch._C.DispatchKey.ADInplaceOrView
                    )
                    with torch._C._ExcludeDispatchKeyGuard(excluded):
                        return foo_op(x)

            @staticmethod
            def backward(ctx, gx):
                return gx

        library.impl("foo", Foo.apply, "Autograd")
        library.impl("foo", foo_impl, "CPU")
        library.impl("foo", foo_meta, "Meta")

        expected_error = "Argument x is not defined to alias output but was aliasing"
        inp = torch.tensor(3.14159 / 3, requires_grad=True)
        with self.assertRaisesRegex(optests.OpCheckError, expected_error):
            optests.opcheck(foo_op, (inp,), {})
167

168
    def test_missing_abstract_impl(self, device):
        """Expect opcheck to fail for an op that has no Meta kernel registered."""
        library = self.lib()
        library.define("foo(Tensor x) -> Tensor")
        foo_op = self.ns().foo.default

        def foo_impl(x):
            # Round-trips through numpy before building the result tensor.
            squared = x.cpu().numpy() ** 2
            return torch.tensor(squared, device=x.device)

        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                with torch._C._AutoDispatchBelowAutograd():
                    return foo_op(x)

            @staticmethod
            def backward(ctx, gx):
                return 2 * gx

        library.impl("foo", Foo.apply, "Autograd")
        for backend in ("CPU", "CUDA"):
            library.impl("foo", foo_impl, backend)

        inp = torch.tensor([0, 1.0], requires_grad=True)
        with self.assertRaisesRegex(
            optests.OpCheckError, "_test_custom_op.foo.default"
        ):
            optests.opcheck(foo_op, (inp,), {})
196

197
    def test_incorrect_abstract_impl(self, device):
        """Expect opcheck to report a shape mismatch between the Meta and backend kernels."""
        library = self.lib()
        library.define("foo(Tensor x) -> Tensor")
        foo_op = self.ns().foo.default

        def foo_impl(x):
            return x**2

        def foo_meta(x):
            # Adds an extra dimension, so the Meta result shape differs from
            # what ``foo_impl`` returns.
            return x.unsqueeze(1) ** 2

        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                # Emulate AutoDispatchBelowADInplaceOrView, which is not bound into python
                guard = torch._C._AutoDispatchBelowAutograd()
                guard2 = torch._C.ExcludeDispatchKeyGuard(
                    torch._C.DispatchKeySet(torch._C.DispatchKey.ADInplaceOrView)
                )
                try:
                    return foo_op(x)
                finally:
                    del guard
                    del guard2

            @staticmethod
            def backward(ctx, gx):
                return gx

        library.impl("foo", Foo.apply, "Autograd")
        for backend in ("CPU", "CUDA"):
            library.impl("foo", foo_impl, backend)
        library.impl("foo", foo_meta, "Meta")

        inp = torch.tensor([0, 1.0], requires_grad=True)
        with self.assertRaisesRegex(optests.OpCheckError, "Shapes .* are not equal"):
            optests.opcheck(foo_op, (inp,), {})
234

235
    def test_missing_functionalization(self, device):
        """Expect opcheck to point out missing functionalization support for a mutating op."""
        library = self.lib()
        library.define("foo(Tensor(a!) x) -> Tensor(a!)")
        foo_op = self.ns().foo.default

        def foo_impl(x):
            return x.sin_()

        def foo_meta(x):
            return x

        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                ctx.mark_dirty(x)
                with torch._C._AutoDispatchBelowAutograd():
                    return foo_op(x)

            @staticmethod
            def backward(ctx, gx):
                return gx

        library.impl("foo", Foo.apply, "Autograd")
        for backend in ("CPU", "CUDA"):
            library.impl("foo", foo_impl, backend)
        library.impl("foo", foo_meta, "Meta")

        expected_error = (
            "Getting these operators to work with functionalization requires some extra work"
        )
        # opcheck receives a clone of the freshly built tensor.
        inp = torch.tensor([0, 1.0]).clone()
        with self.assertRaisesRegex(optests.OpCheckError, expected_error):
            optests.opcheck(foo_op, (inp,), {})
269

270
    def test_autograd_registered_at_backend(self, device):
        """Expect opcheck to complain when an autograd.Function is registered to backend keys."""
        library = self.lib()
        library.define("foo(Tensor x) -> Tensor")
        foo_op = self.ns().foo.default

        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                return x.clone()

            @staticmethod
            def backward(ctx, gx):
                return gx * 0.5

        # The autograd.Function is registered directly on the backend keys.
        for backend in ("CPU", "CUDA"):
            library.impl("foo", Foo.apply, backend)
        library.impl("foo", lambda x: x.clone(), "Meta")

        inp = torch.randn([], requires_grad=True)

        with self.assertRaisesRegex(
            optests.OpCheckError, "does not have an autograd kernel"
        ):
            optests.opcheck(foo_op, (inp,), {})

        # I'm not sure why this is necessary
        del library
298

299
    def test_global_state_mutation(self, device):
        """Expect opcheck to detect an op whose result depends on a call counter."""
        library = self.lib()
        library.define("foo(Tensor x) -> Tensor")
        foo_op = self.ns().foo.default

        class Foo(torch.autograd.Function):
            # Number of times ``forward`` has run; it scales the output, so
            # repeated calls give different results.
            invoked = 0

            @staticmethod
            def forward(ctx, x):
                Foo.invoked += 1
                return x.clone() * Foo.invoked

            @staticmethod
            def backward(ctx, gx):
                return gx

        library.impl("foo", Foo.apply, "CompositeImplicitAutograd")

        expected_error = "eager-mode PyTorch vs AOTAutograd"
        inp = torch.tensor(3.14159 / 3, requires_grad=True)
        with self.assertRaisesRegex(optests.OpCheckError, expected_error):
            optests.opcheck(foo_op, (inp,), {})
323

324
    @ops(custom_op_db, dtypes=OpDTypes.any_one)
    def test_opcheck_opinfo(self, device, dtype, op):
        """Run opcheck on every sample input of each op in ``custom_op_db``.

        ``numpy_nonzero`` and ``numpy_nms`` are expected to fail opcheck; all
        other ops are expected to pass.
        """
        samples = op.sample_inputs(device, dtype, requires_grad=op.supports_autograd)
        for sample in samples:
            call_args = [sample.input, *sample.args]
            call_kwargs = sample.kwargs
            expected_to_fail = op.op in (
                torch.ops._torch_testing.numpy_nonzero,
                torch.ops._torch_testing.numpy_nms,
            )
            expectation = (
                self.assertRaisesRegex(optests.OpCheckError, "failed with")
                if expected_to_fail
                else contextlib.nullcontext()
            )
            with expectation:
                optests.opcheck(op.op, call_args, call_kwargs)
344

345
    def test_opcheck_fails_basic(self, device):
        """Expect opcheck to surface the autograd-not-implemented error of a CustomOp."""
        qualname = f"{self.test_ns}::foo"

        @custom_op(qualname)
        def foo(x: torch.Tensor) -> torch.Tensor:
            ...

        @foo.impl(["cpu", "cuda"])
        def foo_impl(x):
            return x.sum()

        inp = torch.randn(3, device=device, requires_grad=True)
        # Triggers the CustomOp autograd NYI error
        expected_error = "Autograd has not been implemented for operator"
        with self.assertRaisesRegex(optests.OpCheckError, expected_error):
            optests.opcheck(self.get_op(qualname), (inp,), {})
360

361
    def test_autograd_registration_check_autograd_kernel(self, device):
        """An op with an Autograd kernel plus backend kernels passes the registration check."""
        library = self.lib()
        library.define("foo(Tensor x) -> Tensor")
        foo_op = self.ns().foo.default

        def foo_impl(x):
            return x.sin()

        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                with torch._C._AutoDispatchBelowAutograd():
                    return foo_op(x)

            @staticmethod
            def backward(ctx, gx):
                return gx

        library.impl("foo", Foo.apply, "Autograd")
        for backend in ("CPU", "CUDA"):
            library.impl("foo", foo_impl, backend)

        inp = torch.randn(3, requires_grad=True, device=device)
        # Should not raise
        optests.autograd_registration_check(foo_op, (inp,), {})
386

387
    def test_autograd_registration_check_compositeimplicitautograd(self, device):
        """A CompositeImplicitAutograd registration passes the registration check."""
        library = self.lib()
        library.define("foo(Tensor x) -> Tensor")
        foo_op = self.ns().foo.default

        def foo_impl(x):
            return x.sin().cos()

        library.impl("foo", foo_impl, "CompositeImplicitAutograd")

        inp = torch.randn(3, requires_grad=True, device=device)
        # Should not raise
        optests.autograd_registration_check(foo_op, (inp,), {})
400

401
    def test_autograd_registration_check_incorrect_composite(self, device):
        """A CompositeExplicitAutograd-only registration is reported as incorrect."""
        library = self.lib()
        library.define("foo(Tensor x) -> Tensor")
        foo_op = self.ns().foo.default

        def foo_impl(x):
            return x.sin().cos()

        library.impl("foo", foo_impl, "CompositeExplicitAutograd")

        inp = torch.randn(3, requires_grad=True, device=device)
        with self.assertRaisesRegex(AssertionError, "incorrectly registered"):
            optests.autograd_registration_check(foo_op, (inp,), {})
414

415
    def test_autograd_registration_check_incorrect(self, device):
        """An autograd.Function registered on backend keys is reported as incorrect."""
        library = self.lib()
        library.define("foo(Tensor x) -> Tensor")
        foo_op = self.ns().foo.default

        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                return torch.sin(x)

            @staticmethod
            def backward(ctx, gx):
                return gx

        for backend in ("CPU", "CUDA"):
            library.impl("foo", Foo.apply, backend)

        inp = torch.randn(3, requires_grad=True, device=device)
        with self.assertRaisesRegex(AssertionError, "incorrectly registered"):
            optests.autograd_registration_check(foo_op, (inp,), {})
435

436
    def test_assert_raises_regex(self, device):
        """Exercise the ``assert_raises_regex`` helper from the optests package."""
        from torch.testing._internal.optests.aot_autograd import assert_raises_regex

        # Matching exception type and message: the helper swallows the error.
        with assert_raises_regex(RuntimeError, "c"):
            raise RuntimeError("abcd")
        with assert_raises_regex(RuntimeError, "c.*"):
            raise RuntimeError("abcd")

        # Wrong exception type.
        with self.assertRaisesRegex(
            AssertionError, "instead got"
        ), assert_raises_regex(RuntimeError, "c.*"):
            raise ValueError("abcd")

        # Nothing raised at all.
        with self.assertRaisesRegex(
            AssertionError, "Expected exception"
        ), assert_raises_regex(RuntimeError, "c.*"):
            pass

        # Right exception type, but the message does not match the pattern.
        with self.assertRaisesRegex(
            AssertionError, "to match regex"
        ), assert_raises_regex(RuntimeError, "f"):
            raise RuntimeError("abcd")
452

453

454
class TestCustomOp(CustomOpTestCaseBase):
455
    test_ns = "_test_custom_op"
456

457
    def test_invalid_schemas(self):
        """A malformed schema string is rejected with an AssertionError."""
        # Function schema validation goes through torchgen, so this is just a
        # basic test.
        qualname = f"{TestCustomOp.test_ns}::foo"
        with self.assertRaisesRegex(AssertionError, "Invalid function schema: foo"):
            custom_ops.custom_op(qualname, "(")
462

463
    def test_invalid_qualname(self):
        """A qualified name that carries an overload suffix is rejected."""
        qualname_with_overload = f"{TestCustomOp.test_ns}::foo.Tensor"
        with self.assertRaisesRegex(ValueError, "overload"):
            custom_ops.custom_op(qualname_with_overload, "() -> ()")
466

467
    def test_name_must_match(self):
        """Decorating a function whose name differs from the op name is rejected."""
        qualname = f"{TestCustomOp.test_ns}::foo"
        with self.assertRaisesRegex(ValueError, "to have name"):

            # The function is called ``baz`` while the qualname ends in ``foo``.
            @custom_ops.custom_op(qualname)
            def baz(x: Tensor) -> Tensor:
                raise NotImplementedError()
473

474
    def test_unsupported_schemas(self):
        """Schemas outside the supported functional subset raise ValueError."""
        # NOTE(review): ``foo`` is not defined anywhere in this method; the test
        # presumably relies on ``custom_ops.custom_op(...)`` raising before the
        # trailing ``(foo)`` call is evaluated -- confirm.
        qualname = f"{TestCustomOp.test_ns}::foo"

        with self.assertRaisesRegex(ValueError, "only supports functional"):
            custom_ops.custom_op(qualname, "(Tensor(a!) x) -> Tensor(a)")(foo)
        with self.assertRaisesRegex(ValueError, "only supports functional"):
            custom_ops.custom_op(qualname, "(Tensor(a) x) -> Tensor(a)")(foo)
        with self.assertRaisesRegex(ValueError, "only supports functional"):
            custom_ops.custom_op(qualname, "(Tensor x) -> ()")(foo)
        with self.assertRaisesRegex(ValueError, "self"):
            custom_ops.custom_op(qualname, "(Tensor self) -> ()")(foo)
491

492
    # Tests for the older custom_op API
493
    def test_schema_matches_signature(self):
        """The decorated function's signature has to agree with the explicit schema."""
        ns = TestCustomOp.test_ns

        # Positional parameter name differs from the schema.
        with self.assertRaisesRegex(ValueError, "signature to match"):

            @custom_op(f"{ns}::blah", "(Tensor y) -> Tensor")
            def blah(x):
                pass

        # Schema says keyword-only, function takes it positionally.
        with self.assertRaisesRegex(ValueError, "signature to match"):

            @custom_op(f"{ns}::blah2", "(Tensor x, *, Tensor y) -> Tensor")
            def blah2(x, y):
                pass

        # Keyword-only parameter name differs from the schema.
        with self.assertRaisesRegex(ValueError, "signature to match"):

            @custom_op(
                f"{ns}::blah3",
                "(Tensor x, *, Tensor w, Tensor z) -> Tensor",
            )
            def blah3(x, *, y, z):
                pass

        # Keyword-only parameters appear in a different order than in the schema.
        with self.assertRaisesRegex(ValueError, "signature to match"):

            @custom_op(
                f"{ns}::blah4",
                "(Tensor x, *, Tensor z, Tensor y) -> Tensor",
            )
            def blah4(x, *, y, z):
                pass

        # *args is rejected.
        with self.assertRaisesRegex(ValueError, "not supported"):

            @custom_op(f"{ns}::blah5", "(Tensor x) -> Tensor")
            def blah5(*args):
                pass

        # **kwargs is rejected.
        with self.assertRaisesRegex(ValueError, "not supported"):

            @custom_op(f"{ns}::blah6", "(*, Tensor z, Tensor y) -> Tensor")
            def blah6(**kwargs):
                pass

        # Default on a positional parameter is rejected.
        with self.assertRaisesRegex(ValueError, "default arguments"):

            @custom_op(f"{ns}::blah7", "(Tensor x, *, Tensor y) -> Tensor")
            def blah7(x=1, *, y):
                pass

        # Default on a keyword-only parameter is rejected.
        with self.assertRaisesRegex(ValueError, "default arguments"):

            @custom_op(f"{ns}::blah8", "(Tensor x, *, Tensor y) -> Tensor")
            def blah8(x, *, y=1):
                pass

        # kwonly-arg works
        @custom_op(f"{ns}::blah9", "(Tensor x, *, Tensor y) -> Tensor")
        def blah9(x, *, y):
            pass
562

563
    # Tests for the older custom_op API
564
    def test_unsupported_annotation_categories(self):
        """Signatures the schema inference cannot handle raise ValueError."""
        qualname = f"{TestCustomOp.test_ns}::foo"

        with self.assertRaisesRegex(ValueError, "varargs"):

            @custom_op(qualname)
            def foo(*args):
                raise NotImplementedError()

            del foo

        with self.assertRaisesRegex(ValueError, "varkwargs"):

            @custom_op(qualname)
            def foo(**kwargs):
                raise NotImplementedError()

            del foo

        with self.assertRaisesRegex(ValueError, "must have a type annotation"):

            @custom_op(qualname)
            def foo(x):
                raise NotImplementedError()

            del foo

        with self.assertRaisesRegex(ValueError, "default value"):

            @custom_op(qualname)
            def foo(x: Optional[Tensor] = None):
                raise NotImplementedError()

            del foo

        # NOTE(review): this case is an exact repeat of the previous one; it may
        # have been meant to cover a different kind of default -- confirm.
        with self.assertRaisesRegex(ValueError, "default value"):

            @custom_op(qualname)
            def foo(x: Optional[Tensor] = None):
                raise NotImplementedError()

            del foo

        with self.assertRaisesRegex(ValueError, "unsupported"):

            @custom_op(qualname)
            def foo(x: Tensor) -> Tuple[Tensor, ...]:
                raise NotImplementedError()

            del foo
612

613
    def _generate_examples(self, typ):
        """Return a list of example values for the type annotation ``typ``.

        Handles a fixed set of scalar/torch types, ``torch.types.Number``,
        two-member unions that include ``None``, ``list[...]`` and
        ``collections.abc.Sequence[...]`` (resolved through
        ``typing.get_origin``).

        Raises:
            NotImplementedError: if ``typ`` is none of the handled annotations.
        """
        if typ is int:
            return [17]
        if typ is float:
            return [3.14]
        if typ is bool:
            return [True]
        if typ is str:
            return ["foo"]
        if typ is torch.dtype:
            return [torch.float32]
        if typ is torch.device:
            return [torch.device("cpu")]
        if typ == torch.types.Number:
            return [2.718]
        if typ is torch.Tensor:
            return [torch.tensor(3)]
        # Checked before the generic Union branch below, so this exact
        # annotation gets the explicit example list.
        if typ == Optional[torch.types.Number]:
            return [None, 2.718]
        origin = typing.get_origin(typ)
        if origin is Union:
            args = typing.get_args(typ)
            # Only the ``Optional[T]`` shape is handled: two members, one of them None.
            assert len(args) == 2 and (args[0] is type(None) or args[1] is type(None))
            elt = args[0] if args[1] is type(None) else args[1]
            return self._generate_examples(elt) + [None]
        if origin is list:
            args = typing.get_args(typ)
            assert len(args) == 1
            elt = args[0]
            # Three independently generated element lists.
            return [self._generate_examples(elt) for _ in range(3)]
        if origin is collections.abc.Sequence:
            args = typing.get_args(typ)
            assert len(args) == 1
            examples = self._generate_examples(args[0])
            # All ordered pairs of element examples, each pair being one sequence.
            return list(itertools.product(examples, examples))
        raise NotImplementedError(
            f"testrunner cannot generate instance of type {typ}"
        )
655

656
    def test_supported_return_types_single_return(self):
        """Each supported return type round-trips through an op returning one value."""
        qualname = f"{self.test_ns}::foo"
        for return_type in torch._custom_op.impl.SUPPORTED_RETURN_TYPES:
            for example in self._generate_examples(return_type):
                try:

                    @custom_ops.custom_op(qualname)
                    def foo(x: Tensor) -> return_type:
                        raise NotImplementedError()

                    @custom_ops.impl(qualname)
                    def foo_impl(x: Tensor) -> return_type:
                        return example

                    registered = self.get_op(qualname)
                    actual = registered(torch.randn([]))
                    self.assertEqual(actual, example, msg=f"{return_type} {example}")
                finally:
                    # The op is re-created on every iteration, so always tear it down.
                    custom_ops._destroy(qualname)
674

675
    def test_supported_return_types_multi_return(self):
        """Each supported return type round-trips through an op returning a pair."""
        qualname = f"{self.test_ns}::foo"
        for return_type in torch._custom_op.impl.SUPPORTED_RETURN_TYPES:
            for example in self._generate_examples(return_type):
                try:

                    @custom_ops.custom_op(qualname)
                    def foo(x: Tensor) -> Tuple[return_type, return_type]:
                        raise NotImplementedError()

                    @custom_ops.impl(qualname)
                    def foo_impl(x: Tensor) -> Tuple[return_type, return_type]:
                        return (example, example)

                    registered = self.get_op(qualname)
                    actual = registered(torch.randn([]))
                    self.assertEqual(
                        actual, (example, example), msg=f"{return_type} {example}"
                    )
                finally:
                    # The op is re-created on every iteration, so always tear it down.
                    custom_ops._destroy(qualname)
694

695
    def test_supported_param_types(self):
        """Each supported parameter type reaches the CPU kernel with its value intact."""
        qualname = f"{TestCustomOp.test_ns}::foo"
        for param_type in torch._custom_op.impl.SUPPORTED_PARAM_TYPES:

            @custom_ops.custom_op(qualname)
            def foo(x: Tensor, y: param_type) -> Tensor:
                raise NotImplementedError()

            # Last ``y`` seen by the kernel; reset after every call.
            received = None

            @custom_ops.impl(qualname, device_types=["cpu"])
            def foo_cpu(x, y):
                nonlocal received
                received = y
                return x.clone()

            try:
                for example in self._generate_examples(param_type):
                    registered = self.get_op(f"{self.test_ns}::foo")
                    registered(torch.randn([]), example)
                    self.assertEqual(received, example, msg=f"{param_type} {example}")
                    received = None
            finally:
                custom_ops._destroy(qualname)
718

719
    def test_sequences(self):
        """A user-defined ``Sequence`` subclass is accepted for a ``Sequence[int]`` parameter."""
        # Sequence[int] gets automagically turned into int[] in the schema.
        # This test checks that we actually do support arbitrary sequence types.
        class MySequence(collections.abc.Sequence):
            def __init__(self):
                self._container = [1, 2, 3]

            def __len__(self):
                return len(self._container)

            def __getitem__(self, idx):
                return self._container[idx]

        qualname = f"{self.test_ns}::foo"

        @custom_ops.custom_op(qualname)
        def foo(x: torch.Tensor, sizes: Sequence[int]) -> torch.Tensor:
            raise NotImplementedError()

        # Counts how many times the CPU kernel ran.
        called = 0

        @custom_ops.impl(qualname, device_types="cpu")
        def foo_cpu(x, sizes):
            nonlocal called
            called += 1
            # Dispatcher will normalize the sequence type into a List
            self.assertEqual(sizes, [1, 2, 3])
            return x.clone()

        registered = self.get_op(qualname)
        registered(torch.randn([]), MySequence())
        self.assertEqual(called, 1)
751

752
    def test_unsupported_param_types(self):
        """A few parameter annotations that are rejected as unsupported types."""
        # Not comprehensive (it doesn't need to be), just a check that our mechanism works
        qualname = f"{TestCustomOp.test_ns}::foo"

        with self.assertRaisesRegex(ValueError, "unsupported type"):

            @custom_ops.custom_op(qualname)
            def foo(x: Tensor, y: List[Optional[int]]) -> Tensor:
                raise NotImplementedError()

            del foo

        with self.assertRaisesRegex(ValueError, "unsupported type"):
            # int[N] in Dispatcher is a bit wild, so we don't try to support it.
            @custom_ops.custom_op(qualname)
            def foo(x: Tensor, y: Tuple[int, int]) -> Tensor:
                raise NotImplementedError()

            del foo

        with self.assertRaisesRegex(ValueError, "unsupported type"):
            # We could theoretically support this, but the syntax for supporting
            # int[] is Sequence[int]
            @custom_ops.custom_op(qualname)
            def foo(x: Tensor, y: List[int]) -> Tensor:
                raise NotImplementedError()

            del foo

        with self.assertRaisesRegex(ValueError, "unsupported type"):

            @custom_ops.custom_op(qualname)
            def foo(x: Tensor, y: Callable) -> Tensor:
                raise NotImplementedError()

            del foo
786

787
    def test_supported_schemas(self):
        """Define and immediately destroy an op for each schema in a sanity list."""
        # All of these should already be tested by PyTorch codegen
        # (we share the same mechanism), but here's a sanity check.
        single_input_schemas = [
            "(Tensor x) -> Tensor",
            "(Tensor x) -> Tensor y",
            "(Tensor[] x) -> Tensor y",
            "(Tensor x) -> (Tensor, Tensor)",
            "(Tensor x) -> (Tensor y, Tensor z)",
            # NOTE(review): same schema as the previous entry -- confirm intent.
            "(Tensor x) -> (Tensor y, Tensor z)",
        ]
        two_input_schemas = [
            "(Tensor x, Tensor w) -> (Tensor y, Tensor z)",
            "(Tensor x, Tensor w) -> (Tensor, Tensor)",
            "(Tensor x, Tensor w) -> Tensor",
            "(Tensor? x, Tensor w) -> Tensor",
            "(Tensor? x, Tensor[] w) -> Tensor",
            "(Tensor x, int[] w) -> Tensor",
            "(Tensor x, SymInt[] w) -> Tensor",
            "(Tensor x, Scalar w) -> Tensor",
            "(Tensor x, float w) -> Tensor",
            "(Tensor x, float? w) -> Tensor",
            "(Tensor x, bool[] w) -> Tensor",
        ]

        # The first group is registered as ``foo``, the second as ``bar``.
        groups = (("foo", single_input_schemas), ("bar", two_input_schemas))
        for op_name, schema_group in groups:
            qualname = f"{TestCustomOp.test_ns}::{op_name}"
            for schema in schema_group:
                custom_ops.custom_op(qualname, schema)
                custom_ops._destroy(qualname)
818

819
    def test_reserved_ns(self):
        """Defining an op in any reserved namespace raises ValueError."""
        from torch._custom_op.impl import RESERVED_NS

        expected_error = "is a reserved namespace"
        for reserved in RESERVED_NS:
            # Explicit-schema form.
            with self.assertRaisesRegex(ValueError, expected_error):
                custom_ops.custom_op(f"{reserved}::foo", "(Tensor x) -> Tensor")

            # Decorator form.
            with self.assertRaisesRegex(ValueError, expected_error):

                @custom_ops.custom_op(f"{reserved}::foo2")
                def foo2(x: torch.Tensor) -> torch.Tensor:
                    raise NotImplementedError()
831

832
    def test_private_ctor(self):
        """Calling the ``CustomOp`` constructor directly raises RuntimeError."""
        ctor_args = (None,) * 5
        with self.assertRaisesRegex(RuntimeError, "CustomOp constructor is private"):
            CustomOp(*ctor_args)
835

836
    def test_lifetime(self):
        # An op name can only be defined once at a time; destroying the op
        # makes the name available for a fresh definition.
        qualname = f"{TestCustomOp.test_ns}::foo"

        @custom_ops.custom_op(qualname)
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError()

        # Look the op up while it is alive; the handle is not inspected further.
        registered_op = torch._custom_op.impl.get_op(qualname)  # noqa: F841

        # We can't define an op multiple times,
        with self.assertRaisesRegex(RuntimeError, "multiple times"):

            @custom_ops.custom_op(qualname)
            def foo(x: torch.Tensor) -> torch.Tensor:  # noqa: F811
                raise NotImplementedError()

        # Unless we delete the original op.
        custom_ops._destroy(qualname)

        # Smoke test: the same name can now be defined again.
        @custom_ops.custom_op(qualname)
        def foo(x: torch.Tensor) -> torch.Tensor:  # noqa: F811
            raise NotImplementedError()

        custom_ops._destroy(qualname)
859

860
    def test_autograd_notimplemented(self):
        # Without any autograd registration, calling the op with an input that
        # requires grad must raise. Exercised for three signatures: a single
        # Tensor, a sequence of Tensors, and two Tensors.
        qualname = f"{TestCustomOp.test_ns}::foo"
        no_autograd_msg = "Autograd has not been implemented"

        # Case 1: single Tensor argument.
        @custom_ops.custom_op(qualname)
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError()

        x = torch.randn(3, requires_grad=True)
        op = self.get_op(qualname)
        with self.assertRaisesRegex(RuntimeError, no_autograd_msg):
            op(x)
        custom_ops._destroy(qualname)
        del foo

        # Case 2: Tensor-sequence argument where only one element requires grad.
        @custom_ops.custom_op(qualname)
        def foo(x: Sequence[torch.Tensor]) -> torch.Tensor:
            raise NotImplementedError()

        x = torch.randn(3, requires_grad=True)
        y = torch.randn(3)
        op = self.get_op(qualname)
        with self.assertRaisesRegex(RuntimeError, no_autograd_msg):
            op([y, x])
        custom_ops._destroy(qualname)
        del foo

        # Case 3: two Tensor arguments where only the second requires grad.
        @custom_ops.custom_op(qualname)
        def foo(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError()

        x = torch.randn(3, requires_grad=True)
        y = torch.randn(3)
        op = self.get_op(qualname)
        with self.assertRaisesRegex(RuntimeError, no_autograd_msg):
            op(y, x)
        custom_ops._destroy(qualname)
894

895
    def test_autograd_notimplemented_gradmode(self):
        # An op with an impl but no autograd registration is still callable on
        # a requires_grad input while inside torch.no_grad().
        qualname = f"{TestCustomOp.test_ns}::foo"

        @custom_ops.custom_op(qualname)
        def foo(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError()

        @custom_ops.impl(qualname)
        def foo_impl(x, y):
            return x * y

        needs_grad = torch.randn(3, requires_grad=True)
        plain = torch.randn(3)
        op = self.get_op(qualname)
        with torch.no_grad():
            # Shouldn't raise, because we are in no_grad
            op(plain, needs_grad)
910

911
    def test_impl_cpu(self):
        # An impl registered for device_types="cpu" is what runs for a CPU
        # tensor: the op's result must equal calling the impl directly.
        qualname = f"{TestCustomOp.test_ns}::foo"

        @custom_ops.custom_op(qualname)
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError()

        @custom_ops.impl(qualname, device_types="cpu")
        def foo_cpu(x):
            return x.sin()

        sample = torch.randn(3)
        op = self.get_op(qualname)
        actual = op(sample)
        self.assertEqual(actual, foo_cpu(sample))
924

925
    def test_impl_invalid_devices(self):
        # custom_ops.impl accepts every device type in
        # SUPPORTED_DEVICE_TYPE_TO_KEY and raises ValueError for the
        # unsupported ones tried below.
        qualname = f"{TestCustomOp.test_ns}::foo"

        @custom_ops.custom_op(qualname)
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError()

        def foo_impl(x):
            return x.sin()

        from torch._custom_op.impl import SUPPORTED_DEVICE_TYPE_TO_KEY

        for device_type in SUPPORTED_DEVICE_TYPE_TO_KEY:
            # Smoke test: should not raise error
            register = custom_ops.impl(qualname, device_types=device_type)
            register(foo_impl)

        # Not supported by this API: we can either support them in the future
        # or provide some other CustomOp.def_* function. This depends on how
        # common the use cases are.
        unsupported = ["hip", "xla", "mkldnn", ["cpu", "hip"]]
        for invalid_type in unsupported:
            with self.assertRaisesRegex(ValueError, "we only support device_type"):
                custom_ops.impl(qualname, device_types=invalid_type)(foo_impl)
949

950
    def test_backward_partially_registered(self):
        # Registering a backward without a save_for_backward must surface a
        # RuntimeError when the op is run and differentiated.
        qualname = f"{TestCustomOp.test_ns}::foo"

        @custom_ops.custom_op(qualname)
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError()

        @custom_ops.impl(qualname)
        def foo_impl(x):
            return x.sin()

        @custom_ops.impl_backward(qualname)
        def foo_backward(ctx, saved, grad):
            return grad * saved.cos()

        inp = torch.randn([], requires_grad=True)
        op = self.get_op(qualname)
        # Both the forward call and backward() are inside the context, so the
        # error may come from either step.
        with self.assertRaisesRegex(
            RuntimeError, "unable to find a 'save_for_backward'"
        ):
            out = op(inp)
            out.backward()
970

971
    def test_save_for_backward_inputs_are_namedtuple(self):
        # The `inputs` handed to save_for_backward must be a tuple exposing the
        # op's argument names as fields (checked via `_asdict()` and
        # attribute access), and save_for_backward must run exactly once: on
        # the forward call, not again during backward.
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError()

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
        def foo_impl(x):
            return x.sin()

        # Number of times save_for_backward has been invoked.
        hit = 0

        @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_save_for_backward(inputs, output):
            nonlocal hit
            hit += 1
            # assertIsInstance reports the actual object on failure, unlike
            # assertTrue(isinstance(...)), which only says "False is not true".
            self.assertIsInstance(inputs, tuple)
            self.assertEqual(list(inputs._asdict().keys()), ["x"])
            return inputs.x

        @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_backward(ctx, saved, grad):
            return {"x": grad * saved.cos()}

        x = torch.randn([], requires_grad=True)
        op = self.get_op(f"{self.test_ns}::foo")
        y = op(x)
        self.assertEqual(hit, 1)
        y.backward()
        # Backward must not trigger another save_for_backward call.
        self.assertEqual(hit, 1)
1000

1001
    def test_backward_returns_dict(self):
        # A backward that returns a bare Tensor instead of a dict must be
        # rejected when backward() runs.
        qualname = f"{TestCustomOp.test_ns}::foo"

        @custom_ops.custom_op(qualname)
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError()

        @custom_ops.impl(qualname)
        def foo_impl(x):
            return x.sin()

        @custom_ops.impl_save_for_backward(qualname)
        def foo_save_for_backward(inputs, output):
            return inputs.x

        @custom_ops.impl_backward(qualname)
        def foo_backward(ctx, saved, grad):
            return grad * saved.cos()

        inp = torch.randn([], requires_grad=True)
        op = self.get_op(qualname)
        out = op(inp)
        with self.assertRaisesRegex(RuntimeError, "to be a dict"):
            out.backward()
1023

1024
    def test_backward_dict_invalid_keys(self):
        # A backward dict carrying a key ("y") that is not an argument of the
        # op must be rejected when backward() runs.
        qualname = f"{TestCustomOp.test_ns}::foo"

        @custom_ops.custom_op(qualname)
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError()

        @custom_ops.impl(qualname)
        def foo_impl(x):
            return x.sin()

        @custom_ops.impl_save_for_backward(qualname)
        def foo_save_for_backward(inputs, output):
            return inputs.x

        @custom_ops.impl_backward(qualname)
        def foo_backward(ctx, saved, grad):
            return {"x": grad * saved.cos(), "y": None}

        inp = torch.randn([], requires_grad=True)
        op = self.get_op(qualname)
        out = op(inp)
        with self.assertRaisesRegex(RuntimeError, "to have keys {'x'}"):
            out.backward()
1046

1047
    def test_backward_dict_grad_for_nontensor(self):
        # A backward dict that includes an entry for the int argument "dim"
        # must be rejected when backward() runs.
        qualname = f"{TestCustomOp.test_ns}::foo"

        @custom_ops.custom_op(qualname)
        def foo(x: torch.Tensor, dim: int) -> torch.Tensor:
            raise NotImplementedError()

        @custom_ops.impl(qualname)
        def foo_impl(x, dim):
            return x.sin()

        @custom_ops.impl_save_for_backward(qualname)
        def foo_save_for_backward(inputs, output):
            return inputs.x

        @custom_ops.impl_backward(qualname)
        def foo_backward(ctx, saved, grad):
            return {"x": grad * saved.cos(), "dim": None}

        inp = torch.randn([], requires_grad=True)
        op = self.get_op(qualname)
        out = op(inp, 32)
        with self.assertRaisesRegex(RuntimeError, "non-Tensor-like types"):
            out.backward()
1069

1070
    def test_backward_dict_requires_keys_for_input_tensors(self):
        # A backward dict that omits the Tensor argument "y" must be rejected
        # when backward() runs.
        qualname = f"{TestCustomOp.test_ns}::foo"

        @custom_ops.custom_op(qualname)
        def foo(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError()

        @custom_ops.impl(qualname)
        def foo_impl(x, y):
            return x.sin()

        @custom_ops.impl_save_for_backward(qualname)
        def foo_save_for_backward(inputs, output):
            return inputs.x

        @custom_ops.impl_backward(qualname)
        def foo_backward(ctx, saved, grad):
            return {"x": grad * saved.cos()}

        inp = torch.randn([], requires_grad=True)
        op = self.get_op(qualname)
        out = op(inp, inp)
        with self.assertRaisesRegex(RuntimeError, r"to have keys {.*'y'.*}"):
            out.backward()
1092

1093
    def test_backward_dict_requires_keys_for_input_optional_tensors(self):
        # A backward dict that omits the Optional[Tensor] argument "y" must be
        # rejected when backward() runs, even though None was passed for it.
        qualname = f"{TestCustomOp.test_ns}::foo"

        @custom_ops.custom_op(qualname)
        def foo(x: torch.Tensor, y: Optional[torch.Tensor]) -> torch.Tensor:
            raise NotImplementedError()

        @custom_ops.impl(qualname)
        def foo_impl(x, y):
            return x.sin()

        @custom_ops.impl_save_for_backward(qualname)
        def foo_save_for_backward(inputs, output):
            return inputs.x

        @custom_ops.impl_backward(qualname)
        def foo_backward(ctx, saved, grad):
            return {"x": grad * saved.cos()}

        inp = torch.randn([], requires_grad=True)
        op = self.get_op(qualname)
        out = op(inp, None)
        with self.assertRaisesRegex(RuntimeError, r"to have keys {.*'y'.*}"):
            out.backward()
1115

1116
    def test_backward_grads_are_tensor_or_none(self):
        # A backward dict whose value for a Tensor argument is a tuple (rather
        # than a Tensor or None) must be rejected when backward() runs.
        qualname = f"{TestCustomOp.test_ns}::foo"

        @custom_ops.custom_op(qualname)
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError()

        @custom_ops.impl(qualname)
        def foo_impl(x):
            return x.sin()

        @custom_ops.impl_save_for_backward(qualname)
        def foo_save_for_backward(inputs, output):
            return inputs.x

        @custom_ops.impl_backward(qualname)
        def foo_backward(ctx, saved, grad):
            return {"x": (grad * saved.cos(),)}

        inp = torch.randn([], requires_grad=True)
        op = self.get_op(qualname)
        out = op(inp)
        with self.assertRaisesRegex(RuntimeError, "either None or a Tensor"):
            out.backward()
1138

1139
    def test_backward_tensorlist_input_requires_list_grads_with_same_numel(self):
        # For a Tensor-sequence argument of length 3, a gradient list of
        # length 2 must be rejected when backward() runs.
        qualname = f"{TestCustomOp.test_ns}::foo"

        @custom_ops.custom_op(qualname)
        def foo(xs: Sequence[torch.Tensor]) -> torch.Tensor:
            raise NotImplementedError()

        @custom_ops.impl(qualname)
        def foo_impl(xs):
            return xs[0].sin()

        @custom_ops.impl_save_for_backward(qualname)
        def foo_save_for_backward(inputs, output):
            return inputs.xs[0]

        @custom_ops.impl_backward(qualname)
        def foo_backward(ctx, saved, grad):
            return {"xs": [grad * saved.cos(), None]}

        tensors = [torch.randn([], requires_grad=True) for _ in range(3)]
        op = self.get_op(qualname)
        out = op(tensors)
        with self.assertRaisesRegex(RuntimeError, "3 gradients but got 2"):
            out.backward()
1161

1162
    def test_backward_tensorlist_input_requires_list_grads_none_or_Tensor(self):
        # For a Tensor-sequence argument, a gradient list containing an entry
        # that is neither None nor a Tensor (here a tuple) must be rejected
        # when backward() runs.
        qualname = f"{TestCustomOp.test_ns}::foo"

        @custom_ops.custom_op(qualname)
        def foo(xs: Sequence[torch.Tensor]) -> torch.Tensor:
            raise NotImplementedError()

        @custom_ops.impl(qualname)
        def foo_impl(xs):
            return xs[0].sin()

        @custom_ops.impl_save_for_backward(qualname)
        def foo_save_for_backward(inputs, output):
            return inputs.xs[0]

        @custom_ops.impl_backward(qualname)
        def foo_backward(ctx, saved, grad):
            return {"xs": [grad * saved.cos(), None, (None,)]}

        tensors = [torch.randn([], requires_grad=True) for _ in range(3)]
        op = self.get_op(qualname)
        out = op(tensors)
        with self.assertRaisesRegex(RuntimeError, "None or Tensor"):
            out.backward()
1184

1185
    def test_backward_tensorlist_input_requires_list_grads(self):
        # For a Tensor-sequence argument, returning None instead of a list of
        # gradients must be rejected when backward() runs.
        qualname = f"{TestCustomOp.test_ns}::foo"

        @custom_ops.custom_op(qualname)
        def foo(xs: Sequence[torch.Tensor]) -> torch.Tensor:
            raise NotImplementedError()

        @custom_ops.impl(qualname)
        def foo_impl(xs):
            return xs[0].sin()

        @custom_ops.impl_save_for_backward(qualname)
        def foo_save_for_backward(inputs, output):
            return inputs.xs[0]

        @custom_ops.impl_backward(qualname)
        def foo_backward(ctx, saved, grad):
            return {"xs": None}

        tensors = [torch.randn([], requires_grad=True) for _ in range(3)]
        op = self.get_op(qualname)
        out = op(tensors)
        with self.assertRaisesRegex(RuntimeError, "list of gradients"):
            out.backward()
1207

1208
    def test_backward_output_differentiability_type(self):
        # Passing a bare bool as output_differentiability must be rejected at
        # registration time.
        qualname = f"{TestCustomOp.test_ns}::foo"

        @custom_ops.custom_op(qualname)
        def foo(xs: Sequence[torch.Tensor]) -> torch.Tensor:
            raise NotImplementedError()

        with self.assertRaisesRegex(RuntimeError, "output_differentiability"):

            @custom_ops.impl_backward(qualname, output_differentiability=True)
            def foo_backward(ctx, saved, grad):
                return {"xs": None}
1220

1221
    def test_backward_output_differentiability_numel(self):
        # For an op with two outputs, an output_differentiability list of
        # length one must be rejected at registration time.
        qualname = f"{TestCustomOp.test_ns}::foo"

        @custom_ops.custom_op(qualname)
        def foo(xs: Sequence[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
            raise NotImplementedError()

        with self.assertRaisesRegex(RuntimeError, "output_differentiability"):

            @custom_ops.impl_backward(qualname, output_differentiability=[True])
            def foo_backward(ctx, saved, grad):
                return {"xs": None}
1233

1234
    def test_backward_output_differentiability_tensorlist(self):
        # With output_differentiability=[False, True] on an op returning
        # (List[Tensor], Tensor), the tensors in the list output must not
        # require grad while the plain Tensor output must.
        qualname = f"{self.test_ns}::foo"

        @custom_ops.custom_op(qualname)
        def foo(x: Tensor) -> Tuple[List[Tensor], Tensor]:
            raise NotImplementedError()

        @custom_ops.impl(qualname)
        def foo_impl(x):
            return [x.clone(), x.clone()], x.clone()

        @custom_ops.impl_save_for_backward(qualname)
        def foo_save_for_backward(inputs, output):
            return []

        @custom_ops.impl_backward(qualname, output_differentiability=[False, True])
        def foo_backward(ctx, saved, grad_lst, grad):
            return {"x": grad}

        op = self.get_op(qualname)
        inp = torch.randn(3, requires_grad=True)
        (first, second), single = op(inp)
        for non_differentiable in (first, second):
            self.assertFalse(non_differentiable.requires_grad)
        self.assertTrue(single.requires_grad)
1259

1260
    def test_backward_output_differentiability_non_tensor(self):
        # Marking an int output as differentiable must raise when the op is
        # called.
        qualname = f"{self.test_ns}::foo"

        @custom_ops.custom_op(qualname)
        def foo(x: Tensor) -> Tuple[Tensor, int]:
            raise NotImplementedError()

        @custom_ops.impl(qualname)
        def foo_impl(x):
            return x.clone(), 3

        @custom_ops.impl_save_for_backward(qualname)
        def foo_save_for_backward(inputs, output):
            return []

        @custom_ops.impl_backward(qualname, output_differentiability=[True, True])
        def foo_backward(ctx, saved, grad0, grad1):
            return {"x": grad0}

        op = self.get_op(qualname)
        inp = torch.randn(3, requires_grad=True)
        with self.assertRaisesRegex(RuntimeError, "is not a Tensor"):
            op(inp)
1283

1284
    @unittest.skipIf(not TEST_CUDA, "requires CUDA")
    def test_impl_separate(self):
        # Separate impls registered for "cpu" and "cuda" must each be the one
        # used for tensors on the matching device.
        qualname = f"{TestCustomOp.test_ns}::foo"

        @custom_ops.custom_op(qualname)
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError()

        @custom_ops.impl(qualname, device_types="cpu")
        def foo_cpu(x):
            return x.sin()

        @custom_ops.impl(qualname, device_types="cuda")
        def foo_cuda(x):
            return x.cos()

        # CPU input: result must match the CPU impl.
        cpu_input = torch.randn(3)
        op = self.get_op(qualname)
        cpu_result = op(cpu_input)
        self.assertEqual(cpu_result, foo_cpu(cpu_input))

        # CUDA input: result must match the CUDA impl.
        cuda_input = cpu_input.cuda()
        op = self.get_op(qualname)
        cuda_result = op(cuda_input)
        self.assertEqual(cuda_result, foo_cuda(cuda_input))
1307

1308
    @unittest.skipIf(not TEST_CUDA, "requires CUDA")
    def test_impl_multiple(self):
        # An impl registered without device_types must be used for both CPU
        # and CUDA inputs.
        qualname = f"{TestCustomOp.test_ns}::foo"

        @custom_ops.custom_op(qualname)
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError()

        @custom_ops.impl(qualname)
        def foo_impl(x):
            return x.cos()

        op = self.get_op(qualname)

        cpu_input = torch.randn(3)
        cpu_result = op(cpu_input)
        self.assertEqual(cpu_result, foo_impl(cpu_input))

        cuda_input = cpu_input.cuda()
        cuda_result = op(cuda_input)
        self.assertEqual(cuda_result, foo_impl(cuda_input))
1326

1327
    def test_impl_abstract_overload(self):
        # impl_abstract can target a named overload ("sin.blah"); afterwards
        # the overload must be callable on a meta tensor without raising.
        fragment = self.lib()
        fragment.define("sin.blah(Tensor x) -> Tensor")

        overload_name = f"{self.test_ns}::sin.blah"
        torch.library.impl_abstract(overload_name, torch.empty_like, lib=fragment)

        overload = self.ns().sin.blah
        meta_input = torch.randn(3, device="meta")
        overload(meta_input)
1338

1339
    def test_impl_meta(self):
        # Calling the op on a meta tensor must produce the shape computed by
        # the registered abstract impl.
        qualname = f"{TestCustomOp.test_ns}::foo"

        @custom_ops.custom_op(qualname)
        def foo(x: torch.Tensor, dim: int) -> torch.Tensor:
            raise NotImplementedError()

        @torch.library.impl_abstract(qualname, lib=self.lib())
        def foo_meta(x, dim):
            # Output drops dimension `dim` of the input shape.
            output_shape = list(x.shape)
            del output_shape[dim]
            return x.new_empty(output_shape)

        meta_input = torch.randn(2, 3, device="meta")
        op = self.get_op(qualname)
        actual = op(meta_input, 1)
        expected = foo_meta(meta_input, 1)
        self.assertEqual(actual.shape, expected.shape)
1354

1355
    def test_duplicate_impl(self):
        # Registering a second abstract impl for the same op must raise, and
        # the message must contain a "test_custom_ops.py:<line>" location.
        qualname = f"{TestCustomOp.test_ns}::foo"

        @custom_ops.custom_op(qualname)
        def foo(x: torch.Tensor, dim: int) -> torch.Tensor:
            raise NotImplementedError()

        @torch.library.impl_abstract(qualname, lib=self.lib())
        def foo_meta(x, dim):
            output_shape = list(x.shape)
            del output_shape[dim]
            return x.new_empty(output_shape)

        location_pattern = r"test_custom_ops.py:\d+"
        with self.assertRaisesRegex(RuntimeError, location_pattern):

            @torch.library.impl_abstract(qualname, lib=self.lib())
            def foo_meta2(x, dim):
                output_shape = list(x.shape)
                del output_shape[dim]
                return x.new_empty(output_shape)
1373

1374
    def test_new_data_dependent_symint(self):
        # Inside an abstract impl, ctx.new_dynamic_size must accept min=1 and
        # must raise ValueError for a negative min and for max=x.numel().
        qualname = f"{TestCustomOp.test_ns}::foo"

        @custom_ops.custom_op(qualname)
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError()

        @torch.library.impl_abstract(qualname, lib=self.lib())
        def foo_meta(x):
            ctx = torch.library.get_ctx()
            # Valid request: must not raise.
            ctx.new_dynamic_size(min=1)
            with self.assertRaisesRegex(ValueError, "greater than or equal to 0"):
                ctx.new_dynamic_size(min=-1)
            with self.assertRaisesRegex(ValueError, "SymInt"):
                ctx.new_dynamic_size(max=x.numel())
            return torch.clone(x)

        sample = torch.randn(2, 3, device="cpu")
        op = self.get_op(qualname)
        # Symbolic tracing is what drives the call into the op here.
        make_fx(op, tracing_mode="symbolic")(sample)
1392

1393
    def test_meta_for_data_dependent_shape_operation(self):
        # Calling numpy_nonzero on a meta tensor must raise a RuntimeError
        # mentioning the data-dependent output shape.
        meta_input = torch.randn(10, device="meta")
        with self.assertRaisesRegex(RuntimeError, "data-dependent output shape"):
            torch.ops._torch_testing.numpy_nonzero(meta_input)
1397

1398
    def test_basic_make_fx(self):
        # More serious tests are in our CustomOp opinfo db,
        # this one is just a sanity check: symbolically tracing the op must
        # yield generated code that references the op by name.
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError()

        @torch.library.impl_abstract(f"{TestCustomOp.test_ns}::foo", lib=self.lib())
        def foo_meta(x):
            return x.sum()

        x = torch.randn(3)
        op = self.get_op(f"{self.test_ns}::foo")
        gm = make_fx(op, tracing_mode="symbolic")(x)
        # assertIn prints both the needle and the generated code on failure,
        # whereas assertTrue(... in ...) only reports "False is not true".
        self.assertIn(f"{TestCustomOp.test_ns}.foo", gm.code)
1413

1414
    def test_not_implemented_error(self):
        # Calling an op that has no implementation must raise
        # NotImplementedError with a message specific to the situation: CPU
        # input, meta input, or an op with no Tensor inputs at all.
        foo_qualname = f"{TestCustomOp.test_ns}::foo"

        @custom_ops.custom_op(foo_qualname)
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError()

        op = None  # rebound below; keeps the three cases visually parallel

        # CPU tensor, no cpu impl.
        cpu_input = torch.randn(3)
        op = self.get_op(foo_qualname)
        with self.assertRaisesRegex(NotImplementedError, "cpu impl registered"):
            op(cpu_input)

        # Meta tensor, no abstract impl / Meta kernel.
        meta_input = torch.randn(3, device="meta")
        with self.assertRaisesRegex(
            NotImplementedError, "no abstract impl or Meta kernel"
        ):
            op(meta_input)

        # Op whose only argument is a sequence of ints.
        bar_qualname = f"{TestCustomOp.test_ns}::bar"

        @custom_ops.custom_op(bar_qualname)
        def bar(sizes: Sequence[int]) -> torch.Tensor:
            raise NotImplementedError()

        op = self.get_op(bar_qualname)
        with self.assertRaisesRegex(NotImplementedError, "no Tensor inputs"):
            op((1, 2, 3))
1437

1438
    def test_abstract_registration_location(self):
        # The recorded source of numpy_nonzero's abstract impl must point at a
        # "custom_op_db.py:<line>" location.
        qualname = "_torch_testing::numpy_nonzero"

        # Looked up but not inspected further.
        found_op = torch._custom_op.impl._find_custom_op(qualname)  # noqa: F841

        registry_entry = torch._library.simple_registry.singleton.find(qualname)
        source = registry_entry.abstract_impl.kernel.source
        self.assertRegex(source, r".*custom_op_db.py:\d+")
1446

1447
    def test_data_dependent_basic(self):
        # Symbolically tracing a function that calls numpy_nonzero must yield
        # generated code that mentions "nonzero".
        def f(x):
            return torch.ops._torch_testing.numpy_nonzero(x)

        x = torch.randn(5, 5)
        gm = make_fx(f, tracing_mode="symbolic")(x)
        # assertIn shows the generated code on failure, unlike
        # assertTrue(... in ...), which only reports "False is not true".
        self.assertIn("nonzero", gm.code)
1454

1455
    def test_data_dependent_fake_tracing(self):
        # Smoke test: fake-mode tracing through numpy_nonzero must not raise.
        def call_nonzero(x):
            return torch.ops._torch_testing.numpy_nonzero(x)

        sample = torch.randn(5, 5)
        # We've updated to attempt to use unbacked symints even for fake
        # tracing
        tracer = make_fx(call_nonzero, tracing_mode="fake")
        tracer(sample)
1463

1464
    def test_symints(self):
1465
        def f(x):
1466
            return torch.ops._torch_testing.numpy_view_copy(x, x.shape)
1467

1468
        x = torch.randn(2, 3, 4)
1469
        gm = make_fx(f, tracing_mode="symbolic")(x)
1470
        result = gm(x)
1471
        self.assertEqual(result, f(x))
1472
        self.assertExpectedInline(
1473
            gm.code.strip(),
1474
            """\
1475
def forward(self, x_1):
1476
    sym_size_int = torch.ops.aten.sym_size.int(x_1, 0)
1477
    sym_size_int_1 = torch.ops.aten.sym_size.int(x_1, 1)
1478
    sym_size_int_2 = torch.ops.aten.sym_size.int(x_1, 2)
1479
    numpy_view_copy = torch.ops._torch_testing.numpy_view_copy.default(x_1, [sym_size_int, sym_size_int_1, sym_size_int_2]);  x_1 = sym_size_int = sym_size_int_1 = sym_size_int_2 = None
1480
    return numpy_view_copy""",  # noqa: B950
1481
        )
1482

1483
    @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work on windows")
    @unittest.skipIf(
        sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+"
    )
    def test_data_dependent_compile(self):
        # Compiling a function that calls numpy_nonzero must record exactly
        # one graph break, with the message pinned below.
        import torch._dynamo.testing
        from torch._dynamo.utils import counters

        counters.clear()
        compile_counter = torch._dynamo.testing.CompileCounter()

        @torch.compile(backend=compile_counter)
        def f(x):
            return torch.ops._torch_testing.numpy_nonzero(x.clone()).clone()

        f(torch.randn(10))

        expected_breaks = {
            "dynamic shape operator: _torch_testing.numpy_nonzero.default; to enable, set torch._dynamo.config.capture_dynamic_output_shape_ops = True": 1  # noqa: B950
        }
        observed_breaks = dict(counters["graph_break"])
        self.assertEqual(observed_breaks, expected_breaks)
1506

1507
    # Known, pre-existing limitation: torch.compile(dynamic=True) graph-breaks
    # on data-dependent operations by default, hence the expectedFailure
    # marker. Eventually such operations should never cause a graph break.
    @unittest.expectedFailure
    def test_data_dependent_nms_dynamic_compile(self):
        import torch._dynamo.testing
        from torch._dynamo.utils import counters

        counters.clear()
        compile_counter = torch._dynamo.testing.CompileCounter()

        @torch.compile(backend=compile_counter, dynamic=True)
        def f(x, s, i):
            return torch.ops._torch_testing.numpy_nms(x.clone(), s, i).clone()

        x = torch.randn(20, 4)
        s = torch.randn(20)
        f(x, s, 0.1)

        num_breaks = len(counters["graph_break"])
        self.assertEqual(num_breaks, 0)
1525

1526
    def test_impl_on_existing_op(self):
        # An op declared directly through a Library can receive its kernel
        # through custom_ops.impl and then be called like any other op.
        qualname = f"{self.test_ns}::foo"
        self.lib().define("foo(Tensor x) -> Tensor")

        @custom_ops.impl(qualname)
        def foo_impl(x):
            return x.sin()

        op = self.get_op(qualname)
        inp = torch.randn(3)
        out = op(inp)
        self.assertEqual(out, inp.sin())
1539

1540
    @parametrize(
        "key", ["CPU", "CUDA", "CompositeImplicitAutograd", "CompositeExplicitAutograd"]
    )
    def test_impl_on_existing_op_with_cpu_registration(self, key):
        # Once a kernel is registered under `key` through the Library,
        # custom_ops.impl must refuse to register another implementation.
        def foo_impl(x):
            return x.sin()

        library = self.lib()
        library.define("foo(Tensor x) -> Tensor")
        library.impl("foo", foo_impl, key)

        qualname = f"{self.test_ns}::foo"
        # Looked up as in the original test; the op itself is not called.
        self.get_op(qualname)

        with self.assertRaisesRegex(RuntimeError, "already has an implementation"):
            custom_ops.impl(qualname, func=foo_impl)
1556

1557
    def test_abstract_impl_on_existing_op(self):
        # An abstract impl attached to a Library-defined op is what runs under
        # FakeTensorMode; the result must match the input's shape and stride.
        qualname = f"{self.test_ns}::foo"
        self.lib().define("foo(Tensor x) -> Tensor")

        @torch.library.impl_abstract(qualname, lib=self.lib())
        def foo_impl(x):
            return x.sin()

        op = self.get_op(qualname)
        with torch._subclasses.FakeTensorMode():
            fake_in = torch.randn(3)
            fake_out = op(fake_in)
            self.assertEqual(fake_out.shape, fake_in.shape)
            self.assertEqual(fake_out.stride(), fake_in.stride())
1572

1573
    def test_abstract_impl_on_existing_op_with_meta(self):
        # An op that already has a Meta kernel must reject an abstract impl.
        def foo_impl(x):
            return x.sin()

        library = self.lib()
        library.define("foo(Tensor x) -> Tensor")
        library.impl("foo", foo_impl, "Meta")

        qualname = f"{self.test_ns}::foo"
        # Looked up as in the original test; the op itself is not called.
        self.get_op(qualname)

        with self.assertRaisesRegex(RuntimeError, r"already has .*Meta implementation"):
            torch.library.impl_abstract(qualname, func=foo_impl, lib=self.lib())
1586

1587
    def test_abstract_impl_on_existing_op_with_CompositeImplicitAutograd(self):
        # An op with a CompositeImplicitAutograd kernel must reject an
        # abstract impl with an error that names that dispatch key.
        def foo_impl(x):
            return x.sin()

        library = self.lib()
        library.define("foo(Tensor x) -> Tensor")
        library.impl("foo", foo_impl, "CompositeImplicitAutograd")

        qualname = f"{self.test_ns}::foo"
        # Looked up as in the original test; the op itself is not called.
        self.get_op(qualname)

        with self.assertRaisesRegex(RuntimeError, "CompositeImplicitAutograd"):
            torch.library.impl_abstract(qualname, func=foo_impl, lib=self.lib())
1600

1601
    def test_abstract_impl_on_existing_op_with_CompositeExplicitAutograd(self):
        # With a CompositeExplicitAutograd kernel present, an abstract impl
        # may still be registered. The kernel computes sin() while the
        # abstract impl computes sum(), so a scalar-shaped fake result shows
        # that the abstract impl is the one used under FakeTensorMode.
        def foo_impl(x):
            return x.sin()

        library = self.lib()
        library.define("foo(Tensor x) -> Tensor")
        library.impl("foo", foo_impl, "CompositeExplicitAutograd")

        qualname = f"{self.test_ns}::foo"
        op = self.get_op(qualname)

        torch.library.impl_abstract(qualname, func=lambda x: x.sum(), lib=self.lib())
        with torch._subclasses.FakeTensorMode():
            fake_in = torch.randn(10)
            fake_out = op(fake_in)
            self.assertEqual(fake_out.shape, ())
1617

1618
    def _test_backward_impl_raises(self, qualname, err_regex):
        # Helper: both backward-related registration entry points must raise
        # a RuntimeError matching `err_regex` for the op named `qualname`.
        def foo2(x):
            return

        def foo3(x):
            return

        with self.assertRaisesRegex(RuntimeError, err_regex):
            # Explicit form of using the call as a decorator on foo2.
            custom_ops.impl_save_for_backward(qualname)(foo2)

        with self.assertRaisesRegex(RuntimeError, err_regex):
            # Explicit form of using the call as a decorator on foo3.
            custom_ops.impl_backward(qualname)(foo3)
1630

1631
    def test_backward_impl_on_existing_op_incorrect_schema_views(self):
        # A schema whose return aliases its input must be rejected.
        self.lib().define("foo(Tensor(a) x) -> Tensor(a)")
        self._test_backward_impl_raises(
            f"{self.test_ns}::foo", "operator that returns views"
        )
1636

1637
    def test_backward_impl_on_existing_op_incorrect_schema_mutable(self):
        # A schema that mutates its input must be rejected.
        self.lib().define("foo(Tensor(a!) x) -> Tensor")
        self._test_backward_impl_raises(f"{self.test_ns}::foo", "non-functional")
1642

1643
    def test_backward_impl_on_existing_op_incorrect_schema_no_output(self):
        # A schema with no return values must be rejected.
        self.lib().define("foo(Tensor x) -> ()")
        self._test_backward_impl_raises(f"{self.test_ns}::foo", "no returns")
1648

1649
    def test_backward_impl_on_existing_op_CompositeImplicitAutograd(self):
        # An op with a CompositeImplicitAutograd kernel must reject backward
        # registrations with an error naming that dispatch key.
        dispatch_key = "CompositeImplicitAutograd"
        library = self.lib()
        library.define("foo(Tensor x) -> Tensor")
        library.impl("foo", lambda x: x.sin().cos(), dispatch_key)
        self._test_backward_impl_raises(f"{self.test_ns}::foo", dispatch_key)
1655

1656
    @parametrize("key", ["Autograd", "AutogradCPU", "AutogradCUDA"])
    def test_backward_impl_on_existing_op_with_key(self, key):
        # An op that already has a kernel under an autograd dispatch key must
        # reject backward registrations; the error is matched against `key`.
        library = self.lib()
        library.define("foo(Tensor x) -> Tensor")
        library.impl("foo", lambda x: x.sin().cos(), key)
        self._test_backward_impl_raises(f"{self.test_ns}::foo", key)
1663

1664
    def test_backward_impl_on_existing_op(self):
        # Registers forward, save-for-backward and backward pieces for a
        # Library-defined op and checks the resulting gradient of sin().
        qualname = f"{self.test_ns}::foo"
        self.lib().define("foo(Tensor x) -> Tensor")

        @custom_ops.impl(qualname)
        def foo_impl(x):
            with torch.no_grad():
                out = x.sin()
            return out

        @custom_ops.impl_save_for_backward(qualname)
        def foo_save_for_backward(inputs, output):
            # Only the input is saved for the backward formula.
            return inputs.x

        @custom_ops.impl_backward(qualname)
        def foo_backward(ctx, saved, grad_out):
            # d/dx sin(x) = cos(x); gradients are returned keyed by input name.
            return {"x": grad_out * saved.cos()}

        op = self.get_op(qualname)
        inp = torch.randn([], requires_grad=True)
        out = op(inp)
        grads = torch.autograd.grad(out, inp)
        (grad_inp,) = grads
        self.assertEqual(grad_inp, inp.cos())
1687

1688
    @parametrize(
        "tags",
        [
            subtest(torch.Tag.pointwise, "single"),
            subtest((torch.Tag.pointwise,), "tuple"),
            subtest([torch.Tag.pointwise], "list"),
        ],
    )
    def test_define_with_tags(self, tags):
        # `tags` arrives as a single Tag, a tuple of Tags, or a list of Tags.
        # It is passed to torch.library.define unchanged so that every
        # parametrization is actually exercised (previously it was overwritten
        # with a hard-coded tuple, making all three variants identical).
        lib = self.lib()
        torch.library.define(
            f"{self.test_ns}::foo", "(Tensor x) -> Tensor", lib=lib, tags=tags
        )
        # A bare Tag is not iterable as a sequence of tags, so normalize it
        # separately when building the expected value.
        expected = [tags] if isinstance(tags, torch.Tag) else list(tags)
        actual = self.ns().foo.default.tags
        self.assertTrue(isinstance(actual, list))
        self.assertEqual(actual, expected)
1705

1706
    def test_builtin_aten_ops_are_pt2_compliant(self):
        # The listed built-in aten overloads must carry pt2_compliant_tag.
        aten = torch.ops.aten
        for overload in (aten.sin.default, aten.sum.dim_IntList):
            self.assertIn(torch.Tag.pt2_compliant_tag, overload.tags)
1709

1710
    def test_builtin_torchscript_ops(self):
        # The listed complex overloads must carry pt2_compliant_tag.
        aten = torch.ops.aten
        for overload in (aten.sub.complex, aten.mul.complex):
            self.assertIn(torch.Tag.pt2_compliant_tag, overload.tags)
1713

1714
    def test_autogen_aten_ops_are_pt2_compliant(self):
        # The listed overloads must carry both the `generated` tag and
        # pt2_compliant_tag.
        aten = torch.ops.aten
        overloads = (aten._foreach_copy.default, aten.fill.Tensor_out)
        for overload in overloads:
            op_tags = overload.tags
            self.assertIn(torch.Tag.generated, op_tags)
            self.assertIn(torch.Tag.pt2_compliant_tag, op_tags)
1721

1722
    def test_resolve_packet(self):
        # _jit_resolve_packet returns the overload name selected for the given
        # arguments and raises when no schema matches.
        resolve = torch._C._jit_resolve_packet
        inp = torch.randn(3)

        self.assertEqual(resolve("aten::sum", inp), "default")
        self.assertEqual(resolve("aten::sum", inp, dim=1), "dim_IntList")

        with self.assertRaisesRegex(RuntimeError, "failed to match any schema"):
            resolve("aten::sum", inp, inp, inp)
1732

1733
    def test_define_bad_schema(self):
        # torch.library.define expects a schema without the operator name
        # (e.g. "(Tensor x) -> Tensor"); a schema that repeats the name must
        # be rejected with a ValueError.
        lib = self.lib()
        with self.assertRaisesRegex(ValueError, "expected schema to look like"):
            # Pass the per-test library, as the other define() tests do, so
            # the call uses a Library that tearDown destroys rather than
            # allocating one of its own.
            torch.library.define(
                f"{self.test_ns}::foo", "foo(Tensor x) -> Tensor", lib=lib
            )
1737

1738
    def test_define_and_impl(self):
        # Defines an op with torch.library.define, registers a CPU kernel via
        # the torch.library.impl decorator, and checks the result against
        # torch's own sin().
        lib = self.lib()
        torch.library.define(f"{self.test_ns}::foo", "(Tensor x) -> Tensor", lib=lib)

        @torch.library.impl(f"{self.test_ns}::foo", "CPU", lib=lib)
        def f(x):
            return torch.from_numpy(np.sin(x.numpy()))

        x = torch.randn(3)
        y = self.ns().foo(x)
        # Use a unittest assertion: a bare `assert` is removed under
        # `python -O`, which would silently turn this check into a no-op.
        self.assertTrue(torch.allclose(y, x.sin()))
1749

1750
    def test_define_validation(self):
        # A qualname without a "namespace::" prefix must be rejected.
        unqualified_name = "foo"
        schema = "(Tensor x) -> Tensor"
        with self.assertRaisesRegex(ValueError, "namespace"):
            torch.library.define(unqualified_name, schema)
1753

1754
    def test_legacy_define(self):
        # Legacy calling convention: torch.library.define(lib, schema) used as
        # a decorator on the kernel function.
        lib = self.lib()

        @torch.library.define(lib, "foo(Tensor x) -> Tensor")
        def f(x):
            return torch.from_numpy(np.sin(x.numpy()))

        x = torch.randn(3)
        y = self.ns().foo(x)
        # Use a unittest assertion: a bare `assert` is removed under
        # `python -O`, which would silently turn this check into a no-op.
        self.assertTrue(torch.allclose(y, x.sin()))
1764

1765
    def test_impl_function(self):
        # torch.library.impl called as a plain function (kernel passed
        # positionally) rather than used as a decorator.
        lib = self.lib()
        torch.library.define(f"{self.test_ns}::foo", "(Tensor x) -> Tensor", lib=lib)

        def f(x):
            return torch.from_numpy(np.sin(x.numpy()))

        torch.library.impl(f"{self.test_ns}::foo", "CPU", f, lib=lib)
        x = torch.randn(3)
        y = self.ns().foo(x)
        # Use a unittest assertion: a bare `assert` is removed under
        # `python -O`, which would silently turn this check into a no-op.
        self.assertTrue(torch.allclose(y, x.sin()))
1776

1777
    def test_legacy_impl(self):
        # Legacy calling convention: torch.library.impl(lib, name, key) used
        # as a decorator on the kernel function.
        lib = self.lib()
        torch.library.define(f"{self.test_ns}::foo", "(Tensor x) -> Tensor", lib=lib)

        @torch.library.impl(lib, "foo", "CPU")
        def f(x):
            return torch.from_numpy(np.sin(x.numpy()))

        x = torch.randn(3)
        y = self.ns().foo(x)
        # Use a unittest assertion: a bare `assert` is removed under
        # `python -O`, which would silently turn this check into a no-op.
        self.assertTrue(torch.allclose(y, x.sin()))
1788

1789
    def test_defined_in_python(self):
        # Built-in aten overloads are not defined in Python; ops created
        # through torch.library.define are, with or without an overload name.
        self.assertFalse(torch.ops.aten.sin.default._defined_in_python)
        self.assertFalse(torch.ops.aten.sum.dim_IntList._defined_in_python)

        lib = self.lib()
        # The qualnames must be real f-strings over `self.test_ns`. They were
        # plain strings containing the literal text "{self._test_ns}", which
        # names a non-existent attribute and never produced the intended
        # namespace in the qualname.
        torch.library.define(f"{self.test_ns}::foo", "(Tensor x) -> Tensor", lib=lib)
        ns = self.ns()
        self.assertTrue(ns.foo.default._defined_in_python)

        torch.library.define(
            f"{self.test_ns}::bar.overload", "(Tensor x) -> Tensor", lib=lib
        )
        self.assertTrue(ns.bar.overload._defined_in_python)
1802

1803
    def _test_impl_device(self, name, types, device):
        # Helper: defines op `name`, registers a kernel for the device
        # type(s) in `types`, runs it on a tensor placed on `device`, and
        # compares the result with torch's own sin().
        lib = self.lib()
        torch.library.define(f"{self.test_ns}::{name}", "(Tensor x) -> Tensor", lib=lib)

        # Tie the kernel registration to the per-test library (as the other
        # impl() tests in this file do) so that tearDown removes it together
        # with the op definition.
        @torch.library.impl(f"{self.test_ns}::{name}", types, lib=lib)
        def f(x):
            # The computation goes through NumPy on the CPU; the result is
            # moved back to the input's device.
            x_np = x.cpu().numpy()
            y = torch.from_numpy(np.sin(x_np))
            return y.to(device=x.device)

        x = torch.randn(3, device=device)
        y = getattr(self.ns(), name)(x)
        # Use a unittest assertion: a bare `assert` is removed under
        # `python -O`, which would silently turn this check into a no-op.
        self.assertTrue(torch.allclose(y, x.sin()))
1816

1817
    def test_impl_device_cpu(self):
        # Each case uses a distinct op name and a different `types` spelling.
        cases = (
            ("foo1", "default"),
            ("foo2", ["cpu"]),
            ("foo3", ["cpu", "cuda"]),
        )
        for op_name, device_types in cases:
            self._test_impl_device(op_name, device_types, "cpu")
1821

1822
    @unittest.skipIf(not TEST_CUDA, "requires cuda")
    def test_impl_device_cuda(self):
        # Each case uses a distinct op name and a different `types` spelling.
        cases = (
            ("foo4", "default"),
            ("foo5", ["cuda"]),
            ("foo6", ["cpu", "cuda"]),
        )
        for op_name, device_types in cases:
            self._test_impl_device(op_name, device_types, "cuda")
1827

1828
    def test_impl_device_function(self):
        # torch.library.impl called as a plain function with the "default"
        # device type, rather than used as a decorator.
        lib = self.lib()
        torch.library.define(f"{self.test_ns}::foo", "(Tensor x) -> Tensor", lib=lib)

        def f(x):
            # The computation goes through NumPy on the CPU; the result is
            # moved back to the input's device.
            x_np = x.cpu().numpy()
            y = torch.from_numpy(np.sin(x_np))
            return y.to(device=x.device)

        torch.library.impl(f"{self.test_ns}::foo", "default", f, lib=lib)
        x = torch.randn(3)
        y = self.ns().foo(x)
        # Use a unittest assertion: a bare `assert` is removed under
        # `python -O`, which would silently turn this check into a no-op.
        self.assertTrue(torch.allclose(y, x.sin()))
1841

1842
    def test_impl_device_invalid(self):
        # An unrecognized device type string must be rejected.
        bad_device_type = "somethingsomething"
        with self.assertRaisesRegex(RuntimeError, "Expected one of cpu, cuda"):
            torch.library.impl("blah::blah", bad_device_type)
1845

1846
    def test_autograd_function_backed_op(self):
1847
        cpp_source = """
1848
struct CustomOpAutogradFunction : public torch::autograd::Function<CustomOpAutogradFunction> {
1849
  static constexpr bool is_traceable = true;
1850

1851
  static torch::Tensor forward(
1852
      torch::autograd::AutogradContext* ctx,
1853
      torch::Tensor x) {
1854
    return x;
1855
  }
1856

1857
  static torch::autograd::variable_list backward(
1858
      torch::autograd::AutogradContext *ctx,
1859
      torch::autograd::variable_list grad_output) {
1860
    return grad_output;
1861
  }
1862
};
1863

1864
torch::Tensor custom_op_backed_by_autograd_fn(torch::Tensor x) {
1865
  return CustomOpAutogradFunction::apply(x);
1866
}
1867

1868
TORCH_LIBRARY(mylib, m) {
1869
    m.def("custom_op_backed_by_autograd_fn", custom_op_backed_by_autograd_fn);
1870
}
1871
        """
1872

1873
        module = torch.utils.cpp_extension.load_inline(
1874
            name="mylib",
1875
            cpp_sources=cpp_source,
1876
            functions="custom_op_backed_by_autograd_fn",
1877
            verbose=True,
1878
        )
1879

1880
        x = torch.ones(2, 2, requires_grad=True)
1881
        temp = x.clone().detach()
1882
        out = torch.ops.mylib.custom_op_backed_by_autograd_fn(x)
1883
        loss = out.sum()
1884
        loss.backward()
1885
        self.assertEqual(x.grad, temp)
1886

1887

1888
def op_with_incorrect_schema(testcase, name):
    # Registers `name` in the testcase's namespace with a schema that
    # declares a plain `Tensor` return while the kernel returns `x[:]`, then
    # returns the op looked up by its qualified name.
    schema = f"{name}(Tensor x) -> Tensor"
    library = testcase.lib()
    library.define(schema)
    library.impl(name, lambda x: x[:], "CompositeExplicitAutograd")
    return testcase.get_op(f"{testcase.test_ns}::{name}")
1894

1895

1896
class MiniOpTest(CustomOpTestCaseBase):
    # Small set of operator tests; optests.generate_opcheck_tests is applied
    # to this class further down in the file.
    test_ns = "mini_op_test"

    def _init_op_delayed_backward_error(self):
        # Builds `delayed_error`: the forward clones its input, the backward
        # always raises NotImplementedError.
        name = "delayed_error"
        library = self.lib()
        library.define(f"{name}(Tensor x) -> Tensor")
        library.impl(name, lambda x: x.clone(), "CompositeExplicitAutograd")
        op = self.get_op(f"{self.test_ns}::{name}")

        class Op(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                # Call the op inside the _AutoDispatchBelowAutograd guard.
                with torch._C._AutoDispatchBelowAutograd():
                    out = op(x)
                return out

            @staticmethod
            def backward(ctx, grad):
                raise NotImplementedError()

        def autograd_impl(x):
            return Op.apply(x)

        library.impl(name, autograd_impl, "Autograd")
        return op

    def _init_op_with_no_abstract_impl(self):
        # Builds `no_abstract`: tagged pt2_compliant_tag and given only a CPU
        # kernel.
        name = "no_abstract"
        library = self.lib()
        library.define(
            f"{name}(Tensor x) -> Tensor", tags=(torch.Tag.pt2_compliant_tag,)
        )
        library.impl(name, lambda x: x.clone(), "CPU")
        return torch._library.utils.lookup_op(f"{self.test_ns}::{name}")

    def setUp(self):
        super().setUp()
        # Same creation order as before: no_abstract first, then delayed_error.
        self._op_with_no_abstract_impl = self._init_op_with_no_abstract_impl()
        self._op_delayed_backward_error = self._init_op_delayed_backward_error()

    @optests.dontGenerateOpCheckTests("Testing this API")
    def test_dont_generate(self):
        bad_op = op_with_incorrect_schema(self, "incorrect_schema")
        inp = torch.randn(3)
        bad_op(inp)

    def test_mm(self):
        lhs = torch.randn(2, 3, requires_grad=True)
        rhs = torch.randn(3, 5)
        product = torch.ops.aten.mm.default(lhs, rhs)
        self.assertEqual(product, lhs @ rhs)

    def test_mm_meta(self):
        lhs = torch.randn(2, 3, requires_grad=True, device="meta")
        rhs = torch.randn(3, 5, device="meta")
        product = torch.ops.aten.mm.default(lhs, rhs)
        self.assertEqual(product.shape, (lhs @ rhs).shape)

    def test_mm_fake(self):
        with torch._subclasses.fake_tensor.FakeTensorMode():
            lhs = torch.randn(2, 3, requires_grad=True, device="cpu")
            rhs = torch.randn(3, 5, device="cpu")
            product = torch.ops.aten.mm.default(lhs, rhs)
            self.assertEqual(product.shape, (lhs @ rhs).shape)

    def test_mm_errors(self):
        # Inner dimensions (3 vs 4) do not match, so mm must raise.
        lhs = torch.randn(2, 3, requires_grad=True)
        rhs = torch.randn(4, 5)
        with self.assertRaisesRegex(RuntimeError, "cannot be multiplied"):
            torch.ops.aten.mm.default(lhs, rhs)

    def test_nonzero(self):
        values = torch.tensor([0, 1, 2, 0, 0])
        indices = torch.ops.aten.nonzero.default(values)
        self.assertEqual(indices, torch.tensor([[1], [2]]))

    def test_inplace(self):
        # sin_ must modify its argument in place.
        inp = torch.randn(3)
        before = inp.clone()
        torch.ops.aten.sin_(inp)
        self.assertEqual(inp, before.sin())

    def test_incorrect_schema(self):
        bad_op = op_with_incorrect_schema(self, "incorrect_schema")
        inp = torch.randn(3)
        bad_op(inp)

    def test_no_abstract(self):
        inp = torch.randn(3)
        self._op_with_no_abstract_impl(inp)

    def test_delayed_error(self):
        # The forward succeeds; the error only surfaces during backward.
        delayed_op = self._op_delayed_backward_error
        inp = torch.randn([], requires_grad=True)
        out = delayed_op(inp)
        with self.assertRaises(NotImplementedError):
            out.sum().backward()

    def test_delayed_error_no_requires_grad(self):
        # No backward pass is run here; only the forward call is made.
        delayed_op = self._op_delayed_backward_error
        inp = torch.randn([])
        delayed_op(inp)
1999

2000

2001
class MiniOpTestOther(CustomOpTestCaseBase):
    # Second test case that uses the same namespace string as MiniOpTest.
    test_ns = "mini_op_test"

    def test_nonzero_again(self):
        values = torch.tensor([0, 1, 2, 0, 0])
        indices = torch.ops.aten.nonzero.default(values)
        self.assertEqual(indices, torch.tensor([[1], [2]]))
2008

2009

2010
# Generate opcheck variants of the MiniOpTest tests for ops in the "aten" and
# "mini_op_test" namespaces. The generated pt2_compliant_tag test for
# `no_abstract` is additionally marked as an expected failure.
optests.generate_opcheck_tests(
    MiniOpTest,
    ["aten", "mini_op_test"],
    get_file_path_2(os.path.dirname(__file__), "minioptest_failures_dict.json"),
    additional_decorators={
        "test_pt2_compliant_tag_mini_op_test_no_abstract": [unittest.expectedFailure]
    },
)

# Same generation for MiniOpTestOther, using the same failures-dict file and
# no additional decorators.
optests.generate_opcheck_tests(
    MiniOpTestOther,
    ["aten", "mini_op_test"],
    get_file_path_2(os.path.dirname(__file__), "minioptest_failures_dict.json"),
)
2030

2031

2032
class TestGenerateOpcheckTests(CustomOpTestCaseBase):
2033
    def test_MiniOpTest(self):
2034
        for orig_test in ["test_mm", "test_nonzero"]:
2035
            for (
2036
                test
2037
            ) in torch.testing._internal.optests.generate_tests.DEFAULT_TEST_UTILS:
2038
                expected_test = f"{test}__{orig_test}"
2039
                self.assertTrue(hasattr(MiniOpTest, expected_test), msg=expected_test)
2040

2041
    def test_generate_repro_save_data(self):
2042
        from torch.testing._internal.optests.generate_tests import generate_repro
2043

2044
        args = (torch.ones(2, 2),)
2045
        kwargs = {"mat2": torch.zeros(2, 2)}
2046
        actual = generate_repro(
2047
            "test_schema",
2048
            torch.ops.aten.sin.default,
2049
            args,
2050
            kwargs,
2051
            save_data=True,
2052
            dry_run=True,
2053
        )
2054
        actual = re.sub(r"torch.load\(\".*\.pt\"\)", 'torch.load("repro.pt")', actual)
2055
        self.assertExpectedInline(
2056
            actual,
2057
            """\
2058
# =========================================================
2059
# BEGIN REPRO SCRIPT
2060
# =========================================================
2061
import torch
2062
from torch.testing._internal.optests import opcheck
2063

2064
# Make sure you have loaded the library that contains the op
2065
# via an import or torch.ops.load_library(...)
2066
op = torch.ops.aten.sin.default
2067

2068
args, kwargs = torch.load("repro.pt")
2069
opcheck(op, args, kwargs, test_utils="test_schema")
2070
# =========================================================
2071
# END REPRO SCRIPT
2072
# =========================================================
2073
""",
2074
        )
2075

2076
    def test_generate_repro_no_save_data(self):
2077
        from torch.testing._internal.optests.generate_tests import generate_repro
2078

2079
        args = (torch.ones(2, 2),)
2080
        kwargs = {"mat2": torch.zeros(2, 2)}
2081
        actual = generate_repro(
2082
            "test_schema",
2083
            torch.ops.aten.sin.default,
2084
            args,
2085
            kwargs,
2086
            save_data=False,
2087
            dry_run=True,
2088
        )
2089
        self.assertExpectedInline(
2090
            actual,
2091
            """\
2092
# =========================================================
2093
# BEGIN REPRO SCRIPT
2094
# =========================================================
2095
import torch
2096
from torch.testing._internal.optests import opcheck
2097

2098
# Make sure you have loaded the library that contains the op
2099
# via an import or torch.ops.load_library(...)
2100
op = torch.ops.aten.sin.default
2101

2102
# If you rerun your test with PYTORCH_OPCHECK_PRINT_BETTER_REPRO=1
2103
# we will fill them in same (args, kwargs) as in your test
2104
args = ()  # args to the operator
2105
kwargs = {}  # kwargs to the operator
2106
opcheck(op, args, kwargs, test_utils="test_schema")
2107
# =========================================================
2108
# END REPRO SCRIPT
2109
# =========================================================
2110
""",
2111
        )
2112

2113
    def test_failures_dict_validation(self):
2114
        from torch.testing._internal.optests.generate_tests import (
2115
            FailuresDict,
2116
            validate_failures_dict_structure,
2117
        )
2118

2119
        failures = {
2120
            "mini_op_test::incorrect_schema": {
2121
                "MiniOpTest.test_aot_dispatch_static__test_delayed_error": {
2122
                    "comment": "",
2123
                    "status": "success",
2124
                }
2125
            }
2126
        }
2127
        with self.assertRaisesRegex(RuntimeError, "got status=success"):
2128
            validate_failures_dict_structure(
2129
                FailuresDict("", failures),
2130
                torch.testing._internal.optests.generate_tests.DEFAULT_TEST_UTILS,
2131
                MiniOpTest,
2132
            )
2133

2134
        failures = {
2135
            "mini_op_test::incorrect_schema": {
2136
                "MiniOpTest.test_aot_dispatch__test_delayed_error": {
2137
                    "comment": "",
2138
                    "status": "xfail",
2139
                },
2140
            }
2141
        }
2142
        with self.assertRaisesRegex(RuntimeError, "should begin with one of"):
2143
            validate_failures_dict_structure(
2144
                FailuresDict("", failures),
2145
                torch.testing._internal.optests.generate_tests.DEFAULT_TEST_UTILS,
2146
                MiniOpTest,
2147
            )
2148

2149
        failures = {
2150
            "mini_op_test::incorrect_schema": {
2151
                "MiniOpTest.test_aot_dispatch_static__test_delayed_error_nopenopenope": {
2152
                    "comment": "",
2153
                    "status": "xfail",
2154
                },
2155
            }
2156
        }
2157
        with self.assertRaisesRegex(RuntimeError, "does not exist on the TestCase"):
2158
            validate_failures_dict_structure(
2159
                FailuresDict("", failures),
2160
                torch.testing._internal.optests.generate_tests.DEFAULT_TEST_UTILS,
2161
                MiniOpTest,
2162
            )
2163

2164
    def test_dont_generate_decorator(self):
2165
        self.assertTrue(hasattr(MiniOpTest, "test_dont_generate"))
2166
        self.assertFalse(hasattr(MiniOpTest, "test_schema__test_dont_generate"))
2167

2168
    def test_opcheck(self):
2169
        x = torch.randn(3, requires_grad=True)
2170
        with self.assertRaisesRegex(ValueError, "OpOverload"):
2171
            optests.opcheck(torch.sin, (x,))
2172
        with self.assertRaisesRegex(ValueError, "test_utils to be subset of"):
2173
            optests.opcheck(torch.ops.aten.sin.default, (x,), test_utils="blah")
2174
        result = optests.opcheck(torch.ops.aten.sin.default, (x,))
2175

2176
        self.assertEqual(
2177
            result,
2178
            {
2179
                "test_schema": "SUCCESS",
2180
                "test_autograd_registration": "SUCCESS",
2181
                "test_faketensor": "SUCCESS",
2182
                "test_aot_dispatch_static": "SUCCESS",
2183
                "test_aot_dispatch_dynamic": "SUCCESS",
2184
            },
2185
        )
2186

2187
        result = optests.opcheck(
2188
            torch.ops.aten.sin.default, (x,), test_utils="test_schema"
2189
        )
2190
        self.assertEqual(
2191
            result,
2192
            {
2193
                "test_schema": "SUCCESS",
2194
            },
2195
        )
2196

2197
        result = optests.opcheck(
2198
            torch.ops.aten.sin.default,
2199
            (x,),
2200
            test_utils=["test_schema", "test_faketensor"],
2201
        )
2202
        self.assertEqual(
2203
            result,
2204
            {
2205
                "test_schema": "SUCCESS",
2206
                "test_faketensor": "SUCCESS",
2207
            },
2208
        )
2209

2210
    def test_is_inside_opcheck_mode(self):
2211
        self.assertFalse(optests.is_inside_opcheck_mode())
2212
        with optests.generate_tests.OpCheckMode(
2213
            ["foo"], "bar", lambda x: x, None, "baz", "brr"
2214
        ):
2215
            self.assertTrue(optests.is_inside_opcheck_mode())
2216

2217
    def test_opcheck_bad_op(self):
2218
        op = op_with_incorrect_schema(self, "foo")
2219
        x = torch.randn(3)
2220
        with self.assertRaisesRegex(Exception, "is not defined to alias output"):
2221
            optests.opcheck(op, (x,))
2222

2223
        result = optests.opcheck(op, (x,), raise_exception=False)
2224
        self.assertTrue(isinstance(result["test_schema"], RuntimeError))
2225
        del result["test_schema"]
2226
        self.assertEqual(
2227
            result,
2228
            {
2229
                "test_autograd_registration": "SUCCESS",
2230
                "test_faketensor": "SUCCESS",
2231
                "test_aot_dispatch_static": "SUCCESS",
2232
                "test_aot_dispatch_dynamic": "SUCCESS",
2233
            },
2234
        )
2235

2236

2237
# Device types for which the device-generic tests are instantiated.
only_for = ("cpu", "cuda")

# Create the per-device variants of TestCustomOpTesting in this module's
# namespace, then expand the @parametrize'd tests of TestCustomOp.
instantiate_device_type_tests(
    TestCustomOpTesting,
    globals(),
    only_for=only_for,
)
instantiate_parametrized_tests(TestCustomOp)

if __name__ == "__main__":
    run_tests()
2243

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.