pytorch

Форк
0
412 строк · 16.4 Кб
1
#!/usr/bin/env python3
2
# mypy: allow-untyped-defs
3
"""
4
model_dump: a one-stop shop for TorchScript model inspection.
5

6
The goal of this tool is to provide a simple way to extract lots of
7
useful information from a TorchScript model and make it easy for humans
8
to consume.  It (mostly) replaces zipinfo, common uses of show_pickle,
9
and various ad-hoc analysis notebooks.
10

11
The tool extracts information from the model and serializes it as JSON.
12
That JSON can then be rendered by an HTML+JS page, either by
13
loading the JSON over HTTP or producing a fully self-contained page
14
with all of the code and data burned-in.
15
"""
16

17
# Maintainer notes follow.
18
"""
19
The implementation strategy has tension between 3 goals:
20
- Small file size.
21
- Fully self-contained.
22
- Easy, modern JS environment.
23
Using Preact and HTM achieves 1 and 2 with a decent result for 3.
24
However, the models I tested with result in ~1MB JSON output,
25
so even using something heavier like full React might be tolerable
26
if the build process can be worked out.
27

28
One principle I have followed that I think is very beneficial
29
is to keep the JSON data as close as possible to the model
30
and do most of the rendering logic on the client.
31
This makes for easier development (just refresh, usually),
32
allows for more laziness and dynamism, and lets us add more
33
views of the same data without bloating the HTML file.
34

35
Currently, this code doesn't actually load the model or even
36
depend on any part of PyTorch.  I don't know if that's an important
37
feature to maintain, but it's probably worth preserving the ability
38
to run at least basic analysis on models that cannot be loaded.
39

40
I think the easiest way to develop this code is to cd into model_dump and
41
run "python -m http.server", then load http://localhost:8000/skeleton.html
42
in the browser.  In another terminal, run
43
"python -m torch.utils.model_dump --style=json FILE > \
44
    torch/utils/model_dump/model_info.json"
45
every time you update the Python code or model.
46
When you update JS, just refresh.
47

48
Possible improvements:
49
    - Fix various TODO comments in this file and the JS.
50
    - Make the HTML much less janky, especially the auxiliary data panel.
51
    - Make the auxiliary data panel start small, expand when
52
      data is available, and have a button to clear/contract.
53
    - Clean up the JS.  There's a lot of copypasta because
54
      I don't really know how to use Preact.
55
    - Make the HTML render and work nicely inside a Jupyter notebook.
56
    - Add the ability for JS to choose the URL to load the JSON based
57
      on the page URL (query or hash).  That way we could publish the
58
      inlined skeleton once and have it load various JSON blobs.
59
    - Add a button to expand all expandable sections so ctrl-F works well.
60
    - Add hyperlinking from data to code, and code to code.
61
    - Add hyperlinking from debug info to Diffusion.
62
    - Make small tensor contents available.
63
    - Do something nice for quantized models
64
      (they probably don't work at all right now).
65
"""
66

67
import argparse
68
import io
69
import json
70
import os
71
import pickle
72
import pprint
73
import re
74
import sys
75
import urllib.parse
76
import zipfile
77
from pathlib import Path
78
from typing import Dict
79

80
import torch.utils.show_pickle
81

82

83
# Extra files / rendered pickles larger than this (16 KiB) are skipped
# when inlining their contents into the model info JSON.
DEFAULT_EXTRA_FILE_SIZE_LIMIT = 16 * 1024

# Public API of this module.
__all__ = ['get_storage_info', 'hierarchical_pickle', 'get_model_info', 'get_inline_skeleton',
           'burn_in_info', 'get_info_and_burn_skeleton']
87

88
def get_storage_info(storage):
    """Extract a JSON-friendly description of a pickled storage reference.

    Validates the FakeObject shape that DumpUnpickler produces for
    persistent-id storage references and returns a flat list of
    [dtype_name, key, location, numel].
    """
    assert isinstance(storage, torch.utils.show_pickle.FakeObject)
    assert (storage.module, storage.name) == ("pers", "obj")
    assert storage.state is None
    assert isinstance(storage.args, tuple) and len(storage.args) == 1
    (inner,) = storage.args
    assert isinstance(inner, tuple) and len(inner) == 5
    tag, cls, *rest = inner
    assert tag == "storage"
    assert isinstance(cls, torch.utils.show_pickle.FakeClass)
    assert cls.module == "torch"
    assert cls.name.endswith("Storage")
    # "FloatStorage" -> "Float"; rest carries the remaining storage fields.
    return [cls.name.replace("Storage", "")] + rest
104

105

106
def hierarchical_pickle(data):
    """Recursively convert DumpUnpickler output into JSON-serializable data.

    Primitives pass through unchanged.  Tuples and dicts are wrapped in
    tagged objects because JSON has neither tuples nor non-string dict
    keys.  FakeObjects standing in for known torch constructs (scripted
    modules, tensors, qtensors, devices, ...) are converted to compact
    tagged representations.  Raises Exception for any FakeObject type or
    Python type it doesn't recognize.
    """
    if isinstance(data, (bool, int, float, str, type(None))):
        return data
    if isinstance(data, list):
        return [hierarchical_pickle(d) for d in data]
    if isinstance(data, tuple):
        # JSON has no tuple type; tag it so the JS side can tell it apart
        # from a plain list.
        return {
            "__tuple_values__": hierarchical_pickle(list(data)),
        }
    if isinstance(data, dict):
        # JSON object keys must be strings, so keys and values are
        # serialized as parallel lists instead.
        return {
            "__is_dict__": True,
            "keys": hierarchical_pickle(list(data.keys())),
            "values": hierarchical_pickle(list(data.values())),
        }
    if isinstance(data, torch.utils.show_pickle.FakeObject):
        typename = f"{data.module}.{data.name}"
        if (
            typename.startswith(('__torch__.', 'torch.jit.LoweredWrapper.', 'torch.jit.LoweredModule.'))
        ):
            # A scripted module: only its state (attributes) is interesting.
            assert data.args == ()
            return {
                "__module_type__": typename,
                "state": hierarchical_pickle(data.state),
            }
        if typename == "torch._utils._rebuild_tensor_v2":
            assert data.state is None
            # Newer pickles carry a 7th metadata argument; it is ignored here.
            if len(data.args) == 6:
                storage, offset, size, stride, requires_grad, hooks = data.args
            else:
                storage, offset, size, stride, requires_grad, hooks, metadata = data.args
            storage_info = get_storage_info(storage)
            return {"__tensor_v2__": [storage_info, offset, size, stride, requires_grad]}
        if typename == "torch._utils._rebuild_qtensor":
            assert data.state is None
            storage, offset, size, stride, quantizer, requires_grad, hooks = data.args
            storage_info = get_storage_info(storage)
            assert isinstance(quantizer, tuple)
            assert isinstance(quantizer[0], torch.utils.show_pickle.FakeClass)
            assert quantizer[0].module == "torch"
            if quantizer[0].name == "per_tensor_affine":
                # per_tensor_affine carries (scale, zero_point) extras.
                assert len(quantizer) == 3
                assert isinstance(quantizer[1], float)
                assert isinstance(quantizer[2], int)
                quantizer_extra = list(quantizer[1:3])
            else:
                # Other quantization schemes: record the scheme name only.
                quantizer_extra = []
            quantizer_json = [quantizer[0].name] + quantizer_extra
            return {"__qtensor__": [storage_info, offset, size, stride, quantizer_json, requires_grad]}
        if typename == "torch.jit._pickle.restore_type_tag":
            assert data.state is None
            obj, typ = data.args
            assert isinstance(typ, str)
            # Drop the type tag; keep only the wrapped object.
            return hierarchical_pickle(obj)
        if re.fullmatch(r"torch\.jit\._pickle\.build_[a-z]+list", typename):
            # build_intlist / build_doublelist / ...: unwrap to a plain list.
            assert data.state is None
            ls, = data.args
            assert isinstance(ls, list)
            return hierarchical_pickle(ls)
        if typename == "torch.device":
            assert data.state is None
            name, = data.args
            assert isinstance(name, str)
            # Just forget that it was a device and return the name.
            return name
        if typename == "builtin.UnicodeDecodeError":
            assert data.state is None
            msg, = data.args
            assert isinstance(msg, str)
            # Hack: Pretend this is a module so we don't need custom serialization.
            # Hack: Wrap the message in a tuple so it looks like a nice state object.
            # TODO: Undo at least that second hack.  We should support string states.
            return {
                "__module_type__": typename,
                "state": hierarchical_pickle((msg,)),
            }
        raise Exception(f"Can't prepare fake object of type for JS: {typename}")  # noqa: TRY002
    raise Exception(f"Can't prepare data of type for JS: {type(data)}")  # noqa: TRY002
184

185

186
def get_model_info(
        path_or_file,
        title=None,
        extra_file_size_limit=DEFAULT_EXTRA_FILE_SIZE_LIMIT):
    """Get JSON-friendly information about a model.

    The result is suitable for being saved as model_info.json,
    or passed to burn_in_info.

    Args:
        path_or_file: path (str or os.PathLike) to a zip-based
            TorchScript model, or a seekable binary file-like object.
        title: display title; defaults to the path, or "buffer" for
            file-like input.
        extra_file_size_limit: extra JSON files larger than this (bytes)
            and pickles rendering longer than this (chars) are skipped.
    """

    # Determine a display title and total archive size for each of the
    # three accepted input forms.
    if isinstance(path_or_file, os.PathLike):
        default_title = os.fspath(path_or_file)
        file_size = path_or_file.stat().st_size  # type: ignore[attr-defined]
    elif isinstance(path_or_file, str):
        default_title = path_or_file
        file_size = Path(path_or_file).stat().st_size
    else:
        # File-like object: measure by seeking to the end, then rewind so
        # ZipFile can read from the start.
        default_title = "buffer"
        path_or_file.seek(0, io.SEEK_END)
        file_size = path_or_file.tell()
        path_or_file.seek(0)

    title = title or default_title

    with zipfile.ZipFile(path_or_file) as zf:
        # All entries in a TorchScript archive share one top-level
        # directory; record it and collect per-entry size stats.
        path_prefix = None
        zip_files = []
        for zi in zf.infolist():
            prefix = re.sub("/.*", "", zi.filename)
            if path_prefix is None:
                path_prefix = prefix
            elif prefix != path_prefix:
                raise Exception(f"Mismatched prefixes: {path_prefix} != {prefix}")  # noqa: TRY002
            zip_files.append(dict(
                filename=zi.filename,
                compression=zi.compress_type,
                compressed_size=zi.compress_size,
                file_size=zi.file_size,
            ))

        assert path_prefix is not None
        version = zf.read(path_prefix + "/version").decode("utf-8").strip()

        def get_pickle(name):
            # Load <prefix>/<name>.pkl with the fake-object unpickler and
            # convert the result to JSON-friendly form.
            assert path_prefix is not None
            with zf.open(path_prefix + f"/{name}.pkl") as handle:
                raw = torch.utils.show_pickle.DumpUnpickler(handle, catch_invalid_utf8=True).load()
                return hierarchical_pickle(raw)

        model_data = get_pickle("data")
        constants = get_pickle("constants")

        # Intern strings that are likely to be re-used.
        # Pickle automatically detects shared structure,
        # so re-used strings are stored efficiently.
        # However, JSON has no way of representing this,
        # so we have to do it manually.
        interned_strings : Dict[str, int] = {}

        def ist(s):
            # Return a stable integer id for s, assigning one on first sight.
            if s not in interned_strings:
                interned_strings[s] = len(interned_strings)
            return interned_strings[s]

        code_files = {}
        for zi in zf.infolist():
            if not zi.filename.endswith(".py"):
                continue
            with zf.open(zi) as handle:
                raw_code = handle.read()
            with zf.open(zi.filename + ".debug_pkl") as handle:
                raw_debug = handle.read()

            # Parse debug info and add begin/end markers if not present
            # to ensure that we cover the entire source code.
            debug_info_t = pickle.loads(raw_debug)
            text_table = None

            if (len(debug_info_t) == 3 and
                    isinstance(debug_info_t[0], str) and
                    debug_info_t[0] == 'FORMAT_WITH_STRING_TABLE'):
                # Newer debug format stores strings as indices into a
                # shared table; expand entries back to the inline format.
                _, text_table, content = debug_info_t

                def parse_new_format(line):
                    # (0, (('', '', 0), 0, 0))
                    num, ((text_indexes, fname_idx, offset), start, end), tag = line
                    text = ''.join(text_table[x] for x in text_indexes)  # type: ignore[index]
                    fname = text_table[fname_idx]  # type: ignore[index]
                    return num, ((text, fname, offset), start, end), tag

                debug_info_t = map(parse_new_format, content)

            debug_info = list(debug_info_t)
            if not debug_info:
                debug_info.append((0, (('', '', 0), 0, 0)))
            if debug_info[-1][0] != len(raw_code):
                debug_info.append((len(raw_code), (('', '', 0), 0, 0)))

            # Split the generated code into consecutive byte ranges, each
            # annotated with its (interned) original source location.
            code_parts = []
            for di, di_next in zip(debug_info, debug_info[1:]):
                start, source_range, *_ = di
                end = di_next[0]
                assert end > start
                source, s_start, s_end = source_range
                s_text, s_file, s_line = source
                # TODO: Handle this case better.  TorchScript ranges are in bytes,
                # but JS doesn't really handle byte strings.
                # if bytes and chars are not equivalent for this string,
                # zero out the ranges so we don't highlight the wrong thing.
                if len(s_text) != len(s_text.encode("utf-8")):
                    s_start = 0
                    s_end = 0
                text = raw_code[start:end]
                code_parts.append([text.decode("utf-8"), ist(s_file), s_line, ist(s_text), s_start, s_end])
            code_files[zi.filename] = code_parts

        # Inline small JSON extra files verbatim.
        extra_files_json_pattern = re.compile(re.escape(path_prefix) + "/extra/.*\\.json")
        extra_files_jsons = {}
        for zi in zf.infolist():
            if not extra_files_json_pattern.fullmatch(zi.filename):
                continue
            if zi.file_size > extra_file_size_limit:
                continue
            with zf.open(zi) as handle:
                try:
                    json_content = json.load(handle)
                    extra_files_jsons[zi.filename] = json_content
                except json.JSONDecodeError:
                    extra_files_jsons[zi.filename] = "INVALID JSON"

        # Pickles rendered regardless of their rendered size.
        always_render_pickles = {
            "bytecode.pkl",
        }
        extra_pickles = {}
        for zi in zf.infolist():
            if not zi.filename.endswith(".pkl"):
                continue
            with zf.open(zi) as handle:
                # TODO: handle errors here and just ignore the file?
                # NOTE: For a lot of these files (like bytecode),
                # we could get away with just unpickling, but this should be safer.
                obj = torch.utils.show_pickle.DumpUnpickler(handle, catch_invalid_utf8=True).load()
            buf = io.StringIO()
            pprint.pprint(obj, buf)
            contents = buf.getvalue()
            # Checked the rendered length instead of the file size
            # because pickles with shared structure can explode in size during rendering.
            if os.path.basename(zi.filename) not in always_render_pickles and \
                    len(contents) > extra_file_size_limit:
                continue
            extra_pickles[zi.filename] = contents

    return {"model": dict(
        title=title,
        file_size=file_size,
        version=version,
        zip_files=zip_files,
        interned_strings=list(interned_strings),
        code_files=code_files,
        model_data=model_data,
        constants=constants,
        extra_files_jsons=extra_files_jsons,
        extra_pickles=extra_pickles,
    )}
350

351

352
def get_inline_skeleton():
    """Get a fully-inlined skeleton of the frontend.

    The returned HTML page has no external network dependencies for code.
    It can load model_info.json over HTTP, or be passed to burn_in_info.
    """

    import importlib.resources

    # importlib.resources.read_text/read_binary were deprecated in Python
    # 3.11 and removed in 3.13; files() (available since 3.9) is the
    # supported replacement.  Encoding is pinned to utf-8 to match the old
    # read_text default.
    package_files = importlib.resources.files(__package__)
    skeleton = (package_files / "skeleton.html").read_text(encoding="utf-8")
    js_code = (package_files / "code.js").read_text(encoding="utf-8")
    for js_module in ["preact", "htm"]:
        js_lib = (package_files / f"{js_module}.mjs").read_bytes()
        # Embed the library as a data: URL so the page needs no network access.
        js_url = "data:application/javascript," + urllib.parse.quote(js_lib)
        js_code = js_code.replace(f"https://unpkg.com/{js_module}?module", js_url)
    # Splice the (now self-contained) JS directly into the script tag.
    skeleton = skeleton.replace(' src="./code.js">', ">\n" + js_code)
    return skeleton
369

370

371
def burn_in_info(skeleton, info):
    """Burn model info into the HTML skeleton.

    The result will render the hard-coded model info and
    have no external network dependencies for code or data.
    """

    # Python's json serializer does not escape slashes, and this JSON is
    # inlined directly into a <script> tag: a string containing
    # "</script>" would terminate the script early and mess up the page.
    # Escaping every "/" as "\/" (which is legal JSON) prevents that.
    encoded = json.dumps(info, sort_keys=True).replace("/", "\\/")
    placeholder = "BURNED_IN_MODEL_INFO = null"
    return skeleton.replace(placeholder, "BURNED_IN_MODEL_INFO = " + encoded)
385

386

387
def get_info_and_burn_skeleton(path_or_bytesio, **kwargs):
    """Extract model info and return a self-contained HTML page rendering it.

    Extra keyword arguments are forwarded to get_model_info.
    """
    info = get_model_info(path_or_bytesio, **kwargs)
    return burn_in_info(get_inline_skeleton(), info)
392

393

394
def main(argv, *, stdout=None):
    """CLI entry point: dump model info as JSON or a self-contained HTML page.

    Args:
        argv: full argument vector (argv[0] is skipped).
        stdout: optional output stream; defaults to sys.stdout.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--style", choices=["json", "html"])
    parser.add_argument("--title")
    parser.add_argument("model")
    args = parser.parse_args(argv[1:])

    info = get_model_info(args.model, title=args.title)
    out = stdout or sys.stdout

    if args.style == "json":
        out.write(json.dumps(info, sort_keys=True) + "\n")
        return
    if args.style == "html":
        out.write(burn_in_info(get_inline_skeleton(), info))
        return
    raise Exception("Invalid style")  # noqa: TRY002
413

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.