19
import functools as _functools
20
from collections import OrderedDict
60
from struct import unpack
61
from sys import maxsize
62
from typing import Any, Dict, List
68
# Build — and cache via lru_cache — the allowlist of globals the restricted
# unpickler may resolve: container classes, torch classes, dtype fragments,
# tensor/storage classes, and torch._utils tensor-rebuild helpers.
# NOTE(review): this is an extraction-sampled fragment — the bare numeric
# lines are line-numbering artifacts, and several original lines (dict/list
# closers, else-branches, the return statement) are missing from this view.
# Only comments were added; code lines are untouched.
@_functools.lru_cache(maxsize=1)
69
def _get_allowed_globals():
70
# Maps "module.qualname" -> object for each explicitly allowed global.
rc: Dict[str, Any] = {
71
"collections.OrderedDict": OrderedDict,
72
"torch.nn.parameter.Parameter": torch.nn.Parameter,
73
"torch.serialization._get_layout": torch.serialization._get_layout,
74
"torch.Size": torch.Size,
75
"torch.Tensor": torch.Tensor,
84
# Fragment of a dtype listing (surrounding entries and the enclosing
# expression are missing from this view).
torch.float8_e5m2fnuz,
85
torch.float8_e4m3fnuz,
96
# Register every concrete tensor class under its fully qualified name.
for tt in torch._tensor_classes:
97
rc[f"{tt.__module__}.{tt.__name__}"] = tt
99
# Storage classes: non-Typed/Untyped storages are wrapped in a
# torch.serialization.StorageType(...) proxy; the direct registration on
# the next code line presumably sits in the (missing) else-branch for
# TypedStorage/UntypedStorage — confirm against the full file.
for ts in torch._storage_classes:
100
if ts not in (torch.storage.TypedStorage, torch.storage.UntypedStorage):
102
rc[f"{ts.__module__}.{ts.__name__}"] = torch.serialization.StorageType(
106
rc[f"{ts.__module__}.{ts.__name__}"] = ts
109
# Tensor-rebuild helpers allowed as REDUCE targets; each is registered
# under its torch._utils path by the loop body below.
torch._utils._rebuild_parameter,
110
torch._utils._rebuild_tensor,
111
torch._utils._rebuild_tensor_v2,
112
torch._utils._rebuild_tensor_v3,
113
torch._utils._rebuild_sparse_tensor,
114
torch._utils._rebuild_meta_tensor_no_storage,
115
torch._utils._rebuild_nested_tensor,
117
rc[f"torch._utils.{f.__name__}"] = f
121
# _rebuild_from_type_v2 lives in torch._tensor, so it is registered
# explicitly rather than through the torch._utils loop above.
rc["torch._tensor._rebuild_from_type_v2"] = torch._tensor._rebuild_from_type_v2
126
def __init__(self, file, *, encoding: str = "bytes"):
    """Prepare the restricted unpickler to read from *file*.

    The file object's ``readline`` and ``read`` callables are bound
    directly onto the instance (the opcode loop calls them constantly),
    and an empty memo table is created for BINPUT/BINGET bookkeeping.
    """
    # Start with a fresh memo; independent of the bindings below.
    self.memo: Dict[int, Any] = {}
    self.encoding = encoding
    self.readline = file.readline
    self.read = file.read
133
# NOTE(review): extraction-sampled fragment of the restricted unpickler's
# main opcode-dispatch loop. The `def load(self):` header, the `while`
# over opcodes, and many interior lines (else-branches, raise wrappers,
# constant-opcode bodies) are missing from this view; the bare numeric
# lines are line-numbering artifacts. Only comments were added here.
"""Read a pickled object representation from the open file.
135
Return the reconstituted object hierarchy specified in the file.
138
# Fresh operand stack for this load; `self.append` is rebound to the
# stack's own append so the opcode loop avoids repeated lookups.
self.stack: List[Any] = []
139
self.append = self.stack.append
141
readline = self.readline
146
# `key` is the opcode byte string just read from the stream.
assert isinstance(key, bytes_types)
148
# GLOBAL: resolve a "module\nname" pair, but only via the allowlist.
if key[0] == GLOBAL[0]:
149
module = readline()[:-1].decode("utf-8")
150
name = readline()[:-1].decode("utf-8")
151
full_path = f"{module}.{name}"
152
if full_path in _get_allowed_globals():
153
self.append(_get_allowed_globals()[full_path])
155
raise RuntimeError(f"Unsupported class {full_path}")
156
# NEWOBJ: the only class permitted to be instantiated is nn.Parameter.
elif key[0] == NEWOBJ[0]:
157
args = self.stack.pop()
158
cls = self.stack.pop()
159
if cls is not torch.nn.Parameter:
160
raise RuntimeError(f"Trying to instantiate unsupported class {cls}")
161
self.append(torch.nn.Parameter(*args))
162
# REDUCE: call only allowlisted functions with the popped args tuple.
elif key[0] == REDUCE[0]:
163
args = self.stack.pop()
164
func = self.stack[-1]
165
if func not in _get_allowed_globals().values():
167
f"Trying to call reduce for unrecognized function {func}"
169
self.stack[-1] = func(*args)
170
# BUILD: apply pickled state, restricted to Tensor / Parameter /
# OrderedDict (the Tensor branch's body is missing from this view).
elif key[0] == BUILD[0]:
171
state = self.stack.pop()
172
inst = self.stack[-1]
173
if type(inst) is torch.Tensor:
176
elif type(inst) is torch.nn.Parameter:
177
inst.__setstate__(state)
178
elif type(inst) is OrderedDict:
179
inst.__dict__.update(state)
182
f"Can only build Tensor, parameter or dict objects, but got {type(inst)}"
185
# APPEND / APPENDS: mutate only genuine `list` objects.
elif key[0] == APPEND[0]:
186
item = self.stack.pop()
187
list_obj = self.stack[-1]
188
if type(list_obj) is not list:
190
f"Can only append to lists, but got {type(list_obj)}"
192
list_obj.append(item)
193
elif key[0] == APPENDS[0]:
194
items = self.pop_mark()
195
list_obj = self.stack[-1]
196
if type(list_obj) is not list:
198
f"Can only extend lists, but got {type(list_obj)}"
200
list_obj.extend(items)
201
# SETITEM / SETITEMS: dict-style stores into the object below the items.
elif key[0] == SETITEM[0]:
202
(v, k) = (self.stack.pop(), self.stack.pop())
203
self.stack[-1][k] = v
204
elif key[0] == SETITEMS[0]:
205
items = self.pop_mark()
206
# `items` is a flat [k0, v0, k1, v1, ...] run popped off the mark.
for i in range(0, len(items), 2):
207
self.stack[-1][items[i]] = items[i + 1]
208
# MARK: park the current stack on the metastack and start a new one
# (the line creating the fresh stack is missing from this view).
elif key[0] == MARK[0]:
209
self.metastack.append(self.stack)
211
self.append = self.stack.append
212
# TUPLE family: fold the top N / 1 / 2 / 3 stack items into a tuple.
elif key[0] == TUPLE[0]:
213
items = self.pop_mark()
214
self.append(tuple(items))
215
elif key[0] == TUPLE1[0]:
216
self.stack[-1] = (self.stack[-1],)
217
elif key[0] == TUPLE2[0]:
218
self.stack[-2:] = [(self.stack[-2], self.stack[-1])]
219
elif key[0] == TUPLE3[0]:
220
self.stack[-3:] = [(self.stack[-3], self.stack[-2], self.stack[-1])]
222
# Constant opcodes — their one-line push bodies are missing from this
# extraction.
elif key[0] == NONE[0]:
224
elif key[0] == NEWFALSE[0]:
226
elif key[0] == NEWTRUE[0]:
228
elif key[0] == EMPTY_TUPLE[0]:
230
elif key[0] == EMPTY_LIST[0]:
232
elif key[0] == EMPTY_DICT[0]:
234
elif key[0] == EMPTY_SET[0]:
236
# Fixed-width binary scalars: little-endian ints, big-endian double.
elif key[0] == BININT[0]:
237
self.append(unpack("<i", read(4))[0])
238
elif key[0] == BININT1[0]:
239
self.append(self.read(1)[0])
240
elif key[0] == BININT2[0]:
241
self.append(unpack("<H", read(2))[0])
242
elif key[0] == BINFLOAT[0]:
243
self.append(unpack(">d", self.read(8))[0])
244
# BINUNICODE: 4-byte length prefix, UTF-8 payload with surrogatepass;
# the length-sanity comparison line is missing from this view.
elif key[0] == BINUNICODE[0]:
245
strlen = unpack("<I", read(4))[0]
247
raise RuntimeError("String is too long")
248
strval = str(read(strlen), "utf-8", "surrogatepass")
250
# SHORT_BINSTRING: 1-byte length (the read of `strlen` is missing
# here); bytes are decoded unless encoding == "bytes".
elif key[0] == SHORT_BINSTRING[0]:
252
strdata = read(strlen)
253
if self.encoding != "bytes":
254
strdata = strdata.decode(self.encoding, "strict")
256
# BINPERSID: hand a persistent id to persistent_load after validation.
elif key[0] == BINPERSID[0]:
257
pid = self.stack.pop()
259
# NOTE(review): `not type(pid) is not int` double-negates to
# `type(pid) is int`, so this condition only fires when pid is an int
# (never a tuple) — non-tuple, non-int pids slip past the check.
# Likely meant `type(pid) is not tuple and type(pid) is not int`;
# confirm against upstream before fixing.
if type(pid) is not tuple and not type(pid) is not int:
261
f"persistent_load id must be tuple or int, but got {type(pid)}"
266
# Only the "storage" tag is accepted for persistent loads.
and torch.serialization._maybe_decode_ascii(pid[0]) != "storage"
269
f"Only persistent_load of storage is allowed, but got {pid[0]}"
271
self.append(self.persistent_load(pid))
272
# BINGET/LONG_BINGET and BINPUT/LONG_BINPUT: memo fetch/store keyed by a
# 1-byte or 4-byte little-endian index.
elif key[0] in [BINGET[0], LONG_BINGET[0]]:
273
idx = (read(1) if key[0] == BINGET[0] else unpack("<I", read(4)))[0]
274
self.append(self.memo[idx])
275
elif key[0] in [BINPUT[0], LONG_BINPUT[0]]:
276
i = (read(1) if key[0] == BINPUT[0] else unpack("<I", read(4)))[0]
278
raise ValueError("negative argument")
279
self.memo[i] = self.stack[-1]
280
# LONG1: length-prefixed little-endian long (the read that fills `data`
# is missing from this view).
elif key[0] == LONG1[0]:
283
self.append(decode_long(data))
285
# PROTO: protocol marker (version handling missing here); STOP: the
# final stack top becomes the loop's result.
elif key[0] == PROTO[0]:
288
elif key[0] == STOP[0]:
289
rc = self.stack.pop()
292
raise RuntimeError(f"Unsupported operand {key[0]}")
297
# Tail of pop_mark(): restore the previous stack from the metastack and
# rebind the cached append (the method header is missing from this view).
self.stack = self.metastack.pop()
298
self.append = self.stack.append
301
def persistent_load(self, pid):
    """Refuse persistent IDs.

    This base implementation always raises; a caller that supports
    storage loading is expected to replace or override it.
    """
    raise UnpicklingError("unsupported persistent id encountered")
305
def load(file, *, encoding: str = "ASCII"):
    """Deserialize *file* using this module's restricted ``Unpickler``.

    Convenience wrapper mirroring the ``pickle.load`` call shape.
    """
    unpickler = Unpickler(file, encoding=encoding)
    return unpickler.load()