pytorch

Форк
0
/
show_pickle.py 
150 строк · 5.2 Кб
1
#!/usr/bin/env python3
2
import sys
3
import pickle
4
import struct
5
import pprint
6
import zipfile
7
import fnmatch
8
from typing import Any, IO, BinaryIO, Union
9

10
# Names exported by ``from show_pickle import *``.
__all__ = [
    "FakeObject",
    "FakeClass",
    "DumpUnpickler",
    "main",
]
11

12
class FakeObject:
    """Inert stand-in for an object that a pickle would have constructed.

    Records the ``module.name`` of the referenced global, the positional
    arguments it was constructed with, and any state delivered through
    ``__setstate__``, so a pickle's contents can be displayed without
    building the real object.
    """

    def __init__(self, module, name, args):
        self.module = module
        self.name = name
        self.args = args
        # NOTE: We don't distinguish between state never set and state set to None.
        self.state = None

    def __repr__(self):
        state_str = "" if self.state is None else f"(state={self.state!r})"
        return f"{self.module}.{self.name}{self.args!r}{state_str}"

    def __setstate__(self, state):
        # Unpickling's BUILD step hands the saved state here; keep it verbatim.
        self.state = state

    @staticmethod
    def pp_format(printer, obj, stream, indent, allowance, context, level):
        """Write ``obj`` to ``stream`` on behalf of a ``pprint.PrettyPrinter``.

        The signature matches pprint's private per-type formatter hooks.
        Output is ``module.name(args)``. When state is set, it is followed by
        ``(state=``, a newline, the state indented one more level, and ``)``.
        All four combinations of empty/non-empty args and unset/set state are
        handled.

        Uses the private ``printer._format`` and ``printer._indent_per_level``
        members of ``pprint.PrettyPrinter``.
        """
        if not obj.args and obj.state is None:
            stream.write(repr(obj))
            return
        stream.write(f"{obj.module}.{obj.name}")
        if obj.args:
            printer._format(obj.args, stream, indent + 1, allowance + 1, context, level)
        else:
            stream.write("()")
        if obj.state is None:
            return
        # State may be large (e.g. a dict of tensors' metadata), so it goes on
        # its own line, one indent level deeper than the constructor call.
        stream.write("(state=\n")
        indent += printer._indent_per_level
        stream.write(" " * indent)
        printer._format(obj.state, stream, indent, allowance + 1, context, level + 1)
        stream.write(")")
44

45

46
class FakeClass:
    """Placeholder for a global (class or function) referenced by a pickle.

    The placeholder can be called directly (REDUCE-style) or used through its
    ``__new__`` (NEWOBJ-style). Either way it returns a ``FakeObject``
    recording the module, name and arguments, instead of constructing
    anything real.
    """

    def __init__(self, module, name):
        self.module, self.name = module, name
        # pickle's NEWOBJ handler evaluates ``cls.__new__(cls, *args)``. This
        # instance attribute shadows the type's __new__, so that lookup lands
        # in fake_new below.
        self.__new__ = self.fake_new  # type: ignore[assignment]

    def __repr__(self):
        return "{}.{}".format(self.module, self.name)

    def __call__(self, *args):
        # Direct call: every positional argument is a constructor argument.
        return FakeObject(self.module, self.name, args)

    def fake_new(self, *args):
        # The first positional argument is the ``cls`` pickle passes along
        # (this FakeClass itself); only the remainder are real arguments.
        ctor_args = args[1:]
        return FakeObject(self.module, self.name, ctor_args)
60

61

62
class DumpUnpickler(pickle._Unpickler):  # type: ignore[name-defined]
    """Unpickler that rebuilds a pickle out of inert placeholders for display.

    Every global reference resolves to a ``FakeClass`` and every persistent
    id to ``FakeObject("pers", "obj", (pid,))``. As a result, loading does not
    import the modules named in the pickle.

    Args:
        file: Binary file-like object holding the pickle.
        catch_invalid_utf8: When True, a length-prefixed string (BINUNICODE,
            SHORT_BINUNICODE or BINUNICODE8) whose bytes are not valid UTF-8
            is replaced by ``FakeObject("builtin", "UnicodeDecodeError",
            (message,))`` instead of aborting the load.
        **kwargs: Forwarded to ``pickle._Unpickler``.
    """

    def __init__(
            self,
            file,
            *,
            catch_invalid_utf8=False,
            **kwargs):
        super().__init__(file, **kwargs)
        self.catch_invalid_utf8 = catch_invalid_utf8

    def find_class(self, module, name):
        # Never import: hand back a placeholder for any referenced global.
        return FakeClass(module, name)

    def persistent_load(self, pid):
        return FakeObject("pers", "obj", (pid,))

    # Private copy of the opcode table so the overrides below don't leak into
    # pickle._Unpickler itself.
    dispatch = dict(pickle._Unpickler.dispatch)  # type: ignore[attr-defined]

    # Custom objects in TorchScript are able to return invalid UTF-8 strings
    # from their pickle (__getstate__) functions.  Install custom loaders
    # for every length-prefixed string opcode that catch the decode exception
    # and replace the string with a sentinel object.
    def _append_utf8(self, str_bytes):
        """Decode ``str_bytes`` as pickle does and push the result on the stack.

        Raises:
            UnicodeDecodeError: invalid UTF-8 and ``catch_invalid_utf8`` is off.
        """
        obj: Any
        try:
            obj = str(str_bytes, "utf-8", "surrogatepass")
        except UnicodeDecodeError as exn:
            if not self.catch_invalid_utf8:
                raise
            obj = FakeObject("builtin", "UnicodeDecodeError", (str(exn),))
        self.append(obj)  # type: ignore[attr-defined]

    def load_binunicode(self):
        # BINUNICODE: 4-byte little-endian length, then UTF-8 bytes.
        strlen, = struct.unpack("<I", self.read(4))  # type: ignore[attr-defined]
        if strlen > sys.maxsize:
            raise pickle.UnpicklingError("String too long.")
        self._append_utf8(self.read(strlen))  # type: ignore[attr-defined]
    dispatch[pickle.BINUNICODE[0]] = load_binunicode  # type: ignore[assignment]

    def load_short_binunicode(self):
        # SHORT_BINUNICODE (protocol 4+): 1-byte length, then UTF-8 bytes.
        strlen = self.read(1)[0]  # type: ignore[attr-defined]
        self._append_utf8(self.read(strlen))  # type: ignore[attr-defined]
    dispatch[pickle.SHORT_BINUNICODE[0]] = load_short_binunicode  # type: ignore[assignment]

    def load_binunicode8(self):
        # BINUNICODE8 (protocol 4+): 8-byte little-endian length, then UTF-8 bytes.
        strlen, = struct.unpack("<Q", self.read(8))  # type: ignore[attr-defined]
        if strlen > sys.maxsize:
            raise pickle.UnpicklingError("String too long.")
        self._append_utf8(self.read(strlen))  # type: ignore[attr-defined]
    dispatch[pickle.BINUNICODE8[0]] = load_binunicode8  # type: ignore[assignment]

    @classmethod
    def dump(cls, in_stream, out_stream, **kwargs):
        """Load one pickle from ``in_stream`` and pretty-print it.

        Args:
            in_stream: Binary file-like object holding the pickle.
            out_stream: Text stream for the output; ``None`` means stdout.
            **kwargs: Forwarded to the constructor (e.g. ``catch_invalid_utf8``).

        Returns:
            The loaded (placeholder-based) value.
        """
        value = cls(in_stream, **kwargs).load()
        pprint.pprint(value, stream=out_stream)
        return value
104

105

106
def main(argv, output_stream=None):
    """Pretty-print the pickle named by ``argv[1]``.

    ``argv[1]`` is either a path to a pickle file, or ``ARCHIVE@MEMBER``
    naming a member of a zip archive. When MEMBER contains ``*`` it is
    matched with ``fnmatch``, and only the first matching member in archive
    order is shown.

    Args:
        argv: Argument vector; must contain exactly two entries.
        output_stream: Text stream to print to; ``None`` means stdout.

    Returns:
        2 after printing usage to stderr when ``argv`` has the wrong length
        and no ``output_stream`` was given; otherwise None.

    Raises:
        Exception: ``argv`` has the wrong length while ``output_stream`` is
            set, or no zip member matches the glob pattern.
    """
    if len(argv) != 2:
        # Don't spam stderr if not using stdout.
        if output_stream is not None:
            raise Exception("Pass argv of length 2.")
        usage_lines = (
            "usage: show_pickle PICKLE_FILE\n",
            "  PICKLE_FILE can be any of:\n",
            "    path to a pickle file\n",
            "    file.zip@member.pkl\n",
            "    file.zip@*/pattern.*\n",
            "      (shell glob pattern for members)\n",
            "      (only first match will be shown)\n",
        )
        for usage_line in usage_lines:
            sys.stderr.write(usage_line)
        return 2

    target = argv[1]
    src: Union[IO[bytes], BinaryIO]
    if "@" not in target:
        with open(target, "rb") as src:
            DumpUnpickler.dump(src, output_stream)
        return None

    archive_path, member = target.split("@", 1)
    with zipfile.ZipFile(archive_path) as archive:
        if "*" not in member:
            with archive.open(member) as src:
                DumpUnpickler.dump(src, output_stream)
            return None
        matching = (
            info
            for info in archive.infolist()
            if fnmatch.fnmatch(info.filename, member)
        )
        first = next(matching, None)
        if first is None:
            raise Exception(f"Could not find member matching {member} in {archive_path}")
        with archive.open(first) as src:
            DumpUnpickler.dump(src, output_stream)
    return None
141

142

143
if __name__ == "__main__":
    # Teach pprint to lay FakeObject instances out across multiple lines.
    # PrettyPrinter._dispatch is a private table, keyed by the type's
    # __repr__ function, that pprint consults for objects too wide for one
    # line. It has worked on the versions tested so far (e.g. 3.7.4). Guard
    # the lookup so a Python without it still runs and prints plain reprs
    # instead of failing with AttributeError.
    dispatch_table = getattr(pprint.PrettyPrinter, "_dispatch", None)
    if isinstance(dispatch_table, dict):
        dispatch_table[FakeObject.__repr__] = FakeObject.pp_format

    sys.exit(main(sys.argv))
151

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.