1
# Owner(s): ["module: __torch_function__"]
13
from torch.testing._internal.common_utils import TestCase, run_tests, TEST_WITH_CROSSREF, TEST_WITH_TORCHDYNAMO
14
from torch.overrides import (
15
handle_torch_function,
17
get_ignored_functions,
18
get_overridable_functions,
19
get_testing_overrides,
21
is_tensor_method_or_property,
23
_get_current_function_mode,
24
_get_current_function_mode_stack,
27
from torch.utils._mode_utils import all_same_mode
28
from torch.utils._pytree import tree_map
32
# The functions below simulate the pure-python torch functions in the
33
# torch.functional namespace. We use examples local to this file rather
34
# than any of the real examples implemented in Python since in the
35
# future those examples might get reimplemented in C++ for speed. This
36
# fake torch function allows us to verify that the dispatch rules work
37
# the same for a torch function implemented in C++ or Python.
40
"""A function multiple arguments and an optional argument"""
41
if has_torch_function((a, b, c)):
42
return handle_torch_function(foo, (a, b, c), a, b, c=c)
48
"""A function with one argument"""
49
if has_torch_function((a,)):
50
return handle_torch_function(bar, (a,), a)
54
"""A function with multiple arguments"""
55
if has_torch_function((a, b)):
56
return handle_torch_function(baz, (a, b), a, b)
60
"""Used to test that errors raised in user implementations get propagated"""
61
if has_torch_function((a,)):
62
return handle_torch_function(quux, (a,), a)
65
# HANDLED_FUNCTIONS_DIAGONAL is a dispatch table that
66
# DiagonalTensor.__torch_function__ uses to determine which override
67
# function to call for a given torch API function. The keys of the
68
# dictionary are functions from the torch API and the values are
69
# function implementations. Implementations are added to
70
# HANDLED_FUNCTIONS_DIAGONAL by decorating a Python function with
71
# implements_diagonal. See the overrides immediately below the definition
72
# of DiagonalTensor for usage examples.
73
HANDLED_FUNCTIONS_DIAGONAL = {}
75
def implements_diagonal(torch_function):
76
"""Register a torch function override for DiagonalTensor.
78
This decorator takes a function in the torch API as a
79
parameter. Applying this decorator to a function adds that function
80
as the registered override for the torch function passed as a
81
parameter to the decorator. See DiagonalTensor.__torch_function__
82
for the runtime dispatch implementation and the decorated functions
83
immediately below DiagonalTensor for usage examples.
85
@functools.wraps(torch_function)
87
HANDLED_FUNCTIONS_DIAGONAL[torch_function] = func
92
"""A class with __torch_function__ and a specific diagonal representation
94
This class has limited utility and is mostly useful for verifying that the
95
dispatch mechanism works as expected. It is based on the `DiagonalArray
96
example`_ in the NumPy documentation.
98
Note that this class does *not* inherit from ``torch.tensor``, interaction
99
with the pytorch dispatch system happens via the ``__torch_function__``
102
``DiagonalTensor`` represents a 2D tensor with *N* rows and columns that has
103
diagonal entries set to *value* and all other entries set to zero. The
104
main functionality of ``DiagonalTensor`` is to provide a more compact
105
string representation of a diagonal tensor than in the base tensor class:
107
>>> d = DiagonalTensor(5, 2)
109
DiagonalTensor(N=5, value=2)
111
tensor([[2., 0., 0., 0., 0.],
112
[0., 2., 0., 0., 0.],
113
[0., 0., 2., 0., 0.],
114
[0., 0., 0., 2., 0.],
115
[0., 0., 0., 0., 2.]])
117
Note that to simplify testing, matrix multiplication of ``DiagonalTensor``
123
.. _DiagonalArray example:
124
https://numpy.org/devdocs/user/basics.dispatch.html
126
# This is defined as a class attribute so that SubDiagonalTensor
127
# below which subclasses DiagonalTensor can re-use DiagonalTensor's
128
# __torch_function__ implementation.
129
handled_functions = HANDLED_FUNCTIONS_DIAGONAL
131
def __init__(self, N, value):
136
return f"DiagonalTensor(N={self._N}, value={self._i})"
139
return self._i * np.eye(self._N)
142
return self._i * torch.eye(self._N)
145
def __torch_function__(cls, func, types, args=(), kwargs=None):
148
if func not in cls.handled_functions:
149
return NotImplemented
150
return cls.handled_functions[func](*args, **kwargs)
152
def __eq__(self, other):
153
return type(other) is type(self) and self._N == other._N and self._i == other._i
155
@implements_diagonal(torch.mean)
157
return float(mat._i) / mat._N
159
@implements_diagonal(torch.mm)
160
def diagonal_mm(mat1, mat2):
163
@implements_diagonal(torch.div)
164
def diagonal_div(input, other, out=None):
167
@implements_diagonal(torch.add)
171
@implements_diagonal(foo)
172
def diagonal_foo(a, b, c=None):
175
@implements_diagonal(bar)
179
@implements_diagonal(quux)
183
# The dispatch table for SubTensor's __torch_function__ implementation.
184
HANDLED_FUNCTIONS_SUB = {}
186
def implements_sub(torch_function):
187
"Register a torch function override for SubTensor"
188
@functools.wraps(torch_function)
190
HANDLED_FUNCTIONS_SUB[torch_function] = func
194
class SubTensor(torch.Tensor):
195
"""A subclass of torch.Tensor use for testing __torch_function__ dispatch
197
This class has the property that matrix multiplication returns zero:
199
>>> s = SubTensor([[1, 1], [1, 1]])
202
>>> t = torch.tensor([[1, 1], [1, 1]])
211
This is useful for testing that the semantics for overriding torch
212
functions are working correctly.
215
def __torch_function__(cls, func, types, args=(), kwargs=None):
219
if func not in HANDLED_FUNCTIONS_SUB:
220
return NotImplemented
221
return HANDLED_FUNCTIONS_SUB[func](*args, **kwargs)
223
class SubTensor2(torch.Tensor):
226
class SubSubTensor2(SubTensor2):
229
class SubTensor3(torch.Tensor):
232
@implements_sub(torch.mean)
236
@implements_sub(torch.mm)
237
def sub_mm(mat1, mat2):
244
@implements_sub(torch.div)
def sub_div(input, other, out=None):
    """SubTensor's ``torch.div`` override, which deliberately declines the call.

    Returning ``NotImplemented`` lets dispatch fall through to another
    argument's ``__torch_function__``. If no argument handles the call,
    a TypeError is raised.
    """
    return NotImplemented
248
# The dispatch table for SubDiagonalTensor's __torch_function__ implementation.
249
HANDLED_FUNCTIONS_SUB_DIAGONAL = {}
251
def implements_sub_diagonal(torch_function):
252
"Register a torch function override for SubDiagonalTensor"
253
@functools.wraps(torch_function)
255
HANDLED_FUNCTIONS_SUB_DIAGONAL[torch_function] = func
259
class SubDiagonalTensor(DiagonalTensor):
260
"""A subclass of ``DiagonalTensor`` to test custom dispatch
262
This class tests semantics for defining ``__torch_function__`` on a
263
subclass of another class that defines ``__torch_function__``. The
264
only difference compared with the superclass is that this class
265
provides a slightly different repr as well as custom implementations
266
of ``mean`` and ``mm``, scaling the mean by a factor of 10 and
267
returning 1 from ``mm`` instead of 0 as ``DiagonalTensor`` does.
269
handled_functions = HANDLED_FUNCTIONS_SUB_DIAGONAL
272
return f"SubDiagonalTensor(N={self._N}, value={self._i})"
275
@implements_sub_diagonal(torch.mean)
def sub_diagonal_mean(mat):
    """SubDiagonalTensor's ``torch.mean`` override: ten times the diagonal mean.

    ``mat._i`` is the diagonal value and ``mat._N`` the matrix size, so the
    result is ``10 * value / N``.
    """
    # Keep the original evaluation order, (10 * value) / N, so the float
    # result is bit-identical.
    scaled_value = 10 * float(mat._i)
    return scaled_value / mat._N
279
@implements_sub_diagonal(bar)
280
def sub_diagonal_bar(mat):
283
@implements_sub_diagonal(torch.mm)
284
def sub_diagonal_mm(mat1, mat2):
287
@implements_sub_diagonal(torch.div)
def sub_diagonal_div(input, other, out=None):
    """SubDiagonalTensor's ``torch.div`` override, which always declines.

    Returning ``NotImplemented`` makes dispatch fall back to the next
    candidate override, e.g. the DiagonalTensor implementation.
    """
    return NotImplemented
291
@implements_sub_diagonal(foo)
def sub_diagonal_foo(a, b, c=None):
    """SubDiagonalTensor's override of the pure-Python ``foo``, which always declines.

    Returning ``NotImplemented`` makes dispatch fall back to the next
    candidate override, mirroring ``sub_diagonal_div`` for a function
    implemented in Python rather than C++.
    """
    return NotImplemented
295
# The dispatch table for TensorLike's __torch_function__ implementation.
296
HANDLED_FUNCTIONS_TENSOR_LIKE = {}
299
# Note: _triggered wrapper
300
# Dict that wraps the implementations from get_testing_overrides into another
301
# function with a _triggered slot/flag. The triggered flag is set when the
302
# implementation is called.
303
WRAPPED_TRIGGERED_IMPLS = {}
306
def triggered_wrapper(f):
308
def wrapped(*args, **kwargs):
309
wrapped._triggered = True
310
return f(*args, **kwargs)
312
wrapped._triggered = False
315
def implements_tensor_like(torch_function):
316
"Register a torch function override for TensorLike"
317
@functools.wraps(torch_function)
319
HANDLED_FUNCTIONS_TENSOR_LIKE[torch_function] = func
323
def generate_tensor_like_torch_implementations():
324
torch_vars = vars(torch)
326
testing_overrides = get_testing_overrides()
327
# test/test_cpp_api_parity.py monkeypatches torch.nn to have a new
328
# function sample_functional. Depending on what order you run pytest
329
# collection, this may trigger the error here. This is a hack to fix
330
# the problem. A more proper fix is to make the "not tested" check
331
# a test on its own, and to make sure the monkeypatch is only installed
332
# for the span of the relevant test (and deleted afterwards)
333
testing_ignore = {"sample_functional", "autocast"}
334
for namespace, funcs in get_overridable_functions().items():
336
if func not in testing_overrides and func.__name__ not in testing_ignore:
337
untested_funcs.append(f"{namespace}.{func.__name__}")
339
"The following functions are not tested for __torch_function__ "
340
"support, please ensure there is an entry in the dict returned by "
341
"torch.overrides.get_testing_overrides for this function or if a "
342
"__torch_function__ override does not make sense, add an entry to "
343
"the tuple returned by torch._overrides.get_ignored_functions.\n\n{}"
345
assert len(untested_funcs) == 0, msg.format(pprint.pformat(untested_funcs))
346
for func, override in testing_overrides.items():
347
# decorate the overrides with implements_tensor_like if it's not a
348
# torch.Tensor method
349
wrapped = triggered_wrapper(override)
350
# See note: "_triggered wrapper"
351
WRAPPED_TRIGGERED_IMPLS[func] = wrapped
352
if is_tensor_method_or_property(func):
353
implements_sub(func)(wrapped)
355
implements_tensor_like(func)(wrapped)
357
generate_tensor_like_torch_implementations()
360
"""A class that overrides the full torch API
362
This class is used to explicitly test that the full torch.tensor API
363
can be overriden with a class that defines __torch_function__.
366
def __torch_function__(cls, func, types, args=(), kwargs=None):
370
if func not in HANDLED_FUNCTIONS_TENSOR_LIKE:
371
return NotImplemented
372
# In this case _torch_function_ should override TensorLike objects
373
return HANDLED_FUNCTIONS_TENSOR_LIKE[func](*args, **kwargs)
375
class TestTorchFunctionOverride(TestCase):
378
cls._stack = contextlib.ExitStack()
379
if TEST_WITH_TORCHDYNAMO:
380
# Add classes to the wrapped tensor subclasses
381
@contextlib.contextmanager
382
def setup_subclasses():
383
old = set(torch._dynamo.config.traceable_tensor_subclasses)
384
torch._dynamo.config.traceable_tensor_subclasses.add(DiagonalTensor)
388
torch._dynamo.config.traceable_tensor_subclasses.clear()
389
torch._dynamo.config.traceable_tensor_subclasses.update(old)
391
cls._stack.enter_context(setup_subclasses())
394
def tearDownClass(cls):
397
def test_mean_semantics(self):
    """Test that a function with one argument can be overridden"""
    diag = DiagonalTensor(5, 2)
    sub = SubTensor([[1, 2], [1, 2]])
    sub_diag = SubDiagonalTensor(5, 2)
    # (operand, expected torch.mean result, expected bar result);
    # checked in order, torch.mean before bar for each operand.
    expectations = (
        (diag, 0.4, -1),
        (sub, 0, 1),
        (sub_diag, 4.0, 0),
    )
    for operand, expected_mean, expected_bar in expectations:
        self.assertEqual(torch.mean(operand), expected_mean)
        self.assertEqual(bar(operand), expected_bar)
409
def test_has_torch_function_non_sequence(self):
    # has_torch_function must reject an argument that is not a sequence.
    not_a_sequence = object()
    with self.assertRaisesRegex(TypeError, "expected a sequence"):
        has_torch_function(not_a_sequence)
413
def test_mm_semantics(self):
    """Test that a function with multiple arguments can be overridden"""
    diag = DiagonalTensor(5, 2)
    dense = torch.eye(5) * 2
    sub = SubTensor([[1, 2], [1, 2]])
    sub_diag = SubDiagonalTensor(5, 2)
    # Each entry is (first mm operand, second mm operand, expected result);
    # entries are checked in order.
    cases = (
        # only DiagonalTensor so should always get DiagonalTensor result
        (diag, diag, 0),
        # tensor and DiagonalTensor, always return DiagonalTensor result
        (diag, dense, 0),
        (dense, diag, 0),
        # only SubTensor so should always get SubTensor result
        (sub, sub, -1),
        # tensor and SubTensor so should always get SubTensor result
        (sub, dense, -1),
        (dense, sub, -1),
        # DiagonalTensor and SubTensor are unrelated classes so the result
        # depends on which argument appears first
        (sub, diag, -1),
        (diag, sub, 0),
        # SubDiagonalTensor should take precedence over DiagonalTensor
        # but should behave otherwise the same as DiagonalTensor
        (sub_diag, sub_diag, 1),
        (sub_diag, diag, 1),
        (diag, sub_diag, 1),
        (sub_diag, dense, 1),
        (dense, sub_diag, 1),
        (sub, sub_diag, -1),
        (sub_diag, sub, 1),
    )
    for lhs, rhs, expected in cases:
        self.assertEqual(torch.mm(lhs, rhs), expected)
443
def test_precedence_semantics(self):
444
"""Test semantics for __torch_function__ for functions that take
447
For functions that take multiple arguments, the appropriate
448
__torch_function__ implementation to call is determined by
449
examining the types of the arguments. The precedence order is
450
left-to-right in the argument list, except subclasses are always
451
checked before superclasses. The first result of calling the
452
implementations in precedence order that is not NotImplemented
453
is returned to the user. If all implementations return
454
NotImplemented, a TypeError is raised.
456
All cases are tested with functions implemented in C++ and
457
either foo or baz, which are python functions defined above that
458
are instrumented to obey the same dispatch rules as the
459
functions in torch.functional.
461
# DiagonalTensor has a valid override and SubDiagonal has an
462
# override that returns NotImplemented so we should call the
463
# DiagonalTensor implementation, returning -1
464
t1 = DiagonalTensor(5, 2)
465
t2 = SubDiagonalTensor(5, 2)
466
self.assertEqual(torch.div(t1, t2), -1)
467
self.assertEqual(torch.div(t2, t1), -1)
468
self.assertEqual(foo(t1, t2), -1)
469
self.assertEqual(foo(t2, t1), -1)
471
# SubTensor has an implementation that returns NotImplemented as
472
# well so it should behave exactly like SubDiagonalTensor in the
474
t3 = SubTensor([[1, 2], [1, 2]])
475
self.assertEqual(torch.div(t1, t3), -1)
476
self.assertEqual(torch.div(t3, t1), -1)
477
self.assertEqual(foo(t1, t3), -1)
478
self.assertEqual(foo(t3, t1), -1)
480
# div between SubTensor and SubDiagonalTensor should raise
481
# TypeError since both have an implementation that
482
# explicitly returns NotImplemented
483
with self.assertRaises(TypeError):
485
with self.assertRaises(TypeError):
487
with self.assertRaises(TypeError):
489
with self.assertRaises(TypeError):
492
# none of DiagonalTensor, SubdiagonalTensor, or SubTensor have a
493
# mul or a baz implementation so all ops should raise TypeError
494
with self.assertRaises(TypeError):
496
with self.assertRaises(TypeError):
498
with self.assertRaises(TypeError):
500
with self.assertRaises(TypeError):
502
with self.assertRaises(TypeError):
504
with self.assertRaises(TypeError):
506
with self.assertRaises(TypeError):
508
with self.assertRaises(TypeError):
510
with self.assertRaises(TypeError):
512
with self.assertRaises(TypeError):
514
with self.assertRaises(TypeError):
516
with self.assertRaises(TypeError):
518
with self.assertRaises(TypeError):
520
with self.assertRaises(TypeError):
522
with self.assertRaises(TypeError):
524
with self.assertRaises(TypeError):
526
with self.assertRaises(TypeError):
528
with self.assertRaises(TypeError):
531
def test_user_implementation_raises(self):
532
"""Test that errors raised in user implementations propagate correctly"""
533
t1 = DiagonalTensor(5, 2)
534
t2 = DiagonalTensor(5, 2)
535
with self.assertRaises(ValueError):
537
with self.assertRaises(ValueError):
540
def test_tensor_subclass_propagation(self):
541
"""this test exercises the functionality described in
542
docs/source/notes/extending.rst#subclassing-torchtensor"""
543
t1 = torch.tensor([5])
544
t2 = torch.tensor([6])
549
ss1 = SubSubTensor2([5])
550
ss2 = SubSubTensor2([6])
552
sn1 = SubTensor3([5])
553
sn2 = SubTensor3([6])
555
# Check that leaf subclass is kept regardless of order
556
self.assertTrue(isinstance(s1 + t2, SubTensor2))
557
self.assertTrue(isinstance(t1 + s2, SubTensor2))
558
self.assertTrue(isinstance(s1 + s2, SubTensor2))
560
# Check indexing subclass is kept
561
self.assertTrue(isinstance(s1[0], SubTensor2))
563
# Check case for subclass of subclass.
564
self.assertTrue(isinstance(ss1 + ss2, SubSubTensor2))
565
self.assertTrue(isinstance(ss1 + s2, SubSubTensor2))
566
self.assertTrue(isinstance(s1 + ss2, SubSubTensor2))
567
self.assertTrue(isinstance(ss1 + ss2, SubSubTensor2))
568
self.assertTrue(isinstance(ss1 + t2, SubSubTensor2))
569
self.assertTrue(isinstance(t1 + ss2, SubSubTensor2))
570
self.assertTrue(isinstance(ss1[0], SubSubTensor2))
572
# Make sure unrelated class trees are not merged.
573
with self.assertRaises(TypeError):
575
with self.assertRaises(TypeError):
579
# https://github.com/szagoruyko/pytorchviz/issues/65
580
class DummyTensor(torch.Tensor):
585
self.assertTrue(c._is_view())
586
self.assertTrue(c._base is a)
589
# Previously, Tensor-like objects that did not subclass from Tensor
590
# did not get wrapped into unary tuples before being passed into
591
# handle_torch_function, in contradiction with how Tensor-likes
594
# NB: this asserts that the arguments get normalized into a tuple
595
# before entering the torch function handler; it could go the
596
# other way but beware https://github.com/pytorch/pytorch/issues/76037
600
def __torch_function__(cls, func, types, args=(), kwargs=None):
601
inputs, outputs = args
602
self.assertEqual(inputs, (x,))
603
self.assertEqual(outputs, (x,))
607
self.assertEqual(torch.autograd.grad(x, x), -1)
609
def test_pow_rpow(self):
610
class NothingImplemented(torch.Tensor):
612
def __torch_function__(cls, func, types, args=(), kwargs=None):
613
return NotImplemented
615
class RPowOnly(torch.Tensor):
617
def __torch_function__(cls, func, types, args=(), kwargs=None):
618
if func is torch.Tensor.__rpow__:
620
return NotImplemented
622
self.assertEqual(NothingImplemented() ** RPowOnly(), -1)
625
def generate_tensor_like_override_tests(cls):
626
from torch.testing._internal.generated.annotated_fn_args import annotated_args
628
def test_generator(func, override):
629
# If func corresponds to a torch.Tensor method or property.
630
if is_tensor_method_or_property(func):
631
# Generate an instance by using SubTensor,
633
return SubTensor([5])
635
# Otherwise, TensorLike.
639
# FIXME The following code does not support kwonly args without defaults.
640
# The fix is easy, as one just needs to save these args when generating the variable
641
# annotated_args. The problem is that, if one does so, one finds a number
642
# of functions that have problematic signatures in native_functions.yaml.
643
# Fixing these would be BC breaking, so hence this terrible hack
644
# https://github.com/pytorch/pytorch/issues/67008
646
if hasattr(func, "__name__") and "linalg_solve_triangular" in func.__name__:
647
kwargs = {"upper": True}
650
is_method = is_tensor_method_or_property(func)
652
def _simple_type_parser(func, arg_name, arg_type):
653
# Guess valid input to aten function based on type of argument
654
if arg_type == "Tensor":
655
return instance_gen()
656
elif arg_type == "TensorList" or arg_type == "ITensorListRef":
657
return [instance_gen(), instance_gen()]
658
elif arg_type == "c10::List<::std::optional<Tensor>>":
659
return [instance_gen(), instance_gen()]
660
elif arg_type == "IntArrayRef" or arg_type == "SymIntArrayRef":
661
size = arg.get("size", 2)
666
elif arg_type == "Scalar":
668
elif arg_type == "bool":
670
elif arg_type == "Dimname":
672
elif arg_type == "DimnameList":
674
elif arg_type.startswith("int"):
676
elif arg_type in {"Stream"}:
677
return torch.Stream()
678
elif arg_type.startswith("float") or arg_type == "double":
680
elif arg_type in {"Generator", "MemoryFormat", "TensorOptions"}:
682
elif arg_type == "ScalarType":
684
elif arg_type == "c10::string_view":
686
elif arg_type == "SymInt":
687
# TODO: generate actual SymbolicInt
691
f"Unsupported argument type {arg_type} for {arg_name} of function {func}"
694
if func in annotated_args:
695
for arg in annotated_args[func]:
696
# Guess valid input to aten function based on type of argument
697
t = arg["simple_type"]
700
if t == "Tensor" and is_method and arg["name"] == "self":
701
# See "Note: properties and __get__"
702
func = func.__get__(instance_gen())
704
arg_to_add = _simple_type_parser(func, arg["name"], t)
705
if "is_kwarg_only" in arg and arg["is_kwarg_only"] == str(True):
706
kwargs[arg["name"]] = arg_to_add
708
func_args.append(arg_to_add)
710
args = inspect.getfullargspec(override)
712
func_args = inspect.getfullargspec(func)
713
# Remove annotations from argspec
714
func_args = type(func_args)(**{**func_args, 'annotations': None})
715
if func_args != args:
716
raise RuntimeError(f"Override for {func} doesn't match its argspec.\n"
717
+ f"Original: {inspect.signature(func)}\n"
718
+ f"Override: {inspect.signature(override)}")
721
nargs = len(args.args)
722
if args.defaults is not None:
723
nargs -= len(args.defaults)
724
func_args = [instance_gen() for _ in range(nargs)]
725
if args.varargs is not None:
726
func_args += [instance_gen(), instance_gen()]
729
ret = func(*func_args, **kwargs)
730
# ret is None for certain protocols, e.g., `__weakref__` and `__setitem__`
731
# This is currently the best check but doesn't work for, for example,
732
# Tensor.__add__ because it redirects to Tensor.add.
733
# See note "_triggered wrapper"
734
if not is_method or ret is None:
735
self.assertTrue(WRAPPED_TRIGGERED_IMPLS[func]._triggered)
738
self.assertEqual(ret, -1)
742
for func, override in get_testing_overrides().items():
743
test_method = test_generator(func, override)
744
if func.__name__ == "__get__":
745
# Note: properties and __get__
746
# __get__ is part of the descriptor protocol.
747
# https://docs.python.org/3/howto/descriptor.html
748
# This is used for properties of the form
749
# torch.Tensor.<property>, with the method __get__
750
# In this case we get the property name in two ways:
752
# This case for properties defined in C.
759
# This one for properties defined in Python.
761
module = "Tensor." + func.__self__.fget.__name__
763
# Unfortunately I couldn't find a way to unify these two cases
764
# and there is no way for general descriptors.
765
elif is_tensor_method_or_property(func):
768
module = func.__module__
770
name = 'test_{}_{}'.format(module.replace('.', '_'), func.__name__)
772
name = f'test_{func.__name__}'
773
test_method.__name__ = name
774
setattr(cls, name, test_method)
776
generate_tensor_like_override_tests(TestTorchFunctionOverride)
779
"Basic data container that knows how to unwrap itself"
780
def __init__(self, data):
781
self.__dict__["_data"] = data
782
self.__dict__["used_attrs"] = set()
783
self.__dict__["used_calls"] = set()
785
def __getattr__(self, name):
786
if name in self.__dict__:
787
return self.__dict__[name]
788
self.used_attrs.add(name)
790
val = getattr(self._data, name)
793
if not isinstance(val, torch.device) and callable(val):
794
c = getattr(type(self._data), name)
795
# Don't append self to args if classmethod/staticmethod
797
return lambda *a, **kw: wrap(self.__torch_function__(c, (Wrapper,), args=a, kwargs=kw))
798
# Otherwise append self to args
799
return lambda *a, **kw: wrap(self.__torch_function__(c, (Wrapper,), args=(self,) + a, kwargs=kw))
803
def __setattr__(self, name, value):
804
if name in self.__dict__:
805
self.__dict__[name] = value
807
self.used_attrs.add(name)
808
setattr(self._data, name, unwrap(value))
810
def __setitem__(self, key, value):
811
self._data[unwrap(key)] = unwrap(value)
813
def __getitem__(self, key):
814
return wrap(self._data[unwrap(key)])
817
def __torch_function__(cls, func, types, args=(), kwargs=None):
820
# Find an instance of this class in the arguments
821
args_of_this_cls = []
823
if isinstance(a, cls):
824
args_of_this_cls.append(a)
825
elif isinstance(a, collections.abc.Sequence):
826
args_of_this_cls.extend(el for el in a if isinstance(el, cls))
827
assert len(args_of_this_cls) > 0
828
for a in args_of_this_cls:
829
a.used_calls.add(func)
830
args = unwrap(tuple(args))
831
kwargs = {k: unwrap(v) for k, v in kwargs.items()}
833
return wrap(func(*args, **kwargs))
835
def __add__(self, other):
836
return self.__torch_function__(torch.add, (Wrapper,), (self, other))
838
def __mul__(self, other):
839
return self.__torch_function__(torch.mul, (Wrapper,), (self, other))
841
def __sub__(self, other):
842
return self.__torch_function__(torch.sub, (Wrapper,), (self, other))
844
def __truediv__(self, other):
845
return self.__torch_function__(torch.true_divide, (Wrapper,), (self, other))
847
def __floordiv__(self, other):
848
return self.__torch_function__(torch.floor_divide, (Wrapper,), (self, other))
850
def __ge__(self, other):
851
return self.__torch_function__(torch.ge, (Wrapper,), (self, other))
853
def __gt__(self, other):
854
return self.__torch_function__(torch.gt, (Wrapper,), (self, other))
856
def __lt__(self, other):
857
return self.__torch_function__(torch.lt, (Wrapper,), (self, other))
859
def __le__(self, other):
860
return self.__torch_function__(torch.le, (Wrapper,), (self, other))
862
def __eq__(self, other):
863
return self.__torch_function__(torch.eq, (Wrapper,), (self, other))
865
def __ne__(self, other):
866
return self.__torch_function__(torch.ne, (Wrapper,), (self, other))
869
return self.__torch_function__(torch.Tensor.__bool__, (Wrapper,), (self,))
872
return self.__torch_function__(torch.Tensor.__int__, (Wrapper,), (self,))
875
return len(self._data)
878
# unwrap inputs if necessary
880
if type(v) in {tuple, list}:
881
return type(v)(unwrap(vi) for vi in v)
883
return v._data if isinstance(v, Wrapper) else v
885
# wrap inputs if necessary
887
if type(v) in {tuple, list}:
888
return type(v)(wrap(vi) for vi in v)
890
return Wrapper(v) if isinstance(v, torch.Tensor) else v
892
class TestEinsumOverride(TestCase):
    """Regression test for gh-38479: einsum must dispatch on Wrapper operands."""

    def test_wrapper(self):
        # Variadic-operand form: an einsum outer product must agree with ger.
        vec_a = Wrapper(torch.randn(5))
        vec_b = Wrapper(torch.randn(4))
        einsum_outer = torch.einsum('i,j->ij', vec_a, vec_b)._data
        ger_outer = torch.ger(vec_a, vec_b)._data
        self.assertEqual(einsum_outer, ger_outer)

        # in the old einsum interface, `operands` is a list
        lhs = Wrapper(torch.randn(2, 3))
        weight = Wrapper(torch.randn(5, 3, 7))
        rhs = Wrapper(torch.randn(2, 7))
        einsum_bilinear = torch.einsum('ik,jkl,il->ij', [lhs, weight, rhs])._data
        reference_bilinear = torch.nn.functional.bilinear(lhs, rhs, weight)._data
        self.assertEqual(einsum_bilinear, reference_bilinear)
907
class TestGradCheckOverride(TestCase):
908
"Test that wrappers work with gradcheck."
909
def test_gradcheck(self):
910
from torch.testing._internal.common_utils import gradcheck, gradgradcheck
912
def run_test(fast_mode):
913
a = wrap(torch.tensor(5.0, dtype=torch.double))
914
b = wrap(torch.tensor(6.0, dtype=torch.double))
916
a.requires_grad = True
917
b.requires_grad = True
919
gradcheck(torch.add, (a, b), raise_exception=False, check_batched_grad=False, fast_mode=fast_mode)
920
gradgradcheck(torch.add, (a, b), raise_exception=False, check_batched_grad=False, fast_mode=fast_mode)
922
total_used_attrs = a.used_attrs.union(b.used_attrs)
923
total_used_calls = a.used_calls.union(b.used_calls)
925
# These attributes (and the functions below) may change
926
# if the gradcheck implementation changes. It's best to
927
# aim for attributes that may be commonly present on other
929
expected_used_attrs = {
943
expected_used_attrs.add('is_complex')
944
expected_used_attrs.add('device')
945
self.assertEqual(expected_used_attrs, total_used_attrs)
947
expected_used_calls = {
948
torch.Tensor.new_zeros,
950
torch.Tensor.is_floating_point,
953
torch.Tensor.requires_grad_,
958
expected_used_calls.add(torch.Tensor.is_complex)
959
self.assertEqual(expected_used_calls, total_used_calls)
960
run_test(fast_mode=True)
961
run_test(fast_mode=False)
963
class TestNamedTuple(TestCase):
964
""" Regression test for gh-47090 """
966
x = torch.tensor([1, 2])
967
xs = x.as_subclass(SubTensor2)
968
r = torch.max(x, dim=0)
969
rs = torch.max(xs, dim=0)
970
self.assertEqual(type(r), type(rs))
971
self.assertEqual(r, rs)
973
class TestGradNewOnesOverride(TestCase):
    """Regression test for gh-47069: new_ones keeps the tensor subclass."""

    def test_newones(self):
        subclassed = torch.tensor([1, 2]).as_subclass(SubTensor2)
        ones = subclassed.new_ones((1, 2))
        # The factory method called on a subclass instance must return that subclass.
        self.assertEqual(type(ones), SubTensor2)
980
class TestPickle(TestCase):
981
"Regression test for gh-47051"
982
def test_pickle(self):
983
t = torch.tensor([1]).as_subclass(SubTensor2)
985
t2 = pickle.loads(pickle.dumps(t))
986
self.assertIs(type(t2), SubTensor2)
987
self.assertEqual(t2.abcd, "e")
989
class TestBroadcastAllOverride(TestCase):
990
""" test for gh-37141 """
991
def test_broadcast_all(self):
992
from torch.distributions.utils import broadcast_all
993
a = torch.tensor([1.2, 3.4, 5.6])
995
b = torch.tensor(5.0)
997
c = torch.tensor([5.0, 5.0, 5.0])
999
o_1 = broadcast_all(a_w, b_w)
1000
self.assertTrue(isinstance(o_1[0], Wrapper))
1001
self.assertTrue(isinstance(o_1[1], Wrapper))
1002
self.assertEqual(o_1[0]._data, a)
1003
self.assertEqual(o_1[1]._data, c)
1005
o_2 = broadcast_all(a_w, b)
1006
self.assertTrue(isinstance(o_2[0], Wrapper))
1007
self.assertTrue(isinstance(o_2[1], Wrapper))
1008
self.assertEqual(o_2[0]._data, a)
1009
self.assertEqual(o_2[1]._data, c)
1011
class TestWrapTorchFunction(TestCase):
1012
def test_wrap_torch_function(self):
1015
def __torch_function__(cls, func, types, args, kwargs):
1021
@torch.overrides.wrap_torch_function(dispatcher)
1025
self.assertEqual(f(A()), -1)
1027
class TestIndexing(TestCase):
1028
""" Regression tests for gh-46277 """
1029
def test_getitem(self):
1032
def __torch_function__(cls, func, types, args, kwargs=None):
1035
t = torch.tensor([5])
1036
self.assertEqual(t[A()], -1)
1037
self.assertEqual(t, torch.tensor([5]))
1039
def test_getitem_subclass(self):
1040
class A(torch.Tensor):
1042
def __torch_function__(cls, func, types, args, kwargs=None):
1045
t = torch.tensor([5])
1046
self.assertEqual(t[A()], -1)
1047
self.assertEqual(t[5, A()], -1)
1048
self.assertEqual(t, torch.tensor([5]))
1050
def test_setitem(self):
1055
def __torch_function__(cls, func, types, args, kwargs=None):
1059
t = torch.tensor([5])
1062
self.assertIn(Tensor.__setitem__, triggered)
1063
self.assertEqual(t, torch.tensor([5]))
1065
def test_setitem_val(self):
1070
def __torch_function__(cls, func, types, args, kwargs=None):
1074
t = torch.tensor([5])
1076
self.assertIn(Tensor.__setitem__, triggered)
1077
self.assertEqual(t, torch.tensor([5]))
1079
def test_setitem_subclass(self):
1082
class A(torch.Tensor):
1084
def __torch_function__(cls, func, types, args, kwargs=None):
1088
t = torch.tensor([5])
1091
self.assertIn(Tensor.__setitem__, triggered)
1092
self.assertEqual(t, torch.tensor([5]))
1095
class TestIterator(TestCase):
1096
# Regression test for gh-54457
1097
def test_iterator(self):
1098
t = torch.tensor([5, 6, 7]).as_subclass(SubTensor2)
1100
self.assertIs(type(next(it)), SubTensor2)
1101
self.assertIs(type(next(it)), SubTensor2)
1102
self.assertIs(type(next(it)), SubTensor2)
1105
class TestRNN(TestCase):
1106
# Regression test for gh-55868
1108
model = torch.nn.RNN(10, 20, 2)
1109
input = Wrapper(torch.randn(1, 5, 10))
1113
class TestDisabledTorchFunction(TestCase):
1114
# Regression test for gh-64687
1115
def test_parameter_does_not_prevent_dispatch(self):
1118
def __torch_function__(cls, func, types, args=(), kwargs=None):
1122
t2 = torch.nn.Parameter(torch.rand(2, 2))
1123
self.assertEqual(torch.add(t2, t1), "called")
1125
inp = torch.rand(10, 10)
1126
self.assertEqual(torch.nn.functional.linear(inp, t1, t2), "called")
1127
self.assertEqual(torch.nn.functional.linear(inp, t2, t1), "called")
1129
class TestResolveName(TestCase):
1130
def test_resolve_name(self):
1131
for cs in get_overridable_functions().values():
1134
eval(torch.overrides.resolve_name(c)),
1136
msg=f"{c}, {torch.overrides.resolve_name(c)}"
1139
class TestTorchFunctionWarning(TestCase):
1140
def test_warn_on_invalid_torch_function_standalone_class(self):
1141
class StandaloneTorchFunctionClass:
1142
def __torch_function__(self, *args, **kwargs):
1144
a = StandaloneTorchFunctionClass()
1145
with self.assertWarnsRegex(DeprecationWarning, "as a plain method is deprecated"):
1146
# Function that handles torch_function on the python side
1147
torch.nn.functional.dropout(a)
1148
with self.assertWarnsRegex(UserWarning, "as a plain method is deprecated"):
1149
# Function that handles torch_function in C++
1152
def test_warn_on_invalid_torch_function_tensor_subclass(self):
1153
class TensorSubclassTorchFunctionClass(torch.Tensor):
1154
def __torch_function__(self, *args, **kwargs):
1156
b = TensorSubclassTorchFunctionClass()
1157
with self.assertWarnsRegex(DeprecationWarning, "as a plain method is deprecated"):
1158
# Function that handles torch_function on the python side
1159
torch.nn.functional.dropout(b)
1160
with self.assertWarnsRegex(UserWarning, "as a plain method is deprecated"):
1161
# Function that handles torch_function in C++
1164
class TestDisabledUserWarnings(TestCase):
1165
def test_no_implicit_user_warning_for_deprecated_functions(self):
1166
self.assertNotWarn(get_ignored_functions)
1167
self.assertNotWarn(get_testing_overrides)
1168
self.assertNotWarn(get_overridable_functions)
1169
self.assertNotWarn(lambda: resolve_name(torch.Tensor.add))
1170
self.assertNotWarn(lambda: is_tensor_method_or_property(torch.Tensor.add))
1172
@unittest.skipIf(TEST_WITH_CROSSREF, "not run with crossref")
1173
class TestTorchFunctionMode(TestCase):
1174
def test_basic(self):
1175
class A(TorchFunctionMode):
1176
def __torch_function__(self, *args, **kwargs):
1178
# NB: factory functions get overridden too!
1181
self.assertEqual(torch.randn(3), -1)
1182
self.assertEqual(torch.add(x, x), -1)
1183
self.assertEqual(torch.split(None, [2]), -1) # python side
1184
self.assertEqual(bar(x), -1)
1186
def test_factory_override(self):
1187
class A(TorchFunctionMode):
1188
def __torch_function__(self, *args, **kwargs):
1192
self.assertEqual(torch.tensor([1]), -1)
1193
self.assertEqual(torch.sparse_coo_tensor(1, 1, 1), -1)
1194
self.assertEqual(torch.sparse_csr_tensor(1, 1, 1), -1)
1195
self.assertEqual(torch.sparse_coo_tensor(1, 1, (1, 1), check_invariants=False), -1)
1196
self.assertEqual(torch.sparse_csr_tensor(1, 1, 1, (1, 1), check_invariants=False), -1)
1197
self.assertEqual(torch.as_tensor([1]), -1)
1199
def test_modes_handle_first(self):
1200
class A(TorchFunctionMode):
1201
def __torch_function__(self, *args, **kwargs):
1206
self.assertEqual(torch.neg(x), -40)
1207
self.assertEqual(torch.mean(x), -40)
1208
self.assertEqual(torch.mm(x, x), -40)
1209
self.assertEqual(bar(x), -40)
1211
def test_modes_return_notimplemented(self):
1212
class MyMode(TorchFunctionMode):
1213
def __torch_function__(self, *args, **kwargs):
1214
return NotImplemented
1218
self.assertEqual(torch.mean(x), 0)
1219
self.assertEqual(torch.mm(x, x), -1)
1220
self.assertEqual(bar(x), 1)
1221
self.assertRaisesRegex(
1222
TypeError, r'SubTensor',
1223
lambda: self.assertEqual(torch.max(x, x)))
1225
def test_with_mode(self):
1226
class ErrorA(RuntimeError):
1229
class A(TorchFunctionMode):
1230
def __torch_function__(self, *args, **kwargs):
1233
with self.assertRaises(ErrorA):
1237
def test_with_mode_created_separately(self):
1238
class ErrorA(RuntimeError):
1241
class A(TorchFunctionMode):
1242
def __torch_function__(self, *args, **kwargs):
1246
with self.assertRaises(ErrorA):
1250
def test_with_nested_modes(self):
1253
class A(TorchFunctionMode):
1254
def __init__(self, msg):
1257
def __torch_function__(self, func, _, args=(), kwargs=None):
1260
out.append(self.msg)
1261
return func(*args, **kwargs)
1267
self.assertEqual(out, ["layer2", "layer1"])
1269
def test_nested_same_mode(self):
1272
class A(TorchFunctionMode):
1273
def __init__(self, msg):
1276
def __torch_function__(self, func, _, args=(), kwargs=None):
1279
out.append(self.msg)
1280
return func(*args, **kwargs)
1282
with A("layer1") as a:
1286
self.assertEqual(out, ["layer1", "layer1"])
1288
def test_error_using_class_method_on_mode(self):
1289
class A(TorchFunctionMode):
1291
def __torch_function__(cls, func, _, args=(), kwargs=None):
1292
return func(args, kwargs)
1294
x = torch.tensor(5.)
1295
with self.assertRaisesRegex(RuntimeError, "classmethod is not supported, please make it a plain method"):
1299
def test_restacking_with_ancestor(self):
1300
class A(TorchFunctionMode):
1310
def test_get_cur_mode(self):
1311
class A(TorchFunctionMode):
1312
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
1316
self.assertEqual(_get_current_function_mode(), mode1)
1320
self.assertEqual(_get_current_function_mode(), mode2)
1323
def test_get_mode_stack(self):
1324
class A(TorchFunctionMode):
1325
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
1328
self.assertEqual(_get_current_function_mode_stack(), [])
1331
self.assertEqual(_get_current_function_mode_stack(), [mode1])
1335
self.assertEqual(_get_current_function_mode_stack(), [mode1, mode2])
1337
def test_all_same_mode(self):
1338
class A(TorchFunctionMode):
1343
self.assertTrue(all_same_mode([x, x, x]))
1344
self.assertFalse(all_same_mode([x, None]))
1345
self.assertFalse(all_same_mode([x, y]))
1347
def test_nested_modes_with_python_has_torch_function(self):
1350
class A(TorchFunctionMode):
1351
def __torch_function__(self, func, types, args=(), kwargs=None):
1353
kwargs = {} if kwargs is None else kwargs
1354
return func(*args, **kwargs)
1356
class B(TorchFunctionMode):
1357
def __torch_function__(self, func, types, args=(), kwargs=None):
1359
kwargs = {} if kwargs is None else kwargs
1360
return func(*args, **kwargs)
1362
x = torch.randn(3, 4)
1367
self.assertEqual(y, x)
1368
self.assertEqual(called, ["B", "A"])
1371
def test_reentrant_mode_idiom(self):
1374
class A(TorchFunctionMode):
1375
def __torch_function__(self, func, types, args=(), kwargs=None):
1379
if func is torch.sub:
1383
return torch.add(input, other, alpha=-1)
1384
return func(*args, **kwargs)
1390
# add hits the torch function again!
1391
self.assertEqual(log, [torch.sub, torch.add])
1393
def test_nn_parse_to(self):
1394
# This failed because the parser thinks the function is called to()
1395
# but it's actually called _parse_to()
1399
class A(TorchFunctionMode):
1400
def __torch_function__(self, func, types, args=(), kwargs=None):
1405
return func(*args, **kwargs)
1408
torch._C._nn._parse_to('cpu')
1410
self.assertTrue(called)
1412
def test_getitem_call(self):
1413
# This failed because the parser thinks the function is called to()
1414
# but it's actually called _parse_to()
1418
class A(TorchFunctionMode):
1419
def __torch_function__(self, func, types, args=(), kwargs=None):
1424
return func(*args, **kwargs)
1431
self.assertTrue(called)
1434
def test_distributions_bernoulli(self):
1435
# This failed because improper use of has_torch_function when
1436
# is_tensor_like should have been used instead, inside the
1437
# broadcasting logic called by distributions (Bernoulli doesn't
1442
class A(TorchFunctionMode):
1443
def __torch_function__(self, func, types, args=(), kwargs=None):
1448
return func(*args, **kwargs)
1451
torch.distributions.Bernoulli(0.3)
1453
self.assertTrue(called)
1455
def test_mode_notimplemented_loop(self):
1456
# Default tensor subclass implementation disables torch function;
1457
# when we redispatch to mode we must not treat the objects as
1462
class A(TorchFunctionMode):
1463
def __torch_function__(self, func, types, args=(), kwargs=None):
1468
# The first time we call, the mode sees an active type that
1469
# it doesn't know how to deal with. The second time, we're
1470
# instructed to treat it "as if it were a tensor", and so
1471
# we keep going. I'm not entirely clear if the subclasses
1472
# disappearing from types is the correct way to do it.
1473
if any(t is not torch.Tensor for t in types):
1474
return NotImplemented
1476
return func(*args, **kwargs)
1478
class B(torch.Tensor):
1486
self.assertIs(type(r), B)
1487
self.assertEqual(called, 2)
1494
self.assertIs(type(r), B)
1495
self.assertEqual(called, 2)
1497
def test_disable_subclass_not_mode(self):
1500
class A(TorchFunctionMode):
1501
def __torch_function__(self, func, types, args=(), kwargs=None):
1506
return func(*args, **kwargs)
1508
class B(torch.Tensor):
1511
x = B(torch.randn(5))
1513
with torch._C.DisableTorchFunctionSubclass():
1514
self.assertNotIsInstance(torch.sum(x), B)
1516
self.assertTrue(called)
1518
def test_disable_subclass_mode(self):
1521
class A(TorchFunctionMode):
1522
def __torch_function__(self, func, types, args=(), kwargs=None):
1527
return func(*args, **kwargs)
1529
class B(torch.Tensor):
1532
x = B(torch.randn(5))
1534
with torch._C.DisableTorchFunction():
1535
self.assertNotIsInstance(torch.sum(x), B)
1537
self.assertFalse(called)
1539
def test_disable_enable_subclass(self):
1542
class A(torch.Tensor):
1545
x = A(torch.randn(5))
1546
with torch._C.DisableTorchFunctionSubclass():
1547
g = torch._C._EnableTorchFunction()
1549
self.assertIsInstance(torch.sum(x), A)
1553
def test_torch_function_all_disabled_api(self):
1554
from torch._C import _is_torch_function_all_disabled
1556
state = _is_torch_function_all_disabled()
1557
self.assertFalse(state)
1559
with torch._C.DisableTorchFunction():
1560
state = _is_torch_function_all_disabled()
1561
self.assertTrue(state)
1563
state = _is_torch_function_all_disabled()
1564
self.assertFalse(state)
1566
with torch._C.DisableTorchFunctionSubclass():
1567
state = _is_torch_function_all_disabled()
1568
self.assertFalse(state)
1570
def test_subclass_hash(self):
1571
class DiagTensor(torch.Tensor):
1572
def __init__(self, diag):
1576
def __torch_function__(cls, func, types, args=(), kwargs=None):
1577
kwargs = kwargs or {}
1579
def get_full_matrices(t):
1580
if isinstance(t, DiagTensor):
1581
return torch.diag_embed(t._diag)
1585
return func(*tree_map(get_full_matrices, args), **tree_map(get_full_matrices, kwargs))
1590
self.assertEqual((a + 1), torch.diag_embed(d) + 1)
1592
# If the hash function was returning the same value, this would
1593
# fail inside `Tensor.__eq__`.
1594
# If __hash__ was going through torch_function, the implementation above would
1595
# be wrong as it would compute the hash on a temporary Tensor thus not ensuring
1596
# the uniqueness of the hash that we rely on for Tensors.
1599
s.add(DiagTensor(d))
1601
def test_custom_device_type(self):
1602
class CustomDeviceContext(TorchFunctionMode):
1604
def __torch_function__(self, func, types, args=(), kwargs=None):
1605
kwargs = kwargs or {}
1606
if func == torch.device:
1607
if args and isinstance(args[0], int):
1608
args = ("xla", args[0])
1609
elif isinstance(kwargs.get('device'), int):
1610
kwargs['device'] = f"xla:{kwargs.get('device')}"
1611
return func(*args, **kwargs)
1613
with CustomDeviceContext():
1614
d_args = torch.device(0)
1615
self.assertEqual(d_args.type, "xla")
1616
self.assertEqual(d_args.index, 0)
1617
d_kwargs = torch.device(device=0)
1618
self.assertEqual(d_kwargs.type, "xla")
1619
self.assertEqual(d_kwargs.index, 0)
1621
def test_device_context_semantics(self):
1622
from torch._C import _len_torch_function_stack
1623
from torch.utils._device import DeviceContext
1625
torch.set_default_device("cuda")
1628
return [torch._C._get_function_stack_at(i) for i in range(_len_torch_function_stack())]
1630
base_mode = BaseTorchFunctionMode()
1632
torch.set_default_device("cpu")
1633
x = torch.ones(2, 2)
1635
self.assertIsInstance(stack[0], DeviceContext)
1636
self.assertEqual(stack[0].device, torch.device("cpu"))
1639
self.assertIsInstance(stack[0], DeviceContext)
1640
self.assertEqual(stack[0].device, torch.device("cpu"))
1642
torch.set_default_device(None)
1648
if __name__ == '__main__':