pytorch

Форк
0
/
_weights_only_unpickler.py 
306 строк · 10.9 Кб
1
# Unpickler restricted to loading only state dicts
2
# Restrict constructing types to a list defined in _get_allowed_globals()
3
# Restrict BUILD operation to `Tensor`, `Parameter` and `OrderedDict` types only
4
# Restrict APPEND/APPENDS to `list`
5
# In `GLOBALS` operation do not do class lookup by name, but rather rely on dictionary
6
# defined by `_get_allowed_globals()` method, that contains:
7
# - torch types (Storage, dtypes, Tensor, `torch.Size`),
8
# - `torch._utils._rebuild` functions.
9
# - `torch.nn.Parameter`
10
# - `collections.OrderedDict`
11

12
# Based on https://github.com/python/cpython/blob/main/Lib/pickle.py
13
# Expected to be useful for loading PyTorch model weights
14
# For example:
15
# data = urllib.request.urlopen('https://download.pytorch.org/models/resnet50-0676ba61.pth').read()
16
# buf = io.BytesIO(data)
17
# weights = torch.load(buf, weights_only = True)
18

19
import functools as _functools
20
from collections import OrderedDict
21
from pickle import (
22
    APPEND,
23
    APPENDS,
24
    BINFLOAT,
25
    BINGET,
26
    BININT,
27
    BININT1,
28
    BININT2,
29
    BINPERSID,
30
    BINPUT,
31
    BINUNICODE,
32
    BUILD,
33
    bytes_types,
34
    decode_long,
35
    EMPTY_DICT,
36
    EMPTY_LIST,
37
    EMPTY_SET,
38
    EMPTY_TUPLE,
39
    GLOBAL,
40
    LONG1,
41
    LONG_BINGET,
42
    LONG_BINPUT,
43
    MARK,
44
    NEWFALSE,
45
    NEWOBJ,
46
    NEWTRUE,
47
    NONE,
48
    PROTO,
49
    REDUCE,
50
    SETITEM,
51
    SETITEMS,
52
    SHORT_BINSTRING,
53
    STOP,
54
    TUPLE,
55
    TUPLE1,
56
    TUPLE2,
57
    TUPLE3,
58
    UnpicklingError,
59
)
60
from struct import unpack
61
from sys import maxsize
62
from typing import Any, Dict, List
63

64
import torch
65

66

67
# Unpickling machinery
68
@_functools.lru_cache(maxsize=1)
69
def _get_allowed_globals():
70
    rc: Dict[str, Any] = {
71
        "collections.OrderedDict": OrderedDict,
72
        "torch.nn.parameter.Parameter": torch.nn.Parameter,
73
        "torch.serialization._get_layout": torch.serialization._get_layout,
74
        "torch.Size": torch.Size,
75
        "torch.Tensor": torch.Tensor,
76
    }
77
    # dtype
78
    for t in [
79
        torch.complex32,
80
        torch.complex64,
81
        torch.complex128,
82
        torch.float8_e5m2,
83
        torch.float8_e4m3fn,
84
        torch.float8_e5m2fnuz,
85
        torch.float8_e4m3fnuz,
86
        torch.float16,
87
        torch.float32,
88
        torch.float64,
89
        torch.int8,
90
        torch.int16,
91
        torch.int32,
92
        torch.int64,
93
    ]:
94
        rc[str(t)] = t
95
    # Tensor classes
96
    for tt in torch._tensor_classes:
97
        rc[f"{tt.__module__}.{tt.__name__}"] = tt
98
    # Storage classes
99
    for ts in torch._storage_classes:
100
        if ts not in (torch.storage.TypedStorage, torch.storage.UntypedStorage):
101
            # Wrap legacy storage types in a dummy class
102
            rc[f"{ts.__module__}.{ts.__name__}"] = torch.serialization.StorageType(
103
                ts.__name__
104
            )
105
        else:
106
            rc[f"{ts.__module__}.{ts.__name__}"] = ts
107
    # Rebuild functions
108
    for f in [
109
        torch._utils._rebuild_parameter,
110
        torch._utils._rebuild_tensor,
111
        torch._utils._rebuild_tensor_v2,
112
        torch._utils._rebuild_tensor_v3,
113
        torch._utils._rebuild_sparse_tensor,
114
        torch._utils._rebuild_meta_tensor_no_storage,
115
        torch._utils._rebuild_nested_tensor,
116
    ]:
117
        rc[f"torch._utils.{f.__name__}"] = f
118

119
    # Handles Tensor Subclasses, Tensor's with attributes.
120
    # NOTE: It calls into above rebuild functions for regular Tensor types.
121
    rc["torch._tensor._rebuild_from_type_v2"] = torch._tensor._rebuild_from_type_v2
122
    return rc
123

124

125
class Unpickler:
    """Restricted unpickler that only reconstructs PyTorch state-dict objects.

    Only the opcodes dispatched in :meth:`load` are understood; any other
    opcode raises ``RuntimeError``. Restrictions:

    * GLOBAL resolves names exclusively through ``_get_allowed_globals()``.
    * REDUCE may only call values from that allowlist.
    * NEWOBJ may only construct ``torch.nn.Parameter``.
    * BUILD is limited to ``torch.Tensor``, ``torch.nn.Parameter`` and
      ``OrderedDict`` instances.
    * APPEND / APPENDS only target plain ``list`` objects.
    * BINPERSID only accepts ``int`` ids or ``tuple`` ids tagged ``"storage"``.

    Args:
        file: binary file-like object providing ``read`` and ``readline``.
        encoding: codec used to decode SHORT_BINSTRING payloads; with the
            default ``"bytes"`` they are kept as raw ``bytes``.
    """

    def __init__(self, file, *, encoding: str = "bytes"):
        self.encoding = encoding
        self.readline = file.readline
        self.read = file.read
        # Memo table written by BINPUT/LONG_BINPUT, read by BINGET/LONG_BINGET.
        self.memo: Dict[int, Any] = {}

    def load(self):
        """Read a pickled object representation from the open file.

        Return the reconstituted object hierarchy specified in the file.

        Raises:
            EOFError: if the stream ends before a STOP opcode.
            RuntimeError: on an unsupported opcode, a global that is not
                allowlisted, a disallowed REDUCE/NEWOBJ/BUILD/APPEND target,
                or a persistent id that is not an int or a "storage" tuple.
        """
        self.metastack = []
        self.stack: List[Any] = []
        self.append = self.stack.append
        read = self.read
        readline = self.readline
        while True:
            key = read(1)
            if not key:
                raise EOFError
            assert isinstance(key, bytes_types)
            # Risky operators
            if key[0] == GLOBAL[0]:
                # Module and name are newline-terminated lines; strip the "\n".
                module = readline()[:-1].decode("utf-8")
                name = readline()[:-1].decode("utf-8")
                full_path = f"{module}.{name}"
                allowed_globals = _get_allowed_globals()
                if full_path in allowed_globals:
                    self.append(allowed_globals[full_path])
                else:
                    raise RuntimeError(f"Unsupported class {full_path}")
            elif key[0] == NEWOBJ[0]:
                args = self.stack.pop()
                cls = self.stack.pop()
                if cls is not torch.nn.Parameter:
                    raise RuntimeError(f"Trying to instantiate unsupported class {cls}")
                self.append(torch.nn.Parameter(*args))
            elif key[0] == REDUCE[0]:
                args = self.stack.pop()
                func = self.stack[-1]
                # Only callables that are themselves allowlisted may be invoked.
                if func not in _get_allowed_globals().values():
                    raise RuntimeError(
                        f"Trying to call reduce for unrecognized function {func}"
                    )
                self.stack[-1] = func(*args)
            elif key[0] == BUILD[0]:
                state = self.stack.pop()
                inst = self.stack[-1]
                # Exact type checks (not isinstance) keep subclasses out.
                if type(inst) is torch.Tensor:
                    # Legacy unpickling
                    inst.set_(*state)
                elif type(inst) is torch.nn.Parameter:
                    inst.__setstate__(state)
                elif type(inst) is OrderedDict:
                    # Restore instance attributes pickled alongside the mapping.
                    inst.__dict__.update(state)
                else:
                    raise RuntimeError(
                        f"Can only build Tensor, parameter or dict objects, but got {type(inst)}"
                    )
            # Stack manipulation
            elif key[0] == APPEND[0]:
                item = self.stack.pop()
                list_obj = self.stack[-1]
                if type(list_obj) is not list:
                    raise RuntimeError(
                        f"Can only append to lists, but got {type(list_obj)}"
                    )
                list_obj.append(item)
            elif key[0] == APPENDS[0]:
                items = self.pop_mark()
                list_obj = self.stack[-1]
                if type(list_obj) is not list:
                    raise RuntimeError(
                        f"Can only extend lists, but got {type(list_obj)}"
                    )
                list_obj.extend(items)
            elif key[0] == SETITEM[0]:
                (v, k) = (self.stack.pop(), self.stack.pop())
                self.stack[-1][k] = v
            elif key[0] == SETITEMS[0]:
                # Items since the mark alternate key, value, key, value, ...
                items = self.pop_mark()
                for i in range(0, len(items), 2):
                    self.stack[-1][items[i]] = items[i + 1]
            elif key[0] == MARK[0]:
                # Save the current stack and start a fresh one; pop_mark()
                # returns the fresh one and restores the saved stack.
                self.metastack.append(self.stack)
                self.stack = []
                self.append = self.stack.append
            elif key[0] == TUPLE[0]:
                items = self.pop_mark()
                self.append(tuple(items))
            elif key[0] == TUPLE1[0]:
                self.stack[-1] = (self.stack[-1],)
            elif key[0] == TUPLE2[0]:
                self.stack[-2:] = [(self.stack[-2], self.stack[-1])]
            elif key[0] == TUPLE3[0]:
                self.stack[-3:] = [(self.stack[-3], self.stack[-2], self.stack[-1])]
            # Basic types construction
            elif key[0] == NONE[0]:
                self.append(None)
            elif key[0] == NEWFALSE[0]:
                self.append(False)
            elif key[0] == NEWTRUE[0]:
                self.append(True)
            elif key[0] == EMPTY_TUPLE[0]:
                self.append(())
            elif key[0] == EMPTY_LIST[0]:
                self.append([])
            elif key[0] == EMPTY_DICT[0]:
                self.append({})
            elif key[0] == EMPTY_SET[0]:
                self.append(set())
            elif key[0] == BININT[0]:
                # Signed 32-bit little-endian integer.
                self.append(unpack("<i", read(4))[0])
            elif key[0] == BININT1[0]:
                # Unsigned 8-bit integer.
                self.append(self.read(1)[0])
            elif key[0] == BININT2[0]:
                # Unsigned 16-bit little-endian integer.
                self.append(unpack("<H", read(2))[0])
            elif key[0] == BINFLOAT[0]:
                # Big-endian IEEE-754 double.
                self.append(unpack(">d", self.read(8))[0])
            elif key[0] == BINUNICODE[0]:
                strlen = unpack("<I", read(4))[0]
                if strlen > maxsize:
                    raise RuntimeError("String is too long")
                strval = str(read(strlen), "utf-8", "surrogatepass")
                self.append(strval)
            elif key[0] == SHORT_BINSTRING[0]:
                strlen = read(1)[0]
                strdata = read(strlen)
                if self.encoding != "bytes":
                    strdata = strdata.decode(self.encoding, "strict")
                self.append(strdata)
            elif key[0] == BINPERSID[0]:
                pid = self.stack.pop()
                # Only allow persistent load of storage.
                # The id must be exactly an int or a tuple (fixes a former
                # double negation that rejected ints and accepted anything else).
                if type(pid) is not tuple and type(pid) is not int:
                    raise RuntimeError(
                        f"persistent_load id must be tuple or int, but got {type(pid)}"
                    )
                if (
                    type(pid) is tuple
                    and len(pid) > 0
                    and torch.serialization._maybe_decode_ascii(pid[0]) != "storage"
                ):
                    raise RuntimeError(
                        f"Only persistent_load of storage is allowed, but got {pid[0]}"
                    )
                self.append(self.persistent_load(pid))
            elif key[0] in [BINGET[0], LONG_BINGET[0]]:
                # BINGET carries a 1-byte index, LONG_BINGET a 4-byte one.
                idx = (read(1) if key[0] == BINGET[0] else unpack("<I", read(4)))[0]
                self.append(self.memo[idx])
            elif key[0] in [BINPUT[0], LONG_BINPUT[0]]:
                i = (read(1) if key[0] == BINPUT[0] else unpack("<I", read(4)))[0]
                if i < 0:
                    raise ValueError("negative argument")
                self.memo[i] = self.stack[-1]
            elif key[0] == LONG1[0]:
                n = read(1)[0]
                data = read(n)
                self.append(decode_long(data))
            # First and last deserializer ops
            elif key[0] == PROTO[0]:
                # Read and ignore proto version
                read(1)[0]
            elif key[0] == STOP[0]:
                rc = self.stack.pop()
                return rc
            else:
                raise RuntimeError(f"Unsupported operand {key[0]}")

    def pop_mark(self):
        """Return the items pushed since the last MARK and restore the outer stack."""
        items = self.stack
        self.stack = self.metastack.pop()
        self.append = self.stack.append
        return items

    def persistent_load(self, pid):
        """Resolve a persistent id pushed by BINPERSID.

        This default always raises ``UnpicklingError``; callers that need
        persistent ids must replace it (e.g. in a subclass).
        """
        raise UnpicklingError("unsupported persistent id encountered")
303

304

305
def load(file, *, encoding: str = "ASCII"):
    """Unpickle one object from *file* using the restricted ``Unpickler``.

    Args:
        file: binary file-like object exposing ``read`` and ``readline``.
        encoding: codec forwarded to ``Unpickler`` for SHORT_BINSTRING
            payloads. Note that this defaults to ``"ASCII"`` here, whereas
            ``Unpickler`` itself defaults to ``"bytes"``.

    Returns:
        The object reconstructed from the pickle stream.
    """
    restricted = Unpickler(file, encoding=encoding)
    return restricted.load()
307

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.