pytorch

Форк
0
/
weak.py 
321 строка · 10.8 Кб
1
from __future__ import annotations
2

3
import weakref
4
from weakref import ref
5
from _weakrefset import _IterationGuard  # type: ignore[attr-defined]
6
from collections.abc import MutableMapping, Mapping
7
from torch import Tensor
8
import collections.abc as _collections_abc
9

10

11
# Public alias for ``weakref.ref``; used in annotations such as the
# ``ref`` attribute of ``TensorWeakRef``.
WeakRef = ref


# Names exported by ``from <this module> import *``.
__all__ = [
    'TensorWeakRef',
    'WeakIdRef',
    'WeakIdKeyDictionary',
    'WeakTensorKeyDictionary',
]
15

16

17
# This file defines a variant of WeakKeyDictionary that overrides the hashing
# behavior of the key to use object identity, rather than the builtin
# __eq__/__hash__ functions.  This is useful for Tensor weak keys, as their
# __eq__ implementation returns a Tensor (elementwise equality), which means
# you can't use them directly with the WeakKeyDictionary in the standard library.
#
# Our implementation strategy is to create a wrapper weak key object, which we
# use as a key in a stock Python dictionary.  This is similar to how weakref
# implements WeakKeyDictionary, but instead of using weakref.ref as the
# wrapper, we use a custom wrapper that has different __eq__ and __hash__
# behavior.  Note that we subsequently store this weak key directly in an
# ORDINARY dictionary: the dictionary is the newly constructed WeakIdRef's only
# user, so if the dictionary held it weakly it would have no strong references
# at all.  Ensuring that only live WeakIdRefs are in the map is handled by a
# weakref callback attached to each stored WeakIdRef, which removes its entry
# once the original key object is collected.
32

33

34
# It would be simpler to implement this with composition, but if we want to
# directly reuse the callback mechanism on weakref, we need the weakref
# and the key to be exactly the same object.  Reusing the callback mechanism
# minimizes the divergence between our implementation and Lib/weakref.py.
#
# NB: Prefer using this when working with weakrefs of Tensors; e.g., do
# WeakIdRef(tensor) rather than weakref.ref(tensor); it handles a number of
# easy-to-get-wrong cases transparently for you.
42
class WeakIdRef(weakref.ref):
    """A weak reference that hashes and compares by the referent's identity.

    ``hash()`` returns ``id()`` of the referent, captured eagerly at
    construction.  Two live references compare equal iff their referents are
    the same object (``is``).  If either reference is dead, they compare
    equal only if they are the same reference object.

    Calling the reference returns the referent, or ``None`` if it has been
    collected.  If the referent has a ``_fix_weakref`` method, it is invoked
    before the referent is returned.

    Args:
        key: the object to reference weakly.
        callback: optional callable invoked with this reference when the
            referent is about to be finalized (standard ``weakref.ref``
            semantics).
    """

    __slots__ = ['_id']

    def __init__(self, key, callback=None):
        # Unlike stock weakref, which preserves hash semantics of the
        # original object but lazily defers hash calls until the first
        # time the user attempts to hash the weakref, we can eagerly
        # cache the id of the key as we know this is definitely the hash
        # method
        self._id = id(key)
        super().__init__(key, callback)  # type: ignore[call-arg]

    def __call__(self):
        r = super().__call__()
        # Special logic for Tensor PyObject resurrection
        if hasattr(r, '_fix_weakref'):
            r._fix_weakref()  # type: ignore[union-attr]
        return r

    def __hash__(self):
        return self._id

    def __eq__(self, other):
        # An attractive but wrong alternate implementation is to only test if
        # the stored _ids match.  This can lead to an ABA problem if you have:
        #
        #   a1 = A()
        #   w1 = WeakIdRef(a1)
        #   del a1
        #   a2 = A()  # suppose it gets the same ID as a1
        #   w2 = WeakIdRef(a2)
        #   print(w1 == w2)
        #
        # This should be False, as a1 and a2 are unrelated (and a1 is
        # dead anyway)
        a = self()
        b = other()
        if a is not None and b is not None:
            return a is b
        return self is other

    def __ne__(self, other):
        # weakref.ref implements its own rich comparison.  Without this
        # override, ``!=`` would resolve to ``weakref.ref.__ne__``, which
        # compares the *referents* with ``!=``.  That disagrees with the
        # identity-based ``__eq__`` above, and for Tensor referents it
        # yields an elementwise Tensor instead of a bool.
        eq = self.__eq__(other)
        if eq is NotImplemented:
            return eq
        return not eq
82

83
# This is the same as WeakIdRef but equality is checked using hash() rather than id.
84
# This will be equivalent to the one above except for classes where hash is not their id.
85
class _WeakHashRef(weakref.ref):
    """A weak reference that hashes and compares by the referent's ``hash()``.

    Like ``WeakIdRef``, but the cached hash is ``hash(key)`` instead of
    ``id(key)``.  Two live references compare equal iff their referents'
    hashes are equal.  If either reference is dead, they compare equal only
    if they are the same reference object.

    Calling the reference returns the referent, or ``None`` if it has been
    collected.  If the referent has a ``_fix_weakref`` method, it is invoked
    before the referent is returned.

    Args:
        key: the (hashable) object to reference weakly.
        callback: optional callable invoked with this reference when the
            referent is about to be finalized (standard ``weakref.ref``
            semantics).
    """

    __slots__ = ['_id']

    def __init__(self, key, callback=None):
        # Unlike stock weakref, which preserves hash semantics of the
        # original object but lazily defers hash calls until the first
        # time the user attempts to hash the weakref, we eagerly cache
        # hash(key) here (stored in ``_id``), since that is exactly the hash
        # this reference type uses.
        self._id = hash(key)
        super().__init__(key, callback)  # type: ignore[call-arg]

    def __call__(self):
        r = super().__call__()
        # Special logic for Tensor PyObject resurrection
        if hasattr(r, '_fix_weakref'):
            r._fix_weakref()  # type: ignore[union-attr]
        return r

    def __hash__(self):
        return self._id

    def __eq__(self, other):
        # Use hash equality to determine ref equality.
        # ScriptObject implements __hash__ to return the wrapped IValue's id, so
        # this is equivalent to doing an identity comparison.
        a = self()
        b = other()
        if a is not None and b is not None:
            return hash(a) == hash(b)
        return self is other

    def __ne__(self, other):
        # weakref.ref implements its own rich comparison.  Without this
        # override, ``!=`` would resolve to ``weakref.ref.__ne__``, which
        # compares the referents with their own ``!=``.  That disagrees with
        # the hash-based ``__eq__`` above.
        eq = self.__eq__(other)
        if eq is NotImplemented:
            return eq
        return not eq
116

117
# This is directly adapted from cpython/Lib/weakref.py
118
class WeakIdKeyDictionary(MutableMapping):
    """Mapping whose keys are held weakly and compared via a wrapper type.

    Adapted from ``weakref.WeakKeyDictionary`` in CPython's Lib/weakref.py.
    Each key is wrapped in ``ref_type`` (``WeakIdRef`` by default) before
    being used as a key of the plain ``dict`` ``self.data``.  Lookups
    therefore use the wrapper's ``__hash__``/``__eq__`` rather than the
    key's own.

    Wrappers stored by ``__setitem__``, ``setdefault`` and ``update`` carry
    the ``self._remove`` callback, which deletes the entry once the key dies.
    If that happens while the mapping is being iterated (``self._iterating``
    is non-empty), the dead wrapper is queued in ``self._pending_removals``
    instead.

    Args:
        dict: optional initial contents, passed to ``update``.
        ref_type: the weak-reference wrapper class used for keys.
    """

    def __init__(self, dict=None, ref_type=WeakIdRef):  # CHANGED
        self.data = {}

        self.ref_type = ref_type  # CHANGED

        def remove(k, selfref=ref(self)):
            # Weakref callback; ``k`` is the dead key wrapper.  Only a weak
            # reference to the mapping is captured, so storing this callback
            # on ``self`` does not create a strong self-reference cycle.
            self = selfref()
            if self is not None:
                if self._iterating:
                    self._pending_removals.append(k)
                else:
                    try:
                        del self.data[k]
                    except KeyError:
                        pass
        self._remove = remove
        # A list of dead weakrefs (keys to be removed)
        self._pending_removals = []
        # Presumably populated by _IterationGuard (CPython's _weakrefset)
        # while an iteration over this mapping is in progress.
        self._iterating = set()
        # Set by explicit removals; tells __len__ that _pending_removals may
        # contain wrappers that are no longer in self.data.
        self._dirty_len = False
        if dict is not None:
            self.update(dict)

    def _commit_removals(self):
        # Drain _pending_removals, deleting each queued wrapper from self.data.
        # NOTE: We don't need to call this method before mutating the dict,
        # because a dead weakref never compares equal to a live weakref,
        # even if they happened to refer to equal objects.
        # However, it means keys may already have been removed.
        pop = self._pending_removals.pop
        d = self.data
        while True:
            try:
                key = pop()
            except IndexError:
                return

            try:
                del d[key]
            except KeyError:
                pass

    def _scrub_removals(self):
        # Keep only the pending removals that are still present in self.data.
        d = self.data
        self._pending_removals = [k for k in self._pending_removals if k in d]
        self._dirty_len = False

    def __delitem__(self, key):
        self._dirty_len = True
        del self.data[self.ref_type(key)]  # CHANGED

    def __getitem__(self, key):
        return self.data[self.ref_type(key)]  # CHANGED

    def __len__(self):
        if self._dirty_len and self._pending_removals:
            # self._pending_removals may still contain keys which were
            # explicitly removed, we have to scrub them (see issue #21173).
            self._scrub_removals()
        return len(self.data) - len(self._pending_removals)

    def __repr__(self):
        return f"<{self.__class__.__name__} at {id(self):#x}>"

    def __setitem__(self, key, value):
        self.data[self.ref_type(key, self._remove)] = value  # CHANGED

    def copy(self):
        """Return a shallow ``WeakIdKeyDictionary`` of the live entries.

        The copy uses the same ``ref_type`` as this mapping, so keys keep
        the same hashing/equality semantics.
        """
        new = WeakIdKeyDictionary(ref_type=self.ref_type)
        with _IterationGuard(self):
            for key, value in self.data.items():
                o = key()
                if o is not None:
                    new[o] = value
        return new

    __copy__ = copy

    def __deepcopy__(self, memo):
        """Return a copy of the live entries with deep-copied values.

        Keys are not copied; they are re-inserted as-is.
        """
        from copy import deepcopy
        new = self.__class__()
        # The no-argument constructor would reset ref_type to its default;
        # keep the source mapping's key semantics instead.
        new.ref_type = self.ref_type
        with _IterationGuard(self):
            for key, value in self.data.items():
                o = key()
                if o is not None:
                    new[o] = deepcopy(value, memo)
        return new

    def get(self, key, default=None):
        return self.data.get(self.ref_type(key), default)  # CHANGED

    def __contains__(self, key):
        try:
            wr = self.ref_type(key)  # CHANGED
        except TypeError:
            # Objects that cannot be weakly referenced can never be keys.
            return False
        return wr in self.data

    def items(self):
        with _IterationGuard(self):
            for wr, value in self.data.items():
                key = wr()
                if key is not None:
                    yield key, value

    def keys(self):
        with _IterationGuard(self):
            for wr in self.data:
                obj = wr()
                if obj is not None:
                    yield obj

    __iter__ = keys

    def values(self):
        with _IterationGuard(self):
            for wr, value in self.data.items():
                if wr() is not None:
                    yield value

    def keyrefs(self):
        """Return a list of weak references to the keys.

        The references are not guaranteed to be 'live' at the time
        they are used, so the result of calling the references needs
        to be checked before being used.  This can be used to avoid
        creating references that will cause the garbage collector to
        keep the keys around longer than needed.

        """
        return list(self.data)

    def popitem(self):
        # Skip over entries whose key has already died.
        self._dirty_len = True
        while True:
            key, value = self.data.popitem()
            o = key()
            if o is not None:
                return o, value

    def pop(self, key, *args):
        self._dirty_len = True
        return self.data.pop(self.ref_type(key), *args)  # CHANGED

    def setdefault(self, key, default=None):
        return self.data.setdefault(self.ref_type(key, self._remove), default)  # CHANGED

    def update(self, dict=None, **kwargs):
        d = self.data
        if dict is not None:
            # The parameter shadows the builtin ``dict``; type({}) recovers it.
            if not hasattr(dict, "items"):
                dict = type({})(dict)
            for key, value in dict.items():
                d[self.ref_type(key, self._remove)] = value  # CHANGED
        if len(kwargs):
            self.update(kwargs)

    def __ior__(self, other):
        self.update(other)
        return self

    def __or__(self, other):
        if isinstance(other, _collections_abc.Mapping):
            c = self.copy()
            c.update(other)
            return c
        return NotImplemented

    def __ror__(self, other):
        if isinstance(other, _collections_abc.Mapping):
            c = self.__class__()
            # Keep this mapping's key semantics rather than the default.
            c.ref_type = self.ref_type
            c.update(other)
            c.update(self)
            return c
        return NotImplemented

    # Default Mapping equality will tests keys for equality, but
    # we want to test ids for equality
    def __eq__(self, other):
        if not isinstance(other, Mapping):
            return NotImplemented
        return {id(k): v for k, v in self.items()} == {id(k): v for k, v in other.items()}
300

301
# Convenience alias
302
# Alias: the same class as WeakIdKeyDictionary, named for its primary use case
# of Tensor keys (whose __eq__ is elementwise and unusable for dict lookup).
WeakTensorKeyDictionary = WeakIdKeyDictionary
303

304

305
class TensorWeakRef:
    """Weak reference to a Tensor that calls ``_fix_weakref()`` on access.

    Unwrapping a raw weakref to a Tensor requires a ``_fix_weakref()`` call on
    the result.  This wrapper performs that call each time it successfully
    dereferences, and returns ``None`` once the Tensor has been collected.
    """

    ref: WeakRef[Tensor]

    def __init__(self, tensor: Tensor):
        assert isinstance(tensor, Tensor)
        self.ref = weakref.ref(tensor)

    def __call__(self):
        target = self.ref()
        if target is not None:
            assert isinstance(target, Tensor)
            # TODO, add _fix_weakref type binding
            target._fix_weakref()  # type: ignore[attr-defined]
        return target
322

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.