1
from __future__ import annotations
5
from _weakrefset import _IterationGuard # type: ignore[attr-defined]
6
from collections.abc import MutableMapping, Mapping
7
from torch import Tensor
8
import collections.abc as _collections_abc
14
__all__ = ['TensorWeakRef', 'WeakIdRef', 'WeakIdKeyDictionary', 'WeakTensorKeyDictionary']
17
# This file defines a variant of WeakKeyDictionary that overrides the hashing
18
# behavior of the key to use object identity, rather than the builtin
19
# __eq__/__hash__ functions. This is useful for Tensor weak keys, as their
20
# __eq__ implementation returns a Tensor (elementwise equality), which means
21
# you can't use them directly with the WeakKeyDictionary in standard library.
23
# Our implementation strategy is to create a wrapper weak key object, which we
24
# use as a key in a stock Python dictionary. This is similar to how weakref
25
# implements WeakKeyDictionary, but instead of using weakref.ref as the
26
# wrapper, we use a custom wrapper that has different __eq__ and __hash__
27
# behavior. Note that we subsequently store this weak key directly in an
28
# ORDINARY dictionary, since the newly constructed WeakIdKey's only use would
29
# be a dictionary so it would have no strong references. Ensuring that
30
# only live WeakIdKeys are in the map is handled by putting finalizers on the
34
# It is simpler to implement this with composition, but if we want to
35
# directly reuse the callback mechanism on weakref, we need the weakref
36
# and the key to be exactly the same object. Reusing the callback mechanism
37
# minimizes the divergence between our implementation and Lib/weakref.py
39
# NB: Prefer using this when working with weakrefs of Tensors; e.g., do
40
# WeakIdRef(tensor) rather than weakref.ref(tensor); it handles a number of
41
# easy to get wrong cases transparently for you.
42
class WeakIdRef(weakref.ref):
45
def __init__(self, key, callback=None):
46
# Unlike stock weakref, which preserves hash semantics of the
47
# original object but lazily defers hash calls until the first
48
# time the user attempts to hash the weakref, we can eagerly
49
# cache the id of the key as we know this is definitely the hash
52
super().__init__(key, callback) # type: ignore[call-arg]
55
r = super().__call__()
56
# Special logic for Tensor PyObject resurrection
57
if hasattr(r, '_fix_weakref'):
58
r._fix_weakref() # type: ignore[union-attr]
64
def __eq__(self, other):
65
# An attractive but wrong alternate implementation is to only test if
66
# the stored _ids match. This can lead to an ABA problem if you have:
71
# a2 = A() # suppose it gets the same ID as a1
75
# This should be False, as a1 and a2 are unrelated (and a1 is
79
if a is not None and b is not None:
83
# This is the same as WeakIdRef but equality is checked using hash() rather than id.
84
# This will be equivalent to the one above except for classes where hash is not their id.
85
class _WeakHashRef(weakref.ref):
88
def __init__(self, key, callback=None):
89
# Unlike stock weakref, which preserves hash semantics of the
90
# original object but lazily defers hash calls until the first
91
# time the user attempts to hash the weakref, we can eagerly
92
# cache the id of the key as we know this is definitely the hash
95
super().__init__(key, callback) # type: ignore[call-arg]
98
r = super().__call__()
99
# Special logic for Tensor PyObject resurrection
100
if hasattr(r, '_fix_weakref'):
101
r._fix_weakref() # type: ignore[union-attr]
107
def __eq__(self, other):
108
# Use hash equality to determine ref equality.
109
# ScriptObject implements __hash__ to return the wrapped IValue's id, so
110
# this is equivalent to doing an identity comparison.
113
if a is not None and b is not None:
114
return hash(a) == hash(b)
117
# This is directly adapted from cpython/Lib/weakref.py
118
class WeakIdKeyDictionary(MutableMapping):
119
def __init__(self, dict=None, ref_type=WeakIdRef): # CHANGED
122
self.ref_type = ref_type # CHANGED
124
def remove(k, selfref=ref(self)):
128
self._pending_removals.append(k)
134
self._remove = remove
135
# A list of dead weakrefs (keys to be removed)
136
self._pending_removals = []
137
self._iterating = set()
138
self._dirty_len = False
142
def _commit_removals(self):
143
# NOTE: We don't need to call this method before mutating the dict,
144
# because a dead weakref never compares equal to a live weakref,
145
# even if they happened to refer to equal objects.
146
# However, it means keys may already have been removed.
147
pop = self._pending_removals.pop
160
def _scrub_removals(self):
162
self._pending_removals = [k for k in self._pending_removals if k in d]
163
self._dirty_len = False
165
def __delitem__(self, key):
166
self._dirty_len = True
167
del self.data[self.ref_type(key)] # CHANGED
169
def __getitem__(self, key):
170
return self.data[self.ref_type(key)] # CHANGED
173
if self._dirty_len and self._pending_removals:
174
# self._pending_removals may still contain keys which were
175
# explicitly removed, we have to scrub them (see issue #21173).
176
self._scrub_removals()
177
return len(self.data) - len(self._pending_removals)
180
return f"<{self.__class__.__name__} at {id(self):#x}>"
182
def __setitem__(self, key, value):
183
self.data[self.ref_type(key, self._remove)] = value # CHANGED
186
new = WeakIdKeyDictionary()
187
with _IterationGuard(self):
188
for key, value in self.data.items():
196
def __deepcopy__(self, memo):
197
from copy import deepcopy
198
new = self.__class__()
199
with _IterationGuard(self):
200
for key, value in self.data.items():
203
new[o] = deepcopy(value, memo)
206
def get(self, key, default=None):
    """Return the value stored under ``key``, or ``default`` if absent.

    ``key`` is wrapped with ``self.ref_type`` (no callback) before the
    lookup in ``self.data``.
    """
    weak_key = self.ref_type(key)  # CHANGED
    return self.data.get(weak_key, default)
209
def __contains__(self, key):
211
wr = self.ref_type(key) # CHANGED
214
return wr in self.data
217
with _IterationGuard(self):
218
for wr, value in self.data.items():
224
with _IterationGuard(self):
233
with _IterationGuard(self):
234
for wr, value in self.data.items():
239
"""Return a list of weak references to the keys.
241
The references are not guaranteed to be 'live' at the time
242
they are used, so the result of calling the references needs
243
to be checked before being used. This can be used to avoid
244
creating references that will cause the garbage collector to
245
keep the keys around longer than needed.
248
return list(self.data)
251
self._dirty_len = True
253
key, value = self.data.popitem()
258
def pop(self, key, *args):
    """Remove the entry stored under ``key`` and return its value.

    Extra positional arguments are forwarded unchanged to ``dict.pop`` on
    ``self.data``, so an optional default behaves as it does there.
    """
    # Mark the length bookkeeping as stale before the mapping is modified.
    self._dirty_len = True
    weak_key = self.ref_type(key)  # CHANGED
    return self.data.pop(weak_key, *args)
262
def setdefault(self, key, default=None):
    """Return the value stored under ``key``, inserting ``default`` if absent.

    The key is wrapped with ``self.ref_type`` and ``self._remove`` is passed
    as the wrapper's callback; the call is then delegated to
    ``dict.setdefault`` on ``self.data``.
    """
    weak_key = self.ref_type(key, self._remove)  # CHANGED
    return self.data.setdefault(weak_key, default)
265
def update(self, dict=None, **kwargs):
268
if not hasattr(dict, "items"):
269
dict = type({})(dict)
270
for key, value in dict.items():
271
d[self.ref_type(key, self._remove)] = value # CHANGED
275
def __ior__(self, other):
279
def __or__(self, other):
280
if isinstance(other, _collections_abc.Mapping):
284
return NotImplemented
286
def __ror__(self, other):
287
if isinstance(other, _collections_abc.Mapping):
292
return NotImplemented
294
# Default Mapping equality tests keys for equality, but
295
# we want to test ids for equality
296
def __eq__(self, other):
297
if not isinstance(other, Mapping):
298
return NotImplemented
299
return {id(k): v for k, v in self.items()} == {id(k): v for k, v in other.items()}
302
# Alias, not a subclass: WeakTensorKeyDictionary is the very same class object
# as WeakIdKeyDictionary (whose default ref_type is WeakIdRef).
WeakTensorKeyDictionary = WeakIdKeyDictionary
306
"""Wrapper around a weak ref of a Tensor that handles the _fix_weakref() call required when unwrapping a Tensor weakref."""
310
def __init__(self, tensor: Tensor):
311
assert isinstance(tensor, Tensor)
312
self.ref = weakref.ref(tensor)
318
assert isinstance(out, Tensor)
319
# TODO, add _fix_weakref type binding
320
out._fix_weakref() # type: ignore[attr-defined]