pytorch

Форк
0
/
weak.py 
322 строки · 10.8 Кб
1
# mypy: allow-untyped-defs
2
from __future__ import annotations
3

4
import weakref
5
from weakref import ref
6
from _weakrefset import _IterationGuard  # type: ignore[attr-defined]
7
from collections.abc import MutableMapping, Mapping
8
from torch import Tensor
9
import collections.abc as _collections_abc
10

11

12
# Public alias for ``weakref.ref`` (used, e.g., to annotate TensorWeakRef.ref).
WeakRef = weakref.ref


# Names exported by ``from torch.utils.weak import *``.
__all__ = [
    'TensorWeakRef',
    'WeakIdRef',
    'WeakIdKeyDictionary',
    'WeakTensorKeyDictionary',
]
16

17

18
# This file defines a variant of WeakKeyDictionary that overrides the hashing
# behavior of the key to use object identity, rather than the builtin
# __eq__/__hash__ functions.  This is useful for Tensor weak keys, as their
# __eq__ implementation returns a Tensor (elementwise equality), which means
# you can't use them directly with the WeakKeyDictionary in the standard
# library.
#
# Our implementation strategy is to create a wrapper weak key object, which we
# use as a key in a stock Python dictionary.  This is similar to how weakref
# implements WeakKeyDictionary, but instead of using weakref.ref as the
# wrapper, we use a custom wrapper that has different __eq__ and __hash__
# behavior.  Note that we store this weak key directly in an ORDINARY
# dictionary: the dictionary is the only thing that references a newly
# constructed WeakIdRef, so it must hold that reference strongly.  Ensuring
# that only live WeakIdRefs remain in the map is handled by the callback
# attached to each weak reference, which fires when the original key object
# dies.
33

34

35
# It is simpler to implement this with composition, but if we want to
# directly reuse the callback mechanism on weakref, we need the weakref
# and the dictionary key to be exactly the same object.  Reusing the callback
# mechanism minimizes the divergence between our implementation and
# Lib/weakref.py.
#
# NB: Prefer using this when working with weakrefs of Tensors; e.g., do
# WeakIdRef(tensor) rather than weakref.ref(tensor); it handles a number of
# easy-to-get-wrong cases transparently for you.
43
class WeakIdRef(weakref.ref):
    """Weak reference whose hash and equality are based on referent identity.

    ``hash()`` of the ref is the ``id()`` of the referent, captured at
    construction so it remains usable after the referent dies.  Two refs
    compare equal when both referents are alive and are the same object;
    if either referent is dead, a ref is only equal to itself.  ``!=`` is
    always the negation of ``==``.  Dereferencing calls ``_fix_weakref()`` on
    the result when it has that method (Tensor PyObject resurrection).
    """

    __slots__ = ['_id']

    def __init__(self, key, callback=None):
        # Unlike stock weakref, which preserves hash semantics of the
        # original object but lazily defers hash calls until the first
        # time the user attempts to hash the weakref, we can eagerly
        # cache the id of the key as we know this is definitely the hash
        # method
        self._id = id(key)
        super().__init__(key, callback)  # type: ignore[call-arg]

    def __call__(self):
        r = super().__call__()
        # Special logic for Tensor PyObject resurrection
        if hasattr(r, '_fix_weakref'):
            r._fix_weakref()  # type: ignore[union-attr]
        return r

    def __hash__(self):
        return self._id

    def __eq__(self, other):
        # An attractive but wrong alternate implementation is to only test if
        # the stored _ids match.  This can lead to an ABA problem if you have:
        #
        #   a1 = A()
        #   w1 = WeakIdRef(a1)
        #   del a1
        #   a2 = A()  # suppose it gets the same ID as a1
        #   w2 = WeakIdRef(a2)
        #   print(w1 == w2)
        #
        # This should be False, as a1 and a2 are unrelated (and a1 is
        # dead anyway)
        if not isinstance(other, weakref.ref):
            # Let Python fall back to its default (identity) comparison
            # instead of raising TypeError from calling a non-ref ``other``.
            return NotImplemented
        a = self()
        b = other()
        if a is not None and b is not None:
            return a is b
        return self is other

    def __ne__(self, other):
        # weakref.ref defines its own __ne__, which compares the *referents*
        # with ``!=`` when both are alive (for Tensors that yields a Tensor).
        # Because it is inherited, ``!=`` would not be derived from the
        # __eq__ above, so mirror __eq__ explicitly to keep them consistent.
        eq = self.__eq__(other)
        if eq is NotImplemented:
            return eq
        return not eq
83

84
# This is the same as WeakIdRef, but hashing and equality use hash() rather
# than id().  It behaves identically to WeakIdRef except for classes whose
# hash is not their id.
86
class _WeakHashRef(weakref.ref):
    """Weak reference whose hash and equality are based on ``hash(referent)``.

    ``hash(key)`` is computed once at construction (so an unhashable key
    raises TypeError there) and reused as the ref's own hash, which keeps the
    ref hashable after the referent dies.  Two refs compare equal when both
    referents are alive and hash equal; if either referent is dead, a ref is
    only equal to itself.  ``!=`` is always the negation of ``==``.
    Dereferencing calls ``_fix_weakref()`` on the result when it has that
    method.
    """

    __slots__ = ['_id']

    def __init__(self, key, callback=None):
        # Unlike stock weakref, which lazily defers hash calls until the first
        # time the user attempts to hash the weakref, we eagerly cache
        # hash(key) here because it is exactly what __hash__ returns.
        self._id = hash(key)
        super().__init__(key, callback)  # type: ignore[call-arg]

    def __call__(self):
        r = super().__call__()
        # Special logic for Tensor PyObject resurrection
        if hasattr(r, '_fix_weakref'):
            r._fix_weakref()  # type: ignore[union-attr]
        return r

    def __hash__(self):
        return self._id

    def __eq__(self, other):
        # Use hash equality to determine ref equality.
        # ScriptObject implements __hash__ to return the wrapped IValue's id, so
        # this is equivalent to doing an identity comparison.
        if not isinstance(other, weakref.ref):
            # Let Python fall back to its default (identity) comparison
            # instead of raising TypeError from calling a non-ref ``other``.
            return NotImplemented
        a = self()
        b = other()
        if a is not None and b is not None:
            return hash(a) == hash(b)
        return self is other

    def __ne__(self, other):
        # weakref.ref defines its own __ne__ (comparing the referents with
        # ``!=``), which would otherwise be inherited instead of being derived
        # from the __eq__ above; mirror __eq__ to keep the two consistent.
        eq = self.__eq__(other)
        if eq is NotImplemented:
            return eq
        return not eq
117

118
# This is directly adapted from cpython/Lib/weakref.py
119
class WeakIdKeyDictionary(MutableMapping):
    """Mutable mapping that holds its keys weakly and compares them by identity.

    Adapted from ``weakref.WeakKeyDictionary`` in CPython's ``Lib/weakref.py``;
    lines that diverge from it are marked ``# CHANGED``.  Each key is wrapped
    in ``self.ref_type`` (``WeakIdRef`` by default) instead of
    ``weakref.ref``, so keys are hashed and compared via that wrapper rather
    than via their own ``__hash__``/``__eq__``.  An entry is dropped once its
    key object is garbage collected.

    Args:
        dict: optional initial contents, passed to :meth:`update`.
        ref_type: weak-reference class used to wrap keys.  It is called as
            ``ref_type(key)`` for lookups and as ``ref_type(key, callback)``
            when storing.  Copies made via :meth:`copy`, ``copy.copy``,
            ``copy.deepcopy``, ``|`` and reflected ``|`` use the same
            ``ref_type``.
    """

    def __init__(self, dict=None, ref_type=WeakIdRef):  # CHANGED
        # Maps ref_type(key) -> value.  A plain dict suffices because the
        # wrapper refs are referenced only by this dict.
        self.data = {}

        self.ref_type = ref_type  # CHANGED

        def remove(k, selfref=ref(self)):
            # Weakref callback, invoked with the dead wrapper ref ``k`` when a
            # key object is collected.  ``selfref`` is weak so the callback
            # does not keep this dictionary alive.
            self = selfref()
            if self is not None:
                if self._iterating:
                    # An iteration is in progress; deleting from self.data now
                    # would change its size mid-iteration, so defer.
                    self._pending_removals.append(k)
                else:
                    try:
                        del self.data[k]
                    except KeyError:
                        pass
        self._remove = remove
        # A list of dead weakrefs (keys to be removed)
        self._pending_removals = []
        # Non-empty while an _IterationGuard is active (the guard registers
        # itself here; see CPython Lib/_weakrefset.py).
        self._iterating = set()
        # Set by explicit removals; tells __len__ that _pending_removals may
        # hold keys that are no longer in self.data.
        self._dirty_len = False
        if dict is not None:
            self.update(dict)

    def _commit_removals(self):
        # Drain the removals deferred by ``remove`` during iteration
        # (presumably invoked by _IterationGuard when the last active
        # iteration ends -- see CPython Lib/_weakrefset.py).
        # NOTE: We don't need to call this method before mutating the dict,
        # because a dead weakref never compares equal to a live weakref,
        # even if they happened to refer to equal objects.
        # However, it means keys may already have been removed.
        pop = self._pending_removals.pop
        d = self.data
        while True:
            try:
                key = pop()
            except IndexError:
                return

            try:
                del d[key]
            except KeyError:
                pass

    def _scrub_removals(self):
        # Drop pending removals whose key is no longer in self.data (e.g. it
        # was deleted explicitly) so __len__ does not subtract it twice.
        d = self.data
        self._pending_removals = [k for k in self._pending_removals if k in d]
        self._dirty_len = False

    def __delitem__(self, key):
        self._dirty_len = True
        del self.data[self.ref_type(key)]  # CHANGED

    def __getitem__(self, key):
        return self.data[self.ref_type(key)]  # CHANGED

    def __len__(self):
        if self._dirty_len and self._pending_removals:
            # self._pending_removals may still contain keys which were
            # explicitly removed, we have to scrub them (see issue #21173).
            self._scrub_removals()
        # Entries whose key died during an iteration are still in self.data
        # until committed, but are not counted.
        return len(self.data) - len(self._pending_removals)

    def __repr__(self):
        return f"<{self.__class__.__name__} at {id(self):#x}>"

    def __setitem__(self, key, value):
        # Only stored refs carry the removal callback; the refs built for
        # lookups in the other methods are temporary.
        self.data[self.ref_type(key, self._remove)] = value  # CHANGED

    def _empty_like(self):
        # Return an empty instance of this class that wraps keys with the same
        # ref_type.  Without carrying ref_type over, a copy of a dictionary
        # built with a non-default ref_type (e.g. _WeakHashRef) would silently
        # switch to WeakIdRef and start matching keys differently.
        new = self.__class__()
        new.ref_type = self.ref_type
        return new

    def copy(self):
        # Mirrors the stdlib WeakKeyDictionary.copy in returning the base
        # class, while keeping this dictionary's ref_type.
        new = WeakIdKeyDictionary(ref_type=self.ref_type)
        with _IterationGuard(self):
            for key, value in self.data.items():
                o = key()
                if o is not None:
                    new[o] = value
        return new

    __copy__ = copy

    def __deepcopy__(self, memo):
        from copy import deepcopy
        # Keys stay shared (they are held weakly); only values are deep-copied.
        new = self._empty_like()
        with _IterationGuard(self):
            for key, value in self.data.items():
                o = key()
                if o is not None:
                    new[o] = deepcopy(value, memo)
        return new

    def get(self, key, default=None):
        return self.data.get(self.ref_type(key), default)  # CHANGED

    def __contains__(self, key):
        try:
            wr = self.ref_type(key)  # CHANGED
        except TypeError:
            # Objects that cannot be wrapped by ref_type are never keys.
            return False
        return wr in self.data

    def items(self):
        with _IterationGuard(self):
            for wr, value in self.data.items():
                key = wr()
                if key is not None:
                    yield key, value

    def keys(self):
        with _IterationGuard(self):
            for wr in self.data:
                obj = wr()
                if obj is not None:
                    yield obj

    __iter__ = keys

    def values(self):
        with _IterationGuard(self):
            for wr, value in self.data.items():
                if wr() is not None:
                    yield value

    def keyrefs(self):
        """Return a list of weak references to the keys.

        The references are not guaranteed to be 'live' at the time
        they are used, so the result of calling the references needs
        to be checked before being used.  This can be used to avoid
        creating references that will cause the garbage collector to
        keep the keys around longer than needed.

        """
        return list(self.data)

    def popitem(self):
        self._dirty_len = True
        # Entries with dead keys are discarded along the way; once self.data
        # is empty, dict.popitem raises KeyError.
        while True:
            key, value = self.data.popitem()
            o = key()
            if o is not None:
                return o, value

    def pop(self, key, *args):
        self._dirty_len = True
        return self.data.pop(self.ref_type(key), *args)  # CHANGED

    def setdefault(self, key, default=None):
        return self.data.setdefault(self.ref_type(key, self._remove), default)  # CHANGED

    def update(self, dict=None, **kwargs):
        d = self.data
        if dict is not None:
            if not hasattr(dict, "items"):
                dict = type({})(dict)
            for key, value in dict.items():
                d[self.ref_type(key, self._remove)] = value  # CHANGED
        if len(kwargs):
            self.update(kwargs)

    def __ior__(self, other):
        self.update(other)
        return self

    def __or__(self, other):
        if isinstance(other, _collections_abc.Mapping):
            c = self.copy()
            c.update(other)
            return c
        return NotImplemented

    def __ror__(self, other):
        if isinstance(other, _collections_abc.Mapping):
            c = self._empty_like()
            c.update(other)
            c.update(self)
            return c
        return NotImplemented

    # Default Mapping equality would test keys for equality, but
    # we want to test ids for equality
    def __eq__(self, other):
        if not isinstance(other, Mapping):
            return NotImplemented
        return {id(k): v for k, v in self.items()} == {id(k): v for k, v in other.items()}
301

302
# Convenience alias
303
WeakTensorKeyDictionary = WeakIdKeyDictionary
304

305

306
class TensorWeakRef:
    """Weak reference to a Tensor that calls ``_fix_weakref()`` when unwrapped.

    Dereferencing a plain weakref to a Tensor requires a follow-up
    ``_fix_weakref()`` call; calling an instance of this class does that for
    you and returns the Tensor, or ``None`` once it has been collected.
    """

    # Underlying stdlib weak reference to the wrapped Tensor.
    ref: WeakRef[Tensor]

    def __init__(self, tensor: Tensor):
        assert isinstance(tensor, Tensor)
        self.ref = weakref.ref(tensor)

    def __call__(self):
        target = self.ref()
        if target is not None:
            assert isinstance(target, Tensor)
            # TODO, add _fix_weakref type binding
            target._fix_weakref()  # type: ignore[attr-defined]
        return target
323

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.