12
from torch.testing._internal.common_utils import TestCase, run_tests, TEST_WITH_CROSSREF
13
from torch.overrides import (
14
handle_torch_function,
16
get_ignored_functions,
17
get_overridable_functions,
18
get_testing_overrides,
20
is_tensor_method_or_property,
22
_get_current_function_mode,
23
_get_current_function_mode_stack,
25
from torch.utils._mode_utils import all_same_mode
26
from torch.utils._pytree import tree_map
38
"""A function multiple arguments and an optional argument"""
39
if has_torch_function((a, b, c)):
40
return handle_torch_function(foo, (a, b, c), a, b, c=c)
46
"""A function with one argument"""
47
if has_torch_function((a,)):
48
return handle_torch_function(bar, (a,), a)
52
"""A function with multiple arguments"""
53
if has_torch_function((a, b)):
54
return handle_torch_function(baz, (a, b), a, b)
58
"""Used to test that errors raised in user implementations get propagated"""
59
if has_torch_function((a,)):
60
return handle_torch_function(quux, (a,), a)
71
HANDLED_FUNCTIONS_DIAGONAL = {}
73
def implements_diagonal(torch_function):
74
"""Register a torch function override for DiagonalTensor.
76
This decorator takes a function in the torch API as a
77
parameter. Applying this decorator to a function adds that function
78
as the registered override for the torch function passed as a
79
parameter to the decorator. See DiagonalTensor.__torch_function__
80
for the runtime dispatch implementation and the decorated functions
81
immediately below DiagonalTensor for usage examples.
83
@functools.wraps(torch_function)
85
HANDLED_FUNCTIONS_DIAGONAL[torch_function] = func
90
"""A class with __torch_function__ and a specific diagonal representation
92
This class has limited utility and is mostly useful for verifying that the
93
dispatch mechanism works as expected. It is based on the `DiagonalArray
94
example`_ in the NumPy documentation.
96
Note that this class does *not* inherit from ``torch.tensor``, interaction
97
with the pytorch dispatch system happens via the ``__torch_function__``
100
``DiagonalTensor`` represents a 2D tensor with *N* rows and columns that has
101
diagonal entries set to *value* and all other entries set to zero. The
102
main functionality of ``DiagonalTensor`` is to provide a more compact
103
string representation of a diagonal tensor than in the base tensor class:
105
>>> d = DiagonalTensor(5, 2)
107
DiagonalTensor(N=5, value=2)
109
tensor([[2., 0., 0., 0., 0.],
110
[0., 2., 0., 0., 0.],
111
[0., 0., 2., 0., 0.],
112
[0., 0., 0., 2., 0.],
113
[0., 0., 0., 0., 2.]])
115
Note that to simplify testing, matrix multiplication of ``DiagonalTensor``
121
.. _DiagonalArray example:
122
https://numpy.org/devdocs/user/basics.dispatch.html
127
handled_functions = HANDLED_FUNCTIONS_DIAGONAL
129
def __init__(self, N, value):
134
return f"DiagonalTensor(N={self._N}, value={self._i})"
137
return self._i * np.eye(self._N)
140
return self._i * torch.eye(self._N)
143
def __torch_function__(cls, func, types, args=(), kwargs=None):
146
if func not in cls.handled_functions:
147
return NotImplemented
148
return cls.handled_functions[func](*args, **kwargs)
150
def __eq__(self, other):
151
if type(other) is type(self):
152
if self._N == other._N and self._i == other._i:
159
@implements_diagonal(torch.mean)
161
return float(mat._i) / mat._N
163
@implements_diagonal(torch.mm)
164
def diagonal_mm(mat1, mat2):
167
@implements_diagonal(torch.div)
168
def diagonal_div(input, other, out=None):
171
@implements_diagonal(torch.add)
175
@implements_diagonal(foo)
176
def diagonal_foo(a, b, c=None):
179
@implements_diagonal(bar)
183
@implements_diagonal(quux)
188
HANDLED_FUNCTIONS_SUB = {}
190
def implements_sub(torch_function):
191
"Register a torch function override for SubTensor"
192
@functools.wraps(torch_function)
194
HANDLED_FUNCTIONS_SUB[torch_function] = func
198
class SubTensor(torch.Tensor):
199
"""A subclass of torch.Tensor use for testing __torch_function__ dispatch
201
This class has the property that matrix multiplication returns zero:
203
>>> s = SubTensor([[1, 1], [1, 1]])
206
>>> t = torch.tensor([[1, 1], [1, 1]])
215
This is useful for testing that the semantics for overriding torch
216
functions are working correctly.
219
def __torch_function__(cls, func, types, args=(), kwargs=None):
223
if func not in HANDLED_FUNCTIONS_SUB:
224
return NotImplemented
225
return HANDLED_FUNCTIONS_SUB[func](*args, **kwargs)
227
class SubTensor2(torch.Tensor):
230
class SubSubTensor2(SubTensor2):
233
class SubTensor3(torch.Tensor):
236
@implements_sub(torch.mean)
240
@implements_sub(torch.mm)
241
def sub_mm(mat1, mat2):
248
@implements_sub(torch.div)
249
def sub_div(input, other, out=None):
250
return NotImplemented
253
HANDLED_FUNCTIONS_SUB_DIAGONAL = {}
255
def implements_sub_diagonal(torch_function):
256
"Register a torch function override for SubDiagonalTensor"
257
@functools.wraps(torch_function)
259
HANDLED_FUNCTIONS_SUB_DIAGONAL[torch_function] = func
263
class SubDiagonalTensor(DiagonalTensor):
264
"""A subclass of ``DiagonalTensor`` to test custom dispatch
266
This class tests semantics for defining ``__torch_function__`` on a
267
subclass of another class that defines ``__torch_function__``. The
268
only difference compared with the superclass is that this class
269
provides a slightly different repr as well as custom implementations
270
of ``mean`` and ``mm``, scaling the mean by a factor of 10 and
271
returning 1 from ``mm`` instead of 0 as ``DiagonalTensor`` does.
273
handled_functions = HANDLED_FUNCTIONS_SUB_DIAGONAL
276
return f"SubDiagonalTensor(N={self._N}, value={self._i})"
279
@implements_sub_diagonal(torch.mean)
280
def sub_diagonal_mean(mat):
281
return 10 * float(mat._i) / mat._N
283
@implements_sub_diagonal(bar)
284
def sub_diagonal_bar(mat):
287
@implements_sub_diagonal(torch.mm)
288
def sub_diagonal_mm(mat1, mat2):
291
@implements_sub_diagonal(torch.div)
292
def sub_diagonal_div(input, other, out=None):
293
return NotImplemented
295
@implements_sub_diagonal(foo)
296
def sub_diagonal_foo(a, b, c=None):
297
return NotImplemented
300
HANDLED_FUNCTIONS_TENSOR_LIKE = {}
307
WRAPPED_TRIGGERED_IMPLS = {}
310
def triggered_wrapper(f):
312
def wrapped(*args, **kwargs):
313
wrapped._triggered = True
314
return f(*args, **kwargs)
316
wrapped._triggered = False
319
def implements_tensor_like(torch_function):
320
"Register a torch function override for TensorLike"
321
@functools.wraps(torch_function)
323
HANDLED_FUNCTIONS_TENSOR_LIKE[torch_function] = func
327
def generate_tensor_like_torch_implementations():
328
torch_vars = vars(torch)
330
testing_overrides = get_testing_overrides()
337
testing_ignore = {"sample_functional", "autocast"}
338
for namespace, funcs in get_overridable_functions().items():
340
if func not in testing_overrides and func.__name__ not in testing_ignore:
341
untested_funcs.append(f"{namespace}.{func.__name__}")
343
"The following functions are not tested for __torch_function__ "
344
"support, please ensure there is an entry in the dict returned by "
345
"torch.overrides.get_testing_overrides for this function or if a "
346
"__torch_function__ override does not make sense, add an entry to "
347
"the tuple returned by torch._overrides.get_ignored_functions.\n\n{}"
349
assert len(untested_funcs) == 0, msg.format(pprint.pformat(untested_funcs))
350
for func, override in testing_overrides.items():
353
wrapped = triggered_wrapper(override)
355
WRAPPED_TRIGGERED_IMPLS[func] = wrapped
356
if is_tensor_method_or_property(func):
357
implements_sub(func)(wrapped)
359
implements_tensor_like(func)(wrapped)
361
generate_tensor_like_torch_implementations()
364
"""A class that overrides the full torch API
366
This class is used to explicitly test that the full torch.tensor API
367
can be overriden with a class that defines __torch_function__.
370
def __torch_function__(cls, func, types, args=(), kwargs=None):
374
if func not in HANDLED_FUNCTIONS_TENSOR_LIKE:
375
return NotImplemented
377
return HANDLED_FUNCTIONS_TENSOR_LIKE[func](*args, **kwargs)
379
class TestTorchFunctionOverride(TestCase):
380
def test_mean_semantics(self):
381
"""Test that a function with one argument can be overrided"""
382
t1 = DiagonalTensor(5, 2)
383
t2 = SubTensor([[1, 2], [1, 2]])
384
t3 = SubDiagonalTensor(5, 2)
385
self.assertEqual(torch.mean(t1), 0.4)
386
self.assertEqual(bar(t1), -1)
387
self.assertEqual(torch.mean(t2), 0)
388
self.assertEqual(bar(t2), 1)
389
self.assertEqual(torch.mean(t3), 4.0)
390
self.assertEqual(bar(t3), 0)
392
def test_has_torch_function_non_sequence(self):
393
with self.assertRaisesRegex(TypeError, "expected a sequence"):
394
has_torch_function(object())
396
def test_mm_semantics(self):
397
"""Test that a function with multiple arguments can be overrided"""
398
t1 = DiagonalTensor(5, 2)
399
t2 = torch.eye(5) * 2
400
t3 = SubTensor([[1, 2], [1, 2]])
401
t4 = SubDiagonalTensor(5, 2)
403
self.assertEqual(torch.mm(t1, t1), 0)
405
self.assertEqual(torch.mm(t1, t2), 0)
406
self.assertEqual(torch.mm(t2, t1), 0)
408
self.assertEqual(torch.mm(t3, t3), -1)
410
self.assertEqual(torch.mm(t3, t2), -1)
411
self.assertEqual(torch.mm(t2, t3), -1)
414
self.assertEqual(torch.mm(t3, t1), -1)
415
self.assertEqual(torch.mm(t1, t3), 0)
418
self.assertEqual(torch.mm(t4, t4), 1)
419
self.assertEqual(torch.mm(t4, t1), 1)
420
self.assertEqual(torch.mm(t1, t4), 1)
421
self.assertEqual(torch.mm(t4, t2), 1)
422
self.assertEqual(torch.mm(t2, t4), 1)
423
self.assertEqual(torch.mm(t3, t4), -1)
424
self.assertEqual(torch.mm(t4, t3), 1)
426
def test_precedence_semantics(self):
427
"""Test semantics for __torch_function__ for functions that take
430
For functions that take multiple arguments, the appropriate
431
__torch_function__ implementation to call is determined by
432
examining the types of the arguments. The precedence order is
433
left-to-right in the argument list, except subclasses are always
434
checked before superclasses. The first result of calling the
435
implementations in precedence order that is not NotImplemented
436
is returned to the user. If all implementations return
437
NotImplemented, a TypeError is raised.
439
All cases are tested with functions implemented in C++ and
440
either foo or baz, which are python functions defined above that
441
are instrumented to obey the same dispatch rules as the
442
functions in torch.functional.
447
t1 = DiagonalTensor(5, 2)
448
t2 = SubDiagonalTensor(5, 2)
449
self.assertEqual(torch.div(t1, t2), -1)
450
self.assertEqual(torch.div(t2, t1), -1)
451
self.assertEqual(foo(t1, t2), -1)
452
self.assertEqual(foo(t2, t1), -1)
457
t3 = SubTensor([[1, 2], [1, 2]])
458
self.assertEqual(torch.div(t1, t3), -1)
459
self.assertEqual(torch.div(t3, t1), -1)
460
self.assertEqual(foo(t1, t3), -1)
461
self.assertEqual(foo(t3, t1), -1)
466
with self.assertRaises(TypeError):
468
with self.assertRaises(TypeError):
470
with self.assertRaises(TypeError):
472
with self.assertRaises(TypeError):
477
with self.assertRaises(TypeError):
479
with self.assertRaises(TypeError):
481
with self.assertRaises(TypeError):
483
with self.assertRaises(TypeError):
485
with self.assertRaises(TypeError):
487
with self.assertRaises(TypeError):
489
with self.assertRaises(TypeError):
491
with self.assertRaises(TypeError):
493
with self.assertRaises(TypeError):
495
with self.assertRaises(TypeError):
497
with self.assertRaises(TypeError):
499
with self.assertRaises(TypeError):
501
with self.assertRaises(TypeError):
503
with self.assertRaises(TypeError):
505
with self.assertRaises(TypeError):
507
with self.assertRaises(TypeError):
509
with self.assertRaises(TypeError):
511
with self.assertRaises(TypeError):
514
def test_user_implementation_raises(self):
515
"""Test that errors raised in user implementations propagate correctly"""
516
t1 = DiagonalTensor(5, 2)
517
t2 = DiagonalTensor(5, 2)
518
with self.assertRaises(ValueError):
520
with self.assertRaises(ValueError):
523
def test_tensor_subclass_propagation(self):
524
"""this test exercises the functionality described in
525
docs/source/notes/extending.rst#subclassing-torchtensor"""
526
t1 = torch.tensor([5])
527
t2 = torch.tensor([6])
532
ss1 = SubSubTensor2([5])
533
ss2 = SubSubTensor2([6])
535
sn1 = SubTensor3([5])
536
sn2 = SubTensor3([6])
539
self.assertTrue(isinstance(s1 + t2, SubTensor2))
540
self.assertTrue(isinstance(t1 + s2, SubTensor2))
541
self.assertTrue(isinstance(s1 + s2, SubTensor2))
544
self.assertTrue(isinstance(s1[0], SubTensor2))
547
self.assertTrue(isinstance(ss1 + ss2, SubSubTensor2))
548
self.assertTrue(isinstance(ss1 + s2, SubSubTensor2))
549
self.assertTrue(isinstance(s1 + ss2, SubSubTensor2))
550
self.assertTrue(isinstance(ss1 + ss2, SubSubTensor2))
551
self.assertTrue(isinstance(ss1 + t2, SubSubTensor2))
552
self.assertTrue(isinstance(t1 + ss2, SubSubTensor2))
553
self.assertTrue(isinstance(ss1[0], SubSubTensor2))
556
with self.assertRaises(TypeError):
558
with self.assertRaises(TypeError):
563
class DummyTensor(torch.Tensor):
568
self.assertTrue(c._is_view())
569
self.assertTrue(c._base is a)
583
def __torch_function__(cls, func, types, args=(), kwargs=None):
584
inputs, outputs = args
585
self.assertEqual(inputs, (x,))
586
self.assertEqual(outputs, (x,))
590
self.assertEqual(torch.autograd.grad(x, x), -1)
592
def test_pow_rpow(self):
593
class NothingImplemented(torch.Tensor):
595
def __torch_function__(cls, func, types, args=(), kwargs=None):
596
return NotImplemented
598
class RPowOnly(torch.Tensor):
600
def __torch_function__(cls, func, types, args=(), kwargs=None):
601
if func is torch.Tensor.__rpow__:
603
return NotImplemented
605
self.assertEqual(NothingImplemented() ** RPowOnly(), -1)
608
def generate_tensor_like_override_tests(cls):
609
from torch.testing._internal.generated.annotated_fn_args import annotated_args
611
def test_generator(func, override):
613
if is_tensor_method_or_property(func):
616
return SubTensor([5])
629
if hasattr(func, "__name__") and "linalg_solve_triangular" in func.__name__:
630
kwargs = {"upper": True}
633
is_method = is_tensor_method_or_property(func)
635
def _simple_type_parser(func, arg_name, arg_type):
637
if arg_type == "Tensor":
638
return instance_gen()
639
elif arg_type == "TensorList" or arg_type == "ITensorListRef":
640
return [instance_gen(), instance_gen()]
641
elif arg_type == "c10::List<c10::optional<Tensor>>":
642
return [instance_gen(), instance_gen()]
643
elif arg_type == "IntArrayRef" or arg_type == "SymIntArrayRef":
644
size = arg.get("size", 2)
649
elif arg_type == "Scalar":
651
elif arg_type == "bool":
653
elif arg_type == "Dimname":
655
elif arg_type == "DimnameList":
657
elif arg_type.startswith("int"):
659
elif arg_type in {"Stream"}:
660
return torch.Stream()
661
elif arg_type.startswith("float") or arg_type == "double":
663
elif arg_type in {"Generator", "MemoryFormat", "TensorOptions"}:
665
elif arg_type == "ScalarType":
667
elif arg_type == "c10::string_view":
669
elif arg_type == "SymInt":
674
f"Unsupported argument type {arg_type} for {arg_name} of function {func}"
677
if func in annotated_args:
678
for arg in annotated_args[func]:
680
t = arg["simple_type"]
683
if t == "Tensor" and is_method and arg["name"] == "self":
685
func = func.__get__(instance_gen())
687
arg_to_add = _simple_type_parser(func, arg["name"], t)
688
if "is_kwarg_only" in arg and arg["is_kwarg_only"] == str(True):
689
kwargs[arg["name"]] = arg_to_add
691
func_args.append(arg_to_add)
693
args = inspect.getfullargspec(override)
695
func_args = inspect.getfullargspec(func)
697
func_args = type(func_args)(**{**func_args, 'annotations': None})
698
if func_args != args:
699
raise RuntimeError(f"Override for {func} doesn't match its argspec.\n"
700
+ f"Original: {inspect.signature(func)}\n"
701
+ f"Override: {inspect.signature(override)}")
704
nargs = len(args.args)
705
if args.defaults is not None:
706
nargs -= len(args.defaults)
707
func_args = [instance_gen() for _ in range(nargs)]
708
if args.varargs is not None:
709
func_args += [instance_gen(), instance_gen()]
712
ret = func(*func_args, **kwargs)
717
if not is_method or ret is None:
718
self.assertTrue(WRAPPED_TRIGGERED_IMPLS[func]._triggered)
721
self.assertEqual(ret, -1)
725
for func, override in get_testing_overrides().items():
726
test_method = test_generator(func, override)
727
if func.__name__ == "__get__":
744
module = "Tensor." + func.__self__.fget.__name__
748
elif is_tensor_method_or_property(func):
751
module = func.__module__
753
name = 'test_{}_{}'.format(module.replace('.', '_'), func.__name__)
755
name = f'test_{func.__name__}'
756
test_method.__name__ = name
757
setattr(cls, name, test_method)
759
generate_tensor_like_override_tests(TestTorchFunctionOverride)
762
"Basic data container that knows how to unwrap itself"
763
def __init__(self, data):
764
self.__dict__["_data"] = data
765
self.__dict__["used_attrs"] = set()
766
self.__dict__["used_calls"] = set()
768
def __getattr__(self, name):
769
if name in self.__dict__:
770
return self.__dict__[name]
771
self.used_attrs.add(name)
773
val = getattr(self._data, name)
776
if not isinstance(val, torch.device) and callable(val):
777
c = getattr(type(self._data), name)
780
return lambda *a, **kw: wrap(self.__torch_function__(c, (Wrapper,), args=a, kwargs=kw))
782
return lambda *a, **kw: wrap(self.__torch_function__(c, (Wrapper,), args=(self,) + a, kwargs=kw))
786
def __setattr__(self, name, value):
787
if name in self.__dict__:
788
self.__dict__[name] = value
790
self.used_attrs.add(name)
791
setattr(self._data, name, unwrap(value))
793
def __setitem__(self, key, value):
794
self._data[unwrap(key)] = unwrap(value)
796
def __getitem__(self, key):
797
return wrap(self._data[unwrap(key)])
800
def __torch_function__(cls, func, types, args=(), kwargs=None):
804
args_of_this_cls = []
806
if isinstance(a, cls):
807
args_of_this_cls.append(a)
808
elif isinstance(a, collections.abc.Sequence):
809
args_of_this_cls.extend(el for el in a if isinstance(el, cls))
810
assert len(args_of_this_cls) > 0
811
for a in args_of_this_cls:
812
a.used_calls.add(func)
813
args = unwrap(tuple(args))
814
kwargs = {k: unwrap(v) for k, v in kwargs.items()}
816
return wrap(func(*args, **kwargs))
818
def __add__(self, other):
819
return self.__torch_function__(torch.add, (Wrapper,), (self, other))
821
def __mul__(self, other):
822
return self.__torch_function__(torch.mul, (Wrapper,), (self, other))
824
def __sub__(self, other):
825
return self.__torch_function__(torch.sub, (Wrapper,), (self, other))
827
def __truediv__(self, other):
828
return self.__torch_function__(torch.true_divide, (Wrapper,), (self, other))
830
def __floordiv__(self, other):
831
return self.__torch_function__(torch.floor_divide, (Wrapper,), (self, other))
833
def __ge__(self, other):
834
return self.__torch_function__(torch.ge, (Wrapper,), (self, other))
836
def __gt__(self, other):
837
return self.__torch_function__(torch.gt, (Wrapper,), (self, other))
839
def __lt__(self, other):
840
return self.__torch_function__(torch.lt, (Wrapper,), (self, other))
842
def __le__(self, other):
843
return self.__torch_function__(torch.le, (Wrapper,), (self, other))
845
def __eq__(self, other):
846
return self.__torch_function__(torch.eq, (Wrapper,), (self, other))
848
def __ne__(self, other):
849
return self.__torch_function__(torch.ne, (Wrapper,), (self, other))
852
return self.__torch_function__(torch.Tensor.__bool__, (Wrapper,), (self,))
855
return self.__torch_function__(torch.Tensor.__int__, (Wrapper,), (self,))
858
return len(self._data)
863
if type(v) in {tuple, list}:
864
return type(v)(unwrap(vi) for vi in v)
866
return v._data if isinstance(v, Wrapper) else v
870
if type(v) in {tuple, list}:
871
return type(v)(wrap(vi) for vi in v)
873
return Wrapper(v) if isinstance(v, torch.Tensor) else v
875
class TestEinsumOverride(TestCase):
876
"Regression test for gh-38479"
877
def test_wrapper(self):
878
x = Wrapper(torch.randn(5))
879
y = Wrapper(torch.randn(4))
880
self.assertEqual(torch.einsum('i,j->ij', x, y)._data,
881
torch.ger(x, y)._data)
884
a = Wrapper(torch.randn(2, 3))
885
b = Wrapper(torch.randn(5, 3, 7))
886
c = Wrapper(torch.randn(2, 7))
887
self.assertEqual(torch.einsum('ik,jkl,il->ij', [a, b, c])._data,
888
torch.nn.functional.bilinear(a, c, b)._data)
890
class TestGradCheckOverride(TestCase):
891
"Test that wrappers work with gradcheck."
892
def test_gradcheck(self):
893
from torch.testing._internal.common_utils import gradcheck, gradgradcheck
895
def run_test(fast_mode):
896
a = wrap(torch.tensor(5.0, dtype=torch.double))
897
b = wrap(torch.tensor(6.0, dtype=torch.double))
899
a.requires_grad = True
900
b.requires_grad = True
902
gradcheck(torch.add, (a, b), raise_exception=False, check_batched_grad=False, fast_mode=fast_mode)
903
gradgradcheck(torch.add, (a, b), raise_exception=False, check_batched_grad=False, fast_mode=fast_mode)
905
total_used_attrs = a.used_attrs.union(b.used_attrs)
906
total_used_calls = a.used_calls.union(b.used_calls)
912
expected_used_attrs = {
926
expected_used_attrs.add('is_complex')
927
expected_used_attrs.add('device')
928
self.assertEqual(expected_used_attrs, total_used_attrs)
930
expected_used_calls = {
931
torch.Tensor.new_zeros,
933
torch.Tensor.is_floating_point,
936
torch.Tensor.requires_grad_,
941
expected_used_calls.add(torch.Tensor.is_complex)
942
self.assertEqual(expected_used_calls, total_used_calls)
943
run_test(fast_mode=True)
944
run_test(fast_mode=False)
946
class TestNamedTuple(TestCase):
947
""" Regression test for gh-47090 """
949
x = torch.tensor([1, 2])
950
xs = x.as_subclass(SubTensor2)
951
r = torch.max(x, dim=0)
952
rs = torch.max(xs, dim=0)
953
self.assertEqual(type(r), type(rs))
954
self.assertEqual(r, rs)
956
class TestGradNewOnesOverride(TestCase):
957
""" Regression test for gh-47069 """
958
def test_newones(self):
959
t = torch.tensor([1, 2]).as_subclass(SubTensor2)
960
n = t.new_ones((1, 2))
961
self.assertEqual(type(n), SubTensor2)
963
class TestPickle(TestCase):
964
"Regression test for gh-47051"
965
def test_pickle(self):
966
t = torch.tensor([1]).as_subclass(SubTensor2)
968
t2 = pickle.loads(pickle.dumps(t))
969
self.assertIs(type(t2), SubTensor2)
970
self.assertEqual(t2.abcd, "e")
972
class TestBroadcastAllOverride(TestCase):
973
""" test for gh-37141 """
974
def test_broadcast_all(self):
975
from torch.distributions.utils import broadcast_all
976
a = torch.tensor([1.2, 3.4, 5.6])
978
b = torch.tensor(5.0)
980
c = torch.tensor([5.0, 5.0, 5.0])
982
o_1 = broadcast_all(a_w, b_w)
983
self.assertTrue(isinstance(o_1[0], Wrapper))
984
self.assertTrue(isinstance(o_1[1], Wrapper))
985
self.assertEqual(o_1[0]._data, a)
986
self.assertEqual(o_1[1]._data, c)
988
o_2 = broadcast_all(a_w, b)
989
self.assertTrue(isinstance(o_2[0], Wrapper))
990
self.assertTrue(isinstance(o_2[1], Wrapper))
991
self.assertEqual(o_2[0]._data, a)
992
self.assertEqual(o_2[1]._data, c)
994
class TestWrapTorchFunction(TestCase):
995
def test_wrap_torch_function(self):
998
def __torch_function__(cls, func, types, args, kwargs):
1004
@torch.overrides.wrap_torch_function(dispatcher)
1008
self.assertEqual(f(A()), -1)
1010
class TestIndexing(TestCase):
1011
""" Regression tests for gh-46277 """
1012
def test_getitem(self):
1015
def __torch_function__(cls, func, types, args, kwargs=None):
1018
t = torch.tensor([5])
1019
self.assertEqual(t[A()], -1)
1020
self.assertEqual(t, torch.tensor([5]))
1022
def test_getitem_subclass(self):
1023
class A(torch.Tensor):
1025
def __torch_function__(cls, func, types, args, kwargs=None):
1028
t = torch.tensor([5])
1029
self.assertEqual(t[A()], -1)
1030
self.assertEqual(t[5, A()], -1)
1031
self.assertEqual(t, torch.tensor([5]))
1033
def test_setitem(self):
1038
def __torch_function__(cls, func, types, args, kwargs=None):
1042
t = torch.tensor([5])
1045
self.assertIn(Tensor.__setitem__, triggered)
1046
self.assertEqual(t, torch.tensor([5]))
1048
def test_setitem_val(self):
1053
def __torch_function__(cls, func, types, args, kwargs=None):
1057
t = torch.tensor([5])
1059
self.assertIn(Tensor.__setitem__, triggered)
1060
self.assertEqual(t, torch.tensor([5]))
1062
def test_setitem_subclass(self):
1065
class A(torch.Tensor):
1067
def __torch_function__(cls, func, types, args, kwargs=None):
1071
t = torch.tensor([5])
1074
self.assertIn(Tensor.__setitem__, triggered)
1075
self.assertEqual(t, torch.tensor([5]))
1078
class TestIterator(TestCase):
1080
def test_iterator(self):
1081
t = torch.tensor([5, 6, 7]).as_subclass(SubTensor2)
1083
self.assertIs(type(next(it)), SubTensor2)
1084
self.assertIs(type(next(it)), SubTensor2)
1085
self.assertIs(type(next(it)), SubTensor2)
1088
class TestRNN(TestCase):
1091
model = torch.nn.RNN(10, 20, 2)
1092
input = Wrapper(torch.randn(1, 5, 10))
1096
class TestDisabledTorchFunction(TestCase):
1098
def test_parameter_does_not_prevent_dispatch(self):
1101
def __torch_function__(cls, func, types, args=(), kwargs=None):
1105
t2 = torch.nn.Parameter(torch.rand(2, 2))
1106
self.assertEqual(torch.add(t2, t1), "called")
1108
inp = torch.rand(10, 10)
1109
self.assertEqual(torch.nn.functional.linear(inp, t1, t2), "called")
1110
self.assertEqual(torch.nn.functional.linear(inp, t2, t1), "called")
1112
class TestResolveName(TestCase):
1113
def test_resolve_name(self):
1114
for cs in get_overridable_functions().values():
1117
eval(torch.overrides.resolve_name(c)),
1119
msg=f"{c}, {torch.overrides.resolve_name(c)}"
1122
class TestTorchFunctionWarning(TestCase):
1123
def test_warn_on_invalid_torch_function(self):
1125
def __torch_function__(self, *args, **kwargs):
1128
class Bad2(torch.Tensor):
1129
def __torch_function__(self, *args, **kwargs):
1133
for a in (Bad1(), Bad2()):
1134
with self.assertWarnsRegex(DeprecationWarning, "as a plain method is deprecated"):
1136
torch.nn.functional.dropout(a)
1138
with self.assertWarnsRegex(UserWarning, "as a plain method is deprecated"):
1142
class TestDisabledUserWarnings(TestCase):
1143
def test_no_implicit_user_warning_for_deprecated_functions(self):
1144
self.assertNotWarn(get_ignored_functions)
1145
self.assertNotWarn(get_testing_overrides)
1146
self.assertNotWarn(get_overridable_functions)
1147
self.assertNotWarn(lambda: resolve_name(torch.Tensor.add))
1148
self.assertNotWarn(lambda: is_tensor_method_or_property(torch.Tensor.add))
1150
@unittest.skipIf(TEST_WITH_CROSSREF, "not run with crossref")
1151
class TestTorchFunctionMode(TestCase):
1152
def test_basic(self):
1153
class A(TorchFunctionMode):
1154
def __torch_function__(self, *args, **kwargs):
1159
self.assertEqual(torch.randn(3), -1)
1160
self.assertEqual(torch.add(x, x), -1)
1161
self.assertEqual(torch.split(None, [2]), -1)
1162
self.assertEqual(bar(x), -1)
1164
def test_factory_override(self):
1165
class A(TorchFunctionMode):
1166
def __torch_function__(self, *args, **kwargs):
1170
self.assertEqual(torch.tensor([1]), -1)
1171
self.assertEqual(torch.sparse_coo_tensor(1, 1, 1), -1)
1172
self.assertEqual(torch.sparse_csr_tensor(1, 1, 1), -1)
1173
self.assertEqual(torch.sparse_coo_tensor(1, 1, (1, 1), check_invariants=False), -1)
1174
self.assertEqual(torch.sparse_csr_tensor(1, 1, 1, (1, 1), check_invariants=False), -1)
1175
self.assertEqual(torch.as_tensor([1]), -1)
1177
def test_modes_handle_first(self):
1178
class A(TorchFunctionMode):
1179
def __torch_function__(self, *args, **kwargs):
1184
self.assertEqual(torch.neg(x), -40)
1185
self.assertEqual(torch.mean(x), -40)
1186
self.assertEqual(torch.mm(x, x), -40)
1187
self.assertEqual(bar(x), -40)
1189
def test_modes_return_notimplemented(self):
1190
class MyMode(TorchFunctionMode):
1191
def __torch_function__(self, *args, **kwargs):
1192
return NotImplemented
1196
self.assertEqual(torch.mean(x), 0)
1197
self.assertEqual(torch.mm(x, x), -1)
1198
self.assertEqual(bar(x), 1)
1199
self.assertRaisesRegex(
1200
TypeError, r'SubTensor',
1201
lambda: self.assertEqual(torch.max(x, x)))
1203
def test_with_mode(self):
1204
class ErrorA(RuntimeError):
1207
class A(TorchFunctionMode):
1208
def __torch_function__(self, *args, **kwargs):
1211
with self.assertRaises(ErrorA):
1215
def test_with_mode_created_separately(self):
1216
class ErrorA(RuntimeError):
1219
class A(TorchFunctionMode):
1220
def __torch_function__(self, *args, **kwargs):
1224
with self.assertRaises(ErrorA):
1228
def test_with_nested_modes(self):
1231
class A(TorchFunctionMode):
1232
def __init__(self, msg):
1235
def __torch_function__(self, func, _, args=(), kwargs=None):
1238
out.append(self.msg)
1239
return func(*args, **kwargs)
1245
self.assertEqual(out, ["layer2", "layer1"])
1247
def test_nested_same_mode(self):
1250
class A(TorchFunctionMode):
1251
def __init__(self, msg):
1254
def __torch_function__(self, func, _, args=(), kwargs=None):
1257
out.append(self.msg)
1258
return func(*args, **kwargs)
1260
with A("layer1") as a:
1264
self.assertEqual(out, ["layer1", "layer1"])
1266
def test_error_using_class_method_on_mode(self):
1267
class A(TorchFunctionMode):
1269
def __torch_function__(cls, func, _, args=(), kwargs=None):
1270
return func(args, kwargs)
1272
x = torch.tensor(5.)
1273
with self.assertRaisesRegex(RuntimeError, "classmethod is not supported, please make it a plain method"):
1277
def test_restacking_with_ancestor(self):
1278
class A(TorchFunctionMode):
1288
def test_get_cur_mode(self):
1289
class A(TorchFunctionMode):
1290
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
1294
self.assertEqual(_get_current_function_mode(), mode1)
1298
self.assertEqual(_get_current_function_mode(), mode2)
1301
def test_get_mode_stack(self):
1302
class A(TorchFunctionMode):
1303
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
1306
self.assertEqual(_get_current_function_mode_stack(), [])
1309
self.assertEqual(_get_current_function_mode_stack(), [mode1])
1313
self.assertEqual(_get_current_function_mode_stack(), [mode1, mode2])
1315
def test_all_same_mode(self):
1316
class A(TorchFunctionMode):
1321
self.assertTrue(all_same_mode([x, x, x]))
1322
self.assertFalse(all_same_mode([x, None]))
1323
self.assertFalse(all_same_mode([x, y]))
1325
def test_nested_modes_with_python_has_torch_function(self):
1328
class A(TorchFunctionMode):
1329
def __torch_function__(self, func, types, args=(), kwargs=None):
1331
kwargs = {} if kwargs is None else kwargs
1332
return func(*args, **kwargs)
1334
class B(TorchFunctionMode):
1335
def __torch_function__(self, func, types, args=(), kwargs=None):
1337
kwargs = {} if kwargs is None else kwargs
1338
return func(*args, **kwargs)
1340
x = torch.randn(3, 4)
1345
self.assertEqual(y, x)
1346
self.assertEqual(called, ["B", "A"])
1349
def test_reentrant_mode_idiom(self):
1352
class A(TorchFunctionMode):
1353
def __torch_function__(self, func, types, args=(), kwargs=None):
1357
if func is torch.sub:
1361
return torch.add(input, other, alpha=-1)
1362
return func(*args, **kwargs)
1369
self.assertEqual(log, [torch.sub, torch.add])
1371
def test_nn_parse_to(self):
1377
class A(TorchFunctionMode):
1378
def __torch_function__(self, func, types, args=(), kwargs=None):
1383
return func(*args, **kwargs)
1386
torch._C._nn._parse_to('cpu')
1388
self.assertTrue(called)
1390
def test_distributions_bernoulli(self):
1398
class A(TorchFunctionMode):
1399
def __torch_function__(self, func, types, args=(), kwargs=None):
1404
return func(*args, **kwargs)
1407
torch.distributions.Bernoulli(0.3)
1409
self.assertTrue(called)
1411
def test_mode_notimplemented_loop(self):
1418
class A(TorchFunctionMode):
1419
def __torch_function__(self, func, types, args=(), kwargs=None):
1429
if any(t is not torch.Tensor for t in types):
1430
return NotImplemented
1432
return func(*args, **kwargs)
1434
class B(torch.Tensor):
1442
self.assertIs(type(r), B)
1443
self.assertEqual(called, 2)
1450
self.assertIs(type(r), B)
1451
self.assertEqual(called, 2)
1453
def test_disable_subclass_not_mode(self):
1456
class A(TorchFunctionMode):
1457
def __torch_function__(self, func, types, args=(), kwargs=None):
1462
return func(*args, **kwargs)
1464
class B(torch.Tensor):
1467
x = B(torch.randn(5))
1469
with torch._C.DisableTorchFunctionSubclass():
1470
self.assertNotIsInstance(torch.sum(x), B)
1472
self.assertTrue(called)
1474
def test_disable_subclass_mode(self):
1477
class A(TorchFunctionMode):
1478
def __torch_function__(self, func, types, args=(), kwargs=None):
1483
return func(*args, **kwargs)
1485
class B(torch.Tensor):
1488
x = B(torch.randn(5))
1490
with torch._C.DisableTorchFunction():
1491
self.assertNotIsInstance(torch.sum(x), B)
1493
self.assertFalse(called)
1495
def test_disable_enable_subclass(self):
1498
class A(torch.Tensor):
1501
x = A(torch.randn(5))
1502
with torch._C.DisableTorchFunctionSubclass():
1503
g = torch._C._EnableTorchFunction()
1505
self.assertIsInstance(torch.sum(x), A)
1509
def test_subclass_hash(self):
1510
class DiagTensor(torch.Tensor):
1511
def __init__(self, diag):
1515
def __torch_function__(cls, func, types, args=(), kwargs=None):
1516
kwargs = kwargs or {}
1518
def get_full_matrices(t):
1519
if isinstance(t, DiagTensor):
1520
return torch.diag_embed(t._diag)
1524
return func(*tree_map(get_full_matrices, args), **tree_map(get_full_matrices, kwargs))
1529
self.assertEqual((a + 1), torch.diag_embed(d) + 1)
1538
s.add(DiagTensor(d))
1540
def test_custom_device_type(self):
1541
class CustomDeviceContext(TorchFunctionMode):
1543
def __torch_function__(self, func, types, args=(), kwargs=None):
1544
kwargs = kwargs or {}
1545
if func == torch.device:
1546
if args and isinstance(args[0], int):
1547
args = ("xla", args[0])
1548
elif isinstance(kwargs.get('device'), int):
1549
kwargs['device'] = f"xla:{kwargs.get('device')}"
1550
return func(*args, **kwargs)
1552
with CustomDeviceContext():
1553
d_args = torch.device(0)
1554
self.assertEqual(d_args.type, "xla")
1555
self.assertEqual(d_args.index, 0)
1556
d_kwargs = torch.device(device=0)
1557
self.assertEqual(d_kwargs.type, "xla")
1558
self.assertEqual(d_kwargs.index, 0)
1561
if __name__ == '__main__':