pytorch

Fork
0
/
test_custom_ops.py 
3920 lines · 135.4 KB
1
# Owner(s): ["module: custom-operators"]
2

3
import collections
4
import itertools
5
import os
6
import re
7
import subprocess
8
import sys
9
import typing
10
import unittest
11
from typing import *  # noqa: F403
12

13
import numpy as np
14

15
import torch._custom_ops as custom_ops
16
import torch.testing._internal.optests as optests
17
import torch.utils._pytree as pytree
18
import torch.utils.cpp_extension
19
from functorch import make_fx
20
from torch import Tensor
21
from torch._custom_op.impl import CustomOp, infer_schema
22
from torch._library.infer_schema import tuple_to_list
23
from torch._utils_internal import get_file_path_2
24
from torch.testing._internal import custom_op_db
25
from torch.testing._internal.common_cuda import TEST_CUDA
26
from torch.testing._internal.common_device_type import (
27
    instantiate_device_type_tests,
28
    OpDTypes,
29
    ops,
30
)
31
from torch.testing._internal.common_utils import (
32
    instantiate_parametrized_tests,
33
    IS_WINDOWS,
34
    parametrize,
35
    run_tests,
36
    skipIfTorchDynamo,
37
    subtest,
38
    TestCase,
39
)
40
from torch.testing._internal.custom_op_db import numpy_nonzero
41

42

43
# Shadowed by `torch.testing._internal.common_utils.custom_op`
44
from torch._custom_op.impl import custom_op  # usort: skip
45

46

47
def requires_compile(fun):
    """Skip the decorated test function or test class when running on Windows.

    torch.compile is unavailable on Windows, so anything that depends on it is
    wrapped in ``unittest.skipIf(IS_WINDOWS, ...)``.
    """
    skip_on_windows = unittest.skipIf(
        IS_WINDOWS, "torch.compile doesn't work with windows"
    )
    return skip_on_windows(fun)
50

51

52
class CustomOpTestCaseBase(TestCase):
53
    test_ns = "_test_custom_op"
54

55
    def setUp(self):
56
        super().setUp()
57
        self.libraries = []
58

59
    def tearDown(self):
60
        super().tearDown()
61
        import torch._custom_op
62

63
        keys = list(torch._custom_op.impl.global_registry.keys())
64
        for key in keys:
65
            if not key.startswith(f"{self.test_ns}::"):
66
                continue
67
            torch._custom_op.impl.global_registry[key]._destroy()
68
        if hasattr(torch.ops, self.test_ns):
69
            delattr(torch.ops, self.test_ns)
70
        for lib in self.libraries:
71
            lib._destroy()
72
        del self.libraries
73

74
    def ns(self):
75
        return getattr(torch.ops, self.test_ns)
76

77
    def lib(self):
78
        result = torch.library.Library(self.test_ns, "FRAGMENT")  # noqa: TOR901
79
        self.libraries.append(result)
80
        return result
81

82
    def get_op(self, qualname):
83
        return torch._custom_op.impl.get_op(qualname)
84

85

86
@requires_compile
class TestCustomOpTesting(CustomOpTestCaseBase):
    """Tests for the operator-checking utilities.

    Covers ``torch.library.opcheck`` and helpers in
    ``torch.testing._internal.optests``. Most tests register a deliberately
    broken operator in ``test_ns`` and assert that the checker reports the
    specific problem.
    """

    @parametrize("check_gradients", (False, "auto"))
    @parametrize("dynamic", (True, False))
    def test_aot_autograd_check_degenerate_cases(
        self, device, dynamic, check_gradients
    ):
        """aot_autograd_check must accept functions with trivial autograd behavior."""

        def simple(x):
            return x.clone()

        # Input does not require grad: should not raise.
        x = torch.randn(3, device=device)
        optests.aot_autograd_check(
            simple, (x,), {}, dynamic=dynamic, check_gradients=check_gradients
        )

        def outputs_dont_require_grad(x):
            return x.detach()

        # Input requires grad; `simple`'s output does too: should not raise.
        y = torch.randn(3, device=device, requires_grad=True)
        optests.aot_autograd_check(
            simple, (y,), {}, dynamic=dynamic, check_gradients=check_gradients
        )
        # Input requires grad but the output does not: should not raise.
        optests.aot_autograd_check(
            outputs_dont_require_grad,
            (y,),
            {},
            dynamic=dynamic,
            check_gradients=check_gradients,
        )

        def no_outputs(x):
            # NOTE(review): despite its name this returns a (detached) tensor.
            return x.detach()

        # Should not raise, whether or not the input requires grad.
        x = torch.randn(3, device=device, requires_grad=True)
        y = torch.randn(3, device=device, requires_grad=False)
        optests.aot_autograd_check(
            no_outputs, (x,), {}, dynamic=dynamic, check_gradients=check_gradients
        )
        optests.aot_autograd_check(
            no_outputs, (y,), {}, dynamic=dynamic, check_gradients=check_gradients
        )

    def test_incorrect_schema_mutation(self, device):
        """opcheck flags a kernel that mutates an input declared immutable."""
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        op = self.ns().foo.default

        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                guard = torch._C._AutoDispatchBelowAutograd()
                try:
                    return op(x)
                finally:
                    del guard

            @staticmethod
            def backward(ctx, gx):
                return gx

        def foo_impl(x):
            # Mutates `x` in place although the schema declares `Tensor x`.
            x.sin_()
            return x.clone()

        lib.impl("foo", Foo.apply, "Autograd")
        lib.impl("foo", foo_impl, "CPU")
        lib.impl("foo", foo_impl, "CUDA")

        x = torch.tensor(3.14159 / 3, requires_grad=True, device=device)
        with self.assertRaisesRegex(
            optests.OpCheckError, "Argument x is not defined as mutable but was mutated"
        ):
            torch.library.opcheck(op, (x,), {})

    def test_incorrect_schema_view(self, device):
        """opcheck flags a kernel whose output aliases an input the schema
        does not mark as aliased."""
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        op = self.ns().foo.default

        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                # Emulate AutoDispatchBelowADInplaceOrView, which is not bound into python
                with torch._C._AutoDispatchBelowAutograd():
                    with torch._C._ExcludeDispatchKeyGuard(
                        torch._C.DispatchKeySet(torch._C.DispatchKey.ADInplaceOrView)
                    ):
                        return op(x)

            @staticmethod
            def backward(ctx, gx):
                return gx

        def foo_impl(x):
            # Returns a view of the input while the schema promises a fresh tensor.
            return x.view_as(x)

        def foo_meta(x):
            return x.view_as(x)

        lib.impl("foo", Foo.apply, "Autograd")
        lib.impl("foo", foo_impl, "CPU")
        lib.impl("foo", foo_meta, "Meta")

        # Only a CPU kernel is registered, so the input stays on CPU.
        x = torch.tensor(3.14159 / 3, requires_grad=True)
        with self.assertRaisesRegex(
            optests.OpCheckError,
            "Argument x is not defined to alias output but was aliasing",
        ):
            torch.library.opcheck(op, (x,), {})

    def test_missing_abstract_impl(self, device):
        """opcheck fails, naming the op, when no Meta/abstract kernel exists."""
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        op = self.ns().foo.default

        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                with torch._C._AutoDispatchBelowAutograd():
                    return op(x)

            @staticmethod
            def backward(ctx, gx):
                return 2 * gx

        def foo_impl(x):
            # Round-trips through numpy, which fake/meta tensors cannot do.
            return torch.tensor(x.cpu().numpy() ** 2, device=x.device)

        lib.impl("foo", Foo.apply, "Autograd")
        lib.impl("foo", foo_impl, "CPU")
        lib.impl("foo", foo_impl, "CUDA")

        x = torch.tensor([0, 1.0], requires_grad=True)
        with self.assertRaisesRegex(
            optests.OpCheckError,
            "_test_custom_op.foo.default",
        ):
            torch.library.opcheck(op, (x,), {})

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_incorrect_abstract_impl(self, device):
        """opcheck flags a Meta kernel whose output shape disagrees with the
        real kernel's."""
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        op = self.ns().foo.default

        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                # Emulate AutoDispatchBelowADInplaceOrView, which is not bound into python
                guard = torch._C._AutoDispatchBelowAutograd()
                guard2 = torch._C.ExcludeDispatchKeyGuard(
                    torch._C.DispatchKeySet(torch._C.DispatchKey.ADInplaceOrView)
                )
                try:
                    return op(x)
                finally:
                    del guard
                    del guard2

            @staticmethod
            def backward(ctx, gx):
                return gx

        def foo_impl(x):
            return x**2

        def foo_meta(x):
            # Deliberately wrong: adds a dimension the real kernel does not.
            return x.unsqueeze(1) ** 2

        lib.impl("foo", Foo.apply, "Autograd")
        lib.impl("foo", foo_impl, "CPU")
        lib.impl("foo", foo_impl, "CUDA")
        lib.impl("foo", foo_meta, "Meta")

        x = torch.tensor([0, 1.0], requires_grad=True)
        with self.assertRaisesRegex(optests.OpCheckError, "Shapes .* are not equal"):
            torch.library.opcheck(op, (x,), {})

    def test_missing_functionalization(self, device):
        """opcheck reports that ops returning aliased (a!) outputs cannot be
        functionalized."""
        lib = self.lib()
        lib.define("foo(Tensor(a!) x) -> Tensor(a!)")
        op = self.ns().foo.default

        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                ctx.mark_dirty(x)
                with torch._C._AutoDispatchBelowAutograd():
                    return op(x)

            @staticmethod
            def backward(ctx, gx):
                return gx

        def foo_impl(x):
            return x.sin_()

        def foo_meta(x):
            return x

        lib.impl("foo", Foo.apply, "Autograd")
        lib.impl("foo", foo_impl, "CPU")
        lib.impl("foo", foo_impl, "CUDA")
        lib.impl("foo", foo_meta, "Meta")

        x = torch.tensor([0, 1.0])
        y = x.clone()
        with self.assertRaisesRegex(
            optests.OpCheckError,
            "We only support functionalizing operators whose outputs do not have alias annotations",
        ):
            torch.library.opcheck(op, (y,), {})

    def test_autograd_registered_at_backend(self, device):
        """opcheck rejects an autograd.Function registered at a backend key
        instead of an Autograd key."""
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        op = self.ns().foo.default

        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                return x.clone()

            @staticmethod
            def backward(ctx, gx):
                return gx * 0.5

        lib.impl("foo", Foo.apply, "CPU")
        lib.impl("foo", Foo.apply, "CUDA")
        lib.impl("foo", lambda x: x.clone(), "Meta")

        x = torch.randn([], requires_grad=True)

        with self.assertRaisesRegex(
            torch.testing._internal.optests.OpCheckError,
            "does not have an autograd kernel",
        ):
            torch.library.opcheck(op, (x,), {})

        # I'm not sure why this is necessary
        del lib

    def test_global_state_mutation(self, device):
        """opcheck detects results that differ between eager and AOTAutograd
        because the kernel depends on mutable global state."""
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        op = self.ns().foo.default

        class Foo(torch.autograd.Function):
            # Call counter: each invocation scales the output differently.
            invoked = 0

            @staticmethod
            def forward(ctx, x):
                Foo.invoked += 1
                return x.clone() * Foo.invoked

            @staticmethod
            def backward(ctx, gx):
                return gx

        lib.impl("foo", Foo.apply, "CompositeImplicitAutograd")

        x = torch.tensor(3.14159 / 3, requires_grad=True)
        with self.assertRaisesRegex(
            optests.OpCheckError, "eager-mode PyTorch vs AOTAutograd"
        ):
            torch.library.opcheck(op, (x,), {})

    @ops(custom_op_db.custom_op_db, dtypes=OpDTypes.any_one)
    def test_opcheck_opinfo(self, device, dtype, op):
        """Every custom op in custom_op_db passes opcheck on its sample inputs."""
        for sample_input in op.sample_inputs(
            device, dtype, requires_grad=op.supports_autograd
        ):
            args = [sample_input.input] + list(sample_input.args)
            kwargs = sample_input.kwargs
            torch.library.opcheck(op.op, args, kwargs)

    def test_opcheck_fails_basic(self, device):
        """opcheck surfaces the CustomOp "autograd not implemented" error."""

        @custom_op(f"{self.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor: ...

        @foo.impl(["cpu", "cuda"])
        def foo_impl(x):
            return x.sum()

        x = torch.randn(3, device=device, requires_grad=True)
        # Triggers the CustomOp autograd NYI error
        with self.assertRaisesRegex(
            optests.OpCheckError, "Autograd has not been implemented for operator"
        ):
            torch.library.opcheck(self.get_op(f"{self.test_ns}::foo"), (x,), {})

    def test_autograd_registration_check_autograd_kernel(self, device):
        """An autograd.Function registered at the Autograd key passes the check."""
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        op = self.ns().foo.default

        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                with torch._C._AutoDispatchBelowAutograd():
                    return op(x)

            @staticmethod
            def backward(ctx, gx):
                return gx

        def foo_impl(x):
            return x.sin()

        lib.impl("foo", Foo.apply, "Autograd")
        lib.impl("foo", foo_impl, "CPU")
        lib.impl("foo", foo_impl, "CUDA")

        x = torch.randn(3, requires_grad=True, device=device)
        # Should not raise
        optests.autograd_registration_check(op, (x,), {})

    def test_autograd_registration_check_compositeimplicitautograd(self, device):
        """A CompositeImplicitAutograd kernel passes the registration check."""
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        op = self.ns().foo.default

        def foo_impl(x):
            return x.sin().cos()

        lib.impl("foo", foo_impl, "CompositeImplicitAutograd")

        x = torch.randn(3, requires_grad=True, device=device)
        # Should not raise
        optests.autograd_registration_check(op, (x,), {})

    def test_autograd_registration_check_incorrect_composite(self, device):
        """A CompositeExplicitAutograd kernel without autograd support is rejected."""
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        op = self.ns().foo.default

        def foo_impl(x):
            return x.sin().cos()

        lib.impl("foo", foo_impl, "CompositeExplicitAutograd")

        x = torch.randn(3, requires_grad=True, device=device)
        with self.assertRaisesRegex(AssertionError, "incorrectly registered"):
            optests.autograd_registration_check(op, (x,), {})

    def test_autograd_registration_check_incorrect(self, device):
        """An autograd.Function registered at backend keys is rejected."""
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        op = self.ns().foo.default

        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                return torch.sin(x)

            @staticmethod
            def backward(ctx, gx):
                return gx

        lib.impl("foo", Foo.apply, "CPU")
        lib.impl("foo", Foo.apply, "CUDA")

        x = torch.randn(3, requires_grad=True, device=device)
        with self.assertRaisesRegex(AssertionError, "incorrectly registered"):
            optests.autograd_registration_check(op, (x,), {})

    def test_assert_raises_regex(self, device):
        """Exercise optests' assert_raises_regex helper on match/mismatch paths."""
        from torch.testing._internal.optests.aot_autograd import assert_raises_regex

        with assert_raises_regex(RuntimeError, "c"):
            raise RuntimeError("abcd")
        with assert_raises_regex(RuntimeError, "c.*"):
            raise RuntimeError("abcd")
        # Wrong exception type.
        with self.assertRaisesRegex(AssertionError, "instead got"):
            with assert_raises_regex(RuntimeError, "c.*"):
                raise ValueError("abcd")
        # No exception at all.
        with self.assertRaisesRegex(AssertionError, "Expected exception"):
            with assert_raises_regex(RuntimeError, "c.*"):
                pass
        # Right type, message does not match.
        with self.assertRaisesRegex(AssertionError, "to match regex"):
            with assert_raises_regex(RuntimeError, "f"):
                raise RuntimeError("abcd")
464

465

466
class TestCustomOp(CustomOpTestCaseBase):
467
    test_ns = "_test_custom_op"
468

469
    @requires_compile
    def test_functionalize_error(self):
        """Compiling ops whose outputs carry alias annotations must fail.

        Two flavors are checked: an in-place op (``Tensor(a!)``) and a view op
        (``Tensor(a)``). Both should be rejected by functionalization.
        """
        expected = r".*We only support functionalizing operators whose outputs do not have alias annotations"

        with torch.library._scoped_library(self.test_ns, "FRAGMENT") as lib:
            # In-place op: the output aliases and mutates the input.
            lib.define("foo(Tensor(a!) x) -> Tensor(a!)")

            def foo(x):
                return x.sin_()

            lib.impl("foo", foo, "CompositeExplicitAutograd")
            foo_op = self.get_op(f"{self.test_ns}::foo")

            # View op: the output aliases the input without mutating it.
            lib.define("bar(Tensor(a) x) -> Tensor(a)")

            def bar(x):
                return x.view(-1)

            lib.impl("bar", bar, "CompositeExplicitAutograd")
            bar_op = self.get_op(f"{self.test_ns}::bar")

            @torch.compile(backend="aot_eager", fullgraph=True)
            def f(x):
                return foo_op(x)

            @torch.compile(backend="aot_eager", fullgraph=True)
            def g(x):
                return bar_op(x)

            x = torch.randn(3)
            for compiled_fn in (f, g):
                with self.assertRaisesRegex(RuntimeError, expected):
                    compiled_fn(x)
504

505
    def test_invalid_schemas(self):
        # Function schema validation is delegated to torchgen, so this is only
        # a basic smoke test that a malformed schema string is rejected.
        qualname = f"{TestCustomOp.test_ns}::foo"
        with self.assertRaisesRegex(AssertionError, "Invalid function schema: foo"):
            custom_ops.custom_op(qualname, "(")
510

511
    def test_invalid_qualname(self):
        """A qualified name that includes an overload (``foo.Tensor``) is rejected."""
        qualname_with_overload = f"{TestCustomOp.test_ns}::foo.Tensor"
        with self.assertRaisesRegex(ValueError, "overload"):
            custom_ops.custom_op(qualname_with_overload, "() -> ()")
514

515
    def test_name_must_match(self):
        """The decorated function's name must equal the op name in the qualname."""

        def baz(x: Tensor) -> Tensor:
            raise NotImplementedError

        with self.assertRaisesRegex(ValueError, "to have name"):
            custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")(baz)
521

522
    def test_unsupported_schemas(self):
        """Non-functional schemas, and schemas naming an argument ``self``, are
        rejected when passed as a schema string.

        Passing a schema string makes ``custom_ops.custom_op`` parse and
        validate it in the call itself, so that call is what must raise. No
        function is applied to the result: the trailing ``(foo)`` applications
        that used to follow referenced a name not defined in this test. They
        were only harmless because they were never evaluated, and a regression
        would have surfaced as a NameError/TypeError rather than a clear
        "ValueError not raised" failure.
        """
        qualname = f"{TestCustomOp.test_ns}::foo"
        rejected = [
            # (schema, expected error regex)
            ("(Tensor(a!) x) -> Tensor(a)", "only supports functional"),
            ("(Tensor(a) x) -> Tensor(a)", "only supports functional"),
            ("(Tensor x) -> ()", "only supports functional"),
            ("(Tensor self) -> ()", "self"),
        ]
        for schema, pattern in rejected:
            with self.assertRaisesRegex(ValueError, pattern):
                custom_ops.custom_op(qualname, schema)
539

540
    # Tests for the older custom_op API
541
    def test_schema_matches_signature(self):
        """With a manual schema, the older custom_op API requires the Python
        signature to mirror it.

        Rejected: mismatched names, mismatched keyword-only order, *args,
        **kwargs, and default values. Each function's name equals its op name.
        """

        def blah(x):
            pass

        def blah2(x, y):
            pass

        def blah3(x, *, y, z):
            pass

        def blah4(x, *, y, z):
            pass

        def blah5(*args):
            pass

        def blah6(**kwargs):
            pass

        def blah7(x=1, *, y):
            pass

        def blah8(x, *, y=1):
            pass

        rejected = [
            # (function, schema, expected error regex)
            (blah, "(Tensor y) -> Tensor", "signature to match"),
            (blah2, "(Tensor x, *, Tensor y) -> Tensor", "signature to match"),
            (blah3, "(Tensor x, *, Tensor w, Tensor z) -> Tensor", "signature to match"),
            (blah4, "(Tensor x, *, Tensor z, Tensor y) -> Tensor", "signature to match"),
            (blah5, "(Tensor x) -> Tensor", "not supported"),
            (blah6, "(*, Tensor z, Tensor y) -> Tensor", "not supported"),
            (blah7, "(Tensor x, *, Tensor y) -> Tensor", "default arguments"),
            (blah8, "(Tensor x, *, Tensor y) -> Tensor", "default arguments"),
        ]
        for fn, schema, pattern in rejected:
            with self.assertRaisesRegex(ValueError, pattern):
                custom_op(f"{TestCustomOp.test_ns}::{fn.__name__}", schema)(fn)

        # Keyword-only arguments that line up with the schema are accepted.
        @custom_op(
            f"{TestCustomOp.test_ns}::blah9", "(Tensor x, *, Tensor y) -> Tensor"
        )
        def blah9(x, *, y):
            pass
610

611
    def test_infer_schema_no_return(self):
        """torch.library.custom_op rejects a function lacking a return annotation."""

        def foo(x: torch.Tensor, y: int):
            return x * y

        with self.assertRaisesRegex(
            ValueError, "No return type annotation was provided. Please add one."
        ):
            torch.library.custom_op("mylib::foo", mutates_args={})(foo)
619

620
    def test_infer_schema_supported(self):
        """infer_schema turns supported annotations into the expected schema text.

        Each expected schema stays an inline literal inside its
        ``assertExpectedInline`` call so expecttest's accept mode can rewrite it.
        """

        def one_tensor(x: Tensor) -> Tensor:
            return torch.empty([])

        self.assertExpectedInline(
            infer_schema(one_tensor, mutates_args=()), """(Tensor x) -> Tensor"""
        )

        # A Python `int` annotation is inferred as SymInt.
        def mixed_kwonly(x: Tensor, *, y: int, z: float) -> Tensor:
            return torch.empty([])

        self.assertExpectedInline(
            infer_schema(mixed_kwonly, mutates_args=()),
            """(Tensor x, *, SymInt y, float z) -> Tensor""",
        )

        def only_kwonly(*, y: Tensor) -> Tensor:
            return torch.empty([])

        self.assertExpectedInline(
            infer_schema(only_kwonly, mutates_args=()), """(*, Tensor y) -> Tensor"""
        )

        def scalars(
            x: Tensor,
            y: int,
            z: bool,
            a: float,
            b: torch.dtype,
            c: torch.device,
            d: torch.types.Number,
        ) -> Tuple[Tensor, int, float, bool]:
            return torch.empty([]), 1, 0.1, True

        self.assertExpectedInline(
            infer_schema(scalars, mutates_args=()),
            """(Tensor x, SymInt y, bool z, float a, ScalarType b, Device c, Scalar d) -> (Tensor, SymInt, float, bool)""",
        )

        def containers(
            x: Tensor,
            y: Sequence[Tensor],
            z: Optional[Tensor],
            w: Sequence[Optional[Tensor]],
        ) -> List[Tensor]:
            return [torch.empty([])]

        self.assertExpectedInline(
            infer_schema(containers, mutates_args=()),
            """(Tensor x, Tensor[] y, Tensor? z, Tensor?[] w) -> Tensor[]""",
        )

        def nested_return(x: Tensor) -> Tuple[List[Tensor], Tensor]:
            return [torch.empty([])], torch.empty([])

        self.assertExpectedInline(
            infer_schema(nested_return, mutates_args=()),
            """(Tensor x) -> (Tensor[], Tensor)""",
        )

        def no_params() -> Tensor:
            return torch.empty([])

        self.assertExpectedInline(
            infer_schema(no_params, mutates_args=()), """() -> Tensor"""
        )

        def returns_none(x: Tensor) -> None:
            pass

        self.assertExpectedInline(
            infer_schema(returns_none, mutates_args=()), """(Tensor x) -> ()"""
        )

        def tensor_lists(
            x: Tensor, y: List[Tensor], z: List[Tensor], w: List[Optional[Tensor]]
        ) -> None:
            pass

        self.assertExpectedInline(
            infer_schema(tensor_lists, mutates_args=()),
            """(Tensor x, Tensor[] y, Tensor[] z, Tensor?[] w) -> ()""",
        )
        # Only the named arguments receive a mutable alias annotation ...
        self.assertExpectedInline(
            infer_schema(tensor_lists, mutates_args={"x", "w", "z"}),
            """(Tensor(a0!) x, Tensor[] y, Tensor(a2!)[] z, Tensor(a3!)?[] w) -> ()""",
        )
        # ... while "unknown" annotates every tensor argument as mutable.
        self.assertExpectedInline(
            infer_schema(tensor_lists, mutates_args="unknown"),
            """(Tensor(a0!) x, Tensor(a1!)[] y, Tensor(a2!)[] z, Tensor(a3!)?[] w) -> ()""",
        )

        def with_defaults(
            x: Tensor,
            a: Optional[int] = None,
            b: float = 3.14,
            c: bool = True,
            d: int = 3,
            e: str = "foo",
            f: torch.dtype = torch.float,
            g: torch.dtype = torch.float32,
            h: torch.dtype = torch.int,
            i: torch.device = torch.device("cpu:0"),
            j: torch.device = "cpu",
        ) -> None:
            pass

        self.assertExpectedInline(
            infer_schema(with_defaults, mutates_args=()),
            (
                """(Tensor x, SymInt? a=None, float b=3.14, bool c=True, SymInt d=3, str e="foo", """
                """ScalarType f=float32, ScalarType g=float32, ScalarType h=int32, Device i="cpu:0", Device j="cpu") -> ()"""
            ),
        )

        def sin_impl(x: torch.Tensor) -> torch.Tensor:
            return x.sin()

        # The public torch.library.infer_schema prefixes the given op_name.
        schema = torch.library.infer_schema(sin_impl, op_name="myop", mutates_args={})
        self.assertExpectedInline(schema, "myop(Tensor x) -> Tensor")
739

740
    def test_infer_schema_unsupported(self):
        """infer_schema raises ValueError for signatures it cannot express."""

        def expect_rejection(pattern, fn, mutates_args=()):
            with self.assertRaisesRegex(ValueError, pattern):
                infer_schema(fn, mutates_args=mutates_args)

        def foo(*args):
            raise NotImplementedError

        expect_rejection("varargs", foo)

        def foo(**kwargs):
            raise NotImplementedError

        expect_rejection("varkwargs", foo)

        # Every parameter needs an annotation.
        def foo(x):
            raise NotImplementedError

        expect_rejection("must have a type annotation", foo)

        # Variadic tuples have no schema equivalent.
        def foo(x: Tensor) -> Tuple[Tensor, ...]:
            raise NotImplementedError

        expect_rejection("unsupported", foo)

        # Only tensor-like arguments may be listed as mutated.
        def foo(x: Tensor, y: int) -> Tensor:
            raise NotImplementedError

        expect_rejection("can be mutated", foo, {"y"})
775

776
    def _generate_examples(self, typ):
777
        if typ is int:
778
            return [17]
779
        if typ is float:
780
            return [3.14]
781
        if typ is bool:
782
            return [True]
783
        if typ is str:
784
            return ["foo"]
785
        if typ is torch.dtype:
786
            return [torch.float32]
787
        if typ is torch.device:
788
            return [torch.device("cpu")]
789
        if typ == torch.types.Number:
790
            return [2.718]
791
        if typ is torch.Tensor:
792
            return [torch.tensor(3)]
793
        if typ == Optional[torch.types.Number]:
794
            return [None, 2.718]
795
        origin = typing.get_origin(typ)
796
        if origin is Union:
797
            args = typing.get_args(typ)
798
            assert len(args) == 2 and (args[0] is type(None) or args[1] is type(None))
799
            elt = args[0] if args[1] is type(None) else args[1]
800
            return self._generate_examples(elt) + [None]
801
        if origin is list:
802
            args = typing.get_args(typ)
803
            assert len(args) == 1
804
            elt = args[0]
805
            return [
806
                self._generate_examples(elt),
807
                self._generate_examples(elt),
808
                self._generate_examples(elt),
809
            ]
810
        if origin is collections.abc.Sequence:
811
            args = typing.get_args(typ)
812
            assert len(args) == 1
813
            examples = self._generate_examples(args[0])
814
            return list(itertools.product(examples, examples)) + []
815
        raise NotImplementedError(
816
            f"testrunner cannot generate instanstance of type {typ}"
817
        )
818

819
    def test_supported_return_types_single_return(self):
        """Each supported return type round-trips through a single-output op."""
        qualname = f"{self.test_ns}::foo"
        for typ in torch._library.infer_schema.SUPPORTED_RETURN_TYPES:
            for example in self._generate_examples(typ):
                try:

                    @custom_ops.custom_op(qualname)
                    def foo(x: Tensor) -> typ:
                        raise NotImplementedError

                    @custom_ops.impl(qualname)
                    def foo_impl(x: Tensor) -> typ:
                        return example

                    returned = self.get_op(qualname)(torch.randn([]))
                    self.assertEqual(returned, example, msg=f"{typ} {example}")
                finally:
                    # Re-registered on every iteration, so always tear it down.
                    custom_ops._destroy(qualname)
837

838
    def test_supported_return_types_multi_return(self):
        """Each supported return type round-trips as both elements of a 2-tuple."""
        qualname = f"{self.test_ns}::foo"
        for typ in torch._library.infer_schema.SUPPORTED_RETURN_TYPES:
            for example in self._generate_examples(typ):
                try:

                    @custom_ops.custom_op(qualname)
                    def foo(x: Tensor) -> Tuple[typ, typ]:
                        raise NotImplementedError

                    @custom_ops.impl(qualname)
                    def foo_impl(x: Tensor) -> Tuple[typ, typ]:
                        return (example, example)

                    returned = self.get_op(qualname)(torch.randn([]))
                    self.assertEqual(
                        returned, (example, example), msg=f"{typ} {example}"
                    )
                finally:
                    # Re-registered on every iteration, so always tear it down.
                    custom_ops._destroy(qualname)
857

858
    def test_supported_param_types(self):
        """Each supported parameter type reaches the CPU kernel unchanged."""
        registered_name = f"{TestCustomOp.test_ns}::foo"
        for typ in torch._library.infer_schema.SUPPORTED_PARAM_TYPES:

            @custom_ops.custom_op(registered_name)
            def foo(x: Tensor, y: typ) -> Tensor:
                raise NotImplementedError

            # Captures the `y` argument most recently seen by the kernel.
            received = None

            @custom_ops.impl(registered_name, device_types=["cpu"])
            def foo_cpu(x, y):
                nonlocal received
                received = y
                return x.clone()

            try:
                for example in self._generate_examples(typ):
                    op = self.get_op(f"{self.test_ns}::foo")
                    op(torch.randn([]), example)
                    self.assertEqual(received, example, msg=f"{typ} {example}")
                    received = None
            finally:
                custom_ops._destroy(registered_name)
881

882
    def test_sequences(self):
        # `Sequence[int]` is mapped to `int[]` in the schema; check that an
        # arbitrary user-defined Sequence subclass is accepted at call time.
        class MySequence(collections.abc.Sequence):
            def __init__(self) -> None:
                self._container = [1, 2, 3]

            def __getitem__(self, idx):
                return self._container[idx]

            def __len__(self):
                return len(self._container)

        @custom_ops.custom_op(f"{self.test_ns}::foo")
        def foo(x: torch.Tensor, sizes: Sequence[int]) -> torch.Tensor:
            raise NotImplementedError

        num_calls = 0

        @custom_ops.impl(f"{self.test_ns}::foo", device_types="cpu")
        def foo_cpu(x, sizes):
            nonlocal num_calls
            num_calls += 1
            # The kernel sees a plain list, not the MySequence instance.
            self.assertEqual(sizes, [1, 2, 3])
            return x.clone()

        inp = torch.randn([])
        custom_seq = MySequence()
        self.get_op(f"{self.test_ns}::foo")(inp, custom_seq)
        self.assertEqual(num_calls, 1)
914

915
    def test_unsupported_param_types(self):
        # Not comprehensive (it doesn't need to be): a few representative
        # checks that unsupported parameter types are rejected with a helpful
        # ValueError.
        with self.assertRaisesRegex(ValueError, "unsupported type"):

            @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
            def foo(x: Tensor, y: List[Optional[int]]) -> Tensor:
                raise NotImplementedError

            del foo

        with self.assertRaisesRegex(ValueError, "unsupported type"):
            # int[N] in Dispatcher is a bit wild, so we don't try to support it.
            @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
            def foo(x: Tensor, y: Tuple[int, int]) -> Tensor:
                raise NotImplementedError

            del foo

        with self.assertRaisesRegex(ValueError, r"For example, typing.List\[int\]"):
            # test that we propose a correct and supported type.
            @torch.library.custom_op(f"{TestCustomOp.test_ns}::foo", mutates_args={})
            def foo(x: Tensor, y: Tuple[int, int]) -> Tensor:
                raise NotImplementedError

            del foo

        with self.assertRaises(ValueError) as cm:

            @torch.library.custom_op(f"{TestCustomOp.test_ns}::foo", mutates_args={})
            def foo(x: Tensor, y: Tuple[int, float]) -> Tensor:
                raise NotImplementedError

            del foo

        # This check must run after the `with` block. Inside it, the decorator
        # raises first, so the assertion was never executed. For a
        # heterogeneous tuple, no "For example, ..." suggestion is expected.
        self.assertNotIn("example", str(cm.exception))

        with self.assertRaisesRegex(ValueError, "unsupported type"):

            @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
            def foo(x: Tensor, y: Callable) -> Tensor:
                raise NotImplementedError

            del foo
958

959
    def test_supported_schemas(self):
        # All of these should already be tested by PyTorch codegen
        # (we share the same mechanism), but here's a sanity check that each
        # schema string can define, and then destroy, a custom op.
        # (A duplicated "(Tensor x) -> (Tensor y, Tensor z)" entry was removed.)
        schemas = [
            "(Tensor x) -> Tensor",
            "(Tensor x) -> Tensor y",
            "(Tensor[] x) -> Tensor y",
            "(Tensor x) -> (Tensor, Tensor)",
            "(Tensor x) -> (Tensor y, Tensor z)",
        ]
        other_schemas = [
            "(Tensor x, Tensor w) -> (Tensor y, Tensor z)",
            "(Tensor x, Tensor w) -> (Tensor, Tensor)",
            "(Tensor x, Tensor w) -> Tensor",
            "(Tensor? x, Tensor w) -> Tensor",
            "(Tensor? x, Tensor[] w) -> Tensor",
            "(Tensor x, int[] w) -> Tensor",
            "(Tensor x, SymInt[] w) -> Tensor",
            "(Tensor x, Scalar w) -> Tensor",
            "(Tensor x, float w) -> Tensor",
            "(Tensor x, float? w) -> Tensor",
            "(Tensor x, bool[] w) -> Tensor",
        ]

        for schema in schemas:
            custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo", schema)
            custom_ops._destroy(f"{TestCustomOp.test_ns}::foo")
        for schema in other_schemas:
            custom_ops.custom_op(f"{TestCustomOp.test_ns}::bar", schema)
            custom_ops._destroy(f"{TestCustomOp.test_ns}::bar")
990

991
    def test_reserved_ns(self):
        from torch._custom_op.impl import RESERVED_NS

        # Both the schema-string form and the decorator form must reject
        # every reserved namespace.
        for reserved in RESERVED_NS:
            with self.assertRaisesRegex(ValueError, "is a reserved namespace"):
                custom_ops.custom_op(f"{reserved}::foo", "(Tensor x) -> Tensor")

            with self.assertRaisesRegex(ValueError, "is a reserved namespace"):

                @custom_ops.custom_op(f"{reserved}::foo2")
                def foo2(x: torch.Tensor) -> torch.Tensor:
                    raise NotImplementedError
1003

1004
    def test_private_ctor(self):
        # Constructing CustomOp directly must be refused.
        ctor_args = (None,) * 5
        with self.assertRaisesRegex(RuntimeError, "CustomOp constructor is private"):
            CustomOp(*ctor_args)
1007

1008
    def test_lifetime(self):
        # An op cannot be defined twice while it exists, but it can be
        # redefined after being destroyed.
        qualname = f"{TestCustomOp.test_ns}::foo"

        @custom_ops.custom_op(qualname)
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError

        # Smoke test: the newly defined op can be looked up. The result is
        # deliberately left unbound so the module-level `custom_op` import
        # is not shadowed by an unused local.
        torch._custom_op.impl.get_op(qualname)

        # We can't define an op multiple times,
        with self.assertRaisesRegex(RuntimeError, "multiple times"):

            @custom_ops.custom_op(qualname)
            def foo(x: torch.Tensor) -> torch.Tensor:  # noqa: F811
                raise NotImplementedError

        # Unless we delete the original op.
        custom_ops._destroy(qualname)

        # Smoke test: redefinition succeeds after destruction.
        @custom_ops.custom_op(qualname)
        def foo(x: torch.Tensor) -> torch.Tensor:  # noqa: F811
            raise NotImplementedError

        custom_ops._destroy(qualname)
1031

1032
    def test_autograd_notimplemented(self):
        # With no autograd formula registered, calling the op must raise when
        # a Tensor input requires grad. The input may be passed directly,
        # inside a Tensor list, or as a non-first argument.
        qualname = f"{TestCustomOp.test_ns}::foo"

        def check_raises_then_destroy(*args):
            # Look up the current definition, expect the autograd error,
            # then remove the op so the next variant can redefine it.
            op = self.get_op(f"{self.test_ns}::foo")
            with self.assertRaisesRegex(
                RuntimeError, "Autograd has not been implemented"
            ):
                op(*args)
            custom_ops._destroy(qualname)

        @custom_ops.custom_op(qualname)
        def foo(x: torch.Tensor) -> torch.Tensor:  # noqa: F811
            raise NotImplementedError

        needs_grad = torch.randn(3, requires_grad=True)
        check_raises_then_destroy(needs_grad)
        del foo

        @custom_ops.custom_op(qualname)
        def foo(x: Sequence[torch.Tensor]) -> torch.Tensor:
            raise NotImplementedError

        needs_grad = torch.randn(3, requires_grad=True)
        plain = torch.randn(3)
        check_raises_then_destroy([plain, needs_grad])
        del foo

        @custom_ops.custom_op(qualname)
        def foo(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError

        needs_grad = torch.randn(3, requires_grad=True)
        plain = torch.randn(3)
        check_raises_then_destroy(plain, needs_grad)
1066

1067
    def test_autograd_notimplemented_gradmode(self):
        # Under no_grad, a missing autograd formula is not an error, even if
        # an input requires grad.
        qualname = f"{TestCustomOp.test_ns}::foo"

        @custom_ops.custom_op(qualname)
        def foo(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError

        @custom_ops.impl(qualname)
        def foo_impl(x, y):
            return x * y

        grad_input = torch.randn(3, requires_grad=True)
        plain_input = torch.randn(3)
        op = self.get_op(f"{self.test_ns}::foo")
        with torch.no_grad():
            # Shouldn't raise, because we are in no_grad
            op(plain_input, grad_input)
1082

1083
    def test_impl_cpu(self):
        # A CPU-only impl handles CPU inputs.
        qualname = f"{TestCustomOp.test_ns}::foo"

        @custom_ops.custom_op(qualname)
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError

        @custom_ops.impl(qualname, device_types="cpu")
        def foo_cpu(x):
            return x.sin()

        inp = torch.randn(3)
        actual = self.get_op(f"{self.test_ns}::foo")(inp)
        self.assertEqual(actual, foo_cpu(inp))
1096

1097
    def test_impl_invalid_devices(self):
        qualname = f"{TestCustomOp.test_ns}::foo"

        @custom_ops.custom_op(qualname)
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError

        def foo_impl(x):
            return x.sin()

        from torch._custom_op.impl import SUPPORTED_DEVICE_TYPE_TO_KEY

        # Smoke test: each supported device type registers without error.
        for device_type in SUPPORTED_DEVICE_TYPE_TO_KEY:
            custom_ops.impl(qualname, device_types=device_type)(foo_impl)

        # Not supported by this API: we can either support them in the future
        # or provide some other CustomOp.def_* function. This depends on how
        # common the use cases are.
        unsupported = ["hip", "xla", "mkldnn", ["cpu", "hip"]]
        for invalid_type in unsupported:
            with self.assertRaisesRegex(ValueError, "we only support device_type"):
                custom_ops.impl(qualname, device_types=invalid_type)(foo_impl)
1121

1122
    def test_backward_partially_registered(self):
        # A backward formula registered without save_for_backward must be
        # reported once autograd is exercised.
        qualname = f"{TestCustomOp.test_ns}::foo"

        @custom_ops.custom_op(qualname)
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError

        @custom_ops.impl(qualname)
        def foo_impl(x):
            return x.sin()

        @custom_ops.impl_backward(qualname)
        def foo_backward(ctx, saved, grad):
            return grad * saved.cos()

        leaf = torch.randn([], requires_grad=True)
        op = self.get_op(f"{self.test_ns}::foo")
        with self.assertRaisesRegex(
            RuntimeError, "unable to find a 'save_for_backward'"
        ):
            op(leaf).backward()
1142

1143
    def test_save_for_backward_inputs_are_namedtuple(self):
        # save_for_backward receives the inputs as a namedtuple whose fields
        # match the op's parameter names, and runs once per forward call.
        qualname = f"{TestCustomOp.test_ns}::foo"

        @custom_ops.custom_op(qualname)
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError

        @custom_ops.impl(qualname)
        def foo_impl(x):
            return x.sin()

        num_saves = 0

        @custom_ops.impl_save_for_backward(qualname)
        def foo_save_for_backward(inputs, output):
            nonlocal num_saves
            num_saves += 1
            self.assertIsInstance(inputs, tuple)
            self.assertEqual(list(inputs._asdict().keys()), ["x"])
            return inputs.x

        @custom_ops.impl_backward(qualname)
        def foo_backward(ctx, saved, grad):
            return {"x": grad * saved.cos()}

        leaf = torch.randn([], requires_grad=True)
        out = self.get_op(f"{self.test_ns}::foo")(leaf)
        # Called during the forward pass...
        self.assertEqual(num_saves, 1)
        out.backward()
        # ...and not again during backward.
        self.assertEqual(num_saves, 1)
1172

1173
    def test_backward_returns_dict(self):
        # Returning a bare Tensor (instead of a dict of input-name -> grad)
        # from the backward formula is reported when backward runs.
        qualname = f"{TestCustomOp.test_ns}::foo"

        @custom_ops.custom_op(qualname)
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError

        @custom_ops.impl(qualname)
        def foo_impl(x):
            return x.sin()

        @custom_ops.impl_save_for_backward(qualname)
        def foo_save_for_backward(inputs, output):
            return inputs.x

        @custom_ops.impl_backward(qualname)
        def foo_backward(ctx, saved, grad):
            return grad * saved.cos()

        leaf = torch.randn([], requires_grad=True)
        out = self.get_op(f"{self.test_ns}::foo")(leaf)
        with self.assertRaisesRegex(RuntimeError, "to be a dict"):
            out.backward()
1195

1196
    def test_backward_dict_invalid_keys(self):
        # The grad dict may only contain keys naming the op's inputs; the
        # extra "y" key here must be rejected.
        qualname = f"{TestCustomOp.test_ns}::foo"

        @custom_ops.custom_op(qualname)
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError

        @custom_ops.impl(qualname)
        def foo_impl(x):
            return x.sin()

        @custom_ops.impl_save_for_backward(qualname)
        def foo_save_for_backward(inputs, output):
            return inputs.x

        @custom_ops.impl_backward(qualname)
        def foo_backward(ctx, saved, grad):
            return {"x": grad * saved.cos(), "y": None}

        leaf = torch.randn([], requires_grad=True)
        out = self.get_op(f"{self.test_ns}::foo")(leaf)
        with self.assertRaisesRegex(RuntimeError, "to have keys {'x'}"):
            out.backward()
1218

1219
    def test_backward_dict_grad_for_nontensor(self):
        # Supplying a gradient entry for a non-Tensor input (`dim`) is an error.
        qualname = f"{TestCustomOp.test_ns}::foo"

        @custom_ops.custom_op(qualname)
        def foo(x: torch.Tensor, dim: int) -> torch.Tensor:
            raise NotImplementedError

        @custom_ops.impl(qualname)
        def foo_impl(x, dim):
            return x.sin()

        @custom_ops.impl_save_for_backward(qualname)
        def foo_save_for_backward(inputs, output):
            return inputs.x

        @custom_ops.impl_backward(qualname)
        def foo_backward(ctx, saved, grad):
            return {"x": grad * saved.cos(), "dim": None}

        leaf = torch.randn([], requires_grad=True)
        out = self.get_op(f"{self.test_ns}::foo")(leaf, 32)
        with self.assertRaisesRegex(RuntimeError, "non-Tensor-like types"):
            out.backward()
1241

1242
    def test_backward_dict_requires_keys_for_input_tensors(self):
        # Every Tensor input needs a key in the grad dict; omitting "y" fails.
        qualname = f"{TestCustomOp.test_ns}::foo"

        @custom_ops.custom_op(qualname)
        def foo(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError

        @custom_ops.impl(qualname)
        def foo_impl(x, y):
            return x.sin()

        @custom_ops.impl_save_for_backward(qualname)
        def foo_save_for_backward(inputs, output):
            return inputs.x

        @custom_ops.impl_backward(qualname)
        def foo_backward(ctx, saved, grad):
            return {"x": grad * saved.cos()}

        leaf = torch.randn([], requires_grad=True)
        out = self.get_op(f"{self.test_ns}::foo")(leaf, leaf)
        with self.assertRaisesRegex(RuntimeError, r"to have keys {.*'y'.*}"):
            out.backward()
1264

1265
    def test_backward_dict_requires_keys_for_input_optional_tensors(self):
        # Optional Tensor inputs need a key too, even when passed as None.
        qualname = f"{TestCustomOp.test_ns}::foo"

        @custom_ops.custom_op(qualname)
        def foo(x: torch.Tensor, y: Optional[torch.Tensor]) -> torch.Tensor:
            raise NotImplementedError

        @custom_ops.impl(qualname)
        def foo_impl(x, y):
            return x.sin()

        @custom_ops.impl_save_for_backward(qualname)
        def foo_save_for_backward(inputs, output):
            return inputs.x

        @custom_ops.impl_backward(qualname)
        def foo_backward(ctx, saved, grad):
            return {"x": grad * saved.cos()}

        leaf = torch.randn([], requires_grad=True)
        out = self.get_op(f"{self.test_ns}::foo")(leaf, None)
        with self.assertRaisesRegex(RuntimeError, r"to have keys {.*'y'.*}"):
            out.backward()
1287

1288
    def test_backward_grads_are_tensor_or_none(self):
        # A gradient wrapped in a tuple (neither None nor a Tensor) is rejected.
        qualname = f"{TestCustomOp.test_ns}::foo"

        @custom_ops.custom_op(qualname)
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError

        @custom_ops.impl(qualname)
        def foo_impl(x):
            return x.sin()

        @custom_ops.impl_save_for_backward(qualname)
        def foo_save_for_backward(inputs, output):
            return inputs.x

        @custom_ops.impl_backward(qualname)
        def foo_backward(ctx, saved, grad):
            return {"x": (grad * saved.cos(),)}

        leaf = torch.randn([], requires_grad=True)
        out = self.get_op(f"{self.test_ns}::foo")(leaf)
        with self.assertRaisesRegex(RuntimeError, "either None or a Tensor"):
            out.backward()
1310

1311
    def test_backward_tensorlist_input_requires_list_grads_with_same_numel(self):
        # The grad list for a Tensor[] input must have one entry per Tensor:
        # 3 inputs, but only 2 grads are returned here.
        qualname = f"{TestCustomOp.test_ns}::foo"

        @custom_ops.custom_op(qualname)
        def foo(xs: Sequence[torch.Tensor]) -> torch.Tensor:
            raise NotImplementedError

        @custom_ops.impl(qualname)
        def foo_impl(xs):
            return xs[0].sin()

        @custom_ops.impl_save_for_backward(qualname)
        def foo_save_for_backward(inputs, output):
            return inputs.xs[0]

        @custom_ops.impl_backward(qualname)
        def foo_backward(ctx, saved, grad):
            return {"xs": [grad * saved.cos(), None]}

        leaves = [torch.randn([], requires_grad=True) for _ in range(3)]
        out = self.get_op(f"{self.test_ns}::foo")(leaves)
        with self.assertRaisesRegex(RuntimeError, "3 gradients but got 2"):
            out.backward()
1333

1334
    def test_backward_tensorlist_input_requires_list_grads_none_or_Tensor(self):
        # Each entry of a Tensor[] grad list must be None or a Tensor; the
        # trailing `(None,)` tuple here is rejected.
        qualname = f"{TestCustomOp.test_ns}::foo"

        @custom_ops.custom_op(qualname)
        def foo(xs: Sequence[torch.Tensor]) -> torch.Tensor:
            raise NotImplementedError

        @custom_ops.impl(qualname)
        def foo_impl(xs):
            return xs[0].sin()

        @custom_ops.impl_save_for_backward(qualname)
        def foo_save_for_backward(inputs, output):
            return inputs.xs[0]

        @custom_ops.impl_backward(qualname)
        def foo_backward(ctx, saved, grad):
            return {"xs": [grad * saved.cos(), None, (None,)]}

        leaves = [torch.randn([], requires_grad=True) for _ in range(3)]
        out = self.get_op(f"{self.test_ns}::foo")(leaves)
        with self.assertRaisesRegex(RuntimeError, "None or Tensor"):
            out.backward()
1356

1357
    def test_backward_tensorlist_input_requires_list_grads(self):
        # A Tensor[] input needs a list of gradients, not a single None.
        qualname = f"{TestCustomOp.test_ns}::foo"

        @custom_ops.custom_op(qualname)
        def foo(xs: Sequence[torch.Tensor]) -> torch.Tensor:
            raise NotImplementedError

        @custom_ops.impl(qualname)
        def foo_impl(xs):
            return xs[0].sin()

        @custom_ops.impl_save_for_backward(qualname)
        def foo_save_for_backward(inputs, output):
            return inputs.xs[0]

        @custom_ops.impl_backward(qualname)
        def foo_backward(ctx, saved, grad):
            return {"xs": None}

        leaves = [torch.randn([], requires_grad=True) for _ in range(3)]
        out = self.get_op(f"{self.test_ns}::foo")(leaves)
        with self.assertRaisesRegex(RuntimeError, "list of gradients"):
            out.backward()
1379

1380
    def test_backward_output_differentiability_type(self):
        # A bare bool for output_differentiability is rejected at registration
        # time (presumably a per-output list is required).
        qualname = f"{TestCustomOp.test_ns}::foo"

        @custom_ops.custom_op(qualname)
        def foo(xs: Sequence[torch.Tensor]) -> torch.Tensor:
            raise NotImplementedError

        with self.assertRaisesRegex(RuntimeError, "output_differentiability"):

            @custom_ops.impl_backward(qualname, output_differentiability=True)
            def foo_backward(ctx, saved, grad):
                return {"xs": None}
1392

1393
    def test_backward_output_differentiability_numel(self):
        # output_differentiability has one entry, but the op returns two
        # outputs; registration must fail.
        qualname = f"{TestCustomOp.test_ns}::foo"

        @custom_ops.custom_op(qualname)
        def foo(xs: Sequence[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
            raise NotImplementedError

        with self.assertRaisesRegex(RuntimeError, "output_differentiability"):

            @custom_ops.impl_backward(qualname, output_differentiability=[True])
            def foo_backward(ctx, saved, grad):
                return {"xs": None}
1405

1406
    def test_backward_output_differentiability_tensorlist(self):
        # output_differentiability is per output: marking the Tensor[] output
        # non-differentiable leaves its Tensors without requires_grad.
        @custom_ops.custom_op(f"{self.test_ns}::foo")
        def foo(x: Tensor) -> Tuple[List[Tensor], Tensor]:
            raise NotImplementedError

        @custom_ops.impl(f"{self.test_ns}::foo")
        def foo_impl(x):
            return [x.clone(), x.clone()], x.clone()

        @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_save_for_backward(inputs, output):
            return []

        @custom_ops.impl_backward(
            f"{TestCustomOp.test_ns}::foo", output_differentiability=[False, True]
        )
        def foo_backward(ctx, saved, grad_lst, grad):
            return {"x": grad}

        op = self.get_op(f"{self.test_ns}::foo")
        leaf = torch.randn(3, requires_grad=True)
        (first, second), last = op(leaf)
        for non_diff in (first, second):
            self.assertFalse(non_diff.requires_grad)
        self.assertTrue(last.requires_grad)
1431

1432
    def test_backward_output_differentiability_non_tensor(self):
        # Declaring a non-Tensor output (the int) differentiable is an error
        # raised when the op is called.
        @custom_ops.custom_op(f"{self.test_ns}::foo")
        def foo(x: Tensor) -> Tuple[Tensor, int]:
            raise NotImplementedError

        @custom_ops.impl(f"{self.test_ns}::foo")
        def foo_impl(x):
            return x.clone(), 3

        @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_save_for_backward(inputs, output):
            return []

        @custom_ops.impl_backward(
            f"{TestCustomOp.test_ns}::foo", output_differentiability=[True, True]
        )
        def foo_backward(ctx, saved, grad0, grad1):
            return {"x": grad0}

        op = self.get_op(f"{self.test_ns}::foo")
        leaf = torch.randn(3, requires_grad=True)
        with self.assertRaisesRegex(RuntimeError, "is not a Tensor"):
            op(leaf)
1455

1456
    @unittest.skipIf(not TEST_CUDA, "requires CUDA")
    def test_impl_separate(self):
        # Separate per-device kernels: CPU inputs must reach the CPU impl and
        # CUDA inputs the CUDA impl.
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo", device_types="cpu")
        def foo_cpu(x):
            return x.sin()

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo", device_types="cuda")
        def foo_cuda(x):
            return x.cos()

        # One lookup is enough: the same op object serves both devices
        # (the original looked it up a second time for no reason).
        op = self.get_op(f"{self.test_ns}::foo")

        x = torch.randn(3)
        self.assertEqual(op(x), foo_cpu(x))

        x_cuda = x.cuda()
        self.assertEqual(op(x_cuda), foo_cuda(x_cuda))
1479

1480
    @unittest.skipIf(not TEST_CUDA, "requires CUDA")
    def test_impl_multiple(self):
        # An impl registered without device_types serves CPU and CUDA inputs.
        qualname = f"{TestCustomOp.test_ns}::foo"

        @custom_ops.custom_op(qualname)
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError

        @custom_ops.impl(qualname)
        def foo_impl(x):
            return x.cos()

        op = self.get_op(f"{self.test_ns}::foo")
        cpu_input = torch.randn(3)
        self.assertEqual(op(cpu_input), foo_impl(cpu_input))

        gpu_input = cpu_input.cuda()
        self.assertEqual(op(gpu_input), foo_impl(gpu_input))
1498

1499
    def test_impl_abstract_overload(self):
        # impl_abstract accepts an overload-qualified name ("sin.blah") and
        # makes that overload callable on meta tensors.
        lib = self.lib()
        lib.define("sin.blah(Tensor x) -> Tensor")

        overload_name = f"{self.test_ns}::sin.blah"
        torch.library.impl_abstract(overload_name, torch.empty_like, lib=lib)

        blah = self.ns().sin.blah
        blah(torch.randn(3, device="meta"))
1510

1511
    def test_impl_meta(self):
        # A fake (abstract) impl lets the op run on meta tensors.
        qualname = f"{TestCustomOp.test_ns}::foo"

        @custom_ops.custom_op(qualname)
        def foo(x: torch.Tensor, dim: int) -> torch.Tensor:
            raise NotImplementedError

        @torch.library.impl_abstract(qualname, lib=self.lib())
        def foo_meta(x, dim):
            # Output shape: the input shape with `dim` removed.
            shape = list(x.shape)
            shape.pop(dim)
            return x.new_empty(shape)

        meta_input = torch.randn(2, 3, device="meta")
        out = self.get_op(f"{self.test_ns}::foo")(meta_input, 1)
        self.assertEqual(out.shape, foo_meta(meta_input, 1).shape)
1526

1527
    def test_duplicate_impl(self):
        # A second fake impl for the same op is rejected, and the error
        # message cites a location in this file.
        qualname = f"{TestCustomOp.test_ns}::foo"

        @custom_ops.custom_op(qualname)
        def foo(x: torch.Tensor, dim: int) -> torch.Tensor:
            raise NotImplementedError

        @torch.library.impl_abstract(qualname, lib=self.lib())
        def foo_meta(x, dim):
            shape = list(x.shape)
            shape.pop(dim)
            return x.new_empty(shape)

        with self.assertRaisesRegex(RuntimeError, r"test_custom_ops.py:\d+"):

            @torch.library.impl_abstract(qualname, lib=self.lib())
            def foo_meta2(x, dim):
                shape = list(x.shape)
                shape.pop(dim)
                return x.new_empty(shape)
1545

1546
    def test_new_data_dependent_symint(self):
        # A fake impl can allocate a data-dependent size through the ctx, and
        # invalid bounds passed to new_dynamic_size are rejected.
        qualname = f"{TestCustomOp.test_ns}::foo"

        @custom_ops.custom_op(qualname)
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError

        @torch.library.impl_abstract(qualname, lib=self.lib())
        def foo_meta(x):
            ctx = torch.library.get_ctx()
            dynamic_size = ctx.new_dynamic_size(min=1)
            # A negative lower bound is not allowed.
            with self.assertRaisesRegex(ValueError, "greater than or equal to 0"):
                ctx.new_dynamic_size(min=-1)
            # Neither is an upper bound taken from x.numel(), which is
            # presumably a SymInt under symbolic tracing.
            with self.assertRaisesRegex(ValueError, "SymInt"):
                ctx.new_dynamic_size(max=x.numel())
            # NB: You must return dynamic sizes!
            return x.new_empty(dynamic_size)

        inp = torch.randn(2, 3, device="cpu")
        traced_op = self.get_op(f"{self.test_ns}::foo")
        make_fx(traced_op, tracing_mode="symbolic")(inp)
1565

1566
    def test_meta_for_data_dependent_shape_operation(self):
        # numpy_nonzero's output shape depends on the input values, so
        # calling it on a meta tensor must fail.
        meta_input = torch.randn(10, device="meta")
        with self.assertRaisesRegex(RuntimeError, "data-dependent output shape"):
            numpy_nonzero(meta_input)
1570

1571
    def test_basic_make_fx(self):
        # More serious tests are in our CustomOp opinfo db,
        # this one is just a sanity check that the op appears in the
        # symbolically traced graph.
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError

        @torch.library.impl_abstract(f"{TestCustomOp.test_ns}::foo", lib=self.lib())
        def foo_meta(x):
            return x.sum()

        x = torch.randn(3)
        op = self.get_op(f"{self.test_ns}::foo")
        gm = make_fx(op, tracing_mode="symbolic")(x)
        # assertIn (unlike assertTrue(a in b)) shows the generated code when
        # it fails.
        self.assertIn(f"{TestCustomOp.test_ns}.foo", gm.code)
1586

1587
    def test_not_implemented_error(self):
        # Calling an op with no kernel registered for the situation raises a
        # NotImplementedError describing what is missing.
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError

        cpu_input = torch.randn(3)
        foo_op = self.get_op(f"{self.test_ns}::foo")
        with self.assertRaisesRegex(NotImplementedError, "cpu impl registered"):
            foo_op(cpu_input)

        meta_input = torch.randn(3, device="meta")
        with self.assertRaisesRegex(NotImplementedError, "no fake impl or Meta kernel"):
            foo_op(meta_input)

        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::bar")
        def bar(sizes: Sequence[int]) -> torch.Tensor:
            raise NotImplementedError

        # An op that takes no Tensor arguments gets its own message.
        bar_op = self.get_op(f"{self.test_ns}::bar")
        with self.assertRaisesRegex(NotImplementedError, "no Tensor inputs"):
            bar_op((1, 2, 3))
1608

1609
    def test_data_dependent_basic(self):
        # Symbolic tracing of a data-dependent op should record it in the graph.
        x = torch.randn(5, 5)
        gm = make_fx(numpy_nonzero, tracing_mode="symbolic")(x)
        # assertIn shows the generated code when it fails.
        self.assertIn("nonzero", gm.code)
1613

1614
    def test_data_dependent_fake_tracing(self):
        # We've updated to attempt to use unbacked symints even for fake
        # tracing, so tracing a data-dependent op in fake mode must succeed.
        inp = torch.randn(5, 5)
        fake_tracer = make_fx(numpy_nonzero, tracing_mode="fake")
        fake_tracer(inp)
1619

1620
    def test_symints(self):
1621
        def f(x):
1622
            return torch.ops._torch_testing.numpy_view_copy(x, x.shape)
1623

1624
        x = torch.randn(2, 3, 4)
1625
        gm = make_fx(f, tracing_mode="symbolic")(x)
1626
        result = gm(x)
1627
        self.assertEqual(result, f(x))
1628
        self.assertExpectedInline(
1629
            gm.code.strip(),
1630
            """\
1631
def forward(self, x_1):
1632
    sym_size_int = torch.ops.aten.sym_size.int(x_1, 0)
1633
    sym_size_int_1 = torch.ops.aten.sym_size.int(x_1, 1)
1634
    sym_size_int_2 = torch.ops.aten.sym_size.int(x_1, 2)
1635
    numpy_view_copy = torch.ops._torch_testing.numpy_view_copy.default(x_1, [sym_size_int, sym_size_int_1, sym_size_int_2]);  x_1 = sym_size_int = sym_size_int_1 = sym_size_int_2 = None
1636
    return numpy_view_copy""",  # noqa: B950
1637
        )
1638

1639
    @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work on windows")
1640
    def test_data_dependent_compile(self):
1641
        import torch._dynamo.testing
1642
        from torch._dynamo.utils import counters
1643

1644
        counters.clear()
1645
        cnt = torch._dynamo.testing.CompileCounter()
1646

1647
        @torch.compile(backend=cnt)
1648
        def f(x):
1649
            return numpy_nonzero(x.clone()).clone()
1650

1651
        f(torch.randn(10))
1652

1653
        self.assertEqual(len(counters["graph_break"]), 1)
1654
        self.assertEqual(next(iter(counters["graph_break"].values())), 1)
1655
        self.assertExpectedInline(
1656
            next(iter(counters["graph_break"].keys())).replace(";", "\n"),
1657
            """\
1658
dynamic shape operator: _torch_testing.numpy_nonzero.default
1659
 to enable, set torch._dynamo.config.capture_dynamic_output_shape_ops = True""",
1660
        )
1661

1662
    # Pre-existing problem: torch.compile(dynamic=True) graph-breaks on
    # data-dependent operations by default. The goal is for it to never
    # graph-break on them; until then this test is an expected failure.
    @unittest.expectedFailure
    def test_data_dependent_nms_dynamic_compile(self):
        """Dynamic compile of numpy_nms should produce zero graph breaks."""
        import torch._dynamo.testing
        from torch._dynamo.utils import counters

        counters.clear()
        compile_counter = torch._dynamo.testing.CompileCounter()

        @torch.compile(backend=compile_counter, dynamic=True)
        def f(x, s, i):
            return torch.ops._torch_testing.numpy_nms(x.clone(), s, i).clone()

        boxes = torch.randn(20, 4)
        scores = torch.randn(20)
        threshold = 0.1
        f(boxes, scores, threshold)

        self.assertEqual(len(counters["graph_break"]), 0)

    def test_impl_on_existing_op(self):
        """custom_ops.impl can attach a kernel to an op defined through Library."""
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        qualname = f"{self.test_ns}::foo"

        @torch._custom_ops.impl(qualname)
        def foo_impl(x):
            return x.sin()

        op = self.get_op(qualname)
        inp = torch.randn(3)
        self.assertEqual(op(inp), inp.sin())

    @parametrize(
        "key", ["CPU", "CUDA", "CompositeImplicitAutograd", "CompositeExplicitAutograd"]
    )
    def test_impl_on_existing_op_with_cpu_registration(self, key):
        """custom_ops.impl refuses an op that already has a kernel for ``key``."""
        qualname = f"{self.test_ns}::foo"
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")

        def foo_impl(x):
            return x.sin()

        lib.impl("foo", foo_impl, key)
        # Look the op up once; the returned handle itself is not needed.
        self.get_op(qualname)

        with self.assertRaisesRegex(RuntimeError, "already has an implementation"):
            custom_ops.impl(qualname, func=foo_impl)

    def test_abstract_impl_on_existing_op(self):
        """An abstract impl on a Library-defined op determines FakeTensor outputs."""
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        qualname = f"{self.test_ns}::foo"

        @torch.library.impl_abstract(qualname, lib=self.lib())
        def foo_impl(x):
            return x.sin()

        op = self.get_op(qualname)
        with torch._subclasses.FakeTensorMode():
            fake_in = torch.randn(3)
            fake_out = op(fake_in)
            self.assertEqual(fake_out.shape, fake_in.shape)
            self.assertEqual(fake_out.stride(), fake_in.stride())

    def test_abstract_impl_on_existing_op_with_meta(self):
        """impl_abstract errors out when the op already has a Meta kernel."""
        qualname = f"{self.test_ns}::foo"
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")

        def foo_impl(x):
            return x.sin()

        lib.impl("foo", foo_impl, "Meta")
        self.get_op(qualname)

        with self.assertRaisesRegex(RuntimeError, r"already has .*Meta implementation"):
            torch.library.impl_abstract(qualname, func=foo_impl, lib=self.lib())

    def test_abstract_impl_on_existing_op_with_CompositeImplicitAutograd(self):
        """impl_abstract errors out when a CompositeImplicitAutograd kernel exists."""
        qualname = f"{self.test_ns}::foo"
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")

        def foo_impl(x):
            return x.sin()

        lib.impl("foo", foo_impl, "CompositeImplicitAutograd")
        self.get_op(qualname)

        with self.assertRaisesRegex(RuntimeError, "CompositeImplicitAutograd"):
            torch.library.impl_abstract(qualname, func=foo_impl, lib=self.lib())

    def test_abstract_impl_on_existing_op_with_CompositeExplicitAutograd(self):
        """An abstract impl can coexist with a CompositeExplicitAutograd kernel."""
        qualname = f"{self.test_ns}::foo"
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")

        def foo_impl(x):
            return x.sin()

        lib.impl("foo", foo_impl, "CompositeExplicitAutograd")
        op = self.get_op(qualname)

        # The abstract impl deliberately disagrees with the real kernel
        # (sum yields a 0-dim result), so the output shape shows which ran.
        torch.library.impl_abstract(qualname, func=lambda x: x.sum(), lib=self.lib())
        with torch._subclasses.FakeTensorMode():
            out = op(torch.randn(10))
            self.assertEqual(out.shape, ())

    def _test_backward_impl_raises(self, qualname, err_regex):
        """Assert both backward-registration decorators reject ``qualname``.

        Args:
            qualname: ``"namespace::name"`` of the operator under test.
            err_regex: pattern the raised RuntimeError message must match.
        """
        registrars = (custom_ops.impl_save_for_backward, custom_ops.impl_backward)
        for register in registrars:
            with self.assertRaisesRegex(RuntimeError, err_regex):

                @register(qualname)
                def _never_used(x):
                    return

    def test_backward_impl_on_existing_op_incorrect_schema_views(self):
        """Backward impls are rejected for an op whose output aliases its input."""
        lib = self.lib()
        lib.define("foo(Tensor(a) x) -> Tensor(a)")
        self._test_backward_impl_raises(
            f"{self.test_ns}::foo", "operator that returns views"
        )

    def test_backward_impl_on_existing_op_incorrect_schema_mutable(self):
        """Backward impls are rejected for an op that mutates an input."""
        lib = self.lib()
        lib.define("foo(Tensor(a!) x) -> Tensor")
        self._test_backward_impl_raises(f"{self.test_ns}::foo", "non-functional")

    def test_backward_impl_on_existing_op_incorrect_schema_no_output(self):
        """Backward impls are rejected for an op with no return values."""
        lib = self.lib()
        lib.define("foo(Tensor x) -> ()")
        self._test_backward_impl_raises(f"{self.test_ns}::foo", "no returns")

    def test_backward_impl_on_existing_op_CompositeImplicitAutograd(self):
        """Backward impls are rejected once a CompositeImplicitAutograd kernel exists."""
        qualname = f"{self.test_ns}::foo"
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        lib.impl("foo", lambda x: x.sin().cos(), "CompositeImplicitAutograd")
        self._test_backward_impl_raises(qualname, "CompositeImplicitAutograd")

    @parametrize("key", ["Autograd", "AutogradCPU", "AutogradCUDA"])
    def test_backward_impl_on_existing_op_with_key(self, key):
        """Backward impls are rejected once an autograd-key kernel is registered."""
        qualname = f"{self.test_ns}::foo"
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        lib.impl("foo", lambda x: x.sin().cos(), key)
        # The error message is expected to mention the conflicting key.
        self._test_backward_impl_raises(qualname, key)

    def test_is_functional_schema(self):
        """is_functional_schema agrees for str, torchgen and torch._C schemas."""
        from torchgen.model import FunctionSchema

        is_functional = torch._library.utils.is_functional_schema
        cases = {
            "foo(Tensor x) -> Tensor": True,
            "foo(Tensor(a) x) -> Tensor": True,
            "foo(Tensor(a!) x) -> Tensor": False,
            "foo(Tensor(a) x) -> Tensor(a)": False,
            "foo(Tensor x) -> ()": False,
        }
        for schema_str, expected in cases.items():
            # The same answer is expected regardless of schema representation.
            self.assertEqual(is_functional(schema_str), expected)
            self.assertEqual(is_functional(FunctionSchema.parse(schema_str)), expected)
            self.assertEqual(is_functional(torch._C.parse_schema(schema_str)), expected)

    def test_incorrect_schema_types(self):
        """Library.define rejects unknown or C++-only type names in a schema."""
        bad_schemas = [
            ("foo12(Tensor a) -> asdfasdf", "unknown type specifier"),
            ("foo12(asdf a) -> Tensor", "unknown type specifier"),
            ("foo12(int64_t a) -> Tensor", "Use `SymInt` or `int`"),
            ("foo12(double a) -> Tensor", "Use `float`"),
        ]
        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
            for schema, message in bad_schemas:
                with self.assertRaisesRegex(RuntimeError, message):
                    lib.define(schema)

    def test_is_tensorlist_like_type(self):
        """is_tensorlist_like_type accepts list/optional-list Tensor types only."""
        # Tensor[], Tensor?[], Tensor[]?, Tensor?[]?
        list_like = [
            torch.ops.aten.where.default._schema.returns[0].type,
            torch.ops.aten.index.Tensor._schema.arguments[1].type,
            torch._C.parse_schema("foo(Tensor[]? x) -> ()").arguments[0].type,
            torch._C.parse_schema("foo(Tensor?[]? x) -> ()").arguments[0].type,
        ]
        # Tensor, IntList
        not_list_like = [
            torch.ops.aten.sin.default._schema.arguments[0].type,
            torch.ops.aten.sum.dim_IntList._schema.arguments[1].type,
        ]
        is_tensorlist_like = torch._library.utils.is_tensorlist_like_type
        for ty in list_like:
            self.assertTrue(is_tensorlist_like(ty))
        for ty in not_list_like:
            self.assertFalse(is_tensorlist_like(ty))

    def test_backward_impl_on_existing_op(self):
        """Forward, save-for-backward and backward impls on a Library-defined op.

        The registered backward computes ``grad_out * cos(x)``, so the gradient
        of ``foo(x)`` is checked against ``x.cos()``.
        """
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        qualname = f"{self.test_ns}::foo"

        @custom_ops.impl(qualname)
        def foo_impl(x):
            # The forward itself records no autograd history.
            with torch.no_grad():
                return x.sin()

        @custom_ops.impl_save_for_backward(qualname)
        def foo_save_for_backward(inputs, output):
            # Keep the input; backward needs it to evaluate cos(x).
            return inputs.x

        @custom_ops.impl_backward(qualname)
        def foo_backward(ctx, saved, grad_out):
            return {"x": grad_out * saved.cos()}

        op = self.get_op(qualname)
        inp = torch.randn([], requires_grad=True)
        (grad_inp,) = torch.autograd.grad(op(inp), inp)
        self.assertEqual(grad_inp, inp.cos())

    @parametrize(
        "tags",
        [
            subtest(torch.Tag.pointwise, "single"),
            subtest((torch.Tag.pointwise,), "tuple"),
            subtest([torch.Tag.pointwise], "list"),
        ],
    )
    def test_define_with_tags(self, tags):
        """torch.library.define accepts tags as a single Tag, a tuple, or a list.

        The parametrized ``tags`` value is forwarded unchanged so each subtest
        exercises its own input form. The expected value normalizes a bare Tag
        to a one-element list, because the op's ``tags`` is always a list.
        """
        lib = self.lib()
        torch.library.define(
            f"{self.test_ns}::foo", "(Tensor x) -> Tensor", lib=lib, tags=tags
        )
        expected = [tags] if isinstance(tags, torch.Tag) else list(tags)
        actual = self.ns().foo.default.tags
        self.assertIsInstance(actual, list)
        self.assertEqual(actual, expected)

    def test_builtin_aten_ops_are_pt2_compliant(self):
        """Native aten overloads carry the pt2_compliant_tag."""
        overloads = (torch.ops.aten.sin.default, torch.ops.aten.sum.dim_IntList)
        for overload in overloads:
            self.assertIn(torch.Tag.pt2_compliant_tag, overload.tags)

    def test_builtin_torchscript_ops(self):
        """The complex sub/mul overloads are tagged pt2_compliant."""
        overloads = (torch.ops.aten.sub.complex, torch.ops.aten.mul.complex)
        for overload in overloads:
            self.assertIn(torch.Tag.pt2_compliant_tag, overload.tags)

    def test_autogen_aten_ops_are_pt2_compliant(self):
        """An autogenerated aten overload is tagged both generated and pt2_compliant."""
        overload = torch.ops.aten.fill.Tensor_out
        self.assertIn(torch.Tag.generated, overload.tags)
        self.assertIn(torch.Tag.pt2_compliant_tag, overload.tags)

    def test_resolve_packet(self):
        """_jit_resolve_packet picks the overload name that matches the arguments."""
        resolve = torch._C._jit_resolve_packet
        x = torch.randn(3)
        self.assertEqual(resolve("aten::sum", x), "default")
        self.assertEqual(resolve("aten::sum", x, dim=1), "dim_IntList")
        with self.assertRaisesRegex(RuntimeError, "failed to match any schema"):
            resolve("aten::sum", x, x, x)

    def test_define_bad_schema(self):
        """torch.library.define rejects a schema that already includes the op name.

        The test-owned ``lib`` is passed explicitly, as in the other define
        tests here, instead of being created and left unused.
        """
        lib = self.lib()
        with self.assertRaisesRegex(ValueError, "expected schema to look like"):
            torch.library.define(
                f"{self.test_ns}::foo", "foo(Tensor x) -> Tensor", lib=lib
            )

    def test_define_and_impl(self):
        """An op from torch.library.define runs a decorator-registered CPU kernel."""
        lib = self.lib()
        qualname = f"{self.test_ns}::foo"
        torch.library.define(qualname, "(Tensor x) -> Tensor", lib=lib)

        @torch.library.impl(qualname, "CPU", lib=lib)
        def numpy_sin(x):
            return torch.from_numpy(np.sin(x.numpy()))

        inp = torch.randn(3)
        out = self.ns().foo(inp)
        assert torch.allclose(out, inp.sin())

    def test_define_validation(self):
        """A qualname without a ``namespace::`` prefix raises ValueError."""
        bare_name = "foo"
        with self.assertRaisesRegex(ValueError, "namespace"):
            torch.library.define(bare_name, "(Tensor x) -> Tensor")

    def test_legacy_define(self):
        """Legacy ``@torch.library.define(lib, schema)`` defines a callable op."""
        lib = self.lib()

        @torch.library.define(lib, "foo(Tensor x) -> Tensor")
        def numpy_sin(x):
            return torch.from_numpy(np.sin(x.numpy()))

        inp = torch.randn(3)
        out = self.ns().foo(inp)
        assert torch.allclose(out, inp.sin())

    def test_impl_function(self):
        """Function-call form ``torch.library.impl(qualname, "CPU", fn, lib=...)``."""
        lib = self.lib()
        qualname = f"{self.test_ns}::foo"
        torch.library.define(qualname, "(Tensor x) -> Tensor", lib=lib)

        def numpy_sin(x):
            return torch.from_numpy(np.sin(x.numpy()))

        torch.library.impl(qualname, "CPU", numpy_sin, lib=lib)
        inp = torch.randn(3)
        out = self.ns().foo(inp)
        assert torch.allclose(out, inp.sin())

    def test_legacy_impl(self):
        """Legacy ``@torch.library.impl(lib, name, key)`` registers a CPU kernel."""
        lib = self.lib()
        torch.library.define(f"{self.test_ns}::foo", "(Tensor x) -> Tensor", lib=lib)

        @torch.library.impl(lib, "foo", "CPU")
        def numpy_sin(x):
            return torch.from_numpy(np.sin(x.numpy()))

        inp = torch.randn(3)
        out = self.ns().foo(inp)
        assert torch.allclose(out, inp.sin())

    def test_defined_in_python(self):
        """``_defined_in_python`` is False for native aten ops, True for Python ones."""
        self.assertFalse(torch.ops.aten.sin.default._defined_in_python)
        self.assertFalse(torch.ops.aten.sum.dim_IntList._defined_in_python)

        lib = self.lib()
        # Qualnames are f-strings over ``self.test_ns`` so they name the same
        # namespace that ``lib`` and ``self.ns()`` refer to.
        torch.library.define(f"{self.test_ns}::foo", "(Tensor x) -> Tensor", lib=lib)
        ns = self.ns()
        self.assertTrue(ns.foo.default._defined_in_python)

        torch.library.define(
            f"{self.test_ns}::bar.overload", "(Tensor x) -> Tensor", lib=lib
        )
        self.assertTrue(ns.bar.overload._defined_in_python)

    def _test_impl_device(self, name, types, device):
        """Register a numpy-backed sin kernel for ``types`` and run it on ``device``.

        Args:
            name: operator name, defined inside ``self.test_ns``.
            types: device type string or list of device types forwarded to
                ``torch.library.impl`` (e.g. ``"default"``, ``["cpu", "cuda"]``).
            device: device on which the input tensor is created.
        """
        lib = self.lib()
        qualname = f"{self.test_ns}::{name}"
        torch.library.define(qualname, "(Tensor x) -> Tensor", lib=lib)

        # Register through the test-owned ``lib`` (as test_define_and_impl and
        # test_impl_device_function do) so the kernel belongs to the same
        # library that defines the op.
        @torch.library.impl(qualname, types, lib=lib)
        def f(x):
            # Compute on CPU via numpy, then return on the input's device.
            x_np = x.cpu().numpy()
            y = torch.from_numpy(np.sin(x_np))
            return y.to(device=x.device)

        x = torch.randn(3, device=device)
        y = getattr(self.ns(), name)(x)
        assert torch.allclose(y, x.sin())

    def test_impl_device_cpu(self):
        """CPU inputs reach kernels registered as default, cpu, or cpu+cuda."""
        for op_name, device_types in (
            ("foo1", "default"),
            ("foo2", ["cpu"]),
            ("foo3", ["cpu", "cuda"]),
        ):
            self._test_impl_device(op_name, device_types, "cpu")

    @unittest.skipIf(not TEST_CUDA, "requires cuda")
    def test_impl_device_cuda(self):
        """CUDA inputs reach kernels registered as default, cuda, or cpu+cuda."""
        for op_name, device_types in (
            ("foo4", "default"),
            ("foo5", ["cuda"]),
            ("foo6", ["cpu", "cuda"]),
        ):
            self._test_impl_device(op_name, device_types, "cuda")

    def test_impl_device_function(self):
        """Function-call form of torch.library.impl with the "default" device type."""
        lib = self.lib()
        qualname = f"{self.test_ns}::foo"
        torch.library.define(qualname, "(Tensor x) -> Tensor", lib=lib)

        def numpy_sin(x):
            # Round-trip through numpy on CPU, then move back to x's device.
            result = torch.from_numpy(np.sin(x.cpu().numpy()))
            return result.to(device=x.device)

        torch.library.impl(qualname, "default", numpy_sin, lib=lib)
        inp = torch.randn(3)
        out = self.ns().foo(inp)
        assert torch.allclose(out, inp.sin())

    def test_impl_device_invalid(self):
        """An unrecognized device type string makes torch.library.impl raise."""
        bogus_device = "somethingsomething"
        with self.assertRaisesRegex(RuntimeError, "Expected one of cpu, cuda"):
            torch.library.impl("blah::blah", bogus_device)

    def test_autograd_function_backed_op(self):
2054
        cpp_source = """
2055
struct CustomOpAutogradFunction : public torch::autograd::Function<CustomOpAutogradFunction> {
2056
  static constexpr bool is_traceable = true;
2057

2058
  static torch::Tensor forward(
2059
      torch::autograd::AutogradContext* ctx,
2060
      const torch::Tensor& x) {
2061
    return x;
2062
  }
2063

2064
  static torch::autograd::variable_list backward(
2065
      torch::autograd::AutogradContext *ctx,
2066
      torch::autograd::variable_list grad_output) {
2067
    return grad_output;
2068
  }
2069
};
2070

2071
torch::Tensor custom_op_backed_by_autograd_fn(const torch::Tensor& x) {
2072
  return CustomOpAutogradFunction::apply(x);
2073
}
2074

2075
TORCH_LIBRARY(mylib, m) {
2076
    m.def("custom_op_backed_by_autograd_fn", custom_op_backed_by_autograd_fn);
2077
}
2078
        """
2079

2080
        module = torch.utils.cpp_extension.load_inline(
2081
            name="mylib",
2082
            cpp_sources=cpp_source,
2083
            functions="custom_op_backed_by_autograd_fn",
2084
            verbose=True,
2085
        )
2086

2087
        x = torch.ones(2, 2, requires_grad=True)
2088
        temp = x.clone().detach()
2089
        out = torch.ops.mylib.custom_op_backed_by_autograd_fn(x)
2090
        loss = out.sum()
2091
        loss.backward()
2092
        self.assertEqual(x.grad, temp)
2093

2094

2095
def op_with_incorrect_schema(testcase, name):
2096
    lib = testcase.lib()
2097
    lib.define(f"{name}(Tensor x) -> Tensor")
2098
    qualname = f"{testcase.test_ns}::{name}"
2099
    lib.impl(name, lambda x: x[:], "CompositeExplicitAutograd")
2100
    return testcase.get_op(qualname)
2101

2102

2103
class MiniOpTest(CustomOpTestCaseBase):
    """Small operator tests, some using ops that are broken on purpose.

    ``setUp`` defines two ops in ``test_ns``: ``no_abstract`` (CPU kernel only,
    tagged pt2_compliant) and ``delayed_error`` (clone forward whose backward
    raises NotImplementedError).
    """

    test_ns = "mini_op_test"

    def _init_op_delayed_backward_error(self):
        """Define ``delayed_error``: forward clones ``x``; backward raises."""
        name = "delayed_error"
        lib = self.lib()
        lib.define(f"{name}(Tensor x) -> Tensor")
        lib.impl(name, lambda x: x.clone(), "CompositeExplicitAutograd")
        op = self.get_op(f"{self.test_ns}::{name}")

        class Op(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                # Dispatch below Autograd so this call skips the Autograd
                # kernel registered below and reaches the clone kernel.
                with torch._C._AutoDispatchBelowAutograd():
                    return op(x)

            @staticmethod
            def backward(ctx, grad):
                raise NotImplementedError

        def autograd_impl(x):
            return Op.apply(x)

        lib.impl(name, autograd_impl, "Autograd")
        return op

    def _init_op_with_no_abstract_impl(self):
        """Define pt2-compliant ``no_abstract`` with only a CPU kernel."""
        name = "no_abstract"
        lib = self.lib()
        lib.define(f"{name}(Tensor x) -> Tensor", tags=(torch.Tag.pt2_compliant_tag,))
        lib.impl(name, lambda x: x.clone(), "CPU")
        return torch._library.utils.lookup_op(f"{self.test_ns}::{name}")

    def setUp(self):
        super().setUp()
        self._op_with_no_abstract_impl = self._init_op_with_no_abstract_impl()
        self._op_delayed_backward_error = self._init_op_delayed_backward_error()

    @optests.dontGenerateOpCheckTests("Testing this API")
    def test_dont_generate(self):
        op = op_with_incorrect_schema(self, "incorrect_schema")
        op(torch.randn(3))

    def test_mm(self):
        lhs = torch.randn(2, 3, requires_grad=True)
        rhs = torch.randn(3, 5)
        self.assertEqual(torch.ops.aten.mm.default(lhs, rhs), lhs @ rhs)

    def test_mm_meta(self):
        lhs = torch.randn(2, 3, requires_grad=True, device="meta")
        rhs = torch.randn(3, 5, device="meta")
        product = torch.ops.aten.mm.default(lhs, rhs)
        self.assertEqual(product.shape, (lhs @ rhs).shape)

    def test_mm_fake(self):
        with torch._subclasses.fake_tensor.FakeTensorMode():
            lhs = torch.randn(2, 3, requires_grad=True, device="cpu")
            rhs = torch.randn(3, 5, device="cpu")
            product = torch.ops.aten.mm.default(lhs, rhs)
            self.assertEqual(product.shape, (lhs @ rhs).shape)

    def test_mm_errors(self):
        lhs = torch.randn(2, 3, requires_grad=True)
        rhs = torch.randn(4, 5)
        # Inner dimensions 3 and 4 do not match.
        with self.assertRaisesRegex(RuntimeError, "cannot be multiplied"):
            torch.ops.aten.mm.default(lhs, rhs)

    def test_nonzero(self):
        found = torch.ops.aten.nonzero.default(torch.tensor([0, 1, 2, 0, 0]))
        self.assertEqual(found, torch.tensor([[1], [2]]))

    def test_inplace(self):
        x = torch.randn(3)
        original = x.clone()
        torch.ops.aten.sin_(x)
        self.assertEqual(x, original.sin())

    def test_incorrect_schema(self):
        op = op_with_incorrect_schema(self, "incorrect_schema")
        op(torch.randn(3))

    def test_no_abstract(self):
        op = self._op_with_no_abstract_impl
        op(torch.randn(3))

    def test_delayed_error(self):
        op = self._op_delayed_backward_error
        x = torch.randn([], requires_grad=True)
        out = op(x)
        # The forward succeeds; the failure only surfaces during backward.
        with self.assertRaises(NotImplementedError):
            out.sum().backward()

    def test_delayed_error_no_requires_grad(self):
        op = self._op_delayed_backward_error
        op(torch.randn([]))


class TestCustomOpAPI(TestCase):
2209
    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_basic(self):
        """custom_op runs its default body; a registered CPU kernel then takes over."""

        @torch.library.custom_op("_torch_testing::add", mutates_args=())
        def add(x: Tensor, y: float) -> Tensor:
            return torch.from_numpy(x.numpy(force=True) + y).to(x.device)

        inp = torch.randn(3)
        scalar = 3.14
        self.assertEqual(add(inp, scalar), inp + scalar)

        kernel_calls = []

        @add.register_kernel("cpu")
        def _(x, y):
            kernel_calls.append(True)
            return torch.from_numpy(x.numpy() + y)

        self.assertEqual(add(inp, scalar), inp + scalar)
        self.assertTrue(kernel_calls)

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_no_grad_skips_autograd(self):
        """setup_context runs only with grad mode on and an input requiring grad."""

        @torch.library.custom_op("_torch_testing::add", mutates_args=())
        def add(x: Tensor, y: float) -> Tensor:
            return torch.from_numpy(x.numpy(force=True) + y).to(x.device)

        setup_calls = 0

        def setup_context(ctx, inputs, output):
            nonlocal setup_calls
            setup_calls += 1

        def backward(ctx, grad):
            raise AssertionError("should not be reached")

        add.register_autograd(backward, setup_context=setup_context)

        # Input requires grad, but grad mode is disabled.
        x = torch.randn(3, requires_grad=True)
        with torch.no_grad():
            y = add(x, 2.0)
        self.assertEqual(setup_calls, 0)
        self.assertEqual(y, x + 2.0)

        # Grad mode is enabled, but no input requires grad.
        x.requires_grad_(False)
        y = add(x, 2.0)
        self.assertEqual(setup_calls, 0)
        self.assertEqual(y, x + 2.0)

        # Both conditions hold, so setup_context runs exactly once.
        x = torch.randn(3, requires_grad=True)
        y = add(x, 2.0)
        self.assertEqual(setup_calls, 1)
        self.assertEqual(y, x + 2.0)

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_manual_schema(self):
        """custom_op honors an explicit schema, including a mutating one."""

        @torch.library.custom_op(
            "_torch_testing::add",
            mutates_args=(),
            schema="(Tensor x, float y) -> Tensor",
        )
        def add(x, y):
            return torch.from_numpy(x.numpy(force=True) + y).to(x.device)

        inp = torch.randn(3)
        scalar = 3.14
        self.assertEqual(add(inp, scalar), inp + scalar)

        @torch.library.custom_op(
            "_torch_testing::sin_",
            mutates_args=["x"],
            schema="(Tensor(a!) x) -> ()",
        )
        def sin_(x):
            # Write sin(x) back into x's storage through its numpy view.
            arr = x.numpy()
            np.sin(arr, out=arr)

        target = torch.randn(3)
        expected = target.sin()
        sin_(target)
        self.assertEqual(target, expected)

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_kwarg_only_tensors(self):
        """Tensor-like kwarg-only arguments are rejected by custom_op,
        register_autograd and register_vmap."""
        message = "kwarg-only Tensor args"
        qualname = "_torch_testing::foo"

        with self.assertRaisesRegex(NotImplementedError, message):

            @torch.library.custom_op(qualname, mutates_args=())
            def foo(x: Tensor, *, y: int, z: Tensor) -> Tensor:
                pass

        with self.assertRaisesRegex(NotImplementedError, message):

            @torch.library.custom_op(qualname, mutates_args=())
            def foo2(x: Tensor, *, y: int, z: Optional[Tensor]) -> Tensor:
                pass

        with self.assertRaisesRegex(NotImplementedError, message):

            @torch.library.custom_op(qualname, mutates_args=())
            def foo3(x: Tensor, *, y: int, z: List[Tensor]) -> Tensor:
                pass

        with torch.library._scoped_library("_torch_testing", "FRAGMENT") as lib:
            lib.define("foo(Tensor x, *, Tensor y) -> Tensor")
            with self.assertRaisesRegex(NotImplementedError, message):
                torch.library.register_autograd(
                    qualname,
                    lambda grad: grad,
                    setup_context=lambda ctx, inputs, keyword_only_inputs, output: None,
                )

            with self.assertRaisesRegex(NotImplementedError, message):
                torch.library.register_vmap(
                    qualname,
                    lambda info, in_dims, x, *, y: (x, 0),
                )

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_register_autograd_kwargonly_low_level(self):
        """register_autograd on a Library op hands kwarg-only inputs to
        setup_context separately from the positional inputs."""
        with torch.library._scoped_library("_torch_testing", "FRAGMENT") as lib:
            lib.define("foo(Tensor x, *, float y) -> Tensor")
            backward_ran = False

            def foo_impl(x, *, y):
                return x * y

            lib.impl("foo", foo_impl, "CPU")

            def backward(ctx, grad):
                nonlocal backward_ran
                backward_ran = True
                return grad * ctx.y

            def setup_context(ctx, inputs, keyword_only_inputs, output):
                assert tuple(keyword_only_inputs.keys()) == ("y",)
                ctx.y = keyword_only_inputs["y"]

            torch.library.register_autograd(
                "_torch_testing::foo", backward, setup_context=setup_context, lib=lib
            )

            x = torch.randn(3, requires_grad=True)
            loss = torch.ops._torch_testing.foo(x, y=3.14).sum()
            loss.backward()
            self.assertTrue(backward_ran)
            self.assertEqual(x.grad, torch.tensor([3.14, 3.14, 3.14]))

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_register_autograd_defaults(self):
        """Schema defaults are filled in before setup_context sees the inputs."""
        with torch.library._scoped_library("_torch_testing", "FRAGMENT") as lib:
            lib.define("foo(Tensor w, int x = 2, *, int y = 3, int z) -> Tensor")

            def foo_impl(w, x=2, *, y=3, z):
                return w * x * y * z

            lib.impl("foo", foo_impl, "CPU")

            backward_ran = False

            def backward(ctx, grad):
                nonlocal backward_ran
                backward_ran = True
                return grad * ctx.c

            def setup_context(ctx, inputs, keyword_only_inputs, output):
                # x's default (2) is positional; y's default (3) is kwarg-only.
                assert len(inputs) == 2
                assert inputs[1] == 2
                assert keyword_only_inputs == {"y": 3, "z": 42}
                ctx.c = keyword_only_inputs["y"] * keyword_only_inputs["z"] * inputs[1]

            torch.library.register_autograd(
                "_torch_testing::foo", backward, setup_context=setup_context, lib=lib
            )

            w = torch.randn(3, requires_grad=True)
            loss = torch.ops._torch_testing.foo(w, z=42).sum()
            loss.backward()
            self.assertTrue(backward_ran)
            self.assertEqual(w.grad, torch.full_like(w, 2 * 3 * 42))

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_manual_schema_error(self):
        """A schema that mutates an arg missing from mutates_args is rejected."""
        with self.assertRaisesRegex(ValueError, "the op mutates {'x'}"):

            @torch.library.custom_op(
                "_torch_testing::sin_",
                mutates_args=(),
                schema="(Tensor(a!) x) -> ()",
            )
            def sin_(x):
                arr = x.numpy()
                np.sin(arr, out=arr)

    def test_supports_tensorlist(self):
        """supports_tensorlist lets an autograd.Function take a list of tensors.

        Covers two independent applies, several backwards over one retained
        graph, the error for calling forward/backward directly, and recursion
        inside both forward and backward.
        """

        def make_leaf():
            return torch.randn([], requires_grad=True)

        @torch._library.autograd.supports_tensorlist
        class Stack(torch.autograd.Function):
            @staticmethod
            def forward(ctx, xs):
                ctx.num_xs = len(xs)
                return torch.stack(xs)

            @staticmethod
            def backward(ctx, grad):
                # One entry for the single list input, holding a flag per element.
                self.assertEqual(ctx.needs_input_grad, ([True] * ctx.num_xs,))
                return list(grad.unbind(0))

        # Two applies in flight; differentiate only the first one.
        first_inputs = [make_leaf() for _ in range(3)]
        second_inputs = [make_leaf() for _ in range(4)]
        first_out = Stack.apply(first_inputs)
        second_out = Stack.apply(second_inputs)
        grads = torch.autograd.grad(first_out.sum(), first_inputs)
        self.assertEqual(grads, [torch.tensor(1.0) for _ in range(3)])

        # One apply, several backwards over the retained graph.
        xs = [make_leaf() for _ in range(3)]
        out = Stack.apply(xs)
        for _ in range(2):
            torch.autograd.grad(out.sum(), xs, retain_graph=True)
        grads = torch.autograd.grad(out.sum(), xs, retain_graph=True)
        self.assertEqual(grads, [torch.tensor(1.0) for _ in range(3)])

        # Calling forward/backward directly (bypassing apply) is an error.
        with self.assertRaisesRegex(NotImplementedError, "Function.forward directly"):
            Stack.forward(None, xs)
        with self.assertRaisesRegex(NotImplementedError, "Function.backward directly"):
            Stack.backward(None, xs)

        # Recursion inside forward.
        @torch._library.autograd.supports_tensorlist
        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, xs):
                if len(xs) > 1:
                    return Foo.apply(xs[1:])
                ctx.len_xs = len(xs)
                return xs[0].sin()

            @staticmethod
            def backward(ctx, grad):
                result = [None] * ctx.len_xs
                result[-1] = grad.cos()
                return result

        self.assertEqual(Foo.apply(xs), xs[-1].sin())

        # Recursion inside backward.
        @torch._library.autograd.supports_tensorlist
        class Bar(torch.autograd.Function):
            @staticmethod
            def forward(ctx, xs):
                return [x + i for i, x in enumerate(xs)]

            @staticmethod
            def backward(ctx, grads):
                return Bar.apply(grads[:2]) + Bar.apply(grads[2:])

        leaves = [torch.tensor(0.0, requires_grad=True) for _ in range(5)]
        sum(Bar.apply(leaves)).backward()
        self.assertEqual(
            [leaf.grad for leaf in leaves], torch.tensor([1.0, 2, 1, 2, 3]).unbind(0)
        )

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_default_values(self):
        """Python defaults are honored at call time and recorded in the schema.

        In the schema, dtype defaults appear as ScalarType enum integers and
        the string device default "cpu" appears as ``torch.device("cpu")``.
        """
        seen = []

        @torch.library.custom_op("_torch_testing::f", mutates_args=())
        def f(
            x: Tensor,
            a: Optional[int] = None,
            b: float = 3.14,
            c: bool = True,
            d: int = 3,
            e: str = "foo",
            f: torch.dtype = torch.float,
            g: torch.dtype = torch.float32,
            h: torch.dtype = torch.int,
            i: torch.device = torch.device("cpu:0"),
            j: torch.device = "cpu",
        ) -> Tensor:
            seen.extend([a, b, c, d, e, f, g, h, i, j])
            return x.clone()

        f(torch.randn(3))
        expected_runtime = [
            None,
            3.14,
            True,
            3,
            "foo",
            torch.float,
            torch.float32,
            torch.int,
            torch.device("cpu:0"),
            "cpu",
        ]
        self.assertEqual(seen, expected_runtime)

        # ScalarType enum values, from c10/core/ScalarType.h.
        float_enum = 6
        int_enum = 3
        schema_args = torch.ops._torch_testing.f.default._schema.arguments
        # The first entry belongs to ``x``, which has no default.
        expected_schema = [
            None,
            None,
            3.14,
            True,
            3,
            "foo",
            float_enum,
            float_enum,
            int_enum,
            torch.device("cpu:0"),
            torch.device("cpu"),
        ]
        self.assertEqual(
            [arg.default_value for arg in schema_args], expected_schema
        )

    def test_mutated_error(self):
        """Naming a non-existent parameter in mutates_args is rejected."""
        expected_msg = r".*{'y'} in mutates_args were not found"
        with self.assertRaisesRegex(ValueError, expected_msg):

            @torch.library.custom_op(
                "_torch_testing::numpy_sin_inplace",
                mutates_args={"y"},
                device_types="cpu",
            )
            def numpy_sin_inplace(x: Tensor) -> None:
                arr = x.numpy()
                np.sin(arr, out=arr)

    def test_mutated(self):
        """Arguments listed in mutates_args get their version counters bumped.

        Checked first with a real in-place numpy kernel. Then a no-op kernel
        declares Optional, list and list-of-Optional tensor arguments as
        mutated: those versions must increase while the undeclared ``x``
        keeps its version.
        """

        @torch.library.custom_op(
            "_torch_testing::numpy_sin_inplace", mutates_args={"x"}, device_types="cpu"
        )
        def numpy_sin_inplace(x: Tensor) -> None:
            arr = x.numpy()
            np.sin(arr, out=arr)

        inp = torch.randn(3)
        before = inp._version
        expected = inp.sin()
        numpy_sin_inplace(inp)
        self.assertEqual(inp, expected)
        self.assertGreater(inp._version, before)

        @torch.library.custom_op("_torch_testing::f", mutates_args={"y", "z", "w"})
        def f(
            x: Tensor, y: Optional[Tensor], z: List[Tensor], w: List[Optional[Tensor]]
        ) -> None:
            return

        args = (
            torch.randn(3),
            torch.randn(3),
            [torch.randn(3), torch.randn(3)],
            [torch.randn(3), None, torch.randn(3)],
        )

        def versions():
            return pytree.tree_map_only(torch.Tensor, lambda t: t._version, args)

        before_all = versions()
        f(*args)
        after_all = versions()

        # x is not in mutates_args, so its version must be unchanged.
        self.assertEqual(before_all[0], after_all[0])
        flat_before, _ = pytree.tree_flatten(before_all[1:])
        flat_after, _ = pytree.tree_flatten(after_all[1:])
        for prev, after in zip(flat_before, flat_after):
            if prev is None and after is None:
                continue
            self.assertGreater(after, prev)

    def test_mutated_unknown(self):
        """mutates_args="unknown" treats every tensor argument as mutated.

        Both a real in-place kernel and a no-op kernel bump the version of
        every tensor argument. A plain string other than "unknown" is
        rejected.
        """

        @torch.library.custom_op(
            "_torch_testing::f", mutates_args="unknown", device_types="cpu"
        )
        def f(x: Tensor) -> None:
            arr = x.numpy()
            np.sin(arr, out=arr)

        inp = torch.randn(3)
        before = inp._version
        expected = inp.sin()
        f(inp)
        self.assertEqual(inp, expected)
        self.assertGreater(inp._version, before)

        @torch.library.custom_op("_torch_testing::f2", mutates_args="unknown")
        def f2(
            x: Tensor, y: Optional[Tensor], z: List[Tensor], w: List[Optional[Tensor]]
        ) -> None:
            return

        args = (
            torch.randn(3),
            torch.randn(3),
            [torch.randn(3), torch.randn(3)],
            [torch.randn(3), None, torch.randn(3)],
        )

        def flat_versions():
            snapshot = pytree.tree_map_only(torch.Tensor, lambda t: t._version, args)
            flat, _ = pytree.tree_flatten(snapshot)
            return flat

        flat_before = flat_versions()
        f2(*args)
        flat_after = flat_versions()
        for prev, after in zip(flat_before, flat_after):
            if prev is None and after is None:
                continue
            self.assertGreater(after, prev)

        with self.assertRaisesRegex(ValueError, "string"):

            @torch.library.custom_op("_torch_testing::f3", mutates_args="x")
            def f3(x: Tensor) -> None:
                return

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_library_register_torch_dispatch_rule_subclass(self):
        """A torch_dispatch rule registered for a subclass fires for its op only.

        Calling ``f`` (mylib::foo) on a TwoTensor hits the rule exactly once.
        The unrelated ``cos`` call on the same TwoTensor does not reach it.
        """
        from torch.testing._internal.two_tensor import TwoTensor

        @torch.library.custom_op("mylib::foo", mutates_args=())
        def f(x: torch.Tensor) -> torch.Tensor:
            return x.sin()

        x = torch.randn(3)
        y = torch.randn(3)
        z = TwoTensor(x, y)

        with torch.library._scoped_library("mylib", "FRAGMENT") as m:
            called = 0

            def TwoTensor_foo(cls, func, types, args, kwargs):
                nonlocal called
                assert cls is TwoTensor
                called += 1
                return x.sin()

            m._register_torch_dispatch_rule("foo", TwoTensor, TwoTensor_foo)

            f(z)
            # cos is not mylib::foo, so it must not increment ``called``.
            z.cos()

        self.assertEqual(called, 1)

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_library_register_torch_dispatch_rule_mode(self):
        """A torch_dispatch rule registered for a mode fires for its op only.

        Under TwoTensorMode, calling ``f`` (mylib::foo) hits the rule exactly
        once. The unrelated ``cos`` call does not reach it.
        """
        from torch.testing._internal.two_tensor import TwoTensorMode

        @torch.library.custom_op("mylib::foo", mutates_args=())
        def f(x: torch.Tensor) -> torch.Tensor:
            return x.sin()

        x = torch.randn(3)

        with torch.library._scoped_library("mylib", "FRAGMENT") as m:
            called = 0

            def TwoTensor_foo(mode, func, types, args, kwargs):
                nonlocal called
                called += 1
                return x.sin()

            m._register_torch_dispatch_rule("foo", TwoTensorMode, TwoTensor_foo)

            with TwoTensorMode():
                f(x)
                # cos is not mylib::foo, so it must not increment ``called``.
                x.cos()

        self.assertEqual(called, 1)

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    @parametrize("idx", [0, 1, 2, 3, 4, 5])
    def test_library_register_fake_source(self, idx):
        """The fake impl of ``_torch_testing::source{idx}`` records its source.

        The recorded source location must point into custom_op_db.py.
        """
        opname = f"source{idx}"
        op = getattr(torch.ops._torch_testing, opname).default
        entry = torch._library.simple_registry.singleton.find(op._name)
        source = entry.fake_impl.kernel.source
        # TestCase assertions (unlike bare ``assert``) survive ``python -O``
        # and report the offending value on failure.
        self.assertIsNotNone(source)
        self.assertIn("custom_op_db.py", source)

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_library_register_fake(self):
        """register_fake accepts a CustomOpDef, a qualname, or an OpOverload.

        Each spelling returns a non-None decorator, and the fake impl it
        registers is what runs under FakeTensorMode.
        """
        for mode in ["function", "qualname", "opoverload"]:

            @torch.library.custom_op("_torch_testing::add", mutates_args=())
            def add(x: Tensor, y: float) -> Tensor:
                x_np = x.cpu().numpy()
                out_np = x_np + y
                return torch.from_numpy(out_np).to(x.device)

            fake_ran = False

            if mode == "function":
                target = add
            elif mode == "qualname":
                target = "_torch_testing::add"
            elif mode == "opoverload":
                target = torch.ops._torch_testing.add.default
            else:
                raise AssertionError("should not get here")
            dec = torch.library.register_fake(target)
            self.assertIsNotNone(dec)

            @dec
            def _(x, y):
                nonlocal fake_ran
                fake_ran = True
                return torch.empty_like(x)

            with torch._subclasses.fake_tensor.FakeTensorMode():
                inp = torch.randn(3)
                out = add(inp, 3.14)
                self.assertEqual(out.shape, inp.shape)
                self.assertTrue(fake_ran)

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_library_register_torch_dispatch(self):
        """register_torch_dispatch accepts a CustomOpDef, qualname, or OpOverload.

        Each spelling returns a non-None decorator, and the rule it registers
        runs when the op is called under ``MyMode``.
        """
        for mode in ["function", "qualname", "opoverload"]:

            class MyMode(torch.utils._python_dispatch.TorchDispatchMode):
                def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                    # kwargs defaults to None; splatting None would raise TypeError.
                    if kwargs is None:
                        kwargs = {}
                    return func(*args, **kwargs)

            @torch.library.custom_op("_torch_testing::add", mutates_args=())
            def add(x: Tensor, y: float) -> Tensor:
                x_np = x.cpu().numpy()
                out_np = x_np + y
                return torch.from_numpy(out_np).to(x.device)

            called = False

            if mode == "function":
                dec = torch.library.register_torch_dispatch(add, MyMode)
                self.assertIsNotNone(dec)
            elif mode == "qualname":
                dec = torch.library.register_torch_dispatch(
                    "_torch_testing::add", MyMode
                )
                self.assertIsNotNone(dec)
            elif mode == "opoverload":
                dec = torch.library.register_torch_dispatch(
                    torch.ops._torch_testing.add.default, MyMode
                )
                self.assertIsNotNone(dec)
            else:
                raise AssertionError("should not get here")

            @dec
            def _(mode, func, types, args, kwargs):
                nonlocal called
                called = True
                return func(*args, **kwargs)

            with MyMode():
                x = torch.randn(3)
                y = 3.14
                z = add(x, y)
                self.assertEqual(z.shape, x.shape)
                self.assertTrue(called)

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_library_register_torch_dispatch_low_level(self):
        """register_torch_dispatch with ``lib=`` on an op created via lib.define.

        Tries qualname vs OpOverload targets and decorator vs direct-call
        registration. register_torch_dispatch is not given a device_types
        argument here, so only op spelling and registration style are
        iterated.
        """
        modes = ["qualname", "opoverload"]
        calls = ["decorator", "function"]

        for mode, call in itertools.product(modes, calls):
            with torch.library._scoped_library("_torch_testing", "FRAGMENT") as lib:
                lib.define("add10(Tensor x, float y) -> Tensor")

                if mode == "qualname":
                    op = "_torch_testing::add10"
                else:
                    assert mode == "opoverload"
                    op = torch.ops._torch_testing.add10.default

                called = False

                class MyMode(torch.utils._python_dispatch.TorchDispatchMode):
                    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                        # kwargs defaults to None; splatting None would raise.
                        if kwargs is None:
                            kwargs = {}
                        return func(*args, **kwargs)

                if call == "decorator":

                    @torch.library.register_torch_dispatch(op, MyMode, lib=lib)
                    def _(mode, func, types, args, kwargs):
                        nonlocal called
                        called = True
                        x, y = args
                        return x + y

                else:
                    assert call == "function"

                    def add_stuff(mode, func, types, args, kwargs):
                        nonlocal called
                        called = True
                        x, y = args
                        return x + y

                    torch.library.register_torch_dispatch(
                        op, MyMode, add_stuff, lib=lib
                    )

                x = torch.randn(3)
                y = 3.14
                with MyMode():
                    z = torch.ops._torch_testing.add10.default(x, y)
                self.assertEqual(z, x + y)
                self.assertTrue(called)

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_library_register_kernel(self):
        """register_kernel makes CPU calls of a cuda-only CustomOpDef use a new kernel.

        The op is declared with device_types="cuda". A kernel registered with
        device_types "cpu" or None must then be the one that runs for CPU
        inputs. Every combination of op spelling and registration style
        (decorator or direct call) is tried.
        """
        combos = itertools.product(
            ["function", "qualname", "opoverload"],
            ["decorator", "function"],
            ["cpu", None],
        )
        for mode, call, device_types in combos:

            @torch.library.custom_op(
                "_torch_testing::add", mutates_args=(), device_types="cuda"
            )
            def add(x: Tensor, y: float) -> Tensor:
                x_np = x.cpu().numpy()
                out_np = x_np + y
                return torch.from_numpy(out_np).to(x.device)

            if mode == "function":
                op = add
            elif mode == "qualname":
                op = "_torch_testing::add"
            else:
                assert mode == "opoverload"
                op = torch.ops._torch_testing.add.default

            kernel_hit = False

            def add_cpu(x, y):
                nonlocal kernel_hit
                kernel_hit = True
                return torch.from_numpy(x.numpy() + y)

            if call == "decorator":
                # Equivalent to decorating add_cpu with the returned decorator.
                torch.library.register_kernel(op, device_types)(add_cpu)
            else:
                assert call == "function"
                torch.library.register_kernel(op, device_types, add_cpu)

            inp = torch.randn(3)
            scalar = 3.14
            self.assertEqual(add(inp, scalar), inp + scalar)
            self.assertTrue(kernel_hit)

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_library_register_kernel_low_level(self):
        """register_kernel with ``lib=`` on an op created via lib.define.

        Tries qualname vs OpOverload targets, decorator vs direct-call
        registration, and device_types given as a tuple, a string, or None.
        """
        combos = itertools.product(
            ["qualname", "opoverload"],
            ["decorator", "function"],
            [("cpu", "cuda"), "cpu", None],
        )
        for mode, call, device_types in combos:
            with torch.library._scoped_library("_torch_testing", "FRAGMENT") as lib:
                lib.define("add9(Tensor x, float y) -> Tensor")

                if mode == "qualname":
                    op = "_torch_testing::add9"
                else:
                    assert mode == "opoverload"
                    op = torch.ops._torch_testing.add9.default

                kernel_hit = False

                def add_cpu(x, y):
                    nonlocal kernel_hit
                    kernel_hit = True
                    return torch.from_numpy(x.numpy() + y)

                if call == "decorator":
                    # Equivalent to decorating add_cpu with the returned decorator.
                    torch.library.register_kernel(op, device_types, lib=lib)(add_cpu)
                else:
                    assert call == "function"
                    torch.library.register_kernel(op, device_types, add_cpu, lib=lib)

                inp = torch.randn(3)
                scalar = 3.14
                out = torch.ops._torch_testing.add9.default(inp, scalar)
                self.assertEqual(out, inp + scalar)
                self.assertTrue(kernel_hit)

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_library_register_autograd(self):
        """register_autograd accepts a CustomOpDef, a qualname, or an OpOverload.

        In each case the registered backward must be used by autograd.grad
        and produce d(sin x)/dx = cos x.
        """
        for mode in ["function", "qualname", "opoverload"]:

            @torch.library.custom_op("mylib::numpy_sin", mutates_args=())
            def numpy_sin(x: Tensor) -> Tensor:
                x_np = x.cpu().numpy()
                y_np = np.sin(x_np)
                return torch.from_numpy(y_np).to(device=x.device)

            # setup_context only stashes state on ctx; it returns nothing.
            def setup_context(ctx, inputs, output) -> None:
                (x,) = inputs
                ctx.save_for_backward(x)

            called = False

            def backward(ctx, grad):
                nonlocal called
                called = True
                (x,) = ctx.saved_tensors
                return grad * x.cos()

            if mode == "function":
                target = numpy_sin
            elif mode == "qualname":
                target = "mylib::numpy_sin"
            elif mode == "opoverload":
                target = torch.ops.mylib.numpy_sin.default
            else:
                raise AssertionError("should not get here")
            torch.library.register_autograd(
                target, backward, setup_context=setup_context
            )

            x = torch.randn(3, requires_grad=True)
            y = numpy_sin(x)
            (grad_x,) = torch.autograd.grad(y, x, torch.ones_like(y))
            self.assertTrue(called)
            self.assertEqual(grad_x, x.cos())

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_library_register_autograd_low_level(self):
        """register_autograd with ``lib=`` on an op created via lib.define.

        The CPU kernel is registered with lib.impl. The autograd formula is
        registered by qualname or by OpOverload and must be used by
        autograd.grad.
        """
        for mode in ["qualname", "opoverload"]:
            with torch.library._scoped_library("_torch_testing", "FRAGMENT") as lib:
                lib.define("sin5(Tensor x) -> Tensor")

                def numpy_sin(x: Tensor) -> Tensor:
                    # detach() presumably because x may still require grad when
                    # this kernel runs; .numpy() refuses such tensors.
                    x_np = x.cpu().detach().numpy()
                    y_np = np.sin(x_np)
                    return torch.from_numpy(y_np).to(device=x.device)

                # setup_context only stashes state on ctx; it returns nothing.
                def setup_context(ctx, inputs, output) -> None:
                    (x,) = inputs
                    ctx.save_for_backward(x)

                called = False

                def backward(ctx, grad):
                    nonlocal called
                    called = True
                    (x,) = ctx.saved_tensors
                    return grad * x.cos()

                lib.impl("sin5", numpy_sin, "CPU")

                if mode == "qualname":
                    target = "_torch_testing::sin5"
                else:
                    assert mode == "opoverload"
                    target = torch.ops._torch_testing.sin5.default
                torch.library.register_autograd(
                    target,
                    backward,
                    setup_context=setup_context,
                    lib=lib,
                )

                x = torch.randn(3, requires_grad=True)
                y = torch.ops._torch_testing.sin5(x)
                (grad_x,) = torch.autograd.grad(y, x, torch.ones_like(y))
                self.assertTrue(called)
                self.assertEqual(grad_x, x.cos())

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
3049
    def test_fake(self):
3050
        @torch.library.custom_op("_torch_testing::add", mutates_args=())
3051
        def add(x: Tensor, y: float) -> Tensor:
3052
            x_np = x.cpu().numpy()
3053
            out_np = x_np + y
3054
            return torch.from_numpy(out_np).to(x.device)
3055

3056
        x = torch.randn(3)
3057
        y = 3.14
3058
        z = add(x, y)
3059
        self.assertEqual(z, x + y)
3060

3061
        try:
3062
            with torch._subclasses.fake_tensor.FakeTensorMode():
3063
                x = torch.randn(3)
3064
                add(x, y)
3065
            raise AssertionError("should not be hit")
3066
        except RuntimeError as e:
3067
            abstract_impl_error_msg = str(e)
3068
        abstract_impl_error_msg = re.sub(
3069
            r"0x.*>\)>", "0xDEADBEEF>)>", abstract_impl_error_msg
3070
        ).replace(". ", ".\n")
3071
        self.assertExpectedInline(
3072
            abstract_impl_error_msg,
3073
            """\
3074
There was no fake impl registered for <CustomOpDef(_torch_testing::add)>.
3075
This is necessary for torch.compile/export/fx tracing to work.
3076
Please use `add.register_fake` to add an fake impl.""",
3077
        )
3078

3079
        if not IS_WINDOWS:
3080

3081
            @torch.compile(backend="eager")
3082
            def f(x, y):
3083
                return add(x, y)
3084

3085
            x = torch.randn(3)
3086
            with self.assertRaisesRegex(RuntimeError, "no fake impl"):
3087
                f(x, y)
3088

3089
        abstract_called = False
3090

3091
        @add.register_fake
3092
        def _(x, y):
3093
            nonlocal abstract_called
3094
            abstract_called = True
3095
            return torch.empty_like(x)
3096

3097
        with torch._subclasses.fake_tensor.FakeTensorMode():
3098
            x = torch.randn(3)
3099
            z = add(x, y)
3100
            self.assertEqual(z.shape, x.shape)
3101
            self.assertTrue(abstract_called)
3102

3103
    @skipIfTorchDynamo("recursive dynamo")
    @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work on windows")
    def test_compile(self):
        """A custom op with a fake impl compiles under fullgraph=True.

        The fake impl must run during tracing, the real kernel at runtime,
        and the result must match ``torch.nn.functional.linear``.
        """
        called_impl = False
        called_abstract = False

        @torch.library.custom_op("_torch_testing::linear", mutates_args=())
        def custom_linear(x: Tensor, weight: Tensor, bias: Tensor) -> Tensor:
            nonlocal called_impl
            called_impl = True
            x_np = x.numpy()
            w_np = weight.numpy()
            b_np = bias.numpy()
            # Compute purely in numpy, then convert back explicitly so the
            # kernel returns a Tensor as its signature declares.
            out_np = np.add(x_np @ w_np.T, b_np)
            return torch.from_numpy(out_np)

        @custom_linear.register_fake
        def _(x, weight, bias):
            nonlocal called_abstract
            called_abstract = True
            assert x.dim() == 2
            assert weight.dim() == 2
            assert bias.dim() == 1
            assert x.shape[1] == weight.shape[1]
            assert weight.shape[0] == bias.shape[0]
            assert x.device == weight.device
            return x.new_empty(x.size(0), weight.size(0))

        x = torch.randn(2, 2)
        weight = torch.randn(2, 2)
        bias = torch.randn(2)
        out = torch.compile(custom_linear, backend="eager", fullgraph=True)(
            x, weight, bias
        )
        self.assertEqual(out, torch.nn.functional.linear(x, weight, bias))
        self.assertTrue(called_impl)
        self.assertTrue(called_abstract)

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_register_autograd_error_cases(self):
        """Backward through a custom op with no autograd formula raises."""

        @torch.library.custom_op("_torch_testing::g", mutates_args=())
        def g(x: Tensor) -> Tensor:
            return x.sin()

        out = g(torch.randn(3, requires_grad=True))
        with self.assertRaisesRegex(RuntimeError, "no autograd formula"):
            out.sum().backward()

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_replacement(self):
        """Redefining a custom op under the same name replaces its kernel."""
        inp = torch.randn(3)

        @torch.library.custom_op("_torch_testing::f", mutates_args=())
        def f(x: Tensor) -> Tensor:
            return x.sin()

        self.assertEqual(f(inp), inp.sin())

        @torch.library.custom_op("_torch_testing::f", mutates_args=())
        def f(x: Tensor) -> Tensor:
            return x.cos()

        self.assertEqual(f(inp), inp.cos())

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    @unittest.skipIf(not TEST_CUDA, "requires CUDA")
    def test_split_device(self):
        """CPU and CUDA calls route to separately registered kernels.

        The CPU kernel comes from the custom_op definition and the CUDA
        kernel from ``register_kernel("cuda")``. Per-device call counts
        confirm the routing.
        """
        counts = {"cpu": 0, "cuda": 0}

        @torch.library.custom_op(
            "_torch_testing::f", mutates_args=(), device_types="cpu"
        )
        def f(x: Tensor) -> Tensor:
            counts["cpu"] += 1
            return torch.from_numpy(np.sin(x.numpy()))

        @f.register_kernel("cuda")
        def _(x: Tensor) -> Tensor:
            counts["cuda"] += 1
            out_np = np.sin(x.cpu().numpy())
            return torch.from_numpy(out_np).to(x.device)

        inp = torch.randn(3)
        self.assertEqual(f(inp), inp.sin())
        self.assertEqual(counts["cpu"], 1)
        self.assertEqual(counts["cuda"], 0)

        inp = inp.cuda()
        self.assertEqual(f(inp), inp.sin())
        self.assertEqual(counts["cpu"], 1)
        self.assertEqual(counts["cuda"], 1)

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    @unittest.skipIf(not TEST_CUDA, "requires CUDA")
    def test_multi_types(self):
        """One kernel registered for ("cpu", "cuda") serves both devices."""

        @torch.library.custom_op(
            "_torch_testing::f", mutates_args=(), device_types=("cpu", "cuda")
        )
        def f(x: Tensor) -> Tensor:
            x_np = x.cpu().numpy()
            out_np = np.sin(x_np)
            return torch.from_numpy(out_np).to(x.device)

        base = torch.randn(3)
        for device in ("cpu", "cuda"):
            # .to("cpu") on a CPU tensor returns the tensor itself.
            inp = base.to(device)
            self.assertEqual(f(inp), inp.sin())

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_overloading(self):
        """An op and a named overload of it dispatch to their own kernels.

        ``torch.ops._torch_testing.f`` must resolve a one-argument call to the
        default overload and a two-argument call to ``f.overload``; each call
        must hit only its own implementation.
        """
        called_f = 0
        called_f1 = 0

        @torch.library.custom_op("_torch_testing::f", mutates_args=())
        def f(x: Tensor) -> Tensor:
            nonlocal called_f
            called_f += 1
            return x.clone()

        x = torch.randn(2, 3)
        torch.ops._torch_testing.f(x)
        self.assertEqual(called_f, 1)
        self.assertEqual(called_f1, 0)

        @torch.library.custom_op("_torch_testing::f.overload", mutates_args=())
        def f1(x: Tensor, y: Tensor) -> Tensor:
            nonlocal called_f1
            called_f1 += 1
            return x.clone()

        torch.ops._torch_testing.f(x, x)
        self.assertEqual(called_f1, 1)
        # The two-argument call must not have gone to the default overload.
        self.assertEqual(called_f, 1)
3246

3247
    def test_disallows_output_aliasing(self):
        """A custom op may not return an output aliasing one of its inputs,
        whether a view, the input itself, or an input it declares as mutated."""

        def assert_alias_rejected(op):
            with self.assertRaisesRegex(RuntimeError, "may not alias"):
                op(torch.randn(3))

        # Returning a view of the input.
        @torch.library.custom_op("_torch_testing::f", mutates_args=())
        def f(x: Tensor) -> Tensor:
            return x.view(-1)

        assert_alias_rejected(f)

        # Returning the input object unchanged.
        @torch.library.custom_op("_torch_testing::f", mutates_args=())
        def f(x: Tensor) -> Tensor:
            return x

        assert_alias_rejected(f)

        # Returning an input even though it is listed in mutates_args.
        @torch.library.custom_op(
            "_torch_testing::f", mutates_args={"x"}, device_types="cpu"
        )
        def numpy_sin_inplace(x: Tensor) -> Tensor:
            x_np = x.numpy()
            np.sin(x_np, out=x_np)
            return x

        assert_alias_rejected(numpy_sin_inplace)
3275

3276
    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_factory_function(self):
        """Ops without tensor inputs dispatch on a ``device: torch.device`` argument.

        Covers four cases:

        * a CPU-only factory op runs on CPU and errors for CUDA;
        * defining a device-specific op with neither tensor inputs nor a
          ``device`` argument is rejected;
        * registering a device kernel on such an op is rejected;
        * a kernel registered later via ``register_kernel`` on an op that does
          take ``device`` is the one that runs.
        """

        @torch.library.custom_op(
            "_torch_testing::f", mutates_args={}, device_types="cpu"
        )
        def f(device: torch.device) -> Tensor:
            return torch.ones(3)

        result = f(device="cpu")
        self.assertEqual(result.device, torch.device("cpu"))
        self.assertEqual(result, torch.ones(3))

        # No kernel is registered for CUDA.
        with self.assertRaisesRegex(
            RuntimeError, "f does not have a kernel registered for cuda"
        ):
            f("cuda")

        # Rejected at definition time: no tensor inputs and no `device` arg.
        with self.assertRaisesRegex(
            ValueError,
            "Functions without tensor inputs are required to have a `device: torch.device` argument",
        ):

            @torch.library.custom_op(
                "_torch_testing::f2", mutates_args={}, device_types="cpu"
            )
            def f2() -> Tensor:
                return torch.ones(3)

        # Without device_types the definition itself succeeds...
        @torch.library.custom_op("_torch_testing::f3", mutates_args={})
        def f3() -> Tensor:
            raise NotImplementedError("NYI")

        # ...but adding a device kernel to it is rejected.
        with self.assertRaisesRegex(
            ValueError,
            "Functions without tensor inputs are required to have a `device: torch.device` argument",
        ):

            @f3.register_kernel("cpu")
            def _():
                return torch.zeros(3)

        # With a `device` argument, the CPU kernel registered afterwards is the
        # one that runs (the default body would raise NotImplementedError).
        @torch.library.custom_op("_torch_testing::f4", mutates_args={})
        def f4(device: torch.device) -> Tensor:
            raise NotImplementedError("NYI")

        @f4.register_kernel("cpu")
        def _(device: torch.device):
            return torch.zeros(3)

        result = f4(device="cpu")
        self.assertEqual(result.device, torch.device("cpu"))
        self.assertEqual(result, torch.zeros(3))
3330

3331
    def test_library_schema_infer(self):
        """infer_schema prefixes the schema with ``op_name`` only when given one."""

        def foo_impl(x: torch.Tensor) -> torch.Tensor:
            return x.sin()

        infer = torch.library.infer_schema

        named_schema = infer(foo_impl, op_name="myop", mutates_args={})
        self.assertExpectedInline(named_schema, "myop(Tensor x) -> Tensor")

        bare_schema = infer(foo_impl, mutates_args={})
        self.assertExpectedInline(bare_schema, "(Tensor x) -> Tensor")
3340

3341
    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_set_kernel_enabled(self):
        """``set_kernel_enabled`` toggles a device kernel on and off.

        While the CPU kernel is disabled, calls fall back to the op's original
        body (``x + 1``). Redundant or inapplicable toggles only emit a log
        message.
        """
        inp = torch.ones(1)

        @torch.library.custom_op("mylib::f", mutates_args=())
        def f(x: Tensor) -> Tensor:
            return x + 1

        def check(offset):
            self.assertEqual(f(inp), inp + offset)

        check(1)

        # Toggling a device type that has no registered kernel just logs.
        with self.assertLogs("torch._library.custom_ops") as captured:
            with f.set_kernel_enabled("gpu", enabled=False):
                check(1)
            self.assertIn(
                "no kernel was registered for this device type", captured.output[0]
            )

        @f.register_kernel("cpu")
        def _(x):
            return x + 2

        check(2)

        # Enabling an already-enabled kernel just logs.
        with self.assertLogs("torch._library.custom_ops") as captured:
            with f.set_kernel_enabled("cpu", enabled=True):
                check(2)
            self.assertIn("already enabled", captured.output[0])

        with f.set_kernel_enabled("cpu", enabled=False):
            check(1)

            # Disabling again while already disabled just logs.
            with self.assertLogs("torch._library.custom_ops") as captured:
                with f.set_kernel_enabled("cpu", enabled=False):
                    check(1)
                self.assertIn("already disabled", captured.output[0])

            # Leaving the nested context keeps the kernel disabled.
            check(1)

        with f.set_kernel_enabled("cpu", enabled=True):
            check(2)

        with f.set_kernel_enabled("cpu", enabled=False):
            check(1)
3383

3384
    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_register_vmap_kwargonly_low_level(self):
        """A vmap rule for a ``lib.define``'d op receives keyword-only args."""
        with torch.library._scoped_library("_torch_testing", "FRAGMENT") as lib:
            lib.define("foo(Tensor x, *, float y) -> Tensor")

            def foo_impl(x, *, y):
                return x * y

            lib.impl("foo", foo_impl, "CPU")

            vmap_calls = []

            def vmap(info, in_dims, x, *, y):
                vmap_calls.append(y)
                # Return the batched result together with its output batch dim.
                return x * y, 0

            torch.library.register_vmap("_torch_testing::foo", vmap, lib=lib)

            batched_foo = torch.vmap(torch.ops._torch_testing.foo)
            result = batched_foo(torch.ones(3), y=3.14)
            self.assertTrue(vmap_calls)
            self.assertEqual(result, torch.tensor([3.14, 3.14, 3.14]))
3406

3407
    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_register_vmap_defaults(self):
        """A vmap rule works when defaulted (positional and keyword-only)
        arguments are omitted at the call site."""
        with torch.library._scoped_library("_torch_testing", "FRAGMENT") as lib:
            lib.define("foo(Tensor w, int x = 2, *, int y = 3, int z) -> Tensor")

            def foo_impl(w, x=2, *, y=3, z):
                return w * x * y * z

            lib.impl("foo", foo_impl, "CPU")

            vmap_calls = []

            def vmap(info, in_dims, w, x=2, *, y=3, z):
                vmap_calls.append(in_dims)
                return w * x * y * z, 0

            torch.library.register_vmap("_torch_testing::foo", vmap, lib=lib)

            w = torch.ones(3)
            result = torch.vmap(torch.ops._torch_testing.foo)(w, z=42)
            self.assertTrue(vmap_calls)
            self.assertEqual(result, w * 2 * 3 * 42)
3430

3431
    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_library_register_vmap(self):
        """A vmap rule can be registered by function, qualname, OpOverload or
        the CustomOpDef method, and is used for default, ``out_dims`` and
        ``in_dims`` vmaps."""
        # Each entry registers `rule` for the freshly defined op `op`.
        registrars = {
            "function": lambda op, rule: torch.library.register_vmap(op, rule),
            "qualname": lambda op, rule: torch.library.register_vmap("mylib::f", rule),
            "opoverload": lambda op, rule: torch.library.register_vmap(
                torch.ops.mylib.f.default, rule
            ),
            "c_opdef": lambda op, rule: op.register_vmap(rule),
        }

        for register in registrars.values():

            @torch.library.custom_op("mylib::f", mutates_args=())
            def f(x: Tensor, y: Tensor) -> Tensor:
                return x * y

            calls = []

            def fvmap(info, in_dims, x, y):
                calls.append(in_dims)
                x_bdim, y_bdim = in_dims
                # Put each batch dim last (size-1 if unbatched) so operands
                # broadcast, then move it to the front of the result.
                x = x.unsqueeze(-1) if x_bdim is None else x.movedim(x_bdim, -1)
                y = y.unsqueeze(-1) if y_bdim is None else y.movedim(y_bdim, -1)
                return (x * y).movedim(-1, 0), 0

            register(f, fvmap)

            x = torch.randn(2, 2)
            y = torch.randn(2, 2)
            expected = x * y

            for vmap_kwargs, want in (
                ({}, expected),
                ({"out_dims": 1}, expected.T),
                ({"in_dims": 1}, expected.T),
            ):
                calls.clear()
                result = torch.vmap(f, **vmap_kwargs)(x, y)
                self.assertTrue(calls)
                self.assertEqual(result, want)
3476

3477
    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_library_register_vmap_library_decorator(self):
        """``torch.library.register_vmap(qualname)`` can be used as a decorator."""

        @torch.library.custom_op("mylib::f", mutates_args=())
        def f(x: Tensor, y: Tensor) -> Tensor:
            return x * y

        calls = []

        @torch.library.register_vmap("mylib::f")
        def fvmap(info, in_dims, x, y):
            calls.append(in_dims)
            x_bdim, y_bdim = in_dims
            # Batch dim last (size-1 if unbatched) so operands broadcast.
            x = x.unsqueeze(-1) if x_bdim is None else x.movedim(x_bdim, -1)
            y = y.unsqueeze(-1) if y_bdim is None else y.movedim(y_bdim, -1)
            return (x * y).movedim(-1, 0), 0

        lhs = torch.randn(2, 2)
        rhs = torch.randn(2, 2)

        result = torch.vmap(f)(lhs, rhs)
        self.assertTrue(calls)
        self.assertEqual(result, lhs * rhs)
3502

3503
    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_library_register_vmap_op_decorator(self):
        """``CustomOpDef.register_vmap`` can be used as a decorator."""

        @torch.library.custom_op("mylib::f", mutates_args=())
        def f(x: Tensor, y: Tensor) -> Tensor:
            return x * y

        calls = []

        def _align(t, bdim):
            # Batch dim last (size-1 if unbatched) so operands broadcast.
            return t.unsqueeze(-1) if bdim is None else t.movedim(bdim, -1)

        @f.register_vmap
        def fvmap(info, in_dims, x, y):
            calls.append(in_dims)
            x_bdim, y_bdim = in_dims
            product = _align(x, x_bdim) * _align(y, y_bdim)
            return product.movedim(-1, 0), 0

        lhs = torch.randn(2, 2)
        rhs = torch.randn(2, 2)

        result = torch.vmap(f)(lhs, rhs)
        self.assertTrue(calls)
        self.assertEqual(result, lhs * rhs)
3528

3529
    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_library_register_vmap_register_multiple_times(self):
        """A second ``f.register_vmap`` replaces the first rule."""

        @torch.library.custom_op("mylib::f", mutates_args=())
        def f(x: Tensor, y: Tensor) -> Tensor:
            return x * y

        calls = []

        def _align(t, bdim):
            # Batch dim last (size-1 if unbatched) so operands broadcast.
            return t.unsqueeze(-1) if bdim is None else t.movedim(bdim, -1)

        @f.register_vmap
        def fvmap(info, in_dims, x, y):
            calls.append("mul")
            x_bdim, y_bdim = in_dims
            return (_align(x, x_bdim) * _align(y, y_bdim)).movedim(-1, 0), 0

        x = torch.randn(2, 2)
        y = torch.randn(2, 2)

        result = torch.vmap(f)(x, y)
        self.assertTrue(calls)
        self.assertEqual(result, x * y)
        calls.clear()

        # The replacement rule adds instead of multiplying, so its output is
        # distinguishable from both the op body and the first rule.
        @f.register_vmap
        def fvmap2(info, in_dims, x, y):
            calls.append("add")
            x_bdim, y_bdim = in_dims
            return (_align(x, x_bdim) + _align(y, y_bdim)).movedim(-1, 0), 0

        result = torch.vmap(f)(x, y)
        self.assertTrue(calls)
        self.assertEqual(result, x + y)
3570

3571
    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_library_register_vmap_register_multiple_times_2(self):
        """A second ``torch.library.register_vmap(qualname)`` replaces the first rule."""

        @torch.library.custom_op("mylib::f", mutates_args=())
        def f(x: Tensor, y: Tensor) -> Tensor:
            return x * y

        calls = []

        def _align(t, bdim):
            # Batch dim last (size-1 if unbatched) so operands broadcast.
            return t.unsqueeze(-1) if bdim is None else t.movedim(bdim, -1)

        @torch.library.register_vmap("mylib::f")
        def fvmap(info, in_dims, x, y):
            calls.append("mul")
            x_bdim, y_bdim = in_dims
            return (_align(x, x_bdim) * _align(y, y_bdim)).movedim(-1, 0), 0

        x = torch.randn(2, 2)
        y = torch.randn(2, 2)

        result = torch.vmap(f)(x, y)
        self.assertTrue(calls)
        self.assertEqual(result, x * y)
        calls.clear()

        # The replacement rule adds, so its output differs from the first rule's.
        @torch.library.register_vmap("mylib::f")
        def fvmap2(info, in_dims, x, y):
            calls.append("add")
            x_bdim, y_bdim = in_dims
            return (_align(x, x_bdim) + _align(y, y_bdim)).movedim(-1, 0), 0

        result = torch.vmap(f)(x, y)
        self.assertTrue(calls)
        self.assertEqual(result, x + y)
3612

3613

3614
class MiniOpTestOther(CustomOpTestCaseBase):
    """Additional aten-based cases, in the ``mini_op_test`` namespace, that
    receive their own generated opcheck variants."""

    test_ns = "mini_op_test"

    def test_nonzero_again(self):
        inp = torch.tensor([0, 1, 2, 0, 0])
        # Only indices 1 and 2 hold nonzero values.
        expected = torch.tensor([[1], [2]])
        self.assertEqual(torch.ops.aten.nonzero.default(inp), expected)
3621

3622

3623
# Both mini-op suites share one failures dict and the deprecated default set
# of test utils for their generated opcheck tests.
_MINIOPTEST_FAILURES_DICT_PATH = get_file_path_2(
    os.path.dirname(__file__), "minioptest_failures_dict.json"
)
_MINIOPTEST_TEST_UTILS = optests.generate_tests.DEPRECATED_DEFAULT_TEST_UTILS

optests.generate_opcheck_tests(
    MiniOpTest,
    ["aten", "mini_op_test"],
    _MINIOPTEST_FAILURES_DICT_PATH,
    additional_decorators={
        # NOTE(review): expected failure, presumably because the
        # `no_abstract` op has no abstract/fake impl — confirm.
        "test_pt2_compliant_tag_mini_op_test_no_abstract": [unittest.expectedFailure]
    },
    test_utils=_MINIOPTEST_TEST_UTILS,
)

optests.generate_opcheck_tests(
    MiniOpTestOther,
    ["aten", "mini_op_test"],
    _MINIOPTEST_FAILURES_DICT_PATH,
    test_utils=_MINIOPTEST_TEST_UTILS,
)
3639

3640

3641
class TestGenerateOpcheckTests(CustomOpTestCaseBase):
3642
    def test_MiniOpTest(self):
3643
        for orig_test in ["test_mm", "test_nonzero"]:
3644
            for (
3645
                test
3646
            ) in torch.testing._internal.optests.generate_tests.DEFAULT_TEST_UTILS:
3647
                expected_test = f"{test}__{orig_test}"
3648
                self.assertTrue(hasattr(MiniOpTest, expected_test), msg=expected_test)
3649

3650
    def test_generate_repro_save_data(self):
3651
        from torch.testing._internal.optests.generate_tests import generate_repro
3652

3653
        args = (torch.ones(2, 2),)
3654
        kwargs = {"mat2": torch.zeros(2, 2)}
3655
        actual = generate_repro(
3656
            "test_schema",
3657
            torch.ops.aten.sin.default,
3658
            args,
3659
            kwargs,
3660
            save_data=True,
3661
            dry_run=True,
3662
        )
3663
        actual = re.sub(r"torch.load\(\".*\.pt\"\)", 'torch.load("repro.pt")', actual)
3664
        self.assertExpectedInline(
3665
            actual,
3666
            """\
3667
# =========================================================
3668
# BEGIN REPRO SCRIPT
3669
# =========================================================
3670
import torch
3671
from torch.testing._internal.optests import opcheck
3672

3673
# Make sure you have loaded the library that contains the op
3674
# via an import or torch.ops.load_library(...)
3675
op = torch.ops.aten.sin.default
3676

3677
args, kwargs = torch.load("repro.pt")
3678
opcheck(op, args, kwargs, test_utils="test_schema")
3679
# =========================================================
3680
# END REPRO SCRIPT
3681
# =========================================================
3682
""",
3683
        )
3684

3685
    def test_generate_repro_no_save_data(self):
3686
        from torch.testing._internal.optests.generate_tests import generate_repro
3687

3688
        args = (torch.ones(2, 2),)
3689
        kwargs = {"mat2": torch.zeros(2, 2)}
3690
        actual = generate_repro(
3691
            "test_schema",
3692
            torch.ops.aten.sin.default,
3693
            args,
3694
            kwargs,
3695
            save_data=False,
3696
            dry_run=True,
3697
        )
3698
        self.assertExpectedInline(
3699
            actual,
3700
            """\
3701
# =========================================================
3702
# BEGIN REPRO SCRIPT
3703
# =========================================================
3704
import torch
3705
from torch.testing._internal.optests import opcheck
3706

3707
# Make sure you have loaded the library that contains the op
3708
# via an import or torch.ops.load_library(...)
3709
op = torch.ops.aten.sin.default
3710

3711
# If you rerun your test with PYTORCH_OPCHECK_PRINT_BETTER_REPRO=1
3712
# we will fill them in same (args, kwargs) as in your test
3713
args = ()  # args to the operator
3714
kwargs = {}  # kwargs to the operator
3715
opcheck(op, args, kwargs, test_utils="test_schema")
3716
# =========================================================
3717
# END REPRO SCRIPT
3718
# =========================================================
3719
""",
3720
        )
3721

3722
    def test_failures_dict_validation(self):
3723
        from torch.testing._internal.optests.generate_tests import (
3724
            FailuresDict,
3725
            validate_failures_dict_structure,
3726
        )
3727

3728
        failures = {
3729
            "mini_op_test::incorrect_schema": {
3730
                "MiniOpTest.test_aot_dispatch_dynamic__test_delayed_error": {
3731
                    "comment": "",
3732
                    "status": "success",
3733
                }
3734
            }
3735
        }
3736
        with self.assertRaisesRegex(RuntimeError, "got status=success"):
3737
            validate_failures_dict_structure(
3738
                FailuresDict("", failures),
3739
                torch.testing._internal.optests.generate_tests.DEFAULT_TEST_UTILS,
3740
                MiniOpTest,
3741
            )
3742

3743
        failures = {
3744
            "mini_op_test::incorrect_schema": {
3745
                "MiniOpTest.test_aot_dispatch__test_delayed_error": {
3746
                    "comment": "",
3747
                    "status": "xfail",
3748
                },
3749
            }
3750
        }
3751
        with self.assertRaisesRegex(RuntimeError, "should begin with one of"):
3752
            validate_failures_dict_structure(
3753
                FailuresDict("", failures),
3754
                torch.testing._internal.optests.generate_tests.DEFAULT_TEST_UTILS,
3755
                MiniOpTest,
3756
            )
3757

3758
        failures = {
3759
            "mini_op_test::incorrect_schema": {
3760
                "MiniOpTest.test_aot_dispatch_dynamic__test_delayed_error_nopenopenope": {
3761
                    "comment": "",
3762
                    "status": "xfail",
3763
                },
3764
            }
3765
        }
3766
        with self.assertRaisesRegex(RuntimeError, "does not exist on the TestCase"):
3767
            validate_failures_dict_structure(
3768
                FailuresDict("", failures),
3769
                torch.testing._internal.optests.generate_tests.DEFAULT_TEST_UTILS,
3770
                MiniOpTest,
3771
            )
3772

3773
    def test_dont_generate_decorator(self):
3774
        self.assertTrue(hasattr(MiniOpTest, "test_dont_generate"))
3775
        self.assertFalse(hasattr(MiniOpTest, "test_schema__test_dont_generate"))
3776

3777
    def test_opcheck(self):
3778
        x = torch.randn(3, requires_grad=True)
3779
        with self.assertRaisesRegex(ValueError, "OpOverload"):
3780
            torch.library.opcheck(torch.sin, (x,))
3781
        with self.assertRaisesRegex(ValueError, "test_utils to be subset of"):
3782
            torch.library.opcheck(torch.ops.aten.sin.default, (x,), test_utils="blah")
3783
        result = torch.library.opcheck(torch.ops.aten.sin.default, (x,))
3784

3785
        self.assertEqual(
3786
            result,
3787
            {
3788
                "test_schema": "SUCCESS",
3789
                "test_autograd_registration": "SUCCESS",
3790
                "test_faketensor": "SUCCESS",
3791
                "test_aot_dispatch_dynamic": "SUCCESS",
3792
            },
3793
        )
3794

3795
        result = torch.library.opcheck(
3796
            torch.ops.aten.sin.default, (x,), test_utils="test_schema"
3797
        )
3798
        self.assertEqual(result, {"test_schema": "SUCCESS"})
3799

3800
        result = torch.library.opcheck(
3801
            torch.ops.aten.sin.default,
3802
            (x,),
3803
            test_utils=["test_schema", "test_faketensor"],
3804
        )
3805
        self.assertEqual(
3806
            result,
3807
            {
3808
                "test_schema": "SUCCESS",
3809
                "test_faketensor": "SUCCESS",
3810
            },
3811
        )
3812

3813
    def test_opcheck_customopdef(self):
3814
        sample_inputs = [
3815
            (torch.randn(3),),
3816
            (torch.randn(3, requires_grad=True),),
3817
        ]
3818
        if torch.cuda.is_available():
3819
            sample_inputs.extend(
3820
                [
3821
                    (torch.randn(3, device="cuda"),),
3822
                    (torch.randn(3, device="cuda", requires_grad=True),),
3823
                ]
3824
            )
3825
        for args in sample_inputs:
3826
            torch.library.opcheck(custom_op_db.numpy_cube, args)
3827

3828
    def test_is_inside_opcheck_mode(self):
3829
        self.assertFalse(optests.is_inside_opcheck_mode())
3830
        with optests.generate_tests.OpCheckMode(
3831
            ["foo"], "bar", lambda x: x, None, "baz", "brr"
3832
        ):
3833
            self.assertTrue(optests.is_inside_opcheck_mode())
3834

3835
    def test_opcheck_bad_op(self):
3836
        op = op_with_incorrect_schema(self, "foo")
3837
        x = torch.randn(3)
3838
        with self.assertRaisesRegex(Exception, "is not defined to alias output"):
3839
            torch.library.opcheck(op, (x,))
3840

3841
        result = torch.library.opcheck(op, (x,), raise_exception=False)
3842
        self.assertTrue(isinstance(result["test_schema"], RuntimeError))
3843
        del result["test_schema"]
3844
        self.assertEqual(
3845
            result,
3846
            {
3847
                "test_autograd_registration": "SUCCESS",
3848
                "test_faketensor": "SUCCESS",
3849
                "test_aot_dispatch_dynamic": "SUCCESS",
3850
            },
3851
        )
3852

3853
    def test_opcheck_does_not_require_extra_deps(self):
3854
        # torch.testing._internal.common_utils comes with a lot of additional
3855
        # test-time dependencies. Since opcheck is public API, it should be
3856
        # usable only with pytorch install-time dependencies.
3857
        cmd = [
3858
            sys.executable,
3859
            "-c",
3860
            "import torch; import sys; \
3861
               x = torch.randn(3, requires_grad=True); \
3862
               torch.library.opcheck(torch.ops.aten.sin.default, (x,)); \
3863
               assert 'expecttest' not in sys.modules; \
3864
               assert 'torch.testing._internal.common_utils' not in sys.modules",
3865
        ]
3866
        subprocess.check_output(cmd, shell=False)
3867

3868

3869
class TestTypeConversion(TestCase):
    """In infer_schema(), we try to suggest a correct type when the type annotation is wrong.

    These tests pin ``tuple_to_list``, which maps a ``Tuple[...]`` annotation
    to the ``List[...]`` annotation suggested in its place.
    """

    def setUp(self):
        # Run the base-class setup, as CustomOpTestCaseBase.setUp does.
        super().setUp()
        self.supported_base_types = [
            int,
            float,
            bool,
            str,
            torch.device,
            torch.Tensor,
            torch.dtype,
            torch.types.Number,
        ]

    def test_simple_tuple(self):
        # A bare ``Tuple`` maps to a bare ``List``.
        self.assertEqual(List, tuple_to_list(Tuple))

    def test_supported_types(self):
        # Homogeneous tuples of any length collapse to ``List[t]``.
        for t in self.supported_base_types:
            with self.subTest(t=t):
                result_type = tuple_to_list(Tuple[t, t, t])
                self.assertEqual(result_type, List[t])

                result_type = tuple_to_list(Tuple[t])
                self.assertEqual(result_type, List[t])

    def test_optional(self):
        for t in self.supported_base_types:
            with self.subTest(t=t):
                # Mixing t with Optional[t] yields List[Optional[t]].
                result_type = tuple_to_list(Tuple[t, Optional[t]])
                self.assertEqual(result_type, List[Optional[t]])

                result_type = tuple_to_list(Tuple[t, t, Optional[t]])
                self.assertEqual(result_type, List[Optional[t]])

                # The variadic form Tuple[t, ...] yields List[t].
                result_type = tuple_to_list(Tuple[t, ...])
                self.assertEqual(result_type, List[t])

    def test_mixed_types(self):
        # Distinct element types become a List of their Union.
        result_type = tuple_to_list(Tuple[int, float])
        self.assertEqual(result_type, List[typing.Union[int, float]])

        result_type = tuple_to_list(Tuple[int, float, str])
        self.assertEqual(result_type, List[typing.Union[int, float, str]])
3912

3913

3914
# Generate device-specific variants of TestCustomOpTesting for CPU and CUDA only.
only_for = ("cpu", "cuda")
instantiate_device_type_tests(
    TestCustomOpTesting,
    globals(),
    only_for=only_for,
)

# Expand the parametrized tests declared on these classes.
instantiate_parametrized_tests(TestCustomOp)
instantiate_parametrized_tests(TestCustomOpAPI)

if __name__ == "__main__":
    run_tests()
3921

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.