1
# mypy: allow-untyped-defs
2
from __future__ import annotations
6
from _weakrefset import _IterationGuard # type: ignore[attr-defined]
7
from collections.abc import MutableMapping, Mapping
8
from torch import Tensor
9
import collections.abc as _collections_abc
15
# Public names exported by this module (what `from <module> import *` exposes).
__all__ = ['TensorWeakRef', 'WeakIdRef', 'WeakIdKeyDictionary', 'WeakTensorKeyDictionary']
18
# This file defines a variant of WeakKeyDictionary that overrides the hashing
19
# behavior of the key to use object identity, rather than the builtin
20
# __eq__/__hash__ functions. This is useful for Tensor weak keys, as their
21
# __eq__ implementation returns a Tensor (elementwise equality), which means
22
# you can't use them directly with the WeakKeyDictionary in standard library.
24
# Our implementation strategy is to create a wrapper weak key object, which we
25
# use as a key in a stock Python dictionary. This is similar to how weakref
26
# implements WeakKeyDictionary, but instead of using weakref.ref as the
27
# wrapper, we use a custom wrapper that has different __eq__ and __hash__
28
# behavior. Note that we subsequently store this weak key directly in an
29
# ORDINARY dictionary, since the newly constructed WeakIdKey's only use would
30
# be a dictionary so it would have no strong references. Ensuring that
31
# only live WeakIdKeys are in the map is handled by putting finalizers on the
35
# It is simpler to implement this with composition, but if we want to
36
# directly reuse the callback mechanism on weakref, we need the weakref
37
# and the key to be exactly the same object. Reusing the callback mechanism
38
# minimizes the divergence between our implementation and Lib/weakref.py
40
# NB: Prefer using this when working with weakrefs of Tensors; e.g., do
41
# WeakIdRef(tensor) rather than weakref.ref(tensor); it handles a number of
42
# easy-to-get-wrong cases transparently for you.
43
class WeakIdRef(weakref.ref):
46
def __init__(self, key, callback=None):
47
# Unlike stock weakref, which preserves hash semantics of the
48
# original object but lazily defers hash calls until the first
49
# time the user attempts to hash the weakref, we can eagerly
50
# cache the id of the key as we know this is definitely the hash
53
super().__init__(key, callback) # type: ignore[call-arg]
56
r = super().__call__()
57
# Special logic for Tensor PyObject resurrection
58
if hasattr(r, '_fix_weakref'):
59
r._fix_weakref() # type: ignore[union-attr]
65
def __eq__(self, other):
66
# An attractive but wrong alternate implementation is to only test if
67
# the stored _ids match. This can lead to an ABA problem if you have:
72
# a2 = A() # suppose it gets the same ID as a1
76
# This should be False, as a1 and a2 are unrelated (and a1 is
80
if a is not None and b is not None:
84
# This is the same as WeakIdRef but equality is checked using hash() rather than id.
85
# This will be equivalent to the one above except for classes where hash is not their id.
86
class _WeakHashRef(weakref.ref):
89
def __init__(self, key, callback=None):
90
# Unlike stock weakref, which preserves hash semantics of the
91
# original object but lazily defers hash calls until the first
92
# time the user attempts to hash the weakref, we can eagerly
93
# cache the hash of the key as we know this is definitely the hash
96
super().__init__(key, callback) # type: ignore[call-arg]
99
r = super().__call__()
100
# Special logic for Tensor PyObject resurrection
101
if hasattr(r, '_fix_weakref'):
102
r._fix_weakref() # type: ignore[union-attr]
108
def __eq__(self, other):
109
# Use hash equality to determine ref equality.
110
# ScriptObject implements __hash__ to return the wrapped IValue's id, so
111
# this is equivalent to doing an identity comparison.
114
if a is not None and b is not None:
115
return hash(a) == hash(b)
118
# This is directly adapted from cpython/Lib/weakref.py
119
class WeakIdKeyDictionary(MutableMapping):
120
def __init__(self, dict=None, ref_type=WeakIdRef): # CHANGED
123
self.ref_type = ref_type # CHANGED
125
def remove(k, selfref=ref(self)):
129
self._pending_removals.append(k)
135
self._remove = remove
136
# A list of dead weakrefs (keys to be removed)
137
self._pending_removals = []
138
self._iterating = set()
139
self._dirty_len = False
143
def _commit_removals(self):
144
# NOTE: We don't need to call this method before mutating the dict,
145
# because a dead weakref never compares equal to a live weakref,
146
# even if they happened to refer to equal objects.
147
# However, it means keys may already have been removed.
148
pop = self._pending_removals.pop
161
def _scrub_removals(self):
163
self._pending_removals = [k for k in self._pending_removals if k in d]
164
self._dirty_len = False
166
def __delitem__(self, key):
    """Remove the entry stored for ``key``.

    ``key`` is wrapped with ``self.ref_type`` and the wrapper is used as
    the lookup key in ``self.data``. Any lookup error raised by
    ``self.data`` propagates to the caller.
    """
    # Mark the cached length as dirty first, so the flag is set even when
    # the deletion below raises.
    self._dirty_len = True
    wr = self.ref_type(key)  # CHANGED
    del self.data[wr]
170
def __getitem__(self, key):
    """Return the value stored for ``key``.

    The lookup in ``self.data`` uses ``self.ref_type(key)`` as the key,
    not ``key`` itself. Any lookup error raised by ``self.data``
    propagates to the caller.
    """
    wr = self.ref_type(key)  # CHANGED
    return self.data[wr]
174
if self._dirty_len and self._pending_removals:
175
# self._pending_removals may still contain keys which were
176
# explicitly removed, we have to scrub them (see issue #21173).
177
self._scrub_removals()
178
return len(self.data) - len(self._pending_removals)
181
return f"<{self.__class__.__name__} at {id(self):#x}>"
183
def __setitem__(self, key, value):
    """Store ``value`` under ``key``.

    ``key`` is wrapped with ``self.ref_type``, passing ``self._remove`` as
    the wrapper's callback, and the wrapper is what gets stored in
    ``self.data``.
    """
    wr = self.ref_type(key, self._remove)  # CHANGED
    self.data[wr] = value
187
new = WeakIdKeyDictionary()
188
with _IterationGuard(self):
189
for key, value in self.data.items():
197
def __deepcopy__(self, memo):
198
from copy import deepcopy
199
new = self.__class__()
200
with _IterationGuard(self):
201
for key, value in self.data.items():
204
new[o] = deepcopy(value, memo)
207
def get(self, key, default=None):
    """Return the value stored for ``key``, or ``default`` if there is none.

    ``key`` is wrapped with ``self.ref_type`` before being looked up in
    ``self.data``.
    """
    wr = self.ref_type(key)  # CHANGED
    return self.data.get(wr, default)
210
def __contains__(self, key):
212
wr = self.ref_type(key) # CHANGED
215
return wr in self.data
218
with _IterationGuard(self):
219
for wr, value in self.data.items():
225
with _IterationGuard(self):
234
with _IterationGuard(self):
235
for wr, value in self.data.items():
240
"""Return a list of weak references to the keys.
242
The references are not guaranteed to be 'live' at the time
243
they are used, so the result of calling the references needs
244
to be checked before being used. This can be used to avoid
245
creating references that will cause the garbage collector to
246
keep the keys around longer than needed.
249
return list(self.data)
252
self._dirty_len = True
254
key, value = self.data.popitem()
259
def pop(self, key, *args):
    """Remove the entry for ``key`` and return its value.

    ``key`` is wrapped with ``self.ref_type`` before the lookup. Any extra
    positional arguments are forwarded unchanged to ``self.data.pop``
    (e.g. a default to return when the entry is absent).
    """
    # Mark the cached length as dirty before mutating, so the flag is set
    # even when the pop below raises.
    self._dirty_len = True
    wr = self.ref_type(key)  # CHANGED
    return self.data.pop(wr, *args)
263
def setdefault(self, key, default=None):
    """Return the value for ``key``, inserting ``default`` first if absent.

    ``key`` is wrapped with ``self.ref_type``, passing ``self._remove`` as
    the wrapper's callback, and the wrapper is handed to
    ``self.data.setdefault``.
    """
    wr = self.ref_type(key, self._remove)  # CHANGED
    return self.data.setdefault(wr, default)
266
def update(self, dict=None, **kwargs):
269
if not hasattr(dict, "items"):
270
dict = type({})(dict)
271
for key, value in dict.items():
272
d[self.ref_type(key, self._remove)] = value # CHANGED
276
def __ior__(self, other):
280
def __or__(self, other):
281
if isinstance(other, _collections_abc.Mapping):
285
return NotImplemented
287
def __ror__(self, other):
288
if isinstance(other, _collections_abc.Mapping):
293
return NotImplemented
295
# Default Mapping equality will test keys for equality, but
296
# we want to test ids for equality
297
def __eq__(self, other):
    """Compare against another mapping using key identity.

    Both sides are re-keyed by ``id(key)`` and the resulting dicts are
    compared, so two entries match only if their keys are the very same
    object and their values compare equal.

    Returns ``NotImplemented`` when ``other`` is not a ``Mapping``.
    """
    if isinstance(other, Mapping):
        mine = {id(k): v for k, v in self.items()}
        theirs = {id(k): v for k, v in other.items()}
        return mine == theirs
    return NotImplemented
303
# Alias: WeakTensorKeyDictionary is the very same class object as
# WeakIdKeyDictionary, just exposed under a second name.
WeakTensorKeyDictionary = WeakIdKeyDictionary
307
"""Wrapper around a weak ref of a Tensor that handles the _fix_weakref() call required when unwrapping a Tensor weakref."""
311
def __init__(self, tensor: Tensor):
312
assert isinstance(tensor, Tensor)
313
self.ref = weakref.ref(tensor)
319
assert isinstance(out, Tensor)
320
# TODO, add _fix_weakref type binding
321
out._fix_weakref() # type: ignore[attr-defined]