# pytorch / test_overrides.py -- 1649 lines, 56.4 KB (fork: 0)
1
# Owner(s): ["module: __torch_function__"]
2

3
import torch
4
import numpy as np
5
import inspect
6
import functools
7
import pprint
8
import pickle
9
import collections
10
import unittest
11
import contextlib
12

13
from torch.testing._internal.common_utils import TestCase, run_tests, TEST_WITH_CROSSREF, TEST_WITH_TORCHDYNAMO
14
from torch.overrides import (
15
    handle_torch_function,
16
    has_torch_function,
17
    get_ignored_functions,
18
    get_overridable_functions,
19
    get_testing_overrides,
20
    resolve_name,
21
    is_tensor_method_or_property,
22
    TorchFunctionMode,
23
    _get_current_function_mode,
24
    _get_current_function_mode_stack,
25
    BaseTorchFunctionMode
26
)
27
from torch.utils._mode_utils import all_same_mode
28
from torch.utils._pytree import tree_map
29

30
# Module-level shorthand for the torch.Tensor class.
Tensor = torch.Tensor
31

32
# The functions below simulate the pure-python torch functions in the
33
# torch.functional namespace. We use examples local to this file rather
34
# than any of the real examples implemented in Python since in the
35
# future those examples might get reimplemented in C++ for speed. This
36
# fake torch function allows us to verify that the dispatch rules work
37
# the same for a torch function implemented in C++ or Python.
38

39
def foo(a, b, c=None):
    """A function with multiple arguments and an optional argument.

    Simulates a pure-Python torch function: when any of ``a``, ``b`` or ``c``
    provides a ``__torch_function__`` override, the call is forwarded to
    ``handle_torch_function``; otherwise the plain sum is computed.

    Args:
        a: first operand.
        b: second operand.
        c: optional third operand, added to the sum when given.

    Returns:
        The override's result when dispatch happens, else ``a + b`` or
        ``a + b + c``.
    """
    if has_torch_function((a, b, c)):
        return handle_torch_function(foo, (a, b, c), a, b, c=c)
    # Test for presence instead of truthiness: taking bool() of a
    # multi-element tensor raises, so ``if c:`` would break for such inputs.
    if c is not None:
        return a + b + c
    return a + b
46

47
def bar(a):
    """A function with one argument.

    Forwards to ``handle_torch_function`` when ``a`` provides a
    ``__torch_function__`` override; otherwise returns ``a`` unchanged.
    """
    if has_torch_function((a,)):
        return handle_torch_function(bar, (a,), a)
    return a
52

53
def baz(a, b):
    """A function with multiple arguments.

    Forwards to ``handle_torch_function`` when ``a`` or ``b`` provides a
    ``__torch_function__`` override; otherwise returns ``a + b``.
    """
    if has_torch_function((a, b)):
        return handle_torch_function(baz, (a, b), a, b)
    return a + b
58

59
def quux(a):
    """Used to test that errors raised in user implementations get propagated.

    Forwards to ``handle_torch_function`` when ``a`` provides a
    ``__torch_function__`` override; otherwise returns ``a`` unchanged.
    """
    if has_torch_function((a,)):
        return handle_torch_function(quux, (a,), a)
    return a
64

65
# HANDLED_FUNCTIONS_DIAGONAL is a dispatch table that
66
# DiagonalTensor.__torch_function__ uses to determine which override
67
# function to call for a given torch API function.  The keys of the
68
# dictionary are function names in the torch API and the values are
69
# function implementations. Implementations are added to
70
# HANDLED_FUNCTIONS_DIAGONAL by decorating a python function with
71
# implements_diagonal. See the overrides immediately below the definition
72
# of DiagonalTensor for usage examples.
73
HANDLED_FUNCTIONS_DIAGONAL = {}
74

75
def implements_diagonal(torch_function):
    """Register a torch function override for DiagonalTensor.

    This decorator takes a function in the torch API as a
    parameter. Applying this decorator to a function adds that function
    as the registered override for the torch function passed as a
    parameter to the decorator. See DiagonalTensor.__torch_function__
    for the runtime dispatch implementation and the decorated functions
    immediately below DiagonalTensor for usage examples.

    Args:
        torch_function: the callable used as the key in
            HANDLED_FUNCTIONS_DIAGONAL.

    Returns:
        A decorator that stores the decorated function in
        HANDLED_FUNCTIONS_DIAGONAL and returns it unchanged.
    """
    # NOTE(review): functools.wraps copies torch_function's metadata onto the
    # registering closure itself (it does not wrap torch_function); kept as is.
    @functools.wraps(torch_function)
    def decorator(func):
        HANDLED_FUNCTIONS_DIAGONAL[torch_function] = func
        return func
    return decorator
90

91
class DiagonalTensor:
    """A class with __torch_function__ and a specific diagonal representation

    This class has limited utility and is mostly useful for verifying that the
    dispatch mechanism works as expected. It is based on the `DiagonalArray
    example`_ in the NumPy documentation.

    Note that this class does *not* inherit from ``torch.Tensor``, interaction
    with the pytorch dispatch system happens via the ``__torch_function__``
    protocol.

    ``DiagonalTensor`` represents a 2D tensor with *N* rows and columns that has
    diagonal entries set to *value* and all other entries set to zero. The
    main functionality of ``DiagonalTensor`` is to provide a more compact
    string representation of a diagonal tensor than in the base tensor class:

    >>> d = DiagonalTensor(5, 2)
    >>> d
    DiagonalTensor(N=5, value=2)
    >>> d.tensor()
    tensor([[2., 0., 0., 0., 0.],
            [0., 2., 0., 0., 0.],
            [0., 0., 2., 0., 0.],
            [0., 0., 0., 2., 0.],
            [0., 0., 0., 0., 2.]])

    Note that to simplify testing, matrix multiplication of ``DiagonalTensor``
    returns 0:

    >>> torch.mm(d, d)
    0

    .. _DiagonalArray example:
        https://numpy.org/devdocs/user/basics.dispatch.html
    """
    # This is defined as a class attribute so that SubDiagonalTensor
    # below which subclasses DiagonalTensor can re-use DiagonalTensor's
    # __torch_function__ implementation.
    handled_functions = HANDLED_FUNCTIONS_DIAGONAL

    def __init__(self, N, value):
        """Store the matrix size ``N`` and the diagonal ``value``."""
        self._N = N
        self._i = value

    def __repr__(self):
        return f"DiagonalTensor(N={self._N}, value={self._i})"

    def __array__(self):
        """Return the dense NumPy representation (``value * eye(N)``)."""
        return self._i * np.eye(self._N)

    def tensor(self):
        """Return the dense torch representation (``value * eye(N)``)."""
        return self._i * torch.eye(self._N)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        """Dispatch ``func`` through ``cls.handled_functions``.

        Returns ``NotImplemented`` when no override is registered for
        ``func``; otherwise returns whatever the registered override returns.
        """
        if kwargs is None:
            kwargs = {}
        if func not in cls.handled_functions:
            return NotImplemented
        return cls.handled_functions[func](*args, **kwargs)

    def __eq__(self, other):
        # Equal only to an object of the exact same type with the same
        # size and diagonal value.
        return type(other) is type(self) and self._N == other._N and self._i == other._i
154

155
@implements_diagonal(torch.mean)
def mean(mat):
    """DiagonalTensor override for ``torch.mean``: the diagonal value divided by ``N``."""
    return float(mat._i) / mat._N
158

159
@implements_diagonal(torch.mm)
def diagonal_mm(mat1, mat2):
    """DiagonalTensor override for ``torch.mm``; returns the sentinel 0."""
    return 0
162

163
@implements_diagonal(torch.div)
def diagonal_div(input, other, out=None):
    """DiagonalTensor override for ``torch.div``; returns the sentinel -1."""
    return -1
166

167
@implements_diagonal(torch.add)
def add(mat1, mat2):
    """DiagonalTensor override for ``torch.add`` that always raises ValueError."""
    raise ValueError
170

171
@implements_diagonal(foo)
def diagonal_foo(a, b, c=None):
    """DiagonalTensor override for ``foo``; returns the sentinel -1."""
    return -1
174

175
@implements_diagonal(bar)
def diagonal_bar(a):
    """DiagonalTensor override for ``bar``; returns the sentinel -1."""
    return -1
178

179
@implements_diagonal(quux)
def diagonal_quux(a):
    """DiagonalTensor override for ``quux`` that always raises ValueError."""
    raise ValueError
182

183
# The dispatch table for SubTensor's __torch_function__ implementation.
184
HANDLED_FUNCTIONS_SUB = {}
185

186
def implements_sub(torch_function):
    """Register a torch function override for SubTensor.

    Args:
        torch_function: the callable used as the key in HANDLED_FUNCTIONS_SUB.

    Returns:
        A decorator that stores the decorated function in
        HANDLED_FUNCTIONS_SUB and returns it unchanged.
    """
    # NOTE(review): functools.wraps copies torch_function's metadata onto the
    # registering closure itself (it does not wrap torch_function); kept as is.
    @functools.wraps(torch_function)
    def decorator(func):
        HANDLED_FUNCTIONS_SUB[torch_function] = func
        return func
    return decorator
193

194
class SubTensor(torch.Tensor):
    """A subclass of torch.Tensor used for testing __torch_function__ dispatch

    This class has the property that matrix multiplication returns the
    sentinel value -1 whenever a ``SubTensor`` takes part in it:

    >>> s = SubTensor([[1, 1], [1, 1]])
    >>> torch.mm(s, s)
    -1
    >>> t = torch.tensor([[1, 1], [1, 1]])
    >>> torch.mm(s, t)
    -1
    >>> torch.mm(t, s)
    -1
    >>> torch.mm(t, t)
    tensor([[2, 2],
            [2, 2]])

    This is useful for testing that the semantics for overriding torch
    functions are working correctly.
    """
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        """Dispatch ``func`` through HANDLED_FUNCTIONS_SUB.

        Returns ``NotImplemented`` when no override is registered for
        ``func``; otherwise returns whatever the registered override returns.
        """
        if kwargs is None:
            kwargs = {}

        if func not in HANDLED_FUNCTIONS_SUB:
            return NotImplemented
        return HANDLED_FUNCTIONS_SUB[func](*args, **kwargs)
222

223
class SubTensor2(torch.Tensor):
    """Plain ``torch.Tensor`` subclass with no overrides of its own."""
225

226
class SubSubTensor2(SubTensor2):
    """Subclass of ``SubTensor2`` with no overrides of its own."""
228

229
class SubTensor3(torch.Tensor):
    """Plain ``torch.Tensor`` subclass, unrelated to ``SubTensor2``."""
231

232
@implements_sub(torch.mean)
def sub_mean(mat):
    """SubTensor override for ``torch.mean``; returns the sentinel 0."""
    return 0
235

236
@implements_sub(torch.mm)
def sub_mm(mat1, mat2):
    """SubTensor override for ``torch.mm``; returns the sentinel -1."""
    return -1
239

240
@implements_sub(bar)
def sub_bar(mat):
    """SubTensor override for ``bar``; returns the sentinel 1."""
    return 1
243

244
@implements_sub(torch.div)
def sub_div(input, other, out=None):
    """SubTensor override for ``torch.div`` that declines by returning NotImplemented."""
    return NotImplemented
247

248
# The dispatch table for SubDiagonalTensor's __torch_function__ implementation.
249
HANDLED_FUNCTIONS_SUB_DIAGONAL = {}
250

251
def implements_sub_diagonal(torch_function):
    """Register a torch function override for SubDiagonalTensor.

    Args:
        torch_function: the callable used as the key in
            HANDLED_FUNCTIONS_SUB_DIAGONAL.

    Returns:
        A decorator that stores the decorated function in
        HANDLED_FUNCTIONS_SUB_DIAGONAL and returns it unchanged.
    """
    # NOTE(review): functools.wraps copies torch_function's metadata onto the
    # registering closure itself (it does not wrap torch_function); kept as is.
    @functools.wraps(torch_function)
    def decorator(func):
        HANDLED_FUNCTIONS_SUB_DIAGONAL[torch_function] = func
        return func
    return decorator
258

259
class SubDiagonalTensor(DiagonalTensor):
    """A subclass of ``DiagonalTensor`` to test custom dispatch

    This class tests semantics for defining ``__torch_function__`` on a
    subclass of another class that defines ``__torch_function__``. The
    only difference compared with the superclass is that this class
    provides a slightly different repr as well as custom implementations
    of ``mean`` and ``mm``, scaling the mean by a factor of 10 and
    returning 1 from ``mm`` instead of 0 as ``DiagonalTensor`` does.
    """
    # Shadows the superclass dispatch table so the inherited
    # __torch_function__ looks overrides up in this class's own table.
    handled_functions = HANDLED_FUNCTIONS_SUB_DIAGONAL

    def __repr__(self):
        return f"SubDiagonalTensor(N={self._N}, value={self._i})"
273

274

275
@implements_sub_diagonal(torch.mean)
def sub_diagonal_mean(mat):
    """SubDiagonalTensor override for ``torch.mean``: ten times the diagonal value divided by ``N``."""
    return 10 * float(mat._i) / mat._N
278

279
@implements_sub_diagonal(bar)
def sub_diagonal_bar(mat):
    """SubDiagonalTensor override for ``bar``; returns the sentinel 0."""
    return 0
282

283
@implements_sub_diagonal(torch.mm)
def sub_diagonal_mm(mat1, mat2):
    """SubDiagonalTensor override for ``torch.mm``; returns the sentinel 1."""
    return 1
286

287
@implements_sub_diagonal(torch.div)
def sub_diagonal_div(input, other, out=None):
    """SubDiagonalTensor override for ``torch.div`` that declines by returning NotImplemented."""
    return NotImplemented
290

291
@implements_sub_diagonal(foo)
def sub_diagonal_foo(a, b, c=None):
    """SubDiagonalTensor override for ``foo`` that declines by returning NotImplemented."""
    return NotImplemented
294

295
# The dispatch table for TensorLike's __torch_function__ implementation.
296
HANDLED_FUNCTIONS_TENSOR_LIKE = {}
297

298

299
# Note: _triggered wrapper
300
# Dict that wraps the implementations from get_testing_overrides into another
301
# function with a _triggered slot/flag. The triggered flag is set when the
302
# implementation is called.
303
WRAPPED_TRIGGERED_IMPLS = {}
304

305

306
def triggered_wrapper(f):
    """Wrap ``f`` so that calls to it can be detected afterwards.

    Args:
        f: the callable to wrap.

    Returns:
        A function that forwards all arguments to ``f`` and returns its
        result. The wrapper carries a ``_triggered`` attribute that starts
        as ``False`` and is set to ``True`` on every call, before ``f`` runs.
    """
    @functools.wraps(f)
    def wrapped(*args, **kwargs):
        # Set the flag first so it is recorded even if f raises.
        wrapped._triggered = True
        return f(*args, **kwargs)

    wrapped._triggered = False
    return wrapped
314

315
def implements_tensor_like(torch_function):
    """Register a torch function override for TensorLike.

    Args:
        torch_function: the callable used as the key in
            HANDLED_FUNCTIONS_TENSOR_LIKE.

    Returns:
        A decorator that stores the decorated function in
        HANDLED_FUNCTIONS_TENSOR_LIKE and returns it unchanged.
    """
    # NOTE(review): functools.wraps copies torch_function's metadata onto the
    # registering closure itself (it does not wrap torch_function); kept as is.
    @functools.wraps(torch_function)
    def decorator(func):
        HANDLED_FUNCTIONS_TENSOR_LIKE[torch_function] = func
        return func
    return decorator
322

323
def generate_tensor_like_torch_implementations():
    """Register every testing override and check that none is missing.

    First asserts that each function reported by
    ``get_overridable_functions()`` has an entry in
    ``get_testing_overrides()`` (apart from a small ignore list). Then wraps
    each override with ``triggered_wrapper``, records the wrapper in
    ``WRAPPED_TRIGGERED_IMPLS`` and registers it for ``SubTensor`` (tensor
    methods and properties) or ``TensorLike`` (everything else).

    Raises:
        AssertionError: if some overridable function has no testing override.
    """
    testing_overrides = get_testing_overrides()
    # test/test_cpp_api_parity.py monkeypatches torch.nn to have a new
    # function sample_functional.  Depending on what order you run pytest
    # collection, this may trigger the error here.  This is a hack to fix
    # the problem.  A more proper fix is to make the "not tested" check
    # a test on its own, and to make sure the monkeypatch is only installed
    # for the span of the relevant test (and deleted afterwards)
    testing_ignore = {"sample_functional", "autocast"}
    untested_funcs = [
        f"{namespace}.{func.__name__}"
        for namespace, funcs in get_overridable_functions().items()
        for func in funcs
        if func not in testing_overrides and func.__name__ not in testing_ignore
    ]
    msg = (
        "The following functions are not tested for __torch_function__ "
        "support, please ensure there is an entry in the dict returned by "
        "torch.overrides.get_testing_overrides for this function or if a "
        "__torch_function__ override does not make sense, add an entry to "
        "the tuple returned by torch.overrides.get_ignored_functions.\n\n{}"
    )
    assert not untested_funcs, msg.format(pprint.pformat(untested_funcs))
    for func, override in testing_overrides.items():
        wrapped = triggered_wrapper(override)
        # See note: "_triggered wrapper"
        WRAPPED_TRIGGERED_IMPLS[func] = wrapped
        # torch.Tensor methods/properties are registered for SubTensor;
        # all other functions are registered for TensorLike.
        if is_tensor_method_or_property(func):
            implements_sub(func)(wrapped)
        else:
            implements_tensor_like(func)(wrapped)
356

357
# Run at import time so the dispatch tables and WRAPPED_TRIGGERED_IMPLS are
# filled in before the tests that read them are generated.
generate_tensor_like_torch_implementations()
358

359
class TensorLike:
    """A class that overrides the full torch API

    This class is used to explicitly test that the full torch.tensor API
    can be overridden with a class that defines __torch_function__.
    """
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        """Dispatch ``func`` through HANDLED_FUNCTIONS_TENSOR_LIKE.

        Returns ``NotImplemented`` when no override is registered for
        ``func``; otherwise returns whatever the registered override returns.
        """
        if kwargs is None:
            kwargs = {}

        if func not in HANDLED_FUNCTIONS_TENSOR_LIKE:
            return NotImplemented
        # In this case _torch_function_ should override TensorLike objects
        return HANDLED_FUNCTIONS_TENSOR_LIKE[func](*args, **kwargs)
374

375
class TestTorchFunctionOverride(TestCase):
    """Tests of ``__torch_function__`` dispatch using the fixture classes
    DiagonalTensor, SubTensor, SubDiagonalTensor and the SubTensor2 family."""

    @classmethod
    def setUpClass(cls):
        # Class-level context managers are pushed on this stack so that
        # tearDownClass can unwind all of them with a single close().
        cls._stack = contextlib.ExitStack()
        if TEST_WITH_TORCHDYNAMO:
            # Add classes to the wrapped tensor subclasses
            @contextlib.contextmanager
            def setup_subclasses():
                old = set(torch._dynamo.config.traceable_tensor_subclasses)
                torch._dynamo.config.traceable_tensor_subclasses.add(DiagonalTensor)
                try:
                    yield
                finally:
                    # Restore the exact set that was configured on entry.
                    torch._dynamo.config.traceable_tensor_subclasses.clear()
                    torch._dynamo.config.traceable_tensor_subclasses.update(old)

            cls._stack.enter_context(setup_subclasses())

    @classmethod
    def tearDownClass(cls):
        cls._stack.close()

    def test_mean_semantics(self):
        """Test that a function with one argument can be overridden"""
        t1 = DiagonalTensor(5, 2)
        t2 = SubTensor([[1, 2], [1, 2]])
        t3 = SubDiagonalTensor(5, 2)
        self.assertEqual(torch.mean(t1), 0.4)
        self.assertEqual(bar(t1), -1)
        self.assertEqual(torch.mean(t2), 0)
        self.assertEqual(bar(t2), 1)
        self.assertEqual(torch.mean(t3), 4.0)
        self.assertEqual(bar(t3), 0)

    def test_has_torch_function_non_sequence(self):
        """has_torch_function must reject an argument that is not a sequence"""
        with self.assertRaisesRegex(TypeError, "expected a sequence"):
            has_torch_function(object())

    def test_mm_semantics(self):
        """Test that a function with multiple arguments can be overridden"""
        t1 = DiagonalTensor(5, 2)
        t2 = torch.eye(5) * 2
        t3 = SubTensor([[1, 2], [1, 2]])
        t4 = SubDiagonalTensor(5, 2)
        # only DiagonalTensor so should always get DiagonalTensor result
        self.assertEqual(torch.mm(t1, t1), 0)
        # tensor and DiagonalTensor, always return DiagonalTensor result
        self.assertEqual(torch.mm(t1, t2), 0)
        self.assertEqual(torch.mm(t2, t1), 0)
        # only SubTensor so should always get SubTensor result
        self.assertEqual(torch.mm(t3, t3), -1)
        # tensor and SubTensor so should always get SubTensor result
        self.assertEqual(torch.mm(t3, t2), -1)
        self.assertEqual(torch.mm(t2, t3), -1)
        # DiagonalTensor and SubTensor are unrelated classes so the result
        # depends on which argument appears first
        self.assertEqual(torch.mm(t3, t1), -1)
        self.assertEqual(torch.mm(t1, t3), 0)
        # SubDiagonalTensor should take precedence over DiagonalTensor
        # but should behave otherwise the same as DiagonalTensor
        self.assertEqual(torch.mm(t4, t4), 1)
        self.assertEqual(torch.mm(t4, t1), 1)
        self.assertEqual(torch.mm(t1, t4), 1)
        self.assertEqual(torch.mm(t4, t2), 1)
        self.assertEqual(torch.mm(t2, t4), 1)
        self.assertEqual(torch.mm(t3, t4), -1)
        self.assertEqual(torch.mm(t4, t3), 1)

    def test_precedence_semantics(self):
        """Test semantics for __torch_function__ for functions that take
        multiple arguments

        For functions that take multiple arguments, the appropriate
        __torch_function__ implementation to call is determined by
        examining the types of the arguments. The precedence order is
        left-to-right in the argument list, except subclasses are always
        checked before superclasses. The first result of calling the
        implementations in precedence order that is not NotImplemented
        is returned to the user. If all implementations return
        NotImplemented, a TypeError is raised.

        All cases are tested with functions implemented in C++ and
        either foo or baz, which are python functions defined above that
        are instrumented to obey the same dispatch rules as the
        functions in torch.functional.
        """
        # DiagonalTensor has a valid override and SubDiagonal has an
        # override that returns NotImplemented so we should call the
        # DiagonalTensor implementation, returning -1
        t1 = DiagonalTensor(5, 2)
        t2 = SubDiagonalTensor(5, 2)
        self.assertEqual(torch.div(t1, t2), -1)
        self.assertEqual(torch.div(t2, t1), -1)
        self.assertEqual(foo(t1, t2), -1)
        self.assertEqual(foo(t2, t1), -1)

        # SubTensor has an implementation that returns NotImplemented as
        # well so it should behave exactly like SubDiagonalTensor in the
        # test above
        t3 = SubTensor([[1, 2], [1, 2]])
        self.assertEqual(torch.div(t1, t3), -1)
        self.assertEqual(torch.div(t3, t1), -1)
        self.assertEqual(foo(t1, t3), -1)
        self.assertEqual(foo(t3, t1), -1)

        # div between SubTensor and SubDiagonalTensor should raise
        # TypeError since both have an implementation that
        # explicitly returns NotImplemented
        for op in (torch.div, foo):
            for lhs, rhs in ((t2, t3), (t3, t2)):
                with self.assertRaises(TypeError):
                    op(lhs, rhs)

        # none of DiagonalTensor, SubdiagonalTensor, or SubTensor have a
        # mul or a baz implementation so all ops should raise TypeError
        operands = (t1, t2, t3)
        for op in (torch.mul, baz):
            for lhs in operands:
                for rhs in operands:
                    with self.assertRaises(TypeError):
                        op(lhs, rhs)

    def test_user_implementation_raises(self):
        """Test that errors raised in user implementations propagate correctly"""
        t1 = DiagonalTensor(5, 2)
        t2 = DiagonalTensor(5, 2)
        with self.assertRaises(ValueError):
            torch.add(t1, t2)
        with self.assertRaises(ValueError):
            quux(t1)

    def test_tensor_subclass_propagation(self):
        """this test exercises the functionality described in
        docs/source/notes/extending.rst#subclassing-torchtensor"""
        t1 = torch.tensor([5])
        t2 = torch.tensor([6])

        s1 = SubTensor2([5])
        s2 = SubTensor2([6])

        ss1 = SubSubTensor2([5])
        ss2 = SubSubTensor2([6])

        sn1 = SubTensor3([5])
        sn2 = SubTensor3([6])

        # Check that leaf subclass is kept regardless of order
        self.assertIsInstance(s1 + t2, SubTensor2)
        self.assertIsInstance(t1 + s2, SubTensor2)
        self.assertIsInstance(s1 + s2, SubTensor2)

        # Check indexing subclass is kept
        self.assertIsInstance(s1[0], SubTensor2)

        # Check case for subclass of subclass.
        self.assertIsInstance(ss1 + ss2, SubSubTensor2)
        self.assertIsInstance(ss1 + s2, SubSubTensor2)
        self.assertIsInstance(s1 + ss2, SubSubTensor2)
        self.assertIsInstance(ss1 + ss2, SubSubTensor2)
        self.assertIsInstance(ss1 + t2, SubSubTensor2)
        self.assertIsInstance(t1 + ss2, SubSubTensor2)
        self.assertIsInstance(ss1[0], SubSubTensor2)

        # Make sure unrelated class trees are not merged.
        with self.assertRaises(TypeError):
            s1 + sn2
        with self.assertRaises(TypeError):
            sn1 + s2

    def test_base(self):
        # https://github.com/szagoruyko/pytorchviz/issues/65
        class DummyTensor(torch.Tensor):
            pass

        a = torch.ones(1)
        c = DummyTensor(a)
        self.assertTrue(c._is_view())
        self.assertIs(c._base, a)

    def test_grad(self):
        # Previously, Tensor-like objects that did not subclass from Tensor
        # did not get wrapped into unary tuples before being passed into
        # handle_torch_function, in contradiction with how Tensor-likes
        # were handled
        #
        # NB: this asserts that the arguments get normalized into a tuple
        # before entering the torch function handler; it could go the
        # other way but beware https://github.com/pytorch/pytorch/issues/76037

        class Dummy:
            @classmethod
            def __torch_function__(cls, func, types, args=(), kwargs=None):
                # ``self`` and ``x`` are taken from the enclosing test scope.
                inputs, outputs = args
                self.assertEqual(inputs, (x,))
                self.assertEqual(outputs, (x,))
                return -1

        x = Dummy()
        self.assertEqual(torch.autograd.grad(x, x), -1)

    def test_pow_rpow(self):
        """``**`` must fall back to the right operand's __rpow__ override
        when the left operand's override returns NotImplemented"""
        class NothingImplemented(torch.Tensor):
            @classmethod
            def __torch_function__(cls, func, types, args=(), kwargs=None):
                return NotImplemented

        class RPowOnly(torch.Tensor):
            @classmethod
            def __torch_function__(cls, func, types, args=(), kwargs=None):
                if func is torch.Tensor.__rpow__:
                    return -1
                return NotImplemented

        self.assertEqual(NothingImplemented() ** RPowOnly(), -1)
623

624

625
def generate_tensor_like_override_tests(cls):
    """Attach one generated test method per testing override to ``cls``.

    For every ``(func, override)`` pair returned by
    ``get_testing_overrides()`` a test is built that calls ``func`` with
    guessed arguments and checks that the registered override was reached.
    The test is stored on ``cls`` under a name derived from the function's
    module (or ``Tensor`` for tensor methods/properties) and its
    ``__name__``.

    Args:
        cls: the class that receives the generated test methods.
    """
    from torch.testing._internal.generated.annotated_fn_args import annotated_args

    def test_generator(func, override):
        """Build the test method for a single ``(func, override)`` pair."""
        # If func corresponds to a torch.Tensor method or property.
        if is_tensor_method_or_property(func):
            # Generate an instance by using SubTensor,
            def instance_gen():
                return SubTensor([5])
        else:
            # Otherwise, TensorLike.
            def instance_gen():
                return TensorLike()

        # FIXME The following code does not support kwonly args without defaults.
        # The fix is easy, as one just needs to save these args when generating the variable
        # annotated_args. The problem is that, if one does so, one finds a number
        # of functions that have problematic signatures in native_functions.yaml.
        # Fixing these would be BC breaking, so hence this terrible hack
        # https://github.com/pytorch/pytorch/issues/67008
        kwargs = {}
        if hasattr(func, "__name__") and "linalg_solve_triangular" in func.__name__:
            kwargs = {"upper": True}

        func_args = []
        is_method = is_tensor_method_or_property(func)

        def _simple_type_parser(func, arg_name, arg_type):
            # Guess valid input to aten function based on type of argument
            if arg_type == "Tensor":
                return instance_gen()
            elif arg_type == "TensorList" or arg_type == "ITensorListRef":
                return [instance_gen(), instance_gen()]
            elif arg_type == "c10::List<::std::optional<Tensor>>":
                return [instance_gen(), instance_gen()]
            elif arg_type == "IntArrayRef" or arg_type == "SymIntArrayRef":
                # ``arg`` is the loop variable of the enclosing
                # ``for arg in annotated_args[func]`` loop, read via closure.
                size = arg.get("size", 2)
                if size == 1:
                    return 1
                else:
                    return [1] * size
            elif arg_type == "Scalar":
                return 3.5
            elif arg_type == "bool":
                return False
            elif arg_type == "Dimname":
                return ""
            elif arg_type == "DimnameList":
                return [""]
            elif arg_type.startswith("int"):
                return 0
            elif arg_type in {"Stream"}:
                return torch.Stream()
            elif arg_type.startswith("float") or arg_type == "double":
                return 1.0
            elif arg_type in {"Generator", "MemoryFormat", "TensorOptions"}:
                return None
            elif arg_type == "ScalarType":
                return torch.float32
            elif arg_type == "c10::string_view":
                return ""
            elif arg_type == "SymInt":
                # TODO: generate actual SymbolicInt
                return 1
            else:
                raise RuntimeError(
                    f"Unsupported argument type {arg_type} for {arg_name} of function {func}"
                )

        if func in annotated_args:
            for arg in annotated_args[func]:
                # Guess valid input to aten function based on type of argument
                t = arg["simple_type"]
                if t.endswith("?"):
                    t = t[:-1]
                if t == "Tensor" and is_method and arg["name"] == "self":
                    # See "Note: properties and __get__"
                    func = func.__get__(instance_gen())
                    continue
                arg_to_add = _simple_type_parser(func, arg["name"], t)
                if "is_kwarg_only" in arg and arg["is_kwarg_only"] == str(True):
                    kwargs[arg["name"]] = arg_to_add
                else:
                    func_args.append(arg_to_add)
        else:
            args = inspect.getfullargspec(override)
            try:
                func_args = inspect.getfullargspec(func)
                # Remove annotations from argspec
                # NOTE(review): FullArgSpec is a named tuple, not a mapping, so
                # the ``**func_args`` unpacking below raises TypeError, which
                # the ``except`` swallows; the comparison that follows is
                # therefore never reached. Left unchanged because enabling the
                # check could start rejecting existing overrides.
                func_args = type(func_args)(**{**func_args, 'annotations': None})
                if func_args != args:
                    raise RuntimeError(f"Override for {func} doesn't match its argspec.\n"
                                       + f"Original: {inspect.signature(func)}\n"
                                       + f"Override: {inspect.signature(override)}")
            except TypeError:
                pass
            # One instance per positional parameter of the override that has
            # no default value.
            nargs = len(args.args)
            if args.defaults is not None:
                nargs -= len(args.defaults)
            func_args = [instance_gen() for _ in range(nargs)]
            if args.varargs is not None:
                func_args += [instance_gen(), instance_gen()]

        def test(self):
            ret = func(*func_args, **kwargs)
            # ret is None for certain protocols, e.g., `__weakref__` and `__setitem__`
            # This is currently the best check but doesn't work for, for example,
            # Tensor.__add__ because it redirects to Tensor.add.
            # See note "_triggered wrapper"
            if not is_method or ret is None:
                self.assertTrue(WRAPPED_TRIGGERED_IMPLS[func]._triggered)
                return

            self.assertEqual(ret, -1)

        return test

    for func, override in get_testing_overrides().items():
        test_method = test_generator(func, override)
        if func.__name__ == "__get__":
            # Note: properties and __get__
            # __get__ is part of the descriptor protocol.
            # https://docs.python.org/3/howto/descriptor.html
            # This is used for properties of the form
            # torch.Tensor.<property>, with the method __get__
            # In this case we get the property name in two ways:

            # This case for properties defined in C.
            module = getattr(
                func.__self__,
                "__qualname__",
                None
            )

            # This one for properties defined in Python.
            if module is None:
                module = "Tensor." + func.__self__.fget.__name__

            # Unfortunately I couldn't find a way to unify these two cases
            # and there is no way for general descriptors.
        elif is_tensor_method_or_property(func):
            module = "Tensor"
        else:
            module = func.__module__
        if module:
            name = f"test_{module.replace('.', '_')}_{func.__name__}"
        else:
            name = f'test_{func.__name__}'
        test_method.__name__ = name
        setattr(cls, name, test_method)
775

776
# Add the generated per-override test methods to TestTorchFunctionOverride.
generate_tensor_like_override_tests(TestTorchFunctionOverride)
777

778
class Wrapper:
    """Basic data container that knows how to unwrap itself.

    The wrapped object lives in ``_data``.  For later inspection by the
    tests, the wrapper records every attribute name looked up or assigned
    through it (``used_attrs``) and every function routed through
    ``__torch_function__`` (``used_calls``).
    """
    def __init__(self, data):
        # Write straight into ``__dict__`` so the custom ``__setattr__``
        # below is bypassed for the wrapper's own bookkeeping fields.
        self.__dict__["_data"] = data
        self.__dict__["used_attrs"] = set()
        self.__dict__["used_calls"] = set()

    def __getattr__(self, name):
        """Forward a failed attribute lookup to the wrapped object."""
        if name in self.__dict__:
            return self.__dict__[name]
        self.used_attrs.add(name)

        val = getattr(self._data, name)

        # If it's a method
        if not isinstance(val, torch.device) and callable(val):
            c = getattr(type(self._data), name)
            # Don't append self to args if classmethod/staticmethod
            if c is val:
                return lambda *a, **kw: wrap(self.__torch_function__(c, (Wrapper,), args=a, kwargs=kw))
            # Otherwise append self to args
            return lambda *a, **kw: wrap(self.__torch_function__(c, (Wrapper,), args=(self,) + a, kwargs=kw))

        return wrap(val)

    def __setattr__(self, name, value):
        """Assign to the wrapper's own fields, else to the wrapped object."""
        if name in self.__dict__:
            # Internal bookkeeping field: update it and stop here, so the
            # assignment is neither recorded in ``used_attrs`` nor leaked
            # onto the wrapped object.
            self.__dict__[name] = value
            return

        self.used_attrs.add(name)
        setattr(self._data, name, unwrap(value))

    def __setitem__(self, key, value):
        self._data[unwrap(key)] = unwrap(value)

    def __getitem__(self, key):
        return wrap(self._data[unwrap(key)])

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        """Record ``func`` on every wrapper argument, then call it unwrapped.

        Positional arguments (and the elements of positional sequences) are
        searched for instances of this class; at least one must be present.
        The result of ``func`` is wrapped again before it is returned.
        """
        if kwargs is None:
            kwargs = {}
        # Find an instance of this class in the arguments
        args_of_this_cls = []
        for a in args:
            if isinstance(a, cls):
                args_of_this_cls.append(a)
            elif isinstance(a, collections.abc.Sequence):
                args_of_this_cls.extend(el for el in a if isinstance(el, cls))
        assert len(args_of_this_cls) > 0
        for a in args_of_this_cls:
            a.used_calls.add(func)
        args = unwrap(tuple(args))
        kwargs = {k: unwrap(v) for k, v in kwargs.items()}

        return wrap(func(*args, **kwargs))

    # The operators below dispatch explicitly through ``__torch_function__``
    # so that each use is recorded in ``used_calls``.
    def __add__(self, other):
        return self.__torch_function__(torch.add, (Wrapper,), (self, other))

    def __mul__(self, other):
        return self.__torch_function__(torch.mul, (Wrapper,), (self, other))

    def __sub__(self, other):
        return self.__torch_function__(torch.sub, (Wrapper,), (self, other))

    def __truediv__(self, other):
        return self.__torch_function__(torch.true_divide, (Wrapper,), (self, other))

    def __floordiv__(self, other):
        return self.__torch_function__(torch.floor_divide, (Wrapper,), (self, other))

    def __ge__(self, other):
        return self.__torch_function__(torch.ge, (Wrapper,), (self, other))

    def __gt__(self, other):
        return self.__torch_function__(torch.gt, (Wrapper,), (self, other))

    def __lt__(self, other):
        return self.__torch_function__(torch.lt, (Wrapper,), (self, other))

    def __le__(self, other):
        return self.__torch_function__(torch.le, (Wrapper,), (self, other))

    def __eq__(self, other):
        return self.__torch_function__(torch.eq, (Wrapper,), (self, other))

    def __ne__(self, other):
        return self.__torch_function__(torch.ne, (Wrapper,), (self, other))

    def __bool__(self):
        return self.__torch_function__(torch.Tensor.__bool__, (Wrapper,), (self,))

    def __int__(self):
        return self.__torch_function__(torch.Tensor.__int__, (Wrapper,), (self,))

    def __len__(self):
        return len(self._data)
876

877

878
# unwrap inputs if necessary
def unwrap(v):
    """Return ``v`` with every ``Wrapper`` replaced by its ``_data``.

    Exact ``tuple`` and ``list`` instances are rebuilt element by element
    (recursively) with the same container type.  Any other value is
    returned unchanged unless it is itself a ``Wrapper``.
    """
    if type(v) in {tuple, list}:
        return type(v)(unwrap(vi) for vi in v)

    return v._data if isinstance(v, Wrapper) else v
884

885
# wrap inputs if necessary
def wrap(v):
    """Return ``v`` with every ``torch.Tensor`` placed inside a ``Wrapper``.

    Exact ``tuple`` and ``list`` instances are rebuilt element by element
    (recursively) with the same container type.  Non-tensor values are
    returned unchanged.
    """
    if type(v) in {tuple, list}:
        return type(v)(wrap(vi) for vi in v)

    return Wrapper(v) if isinstance(v, torch.Tensor) else v
891

892
class TestEinsumOverride(TestCase):
    """Regression test for gh-38479.

    ``torch.einsum`` is called with ``Wrapper`` operands, passed both as
    separate arguments and as a single list.
    """
    def test_wrapper(self):
        x = Wrapper(torch.randn(5))
        y = Wrapper(torch.randn(4))
        self.assertEqual(torch.einsum('i,j->ij', x, y)._data,
                         torch.ger(x, y)._data)

        # in the old einsum interface, `operands` is a list
        a = Wrapper(torch.randn(2, 3))
        b = Wrapper(torch.randn(5, 3, 7))
        c = Wrapper(torch.randn(2, 7))
        self.assertEqual(torch.einsum('ik,jkl,il->ij', [a, b, c])._data,
                         torch.nn.functional.bilinear(a, c, b)._data)
906

907
class TestGradCheckOverride(TestCase):
    """Test that wrappers work with gradcheck.

    Besides running ``gradcheck``/``gradgradcheck`` on wrapped inputs, the
    test pins the exact set of attributes and functions that were accessed
    through the wrappers.
    """
    def test_gradcheck(self):
        from torch.testing._internal.common_utils import gradcheck, gradgradcheck

        def run_test(fast_mode):
            a = wrap(torch.tensor(5.0, dtype=torch.double))
            b = wrap(torch.tensor(6.0, dtype=torch.double))

            a.requires_grad = True
            b.requires_grad = True

            gradcheck(torch.add, (a, b), raise_exception=False, check_batched_grad=False, fast_mode=fast_mode)
            gradgradcheck(torch.add, (a, b), raise_exception=False, check_batched_grad=False, fast_mode=fast_mode)

            total_used_attrs = a.used_attrs.union(b.used_attrs)
            total_used_calls = a.used_calls.union(b.used_calls)

            # These attributes (and the functions below) may change
            # if the gradcheck implementation changes. It's best to
            # aim for attributes that may be commonly present on other
            # Tensor-likes.
            expected_used_attrs = {
                'data',
                'dtype',
                'is_floating_point',
                'is_sparse',
                'layout',
                'new_zeros',
                'numel',
                'requires_grad',
                'requires_grad_',
                'size',
                'stride',
            }
            if fast_mode:
                expected_used_attrs.add('is_complex')
                expected_used_attrs.add('device')
            self.assertEqual(expected_used_attrs, total_used_attrs)

            expected_used_calls = {
                torch.Tensor.new_zeros,
                torch.Tensor.size,
                torch.Tensor.is_floating_point,
                torch.Tensor.numel,
                torch.Tensor.stride,
                torch.Tensor.requires_grad_,
                torch.autograd.grad,
                torch.add,
            }
            if fast_mode:
                expected_used_calls.add(torch.Tensor.is_complex)
            self.assertEqual(expected_used_calls, total_used_calls)
        run_test(fast_mode=True)
        run_test(fast_mode=False)
962

963
class TestNamedTuple(TestCase):
    """ Regression test for gh-47090 """
    def test_max(self):
        # torch.max(..., dim=0) on a subclass must return the same result
        # type, and an equal result, as on the plain tensor.
        x = torch.tensor([1, 2])
        xs = x.as_subclass(SubTensor2)
        r = torch.max(x, dim=0)
        rs = torch.max(xs, dim=0)
        self.assertEqual(type(r), type(rs))
        self.assertEqual(r, rs)
972

973
class TestGradNewOnesOverride(TestCase):
    """ Regression test for gh-47069 """
    def test_newones(self):
        # new_ones called on a subclass instance must return that subclass.
        t = torch.tensor([1, 2]).as_subclass(SubTensor2)
        n = t.new_ones((1, 2))
        self.assertEqual(type(n), SubTensor2)
979

980
class TestPickle(TestCase):
    "Regression test for gh-47051"
    def test_pickle(self):
        # A pickle round trip must keep both the subclass type and any
        # attribute set on the instance.
        t = torch.tensor([1]).as_subclass(SubTensor2)
        t.abcd = "e"
        t2 = pickle.loads(pickle.dumps(t))
        self.assertIs(type(t2), SubTensor2)
        self.assertEqual(t2.abcd, "e")
988

989
class TestBroadcastAllOverride(TestCase):
    """ test for gh-37141 """
    def test_broadcast_all(self):
        from torch.distributions.utils import broadcast_all
        a = torch.tensor([1.2, 3.4, 5.6])
        a_w = Wrapper(a)
        b = torch.tensor(5.0)
        b_w = Wrapper(b)
        c = torch.tensor([5.0, 5.0, 5.0])

        # Both inputs wrapped
        o_1 = broadcast_all(a_w, b_w)
        self.assertIsInstance(o_1[0], Wrapper)
        self.assertIsInstance(o_1[1], Wrapper)
        self.assertEqual(o_1[0]._data, a)
        self.assertEqual(o_1[1]._data, c)

        # One wrapped input and one plain tensor: both outputs are wrapped
        o_2 = broadcast_all(a_w, b)
        self.assertIsInstance(o_2[0], Wrapper)
        self.assertIsInstance(o_2[1], Wrapper)
        self.assertEqual(o_2[0]._data, a)
        self.assertEqual(o_2[1]._data, c)
1010

1011
class TestWrapTorchFunction(TestCase):
1012
    def test_wrap_torch_function(self):
1013
        class A:
1014
            @classmethod
1015
            def __torch_function__(cls, func, types, args, kwargs):
1016
                return -1
1017

1018
        def dispatcher(a):
1019
            return (a,)
1020

1021
        @torch.overrides.wrap_torch_function(dispatcher)
1022
        def f(a):
1023
            return a
1024

1025
        self.assertEqual(f(A()), -1)
1026

1027
class TestIndexing(TestCase):
1028
    """ Regression tests for gh-46277 """
1029
    def test_getitem(self):
1030
        class A:
1031
            @classmethod
1032
            def __torch_function__(cls, func, types, args, kwargs=None):
1033
                return -1
1034

1035
        t = torch.tensor([5])
1036
        self.assertEqual(t[A()], -1)
1037
        self.assertEqual(t, torch.tensor([5]))
1038

1039
    def test_getitem_subclass(self):
1040
        class A(torch.Tensor):
1041
            @classmethod
1042
            def __torch_function__(cls, func, types, args, kwargs=None):
1043
                return -1
1044

1045
        t = torch.tensor([5])
1046
        self.assertEqual(t[A()], -1)
1047
        self.assertEqual(t[5, A()], -1)
1048
        self.assertEqual(t, torch.tensor([5]))
1049

1050
    def test_setitem(self):
1051
        triggered = set()
1052

1053
        class A:
1054
            @classmethod
1055
            def __torch_function__(cls, func, types, args, kwargs=None):
1056
                triggered.add(func)
1057
                return -1
1058

1059
        t = torch.tensor([5])
1060
        t[A()] = 1
1061
        t[5, A()] = 1
1062
        self.assertIn(Tensor.__setitem__, triggered)
1063
        self.assertEqual(t, torch.tensor([5]))
1064

1065
    def test_setitem_val(self):
1066
        triggered = set()
1067

1068
        class A:
1069
            @classmethod
1070
            def __torch_function__(cls, func, types, args, kwargs=None):
1071
                triggered.add(func)
1072
                return -1
1073

1074
        t = torch.tensor([5])
1075
        t[0] = A()
1076
        self.assertIn(Tensor.__setitem__, triggered)
1077
        self.assertEqual(t, torch.tensor([5]))
1078

1079
    def test_setitem_subclass(self):
1080
        triggered = set()
1081

1082
        class A(torch.Tensor):
1083
            @classmethod
1084
            def __torch_function__(cls, func, types, args, kwargs=None):
1085
                triggered.add(func)
1086
                return -1
1087

1088
        t = torch.tensor([5])
1089
        t[A()] = 1
1090
        t[5, A()] = 1
1091
        self.assertIn(Tensor.__setitem__, triggered)
1092
        self.assertEqual(t, torch.tensor([5]))
1093

1094

1095
class TestIterator(TestCase):
    # Regression test for gh-54457
    def test_iterator(self):
        # Each of the three elements produced by iterating a subclass
        # instance must itself be of the subclass type.
        t = torch.tensor([5, 6, 7]).as_subclass(SubTensor2)
        it = iter(t)
        for _ in range(3):
            self.assertIs(type(next(it)), SubTensor2)
1103

1104

1105
class TestRNN(TestCase):
    # Regression test for gh-55868
    def test_rnn(self):
        # Passing a Wrapper through an RNN module must not raise.
        model = torch.nn.RNN(10, 20, 2)
        inp = Wrapper(torch.randn(1, 5, 10))
        model(inp)
1111

1112

1113
class TestDisabledTorchFunction(TestCase):
1114
    # Regression test for gh-64687
1115
    def test_parameter_does_not_prevent_dispatch(self):
1116
        class MyTensor:
1117
            @classmethod
1118
            def __torch_function__(cls, func, types, args=(), kwargs=None):
1119
                return "called"
1120

1121
        t1 = MyTensor()
1122
        t2 = torch.nn.Parameter(torch.rand(2, 2))
1123
        self.assertEqual(torch.add(t2, t1), "called")
1124

1125
        inp = torch.rand(10, 10)
1126
        self.assertEqual(torch.nn.functional.linear(inp, t1, t2), "called")
1127
        self.assertEqual(torch.nn.functional.linear(inp, t2, t1), "called")
1128

1129
class TestResolveName(TestCase):
1130
    def test_resolve_name(self):
1131
        for cs in get_overridable_functions().values():
1132
            for c in cs:
1133
                self.assertEqual(
1134
                    eval(torch.overrides.resolve_name(c)),
1135
                    c,
1136
                    msg=f"{c}, {torch.overrides.resolve_name(c)}"
1137
                )
1138

1139
class TestTorchFunctionWarning(TestCase):
1140
    def test_warn_on_invalid_torch_function_standalone_class(self):
1141
        class StandaloneTorchFunctionClass:
1142
            def __torch_function__(self, *args, **kwargs):
1143
                pass
1144
        a = StandaloneTorchFunctionClass()
1145
        with self.assertWarnsRegex(DeprecationWarning, "as a plain method is deprecated"):
1146
            # Function that handles torch_function on the python side
1147
            torch.nn.functional.dropout(a)
1148
        with self.assertWarnsRegex(UserWarning, "as a plain method is deprecated"):
1149
            # Function that handles torch_function in C++
1150
            torch.abs(a)
1151

1152
    def test_warn_on_invalid_torch_function_tensor_subclass(self):
1153
        class TensorSubclassTorchFunctionClass(torch.Tensor):
1154
            def __torch_function__(self, *args, **kwargs):
1155
                pass
1156
        b = TensorSubclassTorchFunctionClass()
1157
        with self.assertWarnsRegex(DeprecationWarning, "as a plain method is deprecated"):
1158
            # Function that handles torch_function on the python side
1159
            torch.nn.functional.dropout(b)
1160
        with self.assertWarnsRegex(UserWarning, "as a plain method is deprecated"):
1161
            # Function that handles torch_function in C++
1162
            torch.abs(b)
1163

1164
class TestDisabledUserWarnings(TestCase):
1165
    def test_no_implicit_user_warning_for_deprecated_functions(self):
1166
        self.assertNotWarn(get_ignored_functions)
1167
        self.assertNotWarn(get_testing_overrides)
1168
        self.assertNotWarn(get_overridable_functions)
1169
        self.assertNotWarn(lambda: resolve_name(torch.Tensor.add))
1170
        self.assertNotWarn(lambda: is_tensor_method_or_property(torch.Tensor.add))
1171

1172
@unittest.skipIf(TEST_WITH_CROSSREF, "not run with crossref")
class TestTorchFunctionMode(TestCase):
    """Tests for ``TorchFunctionMode``: interception of torch functions,
    nesting and re-entering of modes, interaction with tensor subclasses,
    and the ``DisableTorchFunction*`` guards."""

    def test_basic(self):
        class A(TorchFunctionMode):
            def __torch_function__(self, *args, **kwargs):
                return -1
        # NB: factory functions get overridden too!
        x = torch.randn(1)
        with A():
            self.assertEqual(torch.randn(3), -1)
            self.assertEqual(torch.add(x, x), -1)
            self.assertEqual(torch.split(None, [2]), -1)  # python side
            self.assertEqual(bar(x), -1)

    def test_factory_override(self):
        class A(TorchFunctionMode):
            def __torch_function__(self, *args, **kwargs):
                return -1

        with A():
            self.assertEqual(torch.tensor([1]), -1)
            self.assertEqual(torch.sparse_coo_tensor(1, 1, 1), -1)
            self.assertEqual(torch.sparse_csr_tensor(1, 1, 1), -1)
            self.assertEqual(torch.sparse_coo_tensor(1, 1, (1, 1), check_invariants=False), -1)
            self.assertEqual(torch.sparse_csr_tensor(1, 1, 1, (1, 1), check_invariants=False), -1)
            self.assertEqual(torch.as_tensor([1]), -1)

    def test_modes_handle_first(self):
        class A(TorchFunctionMode):
            def __torch_function__(self, *args, **kwargs):
                return -40

        # The mode's result is returned for every call below, even though
        # the argument is a subclass instance.
        x = SubTensor()
        with A():
            self.assertEqual(torch.neg(x), -40)
            self.assertEqual(torch.mean(x), -40)
            self.assertEqual(torch.mm(x, x), -40)
            self.assertEqual(bar(x), -40)

    def test_modes_return_notimplemented(self):
        class MyMode(TorchFunctionMode):
            def __torch_function__(self, *args, **kwargs):
                return NotImplemented

        x = SubTensor()
        with MyMode():
            self.assertEqual(torch.mean(x), 0)
            self.assertEqual(torch.mm(x, x), -1)
            self.assertEqual(bar(x), 1)
            # torch.max must raise a TypeError that mentions SubTensor.
            with self.assertRaisesRegex(TypeError, r'SubTensor'):
                torch.max(x, x)

    def test_with_mode(self):
        class ErrorA(RuntimeError):
            pass

        class A(TorchFunctionMode):
            def __torch_function__(self, *args, **kwargs):
                raise ErrorA

        with self.assertRaises(ErrorA):
            with A():
                torch.empty([])

    def test_with_mode_created_separately(self):
        class ErrorA(RuntimeError):
            pass

        class A(TorchFunctionMode):
            def __torch_function__(self, *args, **kwargs):
                raise ErrorA

        x = A()
        with self.assertRaises(ErrorA):
            with x:
                torch.empty([])

    def test_with_nested_modes(self):
        out = []

        class A(TorchFunctionMode):
            def __init__(self, msg):
                self.msg = msg

            def __torch_function__(self, func, _, args=(), kwargs=None):
                if kwargs is None:
                    kwargs = {}
                out.append(self.msg)
                return func(*args, **kwargs)

        with A("layer1"):
            with A("layer2"):
                torch.empty([])

        # The innermost mode runs first.
        self.assertEqual(out, ["layer2", "layer1"])

    def test_nested_same_mode(self):
        out = []

        class A(TorchFunctionMode):
            def __init__(self, msg):
                self.msg = msg

            def __torch_function__(self, func, _, args=(), kwargs=None):
                if kwargs is None:
                    kwargs = {}
                out.append(self.msg)
                return func(*args, **kwargs)

        with A("layer1") as a:
            with a:
                torch.empty([])

        # Entering the same instance twice makes it run twice.
        self.assertEqual(out, ["layer1", "layer1"])

    def test_error_using_class_method_on_mode(self):
        class A(TorchFunctionMode):
            @classmethod
            def __torch_function__(cls, func, _, args=(), kwargs=None):
                return func(args, kwargs)

        x = torch.tensor(5.)
        with self.assertRaisesRegex(RuntimeError, "classmethod is not supported, please make it a plain method"):
            with A():
                x + x

    def test_restacking_with_ancestor(self):
        class A(TorchFunctionMode):
            pass

        with A():
            with A() as x:
                pass

        # Re-entering a mode after its former enclosing mode has exited
        # must not raise.
        with x:
            pass

    def test_get_cur_mode(self):
        # NOTE(review): A defines __torch_dispatch__ rather than
        # __torch_function__; presumably carried over from the
        # dispatch-mode tests -- confirm it is intentional.
        class A(TorchFunctionMode):
            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                pass

        with A() as mode1:
            self.assertEqual(_get_current_function_mode(), mode1)

        with mode1:
            with A() as mode2:
                self.assertEqual(_get_current_function_mode(), mode2)

    def test_get_mode_stack(self):
        class A(TorchFunctionMode):
            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                pass

        self.assertEqual(_get_current_function_mode_stack(), [])

        with A() as mode1:
            self.assertEqual(_get_current_function_mode_stack(), [mode1])

        with mode1:
            with A() as mode2:
                self.assertEqual(_get_current_function_mode_stack(), [mode1, mode2])

    def test_all_same_mode(self):
        class A(TorchFunctionMode):
            pass

        x = A()
        y = A()
        self.assertTrue(all_same_mode([x, x, x]))
        self.assertFalse(all_same_mode([x, None]))
        self.assertFalse(all_same_mode([x, y]))

    def test_nested_modes_with_python_has_torch_function(self):
        called = []

        class A(TorchFunctionMode):
            def __torch_function__(self, func, types, args=(), kwargs=None):
                called.append("A")
                kwargs = {} if kwargs is None else kwargs
                return func(*args, **kwargs)

        class B(TorchFunctionMode):
            def __torch_function__(self, func, types, args=(), kwargs=None):
                called.append("B")
                kwargs = {} if kwargs is None else kwargs
                return func(*args, **kwargs)

        x = torch.randn(3, 4)
        with A():
            with B():
                y = bar(x)

        self.assertEqual(y, x)
        self.assertEqual(called, ["B", "A"])

    def test_reentrant_mode_idiom(self):
        log = []

        class A(TorchFunctionMode):
            def __torch_function__(self, func, types, args=(), kwargs=None):
                if kwargs is None:
                    kwargs = {}
                log.append(func)
                if func is torch.sub:
                    # Re-enter the mode so the replacement call is seen too.
                    with self:
                        lhs, rhs = args
                        assert not kwargs
                        return torch.add(lhs, rhs, alpha=-1)
                return func(*args, **kwargs)

        x = torch.randn(1)
        y = torch.randn(1)
        with A():
            torch.sub(x, y)
        # add hits the torch function again!
        self.assertEqual(log, [torch.sub, torch.add])

    def test_nn_parse_to(self):
        # This failed because the parser thinks the function is called to()
        # but it's actually called _parse_to()

        called = False

        class A(TorchFunctionMode):
            def __torch_function__(self, func, types, args=(), kwargs=None):
                nonlocal called
                if kwargs is None:
                    kwargs = {}
                called = True
                return func(*args, **kwargs)

        with A():
            torch._C._nn._parse_to('cpu')

        self.assertTrue(called)

    def test_getitem_call(self):
        # Indexing a tensor with a tensor must reach the active mode.

        called = False

        class A(TorchFunctionMode):
            def __torch_function__(self, func, types, args=(), kwargs=None):
                nonlocal called
                if kwargs is None:
                    kwargs = {}
                called = True
                return func(*args, **kwargs)

        a = torch.zeros(5)
        b = torch.tensor(0)
        with A():
            a[b]

        self.assertTrue(called)

    def test_distributions_bernoulli(self):
        # This failed because improper use of has_torch_function when
        # is_tensor_like should have been used instead, inside the
        # broadcasting logic called by distributions (Bernoulli doesn't
        # matter per se)

        called = False

        class A(TorchFunctionMode):
            def __torch_function__(self, func, types, args=(), kwargs=None):
                nonlocal called
                if kwargs is None:
                    kwargs = {}
                called = True
                return func(*args, **kwargs)

        with A():
            torch.distributions.Bernoulli(0.3)

        self.assertTrue(called)

    def test_mode_notimplemented_loop(self):
        # Default tensor subclass implementation disables torch function;
        # when we redispatch to mode we must not treat the objects as
        # eligible

        called = 0

        class A(TorchFunctionMode):
            def __torch_function__(self, func, types, args=(), kwargs=None):
                nonlocal called
                if kwargs is None:
                    kwargs = {}
                called += 1
                # The first time we call, the mode sees an active type that
                # it doesn't know how to deal with.  The second time, we're
                # instructed to treat it "as if it were a tensor", and so
                # we keep going.  I'm not entirely clear if the subclasses
                # disappearing from types is the correct way to do it.
                if any(t is not torch.Tensor for t in types):
                    return NotImplemented
                else:
                    return func(*args, **kwargs)

        class B(torch.Tensor):
            pass

        b = B()

        with A():
            r = torch.neg(b)

        self.assertIs(type(r), B)
        self.assertEqual(called, 2)

        called = 0

        with A():
            r = bar(b)

        self.assertIs(type(r), B)
        self.assertEqual(called, 2)

    def test_disable_subclass_not_mode(self):
        called = False

        class A(TorchFunctionMode):
            def __torch_function__(self, func, types, args=(), kwargs=None):
                nonlocal called
                if kwargs is None:
                    kwargs = {}
                called = True
                return func(*args, **kwargs)

        class B(torch.Tensor):
            pass

        x = B(torch.randn(5))
        with A():
            with torch._C.DisableTorchFunctionSubclass():
                self.assertNotIsInstance(torch.sum(x), B)

        # Only subclass dispatch was disabled; the mode still ran.
        self.assertTrue(called)

    def test_disable_subclass_mode(self):
        called = False

        class A(TorchFunctionMode):
            def __torch_function__(self, func, types, args=(), kwargs=None):
                nonlocal called
                if kwargs is None:
                    kwargs = {}
                called = True
                return func(*args, **kwargs)

        class B(torch.Tensor):
            pass

        x = B(torch.randn(5))
        with A():
            with torch._C.DisableTorchFunction():
                self.assertNotIsInstance(torch.sum(x), B)

        # DisableTorchFunction turns off the mode as well.
        self.assertFalse(called)

    def test_disable_enable_subclass(self):
        class A(torch.Tensor):
            pass

        x = A(torch.randn(5))
        with torch._C.DisableTorchFunctionSubclass():
            g = torch._C._EnableTorchFunction()
            try:
                self.assertIsInstance(torch.sum(x), A)
            finally:
                del g

    def test_torch_function_all_disabled_api(self):
        from torch._C import _is_torch_function_all_disabled

        state = _is_torch_function_all_disabled()
        self.assertFalse(state)

        with torch._C.DisableTorchFunction():
            state = _is_torch_function_all_disabled()
            self.assertTrue(state)

        state = _is_torch_function_all_disabled()
        self.assertFalse(state)

        # Disabling only subclass dispatch does not count as "all disabled".
        with torch._C.DisableTorchFunctionSubclass():
            state = _is_torch_function_all_disabled()
            self.assertFalse(state)

    def test_subclass_hash(self):
        class DiagTensor(torch.Tensor):
            def __init__(self, diag):
                self._diag = diag

            @classmethod
            def __torch_function__(cls, func, types, args=(), kwargs=None):
                kwargs = kwargs or {}

                def get_full_matrices(t):
                    if isinstance(t, DiagTensor):
                        return torch.diag_embed(t._diag)
                    else:
                        return t

                return func(*tree_map(get_full_matrices, args), **tree_map(get_full_matrices, kwargs))

        d = torch.rand(2)
        a = DiagTensor(d)

        self.assertEqual((a + 1), torch.diag_embed(d) + 1)

        # If the hash function was returning the same value, this would
        # fail inside `Tensor.__eq__`.
        # If __hash__ was going through torch_function, the implementation above would
        # be wrong as it would compute the hash on a temporary Tensor thus not ensuring
        # the uniqueness of the hash that we rely on for Tensors.
        s = set()
        s.add(a)
        s.add(DiagTensor(d))

    def test_custom_device_type(self):
        class CustomDeviceContext(TorchFunctionMode):

            def __torch_function__(self, func, types, args=(), kwargs=None):
                kwargs = kwargs or {}
                if func == torch.device:
                    # Rewrite a bare integer device index into an "xla" device.
                    if args and isinstance(args[0], int):
                        args = ("xla", args[0])
                    elif isinstance(kwargs.get('device'), int):
                        kwargs['device'] = f"xla:{kwargs.get('device')}"
                return func(*args, **kwargs)

        with CustomDeviceContext():
            d_args = torch.device(0)
            self.assertEqual(d_args.type, "xla")
            self.assertEqual(d_args.index, 0)
            d_kwargs = torch.device(device=0)
            self.assertEqual(d_kwargs.type, "xla")
            self.assertEqual(d_kwargs.index, 0)

    def test_device_context_semantics(self):
        from torch._C import _len_torch_function_stack
        from torch.utils._device import DeviceContext
        try:
            torch.set_default_device("cuda")

            def get_stack():
                return [torch._C._get_function_stack_at(i) for i in range(_len_torch_function_stack())]

            base_mode = BaseTorchFunctionMode()
            with base_mode:
                torch.set_default_device("cpu")
                x = torch.ones(2, 2)
                stack = get_stack()
                self.assertIsInstance(stack[0], DeviceContext)
                self.assertEqual(stack[0].device, torch.device("cpu"))

            # The cpu DeviceContext is still first after base_mode exits.
            stack = get_stack()
            self.assertIsInstance(stack[0], DeviceContext)
            self.assertEqual(stack[0].device, torch.device("cpu"))
        finally:
            # Always restore the default device for later tests.
            torch.set_default_device(None)
1643

1644

1645

1646

1647

1648
# Run the test suite when this file is executed as a script.
if __name__ == '__main__':
    run_tests()
1650

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.