1
# Owner(s): ["module: meta tensors"]
11
from torch.testing._internal.common_utils import (
12
find_library_location,
20
from torch.utils.weak import _WeakHashRef, WeakIdKeyDictionary
27
# These tests are ported from cpython/Lib/test/test_weakref.py,
28
# but adapted to use tensor rather than object
29
class WeakTest(TestCase):
32
def test_make_weak_keyed_dict_from_dict(self):
34
dict = WeakIdKeyDictionary({o: 364})
35
self.assertEqual(dict[o], 364)
37
def test_make_weak_keyed_dict_from_weak_keyed_dict(self):
39
dict = WeakIdKeyDictionary({o: 364})
40
dict2 = WeakIdKeyDictionary(dict)
41
self.assertEqual(dict[o], 364)
43
def check_popitem(self, klass, key1, value1, key2, value2):
45
weakdict[key1] = value1
46
weakdict[key2] = value2
47
self.assertEqual(len(weakdict), 2)
48
k, v = weakdict.popitem()
49
self.assertEqual(len(weakdict), 1)
51
self.assertIs(v, value1)
53
self.assertIs(v, value2)
54
k, v = weakdict.popitem()
55
self.assertEqual(len(weakdict), 0)
57
self.assertIs(v, value1)
59
self.assertIs(v, value2)
61
def test_weak_keyed_dict_popitem(self):
62
self.check_popitem(WeakIdKeyDictionary, C(), "value 1", C(), "value 2")
64
def check_setdefault(self, klass, key, value1, value2):
68
"invalid test -- value parameters must be distinct objects",
71
o = weakdict.setdefault(key, value1)
72
self.assertIs(o, value1)
73
self.assertIn(key, weakdict)
74
self.assertIs(weakdict.get(key), value1)
75
self.assertIs(weakdict[key], value1)
77
o = weakdict.setdefault(key, value2)
78
self.assertIs(o, value1)
79
self.assertIn(key, weakdict)
80
self.assertIs(weakdict.get(key), value1)
81
self.assertIs(weakdict[key], value1)
83
def test_weak_keyed_dict_setdefault(self):
84
self.check_setdefault(WeakIdKeyDictionary, C(), "value 1", "value 2")
86
def check_update(self, klass, dict):
88
# This exercises d.update(), len(d), d.keys(), k in d,
93
self.assertEqual(len(weakdict), len(dict))
94
for k in weakdict.keys():
95
self.assertIn(k, dict, "mysterious new key appeared in weak dict")
97
self.assertIs(v, weakdict[k])
98
self.assertIs(v, weakdict.get(k))
100
self.assertIn(k, weakdict, "original key disappeared in weak dict")
102
self.assertIs(v, weakdict[k])
103
self.assertIs(v, weakdict.get(k))
105
def test_weak_keyed_dict_update(self):
106
self.check_update(WeakIdKeyDictionary, {C(): 1, C(): 2, C(): 3})
108
def test_weak_keyed_delitem(self):
109
d = WeakIdKeyDictionary()
114
self.assertEqual(len(d), 2)
116
self.assertEqual(len(d), 1)
117
self.assertEqual(list(d.keys()), [o2])
119
def test_weak_keyed_union_operators(self):
123
self.skipTest("dict union not supported in this Python")
128
wkd1 = WeakIdKeyDictionary({o1: 1, o2: 2})
129
wkd2 = WeakIdKeyDictionary({o3: 3, o1: 4})
131
d1 = {o2: "5", o3: "6"}
132
pairs = [(o2, 7), (o3, 8)]
134
tmp1 = wkd1 | wkd2 # Between two WeakKeyDictionaries
135
self.assertEqual(dict(tmp1), dict(wkd1) | dict(wkd2))
136
self.assertIs(type(tmp1), WeakIdKeyDictionary)
138
self.assertEqual(wkd1, tmp1)
140
tmp2 = wkd2 | d1 # Between WeakKeyDictionary and mapping
141
self.assertEqual(dict(tmp2), dict(wkd2) | d1)
142
self.assertIs(type(tmp2), WeakIdKeyDictionary)
144
self.assertEqual(wkd2, tmp2)
146
tmp3 = wkd3.copy() # Between WeakKeyDictionary and iterable key, value
148
self.assertEqual(dict(tmp3), dict(wkd3) | dict(pairs))
149
self.assertIs(type(tmp3), WeakIdKeyDictionary)
151
tmp4 = d1 | wkd3 # Testing .__ror__
152
self.assertEqual(dict(tmp4), d1 | dict(wkd3))
153
self.assertIs(type(tmp4), WeakIdKeyDictionary)
156
self.assertNotIn(4, tmp1.values())
157
self.assertNotIn(4, tmp2.values())
158
self.assertNotIn(1, tmp3.values())
159
self.assertNotIn(1, tmp4.values())
161
def test_weak_keyed_bad_delitem(self):
162
d = WeakIdKeyDictionary()
164
# An attempt to delete an object that isn't there should raise
165
# KeyError. It didn't before 2.3.
166
self.assertRaises(KeyError, d.__delitem__, o)
167
self.assertRaises(KeyError, d.__getitem__, o)
169
# If a key isn't of a weakly referencable type, __getitem__ and
170
# __setitem__ raise TypeError. __delitem__ should too.
171
self.assertRaises(TypeError, d.__delitem__, 13)
172
self.assertRaises(TypeError, d.__getitem__, 13)
173
self.assertRaises(TypeError, d.__setitem__, 13, 13)
175
def test_make_weak_keyed_dict_repr(self):
176
dict = WeakIdKeyDictionary()
177
self.assertRegex(repr(dict), "<WeakIdKeyDictionary at 0x.*>")
179
def check_threaded_weak_dict_copy(self, type_, deepcopy):
180
# `deepcopy` should be either True or False.
183
# Cannot give these slots as weakrefs weren't supported
184
# on these objects until later versions of Python
185
class DummyKey: # noqa: B903
186
def __init__(self, ctr):
189
class DummyValue: # noqa: B903
190
def __init__(self, ctr):
193
def dict_copy(d, exc):
199
except Exception as ex:
202
def pop_and_collect(lst):
205
i = random.randint(0, len(lst) - 1)
208
if gc_ctr % 10000 == 0:
209
gc.collect() # just in case
214
# Initialize d with many entries
215
for i in range(70000):
216
k, v = DummyKey(i), DummyValue(i)
223
t_copy = threading.Thread(
230
t_collect = threading.Thread(target=pop_and_collect, args=(keys,))
242
def test_threaded_weak_key_dict_copy(self):
243
# Issue #35615: Weakref keys or values getting GC'ed during dict
244
# copying should not result in a crash.
245
self.check_threaded_weak_dict_copy(WeakIdKeyDictionary, False)
247
def test_threaded_weak_key_dict_deepcopy(self):
248
# Issue #35615: Weakref keys or values getting GC'ed during dict
249
# copying should not result in a crash.
250
self.check_threaded_weak_dict_copy(WeakIdKeyDictionary, True)
253
# Adapted from cpython/Lib/test/mapping_tests.py
254
class WeakKeyDictionaryTestCase(TestCase):
255
__ref = {torch.randn(1): 1, torch.randn(2): 2, torch.randn(3): 3}
256
type2test = WeakIdKeyDictionary
258
def _reference(self):
259
return self.__ref.copy()
261
def _empty_mapping(self):
262
"""Return an empty mapping object"""
263
return self.type2test()
265
def _full_mapping(self, data):
266
"""Return a mapping object with the value contained in data
268
x = self._empty_mapping()
269
for key, value in data.items():
273
def __init__(self, *args, **kw):
274
unittest.TestCase.__init__(self, *args, **kw)
275
self.reference = self._reference().copy()
277
# A (key, value) pair not in the mapping
278
key, value = self.reference.popitem()
279
self.other = {key: value}
281
# A (key, value) pair in the mapping
282
key, value = self.reference.popitem()
283
self.inmapping = {key: value}
284
self.reference[key] = value
287
# Test for read only operations on mapping
288
p = self._empty_mapping()
289
p1 = dict(p) # workaround for singleton objects
290
d = self._full_mapping(self.reference)
294
for key, value in self.reference.items():
295
self.assertEqual(d[key], value)
296
knownkey = next(iter(self.other.keys()))
297
self.assertRaises(KeyError, lambda: d[knownkey])
299
self.assertEqual(len(p), 0)
300
self.assertEqual(len(d), len(self.reference))
302
for k in self.reference:
305
self.assertNotIn(k, d)
309
) # NB: don't use assertEqual, that doesn't actually use ==
310
self.assertTrue(d == d)
311
self.assertTrue(p != d)
312
self.assertTrue(d != p)
315
self.fail("Empty mapping must compare to False")
317
self.fail("Full mapping must compare to True")
319
# keys(), items(), iterkeys() ...
320
def check_iterandlist(iter, lst, ref):
321
self.assertTrue(hasattr(iter, "__next__"))
322
self.assertTrue(hasattr(iter, "__iter__"))
324
self.assertTrue(set(x) == set(lst) == set(ref))
326
check_iterandlist(iter(d.keys()), list(d.keys()), self.reference.keys())
327
check_iterandlist(iter(d), list(d.keys()), self.reference.keys())
328
check_iterandlist(iter(d.values()), list(d.values()), self.reference.values())
329
check_iterandlist(iter(d.items()), list(d.items()), self.reference.items())
331
key, value = next(iter(d.items()))
332
knownkey, knownvalue = next(iter(self.other.items()))
333
self.assertEqual(d.get(key, knownvalue), value)
334
self.assertEqual(d.get(knownkey, knownvalue), knownvalue)
335
self.assertNotIn(knownkey, d)
337
def test_write(self):
338
# Test for write operations on mapping
339
p = self._empty_mapping()
341
for key, value in self.reference.items():
343
self.assertEqual(p[key], value)
344
for key in self.reference.keys():
346
self.assertRaises(KeyError, lambda: p[key])
347
p = self._empty_mapping()
349
p.update(self.reference)
350
self.assertEqual(dict(p), self.reference)
351
items = list(p.items())
352
p = self._empty_mapping()
354
self.assertEqual(dict(p), self.reference)
355
d = self._full_mapping(self.reference)
357
key, value = next(iter(d.items()))
358
knownkey, knownvalue = next(iter(self.other.items()))
359
self.assertEqual(d.setdefault(key, knownvalue), value)
360
self.assertEqual(d[key], value)
361
self.assertEqual(d.setdefault(knownkey, knownvalue), knownvalue)
362
self.assertEqual(d[knownkey], knownvalue)
364
self.assertEqual(d.pop(knownkey), knownvalue)
365
self.assertNotIn(knownkey, d)
366
self.assertRaises(KeyError, d.pop, knownkey)
368
d[knownkey] = knownvalue
369
self.assertEqual(d.pop(knownkey, default), knownvalue)
370
self.assertNotIn(knownkey, d)
371
self.assertEqual(d.pop(knownkey, default), default)
373
key, value = d.popitem()
374
self.assertNotIn(key, d)
375
self.assertEqual(value, self.reference[key])
376
p = self._empty_mapping()
377
self.assertRaises(KeyError, p.popitem)
379
def test_constructor(self):
380
self.assertEqual(self._empty_mapping(), self._empty_mapping())
383
self.assertTrue(not self._empty_mapping())
384
self.assertTrue(self.reference)
385
self.assertTrue(bool(self._empty_mapping()) is False)
386
self.assertTrue(bool(self.reference) is True)
389
d = self._empty_mapping()
390
self.assertEqual(list(d.keys()), [])
392
self.assertIn(next(iter(self.inmapping.keys())), d.keys())
393
self.assertNotIn(next(iter(self.other.keys())), d.keys())
394
self.assertRaises(TypeError, d.keys, None)
396
def test_values(self):
397
d = self._empty_mapping()
398
self.assertEqual(list(d.values()), [])
400
self.assertRaises(TypeError, d.values, None)
402
def test_items(self):
403
d = self._empty_mapping()
404
self.assertEqual(list(d.items()), [])
406
self.assertRaises(TypeError, d.items, None)
409
d = self._empty_mapping()
410
self.assertEqual(len(d), 0)
412
def test_getitem(self):
415
d[next(iter(self.inmapping.keys()))], next(iter(self.inmapping.values()))
418
self.assertRaises(TypeError, d.__getitem__)
420
def test_update(self):
422
d = self._empty_mapping()
424
self.assertEqual(list(d.items()), list(self.other.items()))
427
d = self._empty_mapping()
429
self.assertEqual(d, self._empty_mapping())
432
d = self._empty_mapping()
433
d.update(self.other.items())
434
self.assertEqual(list(d.items()), list(self.other.items()))
437
d = self._empty_mapping()
438
d.update(self.other.items())
439
self.assertEqual(list(d.items()), list(self.other.items()))
441
# FIXME: Doesn't work with UserDict
442
# self.assertRaises((TypeError, AttributeError), d.update, None)
443
self.assertRaises((TypeError, AttributeError), d.update, 42)
447
class SimpleUserDict:
449
self.d = outerself.reference
454
def __getitem__(self, i):
458
d.update(SimpleUserDict())
459
i1 = sorted((id(k), v) for k, v in d.items())
460
i2 = sorted((id(k), v) for k, v in self.reference.items())
461
self.assertEqual(i1, i2)
463
class Exc(Exception):
466
d = self._empty_mapping()
468
class FailingUserDict:
472
self.assertRaises(Exc, d.update, FailingUserDict())
476
class FailingUserDict:
493
def __getitem__(self, key):
496
self.assertRaises(Exc, d.update, FailingUserDict())
498
class FailingUserDict:
508
if self.i <= ord("z"):
516
def __getitem__(self, key):
519
self.assertRaises(Exc, d.update, FailingUserDict())
521
d = self._empty_mapping()
530
self.assertRaises(Exc, d.update, badseq())
532
self.assertRaises(ValueError, d.update, [(1, 2, 3)])
534
# no test_fromkeys or test_copy as both os.environ and selves don't support it
537
d = self._empty_mapping()
538
self.assertTrue(d.get(next(iter(self.other.keys()))) is None)
539
self.assertEqual(d.get(next(iter(self.other.keys())), 3), 3)
541
self.assertTrue(d.get(next(iter(self.other.keys()))) is None)
542
self.assertEqual(d.get(next(iter(self.other.keys())), 3), 3)
544
d.get(next(iter(self.inmapping.keys()))),
545
next(iter(self.inmapping.values())),
548
d.get(next(iter(self.inmapping.keys())), 3),
549
next(iter(self.inmapping.values())),
551
self.assertRaises(TypeError, d.get)
552
self.assertRaises(TypeError, d.get, None, None, None)
554
def test_setdefault(self):
555
d = self._empty_mapping()
556
self.assertRaises(TypeError, d.setdefault)
558
def test_popitem(self):
559
d = self._empty_mapping()
560
self.assertRaises(KeyError, d.popitem)
561
self.assertRaises(TypeError, d.popitem, 42)
564
d = self._empty_mapping()
565
k, v = next(iter(self.inmapping.items()))
567
self.assertRaises(KeyError, d.pop, next(iter(self.other.keys())))
569
self.assertEqual(d.pop(k), v)
570
self.assertEqual(len(d), 0)
572
self.assertRaises(KeyError, d.pop, k)
575
# Adapted from cpython/Lib/test/mapping_tests.py
576
class WeakKeyDictionaryScriptObjectTestCase(TestCase):
577
def _reference(self):
579
torch.classes._TorchScriptTesting._Foo(1, 2): 1,
580
torch.classes._TorchScriptTesting._Foo(2, 3): 2,
581
torch.classes._TorchScriptTesting._Foo(3, 4): 3,
583
return self.__ref.copy()
585
def _empty_mapping(self):
586
"""Return an empty mapping object"""
587
return WeakIdKeyDictionary(ref_type=_WeakHashRef)
589
def _full_mapping(self, data):
590
"""Return a mapping object with the value contained in data
592
x = self._empty_mapping()
593
for key, value in data.items():
599
raise unittest.SkipTest("non-portable load_library call used in test")
601
def __init__(self, *args, **kw):
602
unittest.TestCase.__init__(self, *args, **kw)
603
if IS_SANDCASTLE or IS_FBCODE:
604
torch.ops.load_library(
605
"//caffe2/test/cpp/jit:test_custom_class_registrations"
608
# don't load the library, just skip the tests in setUp
611
lib_file_path = find_library_location("libtorchbind_test.so")
613
lib_file_path = find_library_location("torchbind_test.dll")
614
torch.ops.load_library(str(lib_file_path))
616
self.reference = self._reference().copy()
618
# A (key, value) pair not in the mapping
619
key, value = self.reference.popitem()
620
self.other = {key: value}
622
# A (key, value) pair in the mapping
623
key, value = self.reference.popitem()
624
self.inmapping = {key: value}
625
self.reference[key] = value
628
# Test for read only operations on mapping
629
p = self._empty_mapping()
630
p1 = dict(p) # workaround for singleton objects
631
d = self._full_mapping(self.reference)
635
for key, value in self.reference.items():
636
self.assertEqual(d[key], value)
637
knownkey = next(iter(self.other.keys()))
638
self.assertRaises(KeyError, lambda: d[knownkey])
640
self.assertEqual(len(p), 0)
641
self.assertEqual(len(d), len(self.reference))
643
for k in self.reference:
646
self.assertNotIn(k, d)
650
) # NB: don't use assertEqual, that doesn't actually use ==
651
self.assertTrue(d == d)
652
self.assertTrue(p != d)
653
self.assertTrue(d != p)
656
self.fail("Empty mapping must compare to False")
658
self.fail("Full mapping must compare to True")
660
# keys(), items(), iterkeys() ...
661
def check_iterandlist(iter, lst, ref):
662
self.assertTrue(hasattr(iter, "__next__"))
663
self.assertTrue(hasattr(iter, "__iter__"))
665
self.assertTrue(set(x) == set(lst) == set(ref))
667
check_iterandlist(iter(d.keys()), list(d.keys()), self.reference.keys())
668
check_iterandlist(iter(d), list(d.keys()), self.reference.keys())
669
check_iterandlist(iter(d.values()), list(d.values()), self.reference.values())
670
check_iterandlist(iter(d.items()), list(d.items()), self.reference.items())
672
key, value = next(iter(d.items()))
673
knownkey, knownvalue = next(iter(self.other.items()))
674
self.assertEqual(d.get(key, knownvalue), value)
675
self.assertEqual(d.get(knownkey, knownvalue), knownvalue)
676
self.assertNotIn(knownkey, d)
678
def test_write(self):
679
# Test for write operations on mapping
680
p = self._empty_mapping()
682
for key, value in self.reference.items():
684
self.assertEqual(p[key], value)
685
for key in self.reference.keys():
687
self.assertRaises(KeyError, lambda: p[key])
688
p = self._empty_mapping()
690
p.update(self.reference)
691
self.assertEqual(dict(p), self.reference)
692
items = list(p.items())
693
p = self._empty_mapping()
695
self.assertEqual(dict(p), self.reference)
696
d = self._full_mapping(self.reference)
698
key, value = next(iter(d.items()))
699
knownkey, knownvalue = next(iter(self.other.items()))
700
self.assertEqual(d.setdefault(key, knownvalue), value)
701
self.assertEqual(d[key], value)
702
self.assertEqual(d.setdefault(knownkey, knownvalue), knownvalue)
703
self.assertEqual(d[knownkey], knownvalue)
705
self.assertEqual(d.pop(knownkey), knownvalue)
706
self.assertNotIn(knownkey, d)
707
self.assertRaises(KeyError, d.pop, knownkey)
709
d[knownkey] = knownvalue
710
self.assertEqual(d.pop(knownkey, default), knownvalue)
711
self.assertNotIn(knownkey, d)
712
self.assertEqual(d.pop(knownkey, default), default)
714
key, value = d.popitem()
715
self.assertNotIn(key, d)
716
self.assertEqual(value, self.reference[key])
717
p = self._empty_mapping()
718
self.assertRaises(KeyError, p.popitem)
720
def test_constructor(self):
721
self.assertEqual(self._empty_mapping(), self._empty_mapping())
724
self.assertTrue(not self._empty_mapping())
725
self.assertTrue(self.reference)
726
self.assertTrue(bool(self._empty_mapping()) is False)
727
self.assertTrue(bool(self.reference) is True)
730
d = self._empty_mapping()
731
self.assertEqual(list(d.keys()), [])
733
self.assertIn(next(iter(self.inmapping.keys())), d.keys())
734
self.assertNotIn(next(iter(self.other.keys())), d.keys())
735
self.assertRaises(TypeError, d.keys, None)
737
def test_values(self):
738
d = self._empty_mapping()
739
self.assertEqual(list(d.values()), [])
741
self.assertRaises(TypeError, d.values, None)
743
def test_items(self):
744
d = self._empty_mapping()
745
self.assertEqual(list(d.items()), [])
747
self.assertRaises(TypeError, d.items, None)
750
d = self._empty_mapping()
751
self.assertEqual(len(d), 0)
753
def test_getitem(self):
756
d[next(iter(self.inmapping.keys()))], next(iter(self.inmapping.values()))
759
self.assertRaises(TypeError, d.__getitem__)
761
def test_update(self):
763
d = self._empty_mapping()
765
self.assertEqual(list(d.items()), list(self.other.items()))
768
d = self._empty_mapping()
770
self.assertEqual(d, self._empty_mapping())
773
d = self._empty_mapping()
774
d.update(self.other.items())
775
self.assertEqual(list(d.items()), list(self.other.items()))
778
d = self._empty_mapping()
779
d.update(self.other.items())
780
self.assertEqual(list(d.items()), list(self.other.items()))
782
# FIXME: Doesn't work with UserDict
783
# self.assertRaises((TypeError, AttributeError), d.update, None)
784
self.assertRaises((TypeError, AttributeError), d.update, 42)
788
class SimpleUserDict:
790
self.d = outerself.reference
795
def __getitem__(self, i):
799
d.update(SimpleUserDict())
800
i1 = sorted((id(k), v) for k, v in d.items())
801
i2 = sorted((id(k), v) for k, v in self.reference.items())
802
self.assertEqual(i1, i2)
804
class Exc(Exception):
807
d = self._empty_mapping()
809
class FailingUserDict:
813
self.assertRaises(Exc, d.update, FailingUserDict())
817
class FailingUserDict:
834
def __getitem__(self, key):
837
self.assertRaises(Exc, d.update, FailingUserDict())
839
class FailingUserDict:
849
if self.i <= ord("z"):
857
def __getitem__(self, key):
860
self.assertRaises(Exc, d.update, FailingUserDict())
862
d = self._empty_mapping()
871
self.assertRaises(Exc, d.update, badseq())
873
self.assertRaises(ValueError, d.update, [(1, 2, 3)])
875
# no test_fromkeys or test_copy as both os.environ and selves don't support it
878
d = self._empty_mapping()
879
self.assertTrue(d.get(next(iter(self.other.keys()))) is None)
880
self.assertEqual(d.get(next(iter(self.other.keys())), 3), 3)
882
self.assertTrue(d.get(next(iter(self.other.keys()))) is None)
883
self.assertEqual(d.get(next(iter(self.other.keys())), 3), 3)
885
d.get(next(iter(self.inmapping.keys()))),
886
next(iter(self.inmapping.values())),
889
d.get(next(iter(self.inmapping.keys())), 3),
890
next(iter(self.inmapping.values())),
892
self.assertRaises(TypeError, d.get)
893
self.assertRaises(TypeError, d.get, None, None, None)
895
def test_setdefault(self):
896
d = self._empty_mapping()
897
self.assertRaises(TypeError, d.setdefault)
899
def test_popitem(self):
900
d = self._empty_mapping()
901
self.assertRaises(KeyError, d.popitem)
902
self.assertRaises(TypeError, d.popitem, 42)
905
d = self._empty_mapping()
906
k, v = next(iter(self.inmapping.items()))
908
self.assertRaises(KeyError, d.pop, next(iter(self.other.keys())))
910
self.assertEqual(d.pop(k), v)
911
self.assertEqual(len(d), 0)
913
self.assertRaises(KeyError, d.pop, k)
916
if __name__ == "__main__":