pytorch

Форк
0
/
test_overrides.py 
1562 строки · 53.2 Кб
1
# Owner(s): ["module: __torch_function__"]
2

3
import torch
4
import numpy as np
5
import inspect
6
import functools
7
import pprint
8
import pickle
9
import collections
10
import unittest
11

12
from torch.testing._internal.common_utils import TestCase, run_tests, TEST_WITH_CROSSREF
13
from torch.overrides import (
14
    handle_torch_function,
15
    has_torch_function,
16
    get_ignored_functions,
17
    get_overridable_functions,
18
    get_testing_overrides,
19
    resolve_name,
20
    is_tensor_method_or_property,
21
    TorchFunctionMode,
22
    _get_current_function_mode,
23
    _get_current_function_mode_stack,
24
)
25
from torch.utils._mode_utils import all_same_mode
26
from torch.utils._pytree import tree_map
27

28
Tensor = torch.Tensor
29

30
# The functions below simulate the pure-python torch functions in the
31
# torch.functional namespace. We use examples local to this file rather
32
# than any of the real examples implemented in Python since in the
33
# future those examples might get reimplemented in C++ for speed. This
34
# fake torch function allows us to verify that the dispatch rules work
35
# the same for a torch function implemented in C++ or Python.
36

37
def foo(a, b, c=None):
38
    """A function multiple arguments and an optional argument"""
39
    if has_torch_function((a, b, c)):
40
        return handle_torch_function(foo, (a, b, c), a, b, c=c)
41
    if c:
42
        return a + b + c
43
    return a + b
44

45
def bar(a):
46
    """A function with one argument"""
47
    if has_torch_function((a,)):
48
        return handle_torch_function(bar, (a,), a)
49
    return a
50

51
def baz(a, b):
52
    """A function with multiple arguments"""
53
    if has_torch_function((a, b)):
54
        return handle_torch_function(baz, (a, b), a, b)
55
    return a + b
56

57
def quux(a):
58
    """Used to test that errors raised in user implementations get propagated"""
59
    if has_torch_function((a,)):
60
        return handle_torch_function(quux, (a,), a)
61
    return a
62

63
# HANDLED_FUNCTIONS_DIAGONAL is a dispatch table that
64
# DiagonalTensor.__torch_function__ uses to determine which override
65
# function to call for a given torch API function.  The keys of the
66
# dictionary are function names in the torch API and the values are
67
# function implementations. Implementations are added to
68
# HANDLED_FUNCTION_DIAGONAL by decorating a python function with
69
# implements_diagonal. See the overrides immediately below the defintion
70
# of DiagonalTensor for usage examples.
71
HANDLED_FUNCTIONS_DIAGONAL = {}
72

73
def implements_diagonal(torch_function):
74
    """Register a torch function override for DiagonalTensor.
75

76
    This decorator takes a function in the torch API as a
77
    parameter. Applying this decorator to a function adds that function
78
    as the registered override for the torch function passed as a
79
    parameter to the decorator. See DiagonalTensor.__torch_function__
80
    for the runtime dispatch implementation and the decorated functions
81
    immediately below DiagonalTensor for usage examples.
82
    """
83
    @functools.wraps(torch_function)
84
    def decorator(func):
85
        HANDLED_FUNCTIONS_DIAGONAL[torch_function] = func
86
        return func
87
    return decorator
88

89
class DiagonalTensor:
90
    """A class with __torch_function__ and a specific diagonal representation
91

92
    This class has limited utility and is mostly useful for verifying that the
93
    dispatch mechanism works as expected. It is based on the `DiagonalArray
94
    example`_ in the NumPy documentation.
95

96
    Note that this class does *not* inherit from ``torch.tensor``, interaction
97
    with the pytorch dispatch system happens via the ``__torch_function__``
98
    protocol.
99

100
    ``DiagonalTensor`` represents a 2D tensor with *N* rows and columns that has
101
    diagonal entries set to *value* and all other entries set to zero. The
102
    main functionality of ``DiagonalTensor`` is to provide a more compact
103
    string representation of a diagonal tensor than in the base tensor class:
104

105
    >>> d = DiagonalTensor(5, 2)
106
    >>> d
107
    DiagonalTensor(N=5, value=2)
108
    >>> d.tensor()
109
    tensor([[2., 0., 0., 0., 0.],
110
            [0., 2., 0., 0., 0.],
111
            [0., 0., 2., 0., 0.],
112
            [0., 0., 0., 2., 0.],
113
            [0., 0., 0., 0., 2.]])
114

115
    Note that to simplify testing, matrix multiplication of ``DiagonalTensor``
116
    returns 0:
117

118
    >>> torch.mm(d, d)
119
    0
120

121
    .. _DiagonalArray example:
122
        https://numpy.org/devdocs/user/basics.dispatch.html
123
    """
124
    # This is defined as a class attribute so that SubDiagonalTensor
125
    # below which subclasses DiagonalTensor can re-use DiagonalTensor's
126
    # __torch_function__ implementation.
127
    handled_functions = HANDLED_FUNCTIONS_DIAGONAL
128

129
    def __init__(self, N, value):
130
        self._N = N
131
        self._i = value
132

133
    def __repr__(self):
134
        return f"DiagonalTensor(N={self._N}, value={self._i})"
135

136
    def __array__(self):
137
        return self._i * np.eye(self._N)
138

139
    def tensor(self):
140
        return self._i * torch.eye(self._N)
141

142
    @classmethod
143
    def __torch_function__(cls, func, types, args=(), kwargs=None):
144
        if kwargs is None:
145
            kwargs = {}
146
        if func not in cls.handled_functions:
147
            return NotImplemented
148
        return cls.handled_functions[func](*args, **kwargs)
149

150
    def __eq__(self, other):
151
        if type(other) is type(self):
152
            if self._N == other._N and self._i == other._i:
153
                return True
154
            else:
155
                return False
156
        else:
157
            return False
158

159
@implements_diagonal(torch.mean)
160
def mean(mat):
161
    return float(mat._i) / mat._N
162

163
@implements_diagonal(torch.mm)
164
def diagonal_mm(mat1, mat2):
165
    return 0
166

167
@implements_diagonal(torch.div)
168
def diagonal_div(input, other, out=None):
169
    return -1
170

171
@implements_diagonal(torch.add)
172
def add(mat1, mat2):
173
    raise ValueError
174

175
@implements_diagonal(foo)
176
def diagonal_foo(a, b, c=None):
177
    return -1
178

179
@implements_diagonal(bar)
180
def diagonal_bar(a):
181
    return -1
182

183
@implements_diagonal(quux)
184
def diagonal_quux(a):
185
    raise ValueError
186

187
# The dispatch table for SubTensor's __torch_function__ implementation.
188
HANDLED_FUNCTIONS_SUB = {}
189

190
def implements_sub(torch_function):
191
    "Register a torch function override for SubTensor"
192
    @functools.wraps(torch_function)
193
    def decorator(func):
194
        HANDLED_FUNCTIONS_SUB[torch_function] = func
195
        return func
196
    return decorator
197

198
class SubTensor(torch.Tensor):
199
    """A subclass of torch.Tensor use for testing __torch_function__ dispatch
200

201
    This class has the property that matrix multiplication returns zero:
202

203
    >>> s = SubTensor([[1, 1], [1, 1]])
204
    >>> torch.mm(s, s)
205
    0
206
    >>> t = torch.tensor([[1, 1], [1, 1]])
207
    >>> torch.mm(s, t)
208
    0
209
    >>> torch.mm(t, s)
210
    0
211
    >>> torch.mm(t, t)
212
    tensor([[2, 2],
213
            [2, 2]])
214

215
    This is useful for testing that the semantics for overriding torch
216
    functions are working correctly.
217
    """
218
    @classmethod
219
    def __torch_function__(cls, func, types, args=(), kwargs=None):
220
        if kwargs is None:
221
            kwargs = {}
222

223
        if func not in HANDLED_FUNCTIONS_SUB:
224
            return NotImplemented
225
        return HANDLED_FUNCTIONS_SUB[func](*args, **kwargs)
226

227
class SubTensor2(torch.Tensor):
228
    pass
229

230
class SubSubTensor2(SubTensor2):
231
    pass
232

233
class SubTensor3(torch.Tensor):
234
    pass
235

236
@implements_sub(torch.mean)
237
def sub_mean(mat):
238
    return 0
239

240
@implements_sub(torch.mm)
241
def sub_mm(mat1, mat2):
242
    return -1
243

244
@implements_sub(bar)
245
def sub_bar(mat):
246
    return 1
247

248
@implements_sub(torch.div)
249
def sub_div(input, other, out=None):
250
    return NotImplemented
251

252
# The dispatch table for SubDiagonalTensor's __torch_function__ implementation.
253
HANDLED_FUNCTIONS_SUB_DIAGONAL = {}
254

255
def implements_sub_diagonal(torch_function):
256
    "Register a torch function override for SubDiagonalTensor"
257
    @functools.wraps(torch_function)
258
    def decorator(func):
259
        HANDLED_FUNCTIONS_SUB_DIAGONAL[torch_function] = func
260
        return func
261
    return decorator
262

263
class SubDiagonalTensor(DiagonalTensor):
264
    """A subclass of ``DiagonalTensor`` to test custom dispatch
265

266
    This class tests semantics for defining ``__torch_function__`` on a
267
    subclass of another class that defines ``__torch_function__``. The
268
    only difference compared with the superclass is that this class
269
    provides a slightly different repr as well as custom implementations
270
    of ``mean`` and ``mm``, scaling the mean by a factor of 10 and
271
    returning 1 from ``mm`` instead of 0 as ``DiagonalTensor`` does.
272
    """
273
    handled_functions = HANDLED_FUNCTIONS_SUB_DIAGONAL
274

275
    def __repr__(self):
276
        return f"SubDiagonalTensor(N={self._N}, value={self._i})"
277

278

279
@implements_sub_diagonal(torch.mean)
280
def sub_diagonal_mean(mat):
281
    return 10 * float(mat._i) / mat._N
282

283
@implements_sub_diagonal(bar)
284
def sub_diagonal_bar(mat):
285
    return 0
286

287
@implements_sub_diagonal(torch.mm)
288
def sub_diagonal_mm(mat1, mat2):
289
    return 1
290

291
@implements_sub_diagonal(torch.div)
292
def sub_diagonal_div(input, other, out=None):
293
    return NotImplemented
294

295
@implements_sub_diagonal(foo)
296
def sub_diagonal_foo(a, b, c=None):
297
    return NotImplemented
298

299
# The dispatch table for SubDiagonalTensor's __torch_function__ implementation.
300
HANDLED_FUNCTIONS_TENSOR_LIKE = {}
301

302

303
# Note: _triggered wrapper
304
# Dict that wraps the implementations from get_testing_overrides into another
305
# function with a _triggered slot/flag. The triggered flag is set when the
306
# implementation is called.
307
WRAPPED_TRIGGERED_IMPLS = {}
308

309

310
def triggered_wrapper(f):
311
    @functools.wraps(f)
312
    def wrapped(*args, **kwargs):
313
        wrapped._triggered = True
314
        return f(*args, **kwargs)
315

316
    wrapped._triggered = False
317
    return wrapped
318

319
def implements_tensor_like(torch_function):
320
    "Register a torch function override for TensorLike"
321
    @functools.wraps(torch_function)
322
    def decorator(func):
323
        HANDLED_FUNCTIONS_TENSOR_LIKE[torch_function] = func
324
        return func
325
    return decorator
326

327
def generate_tensor_like_torch_implementations():
328
    torch_vars = vars(torch)
329
    untested_funcs = []
330
    testing_overrides = get_testing_overrides()
331
    # test/test_cpp_api_parity.py monkeypatches torch.nn to have a new
332
    # function sample_functional.  Depending on what order you run pytest
333
    # collection, this may trigger the error here.  This is a hack to fix
334
    # the problem.  A more proper fix is to make the "not tested" check
335
    # a test on its own, and to make sure the monkeypatch is only installed
336
    # for the span of the relevant test (and deleted afterwards)
337
    testing_ignore = {"sample_functional", "autocast"}
338
    for namespace, funcs in get_overridable_functions().items():
339
        for func in funcs:
340
            if func not in testing_overrides and func.__name__ not in testing_ignore:
341
                untested_funcs.append(f"{namespace}.{func.__name__}")
342
    msg = (
343
        "The following functions are not tested for __torch_function__ "
344
        "support, please ensure there is an entry in the dict returned by "
345
        "torch.overrides.get_testing_overrides for this function or if a "
346
        "__torch_function__ override does not make sense, add an entry to "
347
        "the tuple returned by torch._overrides.get_ignored_functions.\n\n{}"
348
    )
349
    assert len(untested_funcs) == 0, msg.format(pprint.pformat(untested_funcs))
350
    for func, override in testing_overrides.items():
351
        # decorate the overrides with implements_tensor_like if it's not a
352
        # torch.Tensor method
353
        wrapped = triggered_wrapper(override)
354
        # See note: "_triggered wrapper"
355
        WRAPPED_TRIGGERED_IMPLS[func] = wrapped
356
        if is_tensor_method_or_property(func):
357
            implements_sub(func)(wrapped)
358
        else:
359
            implements_tensor_like(func)(wrapped)
360

361
generate_tensor_like_torch_implementations()
362

363
class TensorLike:
364
    """A class that overrides the full torch API
365

366
    This class is used to explicitly test that the full torch.tensor API
367
    can be overriden with a class that defines __torch_function__.
368
    """
369
    @classmethod
370
    def __torch_function__(cls, func, types, args=(), kwargs=None):
371
        if kwargs is None:
372
            kwargs = {}
373

374
        if func not in HANDLED_FUNCTIONS_TENSOR_LIKE:
375
            return NotImplemented
376
        # In this case _torch_function_ should override TensorLike objects
377
        return HANDLED_FUNCTIONS_TENSOR_LIKE[func](*args, **kwargs)
378

379
class TestTorchFunctionOverride(TestCase):
380
    def test_mean_semantics(self):
381
        """Test that a function with one argument can be overrided"""
382
        t1 = DiagonalTensor(5, 2)
383
        t2 = SubTensor([[1, 2], [1, 2]])
384
        t3 = SubDiagonalTensor(5, 2)
385
        self.assertEqual(torch.mean(t1), 0.4)
386
        self.assertEqual(bar(t1), -1)
387
        self.assertEqual(torch.mean(t2), 0)
388
        self.assertEqual(bar(t2), 1)
389
        self.assertEqual(torch.mean(t3), 4.0)
390
        self.assertEqual(bar(t3), 0)
391

392
    def test_has_torch_function_non_sequence(self):
393
        with self.assertRaisesRegex(TypeError, "expected a sequence"):
394
            has_torch_function(object())
395

396
    def test_mm_semantics(self):
397
        """Test that a function with multiple arguments can be overrided"""
398
        t1 = DiagonalTensor(5, 2)
399
        t2 = torch.eye(5) * 2
400
        t3 = SubTensor([[1, 2], [1, 2]])
401
        t4 = SubDiagonalTensor(5, 2)
402
        # only DiagonalTensor so should always get DiagonalTensor result
403
        self.assertEqual(torch.mm(t1, t1), 0)
404
        # tensor and DiagonalTensor, always return DiagonalTensor result
405
        self.assertEqual(torch.mm(t1, t2), 0)
406
        self.assertEqual(torch.mm(t2, t1), 0)
407
        # only SubTensor so should always get SubTensor result
408
        self.assertEqual(torch.mm(t3, t3), -1)
409
        # tensor and SubTensor so should always get SubTensor result
410
        self.assertEqual(torch.mm(t3, t2), -1)
411
        self.assertEqual(torch.mm(t2, t3), -1)
412
        # DiagonalTensor and SubTensor are unrelated classes so the result
413
        # depends on which argument appears first
414
        self.assertEqual(torch.mm(t3, t1), -1)
415
        self.assertEqual(torch.mm(t1, t3), 0)
416
        # SubDiagonalTensor should take precedence over DiagonalTensor
417
        # but should behave otherwise the same as DiagonalTensor
418
        self.assertEqual(torch.mm(t4, t4), 1)
419
        self.assertEqual(torch.mm(t4, t1), 1)
420
        self.assertEqual(torch.mm(t1, t4), 1)
421
        self.assertEqual(torch.mm(t4, t2), 1)
422
        self.assertEqual(torch.mm(t2, t4), 1)
423
        self.assertEqual(torch.mm(t3, t4), -1)
424
        self.assertEqual(torch.mm(t4, t3), 1)
425

426
    def test_precedence_semantics(self):
427
        """Test semantics for __torch_function__ for functions that take
428
        multiple arguments
429

430
        For functions that take multiple arguments, the appropriate
431
        __torch_function__ implementation to call is determined by
432
        examining the types of the arguments. The precedence order is
433
        left-to-right in the argument list, except subclasses are always
434
        checked before superclasses. The first result of calling the
435
        implementations in precedence order that is not NotImplemented
436
        is returned to the user. If all implementations return
437
        NotImplemented, a TypeError is raised.
438

439
        All cases are tested with functions implemented in C++ and
440
        either foo or baz, which are python functions defined above that
441
        are instrumented to obey the same dispatch rules as the
442
        functions in torch.functional.
443
        """
444
        # DiagonalTensor has a valid override and SubDiagonal has an
445
        # override that returns NotImplemented so we should call the
446
        # DiagonalTensor implementation, returning -1
447
        t1 = DiagonalTensor(5, 2)
448
        t2 = SubDiagonalTensor(5, 2)
449
        self.assertEqual(torch.div(t1, t2), -1)
450
        self.assertEqual(torch.div(t2, t1), -1)
451
        self.assertEqual(foo(t1, t2), -1)
452
        self.assertEqual(foo(t2, t1), -1)
453

454
        # SubTensor has an implementation that returns NotImplemented as
455
        # well so it should behave exactly like SubDiagonalTensor in the
456
        # test above
457
        t3 = SubTensor([[1, 2], [1, 2]])
458
        self.assertEqual(torch.div(t1, t3), -1)
459
        self.assertEqual(torch.div(t3, t1), -1)
460
        self.assertEqual(foo(t1, t3), -1)
461
        self.assertEqual(foo(t3, t1), -1)
462

463
        # div between SubTensor and SubDiagonalTensor should raise
464
        # TypeError since both have an implementation that
465
        # explicitly returns NotImplemented
466
        with self.assertRaises(TypeError):
467
            torch.div(t2, t3)
468
        with self.assertRaises(TypeError):
469
            torch.div(t3, t2)
470
        with self.assertRaises(TypeError):
471
            foo(t2, t3)
472
        with self.assertRaises(TypeError):
473
            foo(t3, t2)
474

475
        # none of DiagonalTensor, SubdiagonalTensor, or SubTensor have a
476
        # mul or a baz implementation so all ops should raise TypeError
477
        with self.assertRaises(TypeError):
478
            torch.mul(t1, t1)
479
        with self.assertRaises(TypeError):
480
            torch.mul(t1, t2)
481
        with self.assertRaises(TypeError):
482
            torch.mul(t1, t3)
483
        with self.assertRaises(TypeError):
484
            torch.mul(t2, t1)
485
        with self.assertRaises(TypeError):
486
            torch.mul(t2, t2)
487
        with self.assertRaises(TypeError):
488
            torch.mul(t2, t3)
489
        with self.assertRaises(TypeError):
490
            torch.mul(t3, t1)
491
        with self.assertRaises(TypeError):
492
            torch.mul(t3, t2)
493
        with self.assertRaises(TypeError):
494
            torch.mul(t3, t3)
495
        with self.assertRaises(TypeError):
496
            baz(t1, t1)
497
        with self.assertRaises(TypeError):
498
            baz(t1, t2)
499
        with self.assertRaises(TypeError):
500
            baz(t1, t3)
501
        with self.assertRaises(TypeError):
502
            baz(t2, t1)
503
        with self.assertRaises(TypeError):
504
            baz(t2, t2)
505
        with self.assertRaises(TypeError):
506
            baz(t2, t3)
507
        with self.assertRaises(TypeError):
508
            baz(t3, t1)
509
        with self.assertRaises(TypeError):
510
            baz(t3, t2)
511
        with self.assertRaises(TypeError):
512
            baz(t3, t3)
513

514
    def test_user_implementation_raises(self):
515
        """Test that errors raised in user implementations propagate correctly"""
516
        t1 = DiagonalTensor(5, 2)
517
        t2 = DiagonalTensor(5, 2)
518
        with self.assertRaises(ValueError):
519
            torch.add(t1, t2)
520
        with self.assertRaises(ValueError):
521
            quux(t1)
522

523
    def test_tensor_subclass_propagation(self):
524
        """this test exercises the functionality described in
525
        docs/source/notes/extending.rst#subclassing-torchtensor"""
526
        t1 = torch.tensor([5])
527
        t2 = torch.tensor([6])
528

529
        s1 = SubTensor2([5])
530
        s2 = SubTensor2([6])
531

532
        ss1 = SubSubTensor2([5])
533
        ss2 = SubSubTensor2([6])
534

535
        sn1 = SubTensor3([5])
536
        sn2 = SubTensor3([6])
537

538
        # Check that leaf subclass is kept regardless of order
539
        self.assertTrue(isinstance(s1 + t2, SubTensor2))
540
        self.assertTrue(isinstance(t1 + s2, SubTensor2))
541
        self.assertTrue(isinstance(s1 + s2, SubTensor2))
542

543
        # Check indexing subclass is kept
544
        self.assertTrue(isinstance(s1[0], SubTensor2))
545

546
        # Check case for subclass of subclass.
547
        self.assertTrue(isinstance(ss1 + ss2, SubSubTensor2))
548
        self.assertTrue(isinstance(ss1 + s2, SubSubTensor2))
549
        self.assertTrue(isinstance(s1 + ss2, SubSubTensor2))
550
        self.assertTrue(isinstance(ss1 + ss2, SubSubTensor2))
551
        self.assertTrue(isinstance(ss1 + t2, SubSubTensor2))
552
        self.assertTrue(isinstance(t1 + ss2, SubSubTensor2))
553
        self.assertTrue(isinstance(ss1[0], SubSubTensor2))
554

555
        # Make sure unrelated class trees are not merged.
556
        with self.assertRaises(TypeError):
557
            s1 + sn2
558
        with self.assertRaises(TypeError):
559
            sn1 + s2
560

561
    def test_base(self):
562
        # https://github.com/szagoruyko/pytorchviz/issues/65
563
        class DummyTensor(torch.Tensor):
564
            pass
565

566
        a = torch.ones(1)
567
        c = DummyTensor(a)
568
        self.assertTrue(c._is_view())
569
        self.assertTrue(c._base is a)
570

571
    def test_grad(self):
572
        # Previously, Tensor-like objects that did not subclass from Tensor
573
        # did not get wrapped into unary tuples before being passed into
574
        # handle_torch_function, in contradiction with how Tensor-likes
575
        # were handled
576
        #
577
        # NB: this asserts that the arguments get normalized into a tuple
578
        # before entering the torch function handler; it could go the
579
        # other way but beware https://github.com/pytorch/pytorch/issues/76037
580

581
        class Dummy:
582
            @classmethod
583
            def __torch_function__(cls, func, types, args=(), kwargs=None):
584
                inputs, outputs = args
585
                self.assertEqual(inputs, (x,))
586
                self.assertEqual(outputs, (x,))
587
                return -1
588

589
        x = Dummy()
590
        self.assertEqual(torch.autograd.grad(x, x), -1)
591

592
    def test_pow_rpow(self):
593
        class NothingImplemented(torch.Tensor):
594
            @classmethod
595
            def __torch_function__(cls, func, types, args=(), kwargs=None):
596
                return NotImplemented
597

598
        class RPowOnly(torch.Tensor):
599
            @classmethod
600
            def __torch_function__(cls, func, types, args=(), kwargs=None):
601
                if func is torch.Tensor.__rpow__:
602
                    return -1
603
                return NotImplemented
604

605
        self.assertEqual(NothingImplemented() ** RPowOnly(), -1)
606

607

608
def generate_tensor_like_override_tests(cls):
609
    from torch.testing._internal.generated.annotated_fn_args import annotated_args
610

611
    def test_generator(func, override):
612
        # If func corresponds to a torch.Tensor method or property.
613
        if is_tensor_method_or_property(func):
614
            # Generate an instance by using SubTensor,
615
            def instance_gen():
616
                return SubTensor([5])
617
        else:
618
            # Otherwise, TensorLike.
619
            def instance_gen():
620
                return TensorLike()
621

622
        # FIXME The following code does not support kwonly args without defaults.
623
        # The fix is easy, as one just needs to save these args when generating the variable
624
        # annotated_args. The problem is that, if one does so, one finds a number
625
        # of functions that have problematic signatures in native_functions.yaml.
626
        # Fixing these would be BC breaking, so hence this terrible hack
627
        # https://github.com/pytorch/pytorch/issues/67008
628
        kwargs = {}
629
        if hasattr(func, "__name__") and "linalg_solve_triangular" in func.__name__:
630
            kwargs = {"upper": True}
631

632
        func_args = []
633
        is_method = is_tensor_method_or_property(func)
634

635
        def _simple_type_parser(func, arg_name, arg_type):
636
            # Guess valid input to aten function based on type of argument
637
            if arg_type == "Tensor":
638
                return instance_gen()
639
            elif arg_type == "TensorList" or arg_type == "ITensorListRef":
640
                return [instance_gen(), instance_gen()]
641
            elif arg_type == "c10::List<c10::optional<Tensor>>":
642
                return [instance_gen(), instance_gen()]
643
            elif arg_type == "IntArrayRef" or arg_type == "SymIntArrayRef":
644
                size = arg.get("size", 2)
645
                if size == 1:
646
                    return 1
647
                else:
648
                    return [1] * size
649
            elif arg_type == "Scalar":
650
                return 3.5
651
            elif arg_type == "bool":
652
                return False
653
            elif arg_type == "Dimname":
654
                return ""
655
            elif arg_type == "DimnameList":
656
                return [""]
657
            elif arg_type.startswith("int"):
658
                return 0
659
            elif arg_type in {"Stream"}:
660
                return torch.Stream()
661
            elif arg_type.startswith("float") or arg_type == "double":
662
                return 1.0
663
            elif arg_type in {"Generator", "MemoryFormat", "TensorOptions"}:
664
                return None
665
            elif arg_type == "ScalarType":
666
                return torch.float32
667
            elif arg_type == "c10::string_view":
668
                return ""
669
            elif arg_type == "SymInt":
670
                # TODO: generate actual SymbolicInt
671
                return 1
672
            else:
673
                raise RuntimeError(
674
                    f"Unsupported argument type {arg_type} for {arg_name} of function {func}"
675
                )
676

677
        if func in annotated_args:
678
            for arg in annotated_args[func]:
679
                # Guess valid input to aten function based on type of argument
680
                t = arg["simple_type"]
681
                if t.endswith("?"):
682
                    t = t[:-1]
683
                if t == "Tensor" and is_method and arg["name"] == "self":
684
                    # See "Note: properties and __get__"
685
                    func = func.__get__(instance_gen())
686
                    continue
687
                arg_to_add = _simple_type_parser(func, arg["name"], t)
688
                if "is_kwarg_only" in arg and arg["is_kwarg_only"] == str(True):
689
                    kwargs[arg["name"]] = arg_to_add
690
                else:
691
                    func_args.append(arg_to_add)
692
        else:
693
            args = inspect.getfullargspec(override)
694
            try:
695
                func_args = inspect.getfullargspec(func)
696
                # Remove annotations from argspec
697
                func_args = type(func_args)(**{**func_args, 'annotations': None})
698
                if func_args != args:
699
                    raise RuntimeError(f"Override for {func} doesn't match its argspec.\n"
700
                                       + f"Original: {inspect.signature(func)}\n"
701
                                       + f"Override: {inspect.signature(override)}")
702
            except TypeError:
703
                pass
704
            nargs = len(args.args)
705
            if args.defaults is not None:
706
                nargs -= len(args.defaults)
707
            func_args = [instance_gen() for _ in range(nargs)]
708
            if args.varargs is not None:
709
                func_args += [instance_gen(), instance_gen()]
710

711
        def test(self):
712
            ret = func(*func_args, **kwargs)
713
            # ret is None for certain protocols, e.g., `__weakref__` and `__setitem__`
714
            # This is currently the best check but doesn't work for, for example,
715
            # Tensor.__add__ because it redirects to Tensor.add.
716
            # See note "_triggered wrapper"
717
            if not is_method or ret is None:
718
                self.assertTrue(WRAPPED_TRIGGERED_IMPLS[func]._triggered)
719
                return
720

721
            self.assertEqual(ret, -1)
722

723
        return test
724

725
    for func, override in get_testing_overrides().items():
726
        test_method = test_generator(func, override)
727
        if func.__name__ == "__get__":
728
            # Note: properties and __get__
729
            # __get__ is part of the descriptor protocol.
730
            # https://docs.python.org/3/howto/descriptor.html
731
            # This is used for properties of the form
732
            # torch.Tensor.<property>, with the method __get__
733
            # In this case we get the property name in two ways:
734

735
            # This case for properties defined in C.
736
            module = getattr(
737
                func.__self__,
738
                "__qualname__",
739
                None
740
            )
741

742
            # This one for properties defined in Python.
743
            if module is None:
744
                module = "Tensor." + func.__self__.fget.__name__
745

746
            # Unfortunately I couldn't find a way to unify these two cases
747
            # and there is no way for general descriptors.
748
        elif is_tensor_method_or_property(func):
749
            module = "Tensor"
750
        else:
751
            module = func.__module__
752
        if module:
753
            name = 'test_{}_{}'.format(module.replace('.', '_'), func.__name__)
754
        else:
755
            name = f'test_{func.__name__}'
756
        test_method.__name__ = name
757
        setattr(cls, name, test_method)
758

759
generate_tensor_like_override_tests(TestTorchFunctionOverride)
760

761
class Wrapper:
762
    "Basic data container that knows how to unwrap itself"
763
    def __init__(self, data):
764
        self.__dict__["_data"] = data
765
        self.__dict__["used_attrs"] = set()
766
        self.__dict__["used_calls"] = set()
767

768
    def __getattr__(self, name):
769
        if name in self.__dict__:
770
            return self.__dict__[name]
771
        self.used_attrs.add(name)
772

773
        val = getattr(self._data, name)
774

775
        # If it's a method
776
        if not isinstance(val, torch.device) and callable(val):
777
            c = getattr(type(self._data), name)
778
            # Don't append self to args if classmethod/staticmethod
779
            if c is val:
780
                return lambda *a, **kw: wrap(self.__torch_function__(c, (Wrapper,), args=a, kwargs=kw))
781
            # Otherwise append self to args
782
            return lambda *a, **kw: wrap(self.__torch_function__(c, (Wrapper,), args=(self,) + a, kwargs=kw))
783

784
        return wrap(val)
785

786
    def __setattr__(self, name, value):
787
        if name in self.__dict__:
788
            self.__dict__[name] = value
789

790
        self.used_attrs.add(name)
791
        setattr(self._data, name, unwrap(value))
792

793
    def __setitem__(self, key, value):
794
        self._data[unwrap(key)] = unwrap(value)
795

796
    def __getitem__(self, key):
797
        return wrap(self._data[unwrap(key)])
798

799
    @classmethod
800
    def __torch_function__(cls, func, types, args=(), kwargs=None):
801
        if kwargs is None:
802
            kwargs = {}
803
        # Find an instance of this class in the arguments
804
        args_of_this_cls = []
805
        for a in args:
806
            if isinstance(a, cls):
807
                args_of_this_cls.append(a)
808
            elif isinstance(a, collections.abc.Sequence):
809
                args_of_this_cls.extend(el for el in a if isinstance(el, cls))
810
        assert len(args_of_this_cls) > 0
811
        for a in args_of_this_cls:
812
            a.used_calls.add(func)
813
        args = unwrap(tuple(args))
814
        kwargs = {k: unwrap(v) for k, v in kwargs.items()}
815

816
        return wrap(func(*args, **kwargs))
817

818
    def __add__(self, other):
819
        return self.__torch_function__(torch.add, (Wrapper,), (self, other))
820

821
    def __mul__(self, other):
822
        return self.__torch_function__(torch.mul, (Wrapper,), (self, other))
823

824
    def __sub__(self, other):
825
        return self.__torch_function__(torch.sub, (Wrapper,), (self, other))
826

827
    def __truediv__(self, other):
828
        return self.__torch_function__(torch.true_divide, (Wrapper,), (self, other))
829

830
    def __floordiv__(self, other):
831
        return self.__torch_function__(torch.floor_divide, (Wrapper,), (self, other))
832

833
    def __ge__(self, other):
834
        return self.__torch_function__(torch.ge, (Wrapper,), (self, other))
835

836
    def __gt__(self, other):
837
        return self.__torch_function__(torch.gt, (Wrapper,), (self, other))
838

839
    def __lt__(self, other):
840
        return self.__torch_function__(torch.lt, (Wrapper,), (self, other))
841

842
    def __le__(self, other):
843
        return self.__torch_function__(torch.le, (Wrapper,), (self, other))
844

845
    def __eq__(self, other):
846
        return self.__torch_function__(torch.eq, (Wrapper,), (self, other))
847

848
    def __ne__(self, other):
849
        return self.__torch_function__(torch.ne, (Wrapper,), (self, other))
850

851
    def __bool__(self):
852
        return self.__torch_function__(torch.Tensor.__bool__, (Wrapper,), (self,))
853

854
    def __int__(self):
855
        return self.__torch_function__(torch.Tensor.__int__, (Wrapper,), (self,))
856

857
    def __len__(self):
858
        return len(self._data)
859

860

861
# unwrap inputs if necessary
862
def unwrap(v):
863
    if type(v) in {tuple, list}:
864
        return type(v)(unwrap(vi) for vi in v)
865

866
    return v._data if isinstance(v, Wrapper) else v
867

868
# wrap inputs if necessary
869
def wrap(v):
870
    if type(v) in {tuple, list}:
871
        return type(v)(wrap(vi) for vi in v)
872

873
    return Wrapper(v) if isinstance(v, torch.Tensor) else v
874

875
class TestEinsumOverride(TestCase):
876
    "Regression test for gh-38479"
877
    def test_wrapper(self):
878
        x = Wrapper(torch.randn(5))
879
        y = Wrapper(torch.randn(4))
880
        self.assertEqual(torch.einsum('i,j->ij', x, y)._data,
881
                         torch.ger(x, y)._data)
882

883
        # in the old einsum interface, `operands` is a list
884
        a = Wrapper(torch.randn(2, 3))
885
        b = Wrapper(torch.randn(5, 3, 7))
886
        c = Wrapper(torch.randn(2, 7))
887
        self.assertEqual(torch.einsum('ik,jkl,il->ij', [a, b, c])._data,
888
                         torch.nn.functional.bilinear(a, c, b)._data)
889

890
class TestGradCheckOverride(TestCase):
891
    "Test that wrappers work with gradcheck."
892
    def test_gradcheck(self):
893
        from torch.testing._internal.common_utils import gradcheck, gradgradcheck
894

895
        def run_test(fast_mode):
896
            a = wrap(torch.tensor(5.0, dtype=torch.double))
897
            b = wrap(torch.tensor(6.0, dtype=torch.double))
898

899
            a.requires_grad = True
900
            b.requires_grad = True
901

902
            gradcheck(torch.add, (a, b), raise_exception=False, check_batched_grad=False, fast_mode=fast_mode)
903
            gradgradcheck(torch.add, (a, b), raise_exception=False, check_batched_grad=False, fast_mode=fast_mode)
904

905
            total_used_attrs = a.used_attrs.union(b.used_attrs)
906
            total_used_calls = a.used_calls.union(b.used_calls)
907

908
            # These attributes (and the functions below) may change
909
            # if the gradcheck implementation changes. It's best to
910
            # aim for attributes that may be commonly present on other
911
            # Tensor-likes.
912
            expected_used_attrs = {
913
                'data',
914
                'dtype',
915
                'is_floating_point',
916
                'is_sparse',
917
                'layout',
918
                'new_zeros',
919
                'numel',
920
                'requires_grad',
921
                'requires_grad_',
922
                'size',
923
                'stride',
924
            }
925
            if fast_mode:
926
                expected_used_attrs.add('is_complex')
927
                expected_used_attrs.add('device')
928
            self.assertEqual(expected_used_attrs, total_used_attrs)
929

930
            expected_used_calls = {
931
                torch.Tensor.new_zeros,
932
                torch.Tensor.size,
933
                torch.Tensor.is_floating_point,
934
                torch.Tensor.numel,
935
                torch.Tensor.stride,
936
                torch.Tensor.requires_grad_,
937
                torch.autograd.grad,
938
                torch.add,
939
            }
940
            if fast_mode:
941
                expected_used_calls.add(torch.Tensor.is_complex)
942
            self.assertEqual(expected_used_calls, total_used_calls)
943
        run_test(fast_mode=True)
944
        run_test(fast_mode=False)
945

946
class TestNamedTuple(TestCase):
947
    """ Regression test for gh-47090 """
948
    def test_max(self):
949
        x = torch.tensor([1, 2])
950
        xs = x.as_subclass(SubTensor2)
951
        r = torch.max(x, dim=0)
952
        rs = torch.max(xs, dim=0)
953
        self.assertEqual(type(r), type(rs))
954
        self.assertEqual(r, rs)
955

956
class TestGradNewOnesOverride(TestCase):
957
    """ Regression test for gh-47069 """
958
    def test_newones(self):
959
        t = torch.tensor([1, 2]).as_subclass(SubTensor2)
960
        n = t.new_ones((1, 2))
961
        self.assertEqual(type(n), SubTensor2)
962

963
class TestPickle(TestCase):
964
    "Regression test for gh-47051"
965
    def test_pickle(self):
966
        t = torch.tensor([1]).as_subclass(SubTensor2)
967
        t.abcd = "e"
968
        t2 = pickle.loads(pickle.dumps(t))
969
        self.assertIs(type(t2), SubTensor2)
970
        self.assertEqual(t2.abcd, "e")
971

972
class TestBroadcastAllOverride(TestCase):
973
    """ test for gh-37141 """
974
    def test_broadcast_all(self):
975
        from torch.distributions.utils import broadcast_all
976
        a = torch.tensor([1.2, 3.4, 5.6])
977
        a_w = Wrapper(a)
978
        b = torch.tensor(5.0)
979
        b_w = Wrapper(b)
980
        c = torch.tensor([5.0, 5.0, 5.0])
981

982
        o_1 = broadcast_all(a_w, b_w)
983
        self.assertTrue(isinstance(o_1[0], Wrapper))
984
        self.assertTrue(isinstance(o_1[1], Wrapper))
985
        self.assertEqual(o_1[0]._data, a)
986
        self.assertEqual(o_1[1]._data, c)
987

988
        o_2 = broadcast_all(a_w, b)
989
        self.assertTrue(isinstance(o_2[0], Wrapper))
990
        self.assertTrue(isinstance(o_2[1], Wrapper))
991
        self.assertEqual(o_2[0]._data, a)
992
        self.assertEqual(o_2[1]._data, c)
993

994
class TestWrapTorchFunction(TestCase):
995
    def test_wrap_torch_function(self):
996
        class A:
997
            @classmethod
998
            def __torch_function__(cls, func, types, args, kwargs):
999
                return -1
1000

1001
        def dispatcher(a):
1002
            return (a,)
1003

1004
        @torch.overrides.wrap_torch_function(dispatcher)
1005
        def f(a):
1006
            return a
1007

1008
        self.assertEqual(f(A()), -1)
1009

1010
class TestIndexing(TestCase):
1011
    """ Regression tests for gh-46277 """
1012
    def test_getitem(self):
1013
        class A:
1014
            @classmethod
1015
            def __torch_function__(cls, func, types, args, kwargs=None):
1016
                return -1
1017

1018
        t = torch.tensor([5])
1019
        self.assertEqual(t[A()], -1)
1020
        self.assertEqual(t, torch.tensor([5]))
1021

1022
    def test_getitem_subclass(self):
1023
        class A(torch.Tensor):
1024
            @classmethod
1025
            def __torch_function__(cls, func, types, args, kwargs=None):
1026
                return -1
1027

1028
        t = torch.tensor([5])
1029
        self.assertEqual(t[A()], -1)
1030
        self.assertEqual(t[5, A()], -1)
1031
        self.assertEqual(t, torch.tensor([5]))
1032

1033
    def test_setitem(self):
1034
        triggered = set()
1035

1036
        class A:
1037
            @classmethod
1038
            def __torch_function__(cls, func, types, args, kwargs=None):
1039
                triggered.add(func)
1040
                return -1
1041

1042
        t = torch.tensor([5])
1043
        t[A()] = 1
1044
        t[5, A()] = 1
1045
        self.assertIn(Tensor.__setitem__, triggered)
1046
        self.assertEqual(t, torch.tensor([5]))
1047

1048
    def test_setitem_val(self):
1049
        triggered = set()
1050

1051
        class A:
1052
            @classmethod
1053
            def __torch_function__(cls, func, types, args, kwargs=None):
1054
                triggered.add(func)
1055
                return -1
1056

1057
        t = torch.tensor([5])
1058
        t[0] = A()
1059
        self.assertIn(Tensor.__setitem__, triggered)
1060
        self.assertEqual(t, torch.tensor([5]))
1061

1062
    def test_setitem_subclass(self):
1063
        triggered = set()
1064

1065
        class A(torch.Tensor):
1066
            @classmethod
1067
            def __torch_function__(cls, func, types, args, kwargs=None):
1068
                triggered.add(func)
1069
                return -1
1070

1071
        t = torch.tensor([5])
1072
        t[A()] = 1
1073
        t[5, A()] = 1
1074
        self.assertIn(Tensor.__setitem__, triggered)
1075
        self.assertEqual(t, torch.tensor([5]))
1076

1077

1078
class TestIterator(TestCase):
1079
    # Regression test for gh-54457
1080
    def test_iterator(self):
1081
        t = torch.tensor([5, 6, 7]).as_subclass(SubTensor2)
1082
        it = iter(t)
1083
        self.assertIs(type(next(it)), SubTensor2)
1084
        self.assertIs(type(next(it)), SubTensor2)
1085
        self.assertIs(type(next(it)), SubTensor2)
1086

1087

1088
class TestRNN(TestCase):
1089
    # Regression test for gh-55868
1090
    def test_rnn(self):
1091
        model = torch.nn.RNN(10, 20, 2)
1092
        input = Wrapper(torch.randn(1, 5, 10))
1093
        model(input)
1094

1095

1096
class TestDisabledTorchFunction(TestCase):
1097
    # Regression test for gh-64687
1098
    def test_parameter_does_not_prevent_dispatch(self):
1099
        class MyTensor:
1100
            @classmethod
1101
            def __torch_function__(cls, func, types, args=(), kwargs=None):
1102
                return "called"
1103

1104
        t1 = MyTensor()
1105
        t2 = torch.nn.Parameter(torch.rand(2, 2))
1106
        self.assertEqual(torch.add(t2, t1), "called")
1107

1108
        inp = torch.rand(10, 10)
1109
        self.assertEqual(torch.nn.functional.linear(inp, t1, t2), "called")
1110
        self.assertEqual(torch.nn.functional.linear(inp, t2, t1), "called")
1111

1112
class TestResolveName(TestCase):
1113
    def test_resolve_name(self):
1114
        for cs in get_overridable_functions().values():
1115
            for c in cs:
1116
                self.assertEqual(
1117
                    eval(torch.overrides.resolve_name(c)),
1118
                    c,
1119
                    msg=f"{c}, {torch.overrides.resolve_name(c)}"
1120
                )
1121

1122
class TestTorchFunctionWarning(TestCase):
1123
    def test_warn_on_invalid_torch_function(self):
1124
        class Bad1:
1125
            def __torch_function__(self, *args, **kwargs):
1126
                pass
1127

1128
        class Bad2(torch.Tensor):
1129
            def __torch_function__(self, *args, **kwargs):
1130
                pass
1131

1132
        a = Bad1()
1133
        for a in (Bad1(), Bad2()):
1134
            with self.assertWarnsRegex(DeprecationWarning, "as a plain method is deprecated"):
1135
                # Function that handles torch_function on the python side
1136
                torch.nn.functional.dropout(a)
1137

1138
            with self.assertWarnsRegex(UserWarning, "as a plain method is deprecated"):
1139
                # Function that handles torch_function in C++
1140
                torch.abs(a)
1141

1142
class TestDisabledUserWarnings(TestCase):
1143
    def test_no_implicit_user_warning_for_deprecated_functions(self):
1144
        self.assertNotWarn(get_ignored_functions)
1145
        self.assertNotWarn(get_testing_overrides)
1146
        self.assertNotWarn(get_overridable_functions)
1147
        self.assertNotWarn(lambda: resolve_name(torch.Tensor.add))
1148
        self.assertNotWarn(lambda: is_tensor_method_or_property(torch.Tensor.add))
1149

1150
@unittest.skipIf(TEST_WITH_CROSSREF, "not run with crossref")
1151
class TestTorchFunctionMode(TestCase):
1152
    def test_basic(self):
1153
        class A(TorchFunctionMode):
1154
            def __torch_function__(self, *args, **kwargs):
1155
                return -1
1156
        # NB: factory functions get overridden too!
1157
        x = torch.randn(1)
1158
        with A():
1159
            self.assertEqual(torch.randn(3), -1)
1160
            self.assertEqual(torch.add(x, x), -1)
1161
            self.assertEqual(torch.split(None, [2]), -1)  # python side
1162
            self.assertEqual(bar(x), -1)
1163

1164
    def test_factory_override(self):
1165
        class A(TorchFunctionMode):
1166
            def __torch_function__(self, *args, **kwargs):
1167
                return -1
1168

1169
        with A():
1170
            self.assertEqual(torch.tensor([1]), -1)
1171
            self.assertEqual(torch.sparse_coo_tensor(1, 1, 1), -1)
1172
            self.assertEqual(torch.sparse_csr_tensor(1, 1, 1), -1)
1173
            self.assertEqual(torch.sparse_coo_tensor(1, 1, (1, 1), check_invariants=False), -1)
1174
            self.assertEqual(torch.sparse_csr_tensor(1, 1, 1, (1, 1), check_invariants=False), -1)
1175
            self.assertEqual(torch.as_tensor([1]), -1)
1176

1177
    def test_modes_handle_first(self):
1178
        class A(TorchFunctionMode):
1179
            def __torch_function__(self, *args, **kwargs):
1180
                return -40
1181

1182
        x = SubTensor()
1183
        with A():
1184
            self.assertEqual(torch.neg(x), -40)
1185
            self.assertEqual(torch.mean(x), -40)
1186
            self.assertEqual(torch.mm(x, x), -40)
1187
            self.assertEqual(bar(x), -40)
1188

1189
    def test_modes_return_notimplemented(self):
1190
        class MyMode(TorchFunctionMode):
1191
            def __torch_function__(self, *args, **kwargs):
1192
                return NotImplemented
1193

1194
        x = SubTensor()
1195
        with MyMode():
1196
            self.assertEqual(torch.mean(x), 0)
1197
            self.assertEqual(torch.mm(x, x), -1)
1198
            self.assertEqual(bar(x), 1)
1199
            self.assertRaisesRegex(
1200
                TypeError, r'SubTensor',
1201
                lambda: self.assertEqual(torch.max(x, x)))
1202

1203
    def test_with_mode(self):
1204
        class ErrorA(RuntimeError):
1205
            pass
1206

1207
        class A(TorchFunctionMode):
1208
            def __torch_function__(self, *args, **kwargs):
1209
                raise ErrorA()
1210

1211
        with self.assertRaises(ErrorA):
1212
            with A():
1213
                torch.empty([])
1214

1215
    def test_with_mode_created_separately(self):
1216
        class ErrorA(RuntimeError):
1217
            pass
1218

1219
        class A(TorchFunctionMode):
1220
            def __torch_function__(self, *args, **kwargs):
1221
                raise ErrorA()
1222

1223
        x = A()
1224
        with self.assertRaises(ErrorA):
1225
            with x:
1226
                torch.empty([])
1227

1228
    def test_with_nested_modes(self):
1229
        out = []
1230

1231
        class A(TorchFunctionMode):
1232
            def __init__(self, msg):
1233
                self.msg = msg
1234

1235
            def __torch_function__(self, func, _, args=(), kwargs=None):
1236
                if kwargs is None:
1237
                    kwargs = {}
1238
                out.append(self.msg)
1239
                return func(*args, **kwargs)
1240

1241
        with A("layer1"):
1242
            with A("layer2"):
1243
                torch.empty([])
1244

1245
        self.assertEqual(out, ["layer2", "layer1"])
1246

1247
    def test_nested_same_mode(self):
1248
        out = []
1249

1250
        class A(TorchFunctionMode):
1251
            def __init__(self, msg):
1252
                self.msg = msg
1253

1254
            def __torch_function__(self, func, _, args=(), kwargs=None):
1255
                if kwargs is None:
1256
                    kwargs = {}
1257
                out.append(self.msg)
1258
                return func(*args, **kwargs)
1259

1260
        with A("layer1") as a:
1261
            with a:
1262
                torch.empty([])
1263

1264
        self.assertEqual(out, ["layer1", "layer1"])
1265

1266
    def test_error_using_class_method_on_mode(self):
1267
        class A(TorchFunctionMode):
1268
            @classmethod
1269
            def __torch_function__(cls, func, _, args=(), kwargs=None):
1270
                return func(args, kwargs)
1271

1272
        x = torch.tensor(5.)
1273
        with self.assertRaisesRegex(RuntimeError, "classmethod is not supported, please make it a plain method"):
1274
            with A():
1275
                x + x
1276

1277
    def test_restacking_with_ancestor(self):
1278
        class A(TorchFunctionMode):
1279
            pass
1280

1281
        with A():
1282
            with A() as x:
1283
                pass
1284

1285
        with x:
1286
            pass
1287

1288
    def test_get_cur_mode(self):
1289
        class A(TorchFunctionMode):
1290
            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
1291
                pass
1292

1293
        with A() as mode1:
1294
            self.assertEqual(_get_current_function_mode(), mode1)
1295

1296
        with mode1:
1297
            with A() as mode2:
1298
                self.assertEqual(_get_current_function_mode(), mode2)
1299

1300

1301
    def test_get_mode_stack(self):
1302
        class A(TorchFunctionMode):
1303
            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
1304
                pass
1305

1306
        self.assertEqual(_get_current_function_mode_stack(), [])
1307

1308
        with A() as mode1:
1309
            self.assertEqual(_get_current_function_mode_stack(), [mode1])
1310

1311
        with mode1:
1312
            with A() as mode2:
1313
                self.assertEqual(_get_current_function_mode_stack(), [mode1, mode2])
1314

1315
    def test_all_same_mode(self):
1316
        class A(TorchFunctionMode):
1317
            pass
1318

1319
        x = A()
1320
        y = A()
1321
        self.assertTrue(all_same_mode([x, x, x]))
1322
        self.assertFalse(all_same_mode([x, None]))
1323
        self.assertFalse(all_same_mode([x, y]))
1324

1325
    def test_nested_modes_with_python_has_torch_function(self):
1326
        called = []
1327

1328
        class A(TorchFunctionMode):
1329
            def __torch_function__(self, func, types, args=(), kwargs=None):
1330
                called.append("A")
1331
                kwargs = {} if kwargs is None else kwargs
1332
                return func(*args, **kwargs)
1333

1334
        class B(TorchFunctionMode):
1335
            def __torch_function__(self, func, types, args=(), kwargs=None):
1336
                called.append("B")
1337
                kwargs = {} if kwargs is None else kwargs
1338
                return func(*args, **kwargs)
1339

1340
        x = torch.randn(3, 4)
1341
        with A():
1342
            with B():
1343
                y = bar(x)
1344

1345
        self.assertEqual(y, x)
1346
        self.assertEqual(called, ["B", "A"])
1347

1348

1349
    def test_reentrant_mode_idiom(self):
1350
        log = []
1351

1352
        class A(TorchFunctionMode):
1353
            def __torch_function__(self, func, types, args=(), kwargs=None):
1354
                if kwargs is None:
1355
                    kwargs = {}
1356
                log.append(func)
1357
                if func is torch.sub:
1358
                    with self:
1359
                        input, other = args
1360
                        assert not kwargs
1361
                        return torch.add(input, other, alpha=-1)
1362
                return func(*args, **kwargs)
1363

1364
        x = torch.randn(1)
1365
        y = torch.randn(1)
1366
        with A():
1367
            torch.sub(x, y)
1368
        # add hits the torch function again!
1369
        self.assertEqual(log, [torch.sub, torch.add])
1370

1371
    def test_nn_parse_to(self):
1372
        # This failed because the parser thinks the function is called to()
1373
        # but it's actually called _parse_to()
1374

1375
        called = False
1376

1377
        class A(TorchFunctionMode):
1378
            def __torch_function__(self, func, types, args=(), kwargs=None):
1379
                nonlocal called
1380
                if kwargs is None:
1381
                    kwargs = {}
1382
                called = True
1383
                return func(*args, **kwargs)
1384

1385
        with A():
1386
            torch._C._nn._parse_to('cpu')
1387

1388
        self.assertTrue(called)
1389

1390
    def test_distributions_bernoulli(self):
1391
        # This failed because improper use of has_torch_function when
1392
        # is_tensor_like should have been used instead, inside the
1393
        # broadcasting logic called by distributions (Bernoulli doesn't
1394
        # matter per se)
1395

1396
        called = False
1397

1398
        class A(TorchFunctionMode):
1399
            def __torch_function__(self, func, types, args=(), kwargs=None):
1400
                nonlocal called
1401
                if kwargs is None:
1402
                    kwargs = {}
1403
                called = True
1404
                return func(*args, **kwargs)
1405

1406
        with A():
1407
            torch.distributions.Bernoulli(0.3)
1408

1409
        self.assertTrue(called)
1410

1411
    def test_mode_notimplemented_loop(self):
1412
        # Default tensor subclass implementation disables torch function;
1413
        # when we redispatch to mode we must not treat the objects as
1414
        # eligible
1415

1416
        called = 0
1417

1418
        class A(TorchFunctionMode):
1419
            def __torch_function__(self, func, types, args=(), kwargs=None):
1420
                nonlocal called
1421
                if kwargs is None:
1422
                    kwargs = {}
1423
                called += 1
1424
                # The first time we call, the mode sees an active type that
1425
                # it doesn't know how to deal with.  The second time, we're
1426
                # instructed to treat it "as if it were a tensor", and so
1427
                # we keep going.  I'm not entirely clear if the subclasses
1428
                # disappearing from types is the correct way to do it.
1429
                if any(t is not torch.Tensor for t in types):
1430
                    return NotImplemented
1431
                else:
1432
                    return func(*args, **kwargs)
1433

1434
        class B(torch.Tensor):
1435
            pass
1436

1437
        b = B()
1438

1439
        with A():
1440
            r = torch.neg(b)
1441

1442
        self.assertIs(type(r), B)
1443
        self.assertEqual(called, 2)
1444

1445
        called = 0
1446

1447
        with A():
1448
            r = bar(b)
1449

1450
        self.assertIs(type(r), B)
1451
        self.assertEqual(called, 2)
1452

1453
    def test_disable_subclass_not_mode(self):
1454
        called = False
1455

1456
        class A(TorchFunctionMode):
1457
            def __torch_function__(self, func, types, args=(), kwargs=None):
1458
                nonlocal called
1459
                if kwargs is None:
1460
                    kwargs = {}
1461
                called = True
1462
                return func(*args, **kwargs)
1463

1464
        class B(torch.Tensor):
1465
            pass
1466

1467
        x = B(torch.randn(5))
1468
        with A():
1469
            with torch._C.DisableTorchFunctionSubclass():
1470
                self.assertNotIsInstance(torch.sum(x), B)
1471

1472
        self.assertTrue(called)
1473

1474
    def test_disable_subclass_mode(self):
1475
        called = False
1476

1477
        class A(TorchFunctionMode):
1478
            def __torch_function__(self, func, types, args=(), kwargs=None):
1479
                nonlocal called
1480
                if kwargs is None:
1481
                    kwargs = {}
1482
                called = True
1483
                return func(*args, **kwargs)
1484

1485
        class B(torch.Tensor):
1486
            pass
1487

1488
        x = B(torch.randn(5))
1489
        with A():
1490
            with torch._C.DisableTorchFunction():
1491
                self.assertNotIsInstance(torch.sum(x), B)
1492

1493
        self.assertFalse(called)
1494

1495
    def test_disable_enable_subclass(self):
1496
        called = False
1497

1498
        class A(torch.Tensor):
1499
            pass
1500

1501
        x = A(torch.randn(5))
1502
        with torch._C.DisableTorchFunctionSubclass():
1503
            g = torch._C._EnableTorchFunction()
1504
            try:
1505
                self.assertIsInstance(torch.sum(x), A)
1506
            finally:
1507
                del g
1508

1509
    def test_subclass_hash(self):
1510
        class DiagTensor(torch.Tensor):
1511
            def __init__(self, diag):
1512
                self._diag = diag
1513

1514
            @classmethod
1515
            def __torch_function__(cls, func, types, args=(), kwargs=None):
1516
                kwargs = kwargs or {}
1517

1518
                def get_full_matrices(t):
1519
                    if isinstance(t, DiagTensor):
1520
                        return torch.diag_embed(t._diag)
1521
                    else:
1522
                        return t
1523

1524
                return func(*tree_map(get_full_matrices, args), **tree_map(get_full_matrices, kwargs))
1525

1526
        d = torch.rand(2)
1527
        a = DiagTensor(d)
1528

1529
        self.assertEqual((a + 1), torch.diag_embed(d) + 1)
1530

1531
        # If the hash function was returning the same value, this would
1532
        # fail inside `Tensor.__eq__`.
1533
        # If __hash__ was going through torch_function, the implementation above would
1534
        # be wrong as it would compute the hash on a temporary Tensor thus not ensuring
1535
        # the uniqueness of the hash that we rely on for Tensors.
1536
        s = set()
1537
        s.add(a)
1538
        s.add(DiagTensor(d))
1539

1540
    def test_custom_device_type(self):
1541
        class CustomDeviceContext(TorchFunctionMode):
1542

1543
            def __torch_function__(self, func, types, args=(), kwargs=None):
1544
                kwargs = kwargs or {}
1545
                if func == torch.device:
1546
                    if args and isinstance(args[0], int):
1547
                        args = ("xla", args[0])
1548
                    elif isinstance(kwargs.get('device'), int):
1549
                        kwargs['device'] = f"xla:{kwargs.get('device')}"
1550
                return func(*args, **kwargs)
1551

1552
        with CustomDeviceContext():
1553
            d_args = torch.device(0)
1554
            self.assertEqual(d_args.type, "xla")
1555
            self.assertEqual(d_args.index, 0)
1556
            d_kwargs = torch.device(device=0)
1557
            self.assertEqual(d_kwargs.type, "xla")
1558
            self.assertEqual(d_kwargs.index, 0)
1559

1560

1561
if __name__ == '__main__':
1562
    run_tests()
1563

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.