pytorch
412 строк · 16.4 Кб
1#!/usr/bin/env python3
2# mypy: allow-untyped-defs
3"""
4model_dump: a one-stop shop for TorchScript model inspection.
5
6The goal of this tool is to provide a simple way to extract lots of
7useful information from a TorchScript model and make it easy for humans
8to consume. It (mostly) replaces zipinfo, common uses of show_pickle,
9and various ad-hoc analysis notebooks.
10
11The tool extracts information from the model and serializes it as JSON.
12That JSON can then be rendered by an HTML+JS page, either by
13loading the JSON over HTTP or producing a fully self-contained page
14with all of the code and data burned-in.
15"""
16
17# Maintainer notes follow.
18"""
19The implementation strategy has tension between 3 goals:
20- Small file size.
21- Fully self-contained.
22- Easy, modern JS environment.
23Using Preact and HTM achieves 1 and 2 with a decent result for 3.
24However, the models I tested with result in ~1MB JSON output,
25so even using something heavier like full React might be tolerable
26if the build process can be worked out.
27
28One principle I have followed that I think is very beneficial
29is to keep the JSON data as close as possible to the model
30and do most of the rendering logic on the client.
31This makes for easier development (just refresh, usually),
32allows for more laziness and dynamism, and lets us add more
33views of the same data without bloating the HTML file.
34
35Currently, this code doesn't actually load the model or even
36depend on any part of PyTorch. I don't know if that's an important
37feature to maintain, but it's probably worth preserving the ability
38to run at least basic analysis on models that cannot be loaded.
39
40I think the easiest way to develop this code is to cd into model_dump and
41run "python -m http.server", then load http://localhost:8000/skeleton.html
42in the browser. In another terminal, run
43"python -m torch.utils.model_dump --style=json FILE > \
44torch/utils/model_dump/model_info.json"
45every time you update the Python code or model.
46When you update JS, just refresh.
47
48Possible improvements:
49- Fix various TODO comments in this file and the JS.
50- Make the HTML much less janky, especially the auxiliary data panel.
51- Make the auxiliary data panel start small, expand when
52data is available, and have a button to clear/contract.
53- Clean up the JS. There's a lot of copypasta because
54I don't really know how to use Preact.
55- Make the HTML render and work nicely inside a Jupyter notebook.
56- Add the ability for JS to choose the URL to load the JSON based
57on the page URL (query or hash). That way we could publish the
58inlined skeleton once and have it load various JSON blobs.
59- Add a button to expand all expandable sections so ctrl-F works well.
60- Add hyperlinking from data to code, and code to code.
61- Add hyperlinking from debug info to Diffusion.
62- Make small tensor contents available.
63- Do something nice for quantized models
64(they probably don't work at all right now).
65"""
66
67import argparse
68import io
69import json
70import os
71import pickle
72import pprint
73import re
74import sys
75import urllib.parse
76import zipfile
77from pathlib import Path
78from typing import Dict
79
80import torch.utils.show_pickle
81
82
83DEFAULT_EXTRA_FILE_SIZE_LIMIT = 16 * 1024
84
85__all__ = ['get_storage_info', 'hierarchical_pickle', 'get_model_info', 'get_inline_skeleton',
86'burn_in_info', 'get_info_and_burn_skeleton']
87
def get_storage_info(storage):
    """Convert a fake persistent-load object for a tensor storage into a JSON-friendly list.

    The DumpUnpickler represents every persistent id as
    ``FakeObject("pers", "obj", (payload,))`` where the payload is a
    5-tuple ``("storage", FakeClass(torch, "<Type>Storage"), key, location, numel)``.
    Returns ``["<Type>", key, location, numel]`` with the "Storage" suffix stripped.
    """
    assert isinstance(storage, torch.utils.show_pickle.FakeObject)
    assert storage.module == "pers" and storage.name == "obj"
    assert storage.state is None
    assert isinstance(storage.args, tuple) and len(storage.args) == 1
    (payload,) = storage.args
    assert isinstance(payload, tuple) and len(payload) == 5
    tag, storage_cls, *rest = payload
    assert tag == "storage"
    assert isinstance(storage_cls, torch.utils.show_pickle.FakeClass)
    assert storage_cls.module == "torch" and storage_cls.name.endswith("Storage")
    # Keep only the dtype-ish part of the class name, e.g. "FloatStorage" -> "Float".
    return [storage_cls.name.replace("Storage", ""), *rest]
104
105
def hierarchical_pickle(data):
    """Recursively convert unpickled model data into a JSON-serializable structure.

    Scalars pass through unchanged.  Tuples and dicts are wrapped in tagged
    dicts (``__tuple_values__``, ``__is_dict__``) because JSON cannot
    represent them directly.  ``FakeObject`` placeholders produced by
    DumpUnpickler are translated case-by-case: modules, tensors, qtensors,
    type tags, specialized lists, devices, and decode-error markers.
    Raises a generic Exception for anything unrecognized.
    """
    if isinstance(data, (bool, int, float, str, type(None))):
        return data
    if isinstance(data, list):
        return [hierarchical_pickle(item) for item in data]
    if isinstance(data, tuple):
        return {"__tuple_values__": hierarchical_pickle(list(data))}
    if isinstance(data, dict):
        return {
            "__is_dict__": True,
            "keys": hierarchical_pickle(list(data.keys())),
            "values": hierarchical_pickle(list(data.values())),
        }
    if isinstance(data, torch.utils.show_pickle.FakeObject):
        typename = f"{data.module}.{data.name}"
        # TorchScript module instances carry all their data in `state`.
        if typename.startswith(('__torch__.', 'torch.jit.LoweredWrapper.', 'torch.jit.LoweredModule.')):
            assert data.args == ()
            return {
                "__module_type__": typename,
                "state": hierarchical_pickle(data.state),
            }
        if typename == "torch._utils._rebuild_tensor_v2":
            assert data.state is None
            # Older serializations have 6 args; newer ones append metadata.
            if len(data.args) == 6:
                storage, offset, size, stride, requires_grad, _hooks = data.args
            else:
                storage, offset, size, stride, requires_grad, _hooks, _metadata = data.args
            return {"__tensor_v2__": [get_storage_info(storage), offset, size, stride, requires_grad]}
        if typename == "torch._utils._rebuild_qtensor":
            assert data.state is None
            storage, offset, size, stride, quantizer, requires_grad, _hooks = data.args
            storage_info = get_storage_info(storage)
            assert isinstance(quantizer, tuple)
            scheme = quantizer[0]
            assert isinstance(scheme, torch.utils.show_pickle.FakeClass)
            assert scheme.module == "torch"
            if scheme.name == "per_tensor_affine":
                # (scheme, scale, zero_point)
                assert len(quantizer) == 3
                assert isinstance(quantizer[1], float)
                assert isinstance(quantizer[2], int)
                extra = list(quantizer[1:3])
            else:
                extra = []
            quantizer_json = [scheme.name, *extra]
            return {"__qtensor__": [storage_info, offset, size, stride, quantizer_json, requires_grad]}
        if typename == "torch.jit._pickle.restore_type_tag":
            assert data.state is None
            obj, typ = data.args
            assert isinstance(typ, str)
            # The tag itself is dropped; only the payload matters for display.
            return hierarchical_pickle(obj)
        if re.fullmatch(r"torch\.jit\._pickle\.build_[a-z]+list", typename):
            assert data.state is None
            (ls,) = data.args
            assert isinstance(ls, list)
            return hierarchical_pickle(ls)
        if typename == "torch.device":
            assert data.state is None
            (name,) = data.args
            assert isinstance(name, str)
            # Just forget that it was a device and return the name.
            return name
        if typename == "builtin.UnicodeDecodeError":
            assert data.state is None
            (msg,) = data.args
            assert isinstance(msg, str)
            # Hack: Pretend this is a module so we don't need custom serialization.
            # Hack: Wrap the message in a tuple so it looks like a nice state object.
            # TODO: Undo at least that second hack.  We should support string states.
            return {
                "__module_type__": typename,
                "state": hierarchical_pickle((msg,)),
            }
        raise Exception(f"Can't prepare fake object of type for JS: {typename}")  # noqa: TRY002
    raise Exception(f"Can't prepare data of type for JS: {type(data)}")  # noqa: TRY002
184
185
def get_model_info(
        path_or_file,
        title=None,
        extra_file_size_limit=DEFAULT_EXTRA_FILE_SIZE_LIMIT):
    """Get JSON-friendly information about a model.

    The result is suitable for being saved as model_info.json,
    or passed to burn_in_info.

    Args:
        path_or_file: a path (str or os.PathLike) to a TorchScript zip
            archive, or a seekable binary file-like object containing one.
        title: display title for the frontend; defaults to the path
            (or "buffer" for file-like input).
        extra_file_size_limit: skip extra JSON files and rendered pickles
            larger than this many bytes (default 16 KiB).

    Returns:
        A dict of the form ``{"model": {...}}`` containing zip metadata,
        the hierarchical-pickled data/constants, source code with debug
        ranges, extra JSON files, and pretty-printed extra pickles.
    """

    # Determine the display title and total file size for each input kind.
    if isinstance(path_or_file, os.PathLike):
        default_title = os.fspath(path_or_file)
        file_size = path_or_file.stat().st_size  # type: ignore[attr-defined]
    elif isinstance(path_or_file, str):
        default_title = path_or_file
        file_size = Path(path_or_file).stat().st_size
    else:
        # File-like object: measure by seeking to the end, then rewind
        # so ZipFile can read from the beginning.
        default_title = "buffer"
        path_or_file.seek(0, io.SEEK_END)
        file_size = path_or_file.tell()
        path_or_file.seek(0)

    title = title or default_title

    with zipfile.ZipFile(path_or_file) as zf:
        # TorchScript archives place everything under a single top-level
        # directory; record it and verify every entry agrees.
        path_prefix = None
        zip_files = []
        for zi in zf.infolist():
            prefix = re.sub("/.*", "", zi.filename)
            if path_prefix is None:
                path_prefix = prefix
            elif prefix != path_prefix:
                raise Exception(f"Mismatched prefixes: {path_prefix} != {prefix}")  # noqa: TRY002
            zip_files.append(dict(
                filename=zi.filename,
                compression=zi.compress_type,
                compressed_size=zi.compress_size,
                file_size=zi.file_size,
            ))

        assert path_prefix is not None
        version = zf.read(path_prefix + "/version").decode("utf-8").strip()

        def get_pickle(name):
            # Load "<prefix>/<name>.pkl" with the fake-object unpickler
            # (never executes real reduce functions) and convert to
            # a JSON-friendly structure.
            assert path_prefix is not None
            with zf.open(path_prefix + f"/{name}.pkl") as handle:
                raw = torch.utils.show_pickle.DumpUnpickler(handle, catch_invalid_utf8=True).load()
                return hierarchical_pickle(raw)

        model_data = get_pickle("data")
        constants = get_pickle("constants")

        # Intern strings that are likely to be re-used.
        # Pickle automatically detects shared structure,
        # so re-used strings are stored efficiently.
        # However, JSON has no way of representing this,
        # so we have to do it manually.
        interned_strings : Dict[str, int] = {}

        def ist(s):
            # Return a stable small integer id for string `s`,
            # assigning ids in first-seen order.
            if s not in interned_strings:
                interned_strings[s] = len(interned_strings)
            return interned_strings[s]

        code_files = {}
        for zi in zf.infolist():
            if not zi.filename.endswith(".py"):
                continue
            with zf.open(zi) as handle:
                raw_code = handle.read()
            # Every .py entry is assumed to have a sibling ".py.debug_pkl"
            # with source-range info — TODO confirm for all archive versions.
            with zf.open(zi.filename + ".debug_pkl") as handle:
                raw_debug = handle.read()

            # Parse debug info and add begin/end markers if not present
            # to ensure that we cover the entire source code.
            debug_info_t = pickle.loads(raw_debug)
            text_table = None

            # Newer archives use a string-table format: entries reference
            # indices into a shared text table instead of inline strings.
            if (len(debug_info_t) == 3 and
                    isinstance(debug_info_t[0], str) and
                    debug_info_t[0] == 'FORMAT_WITH_STRING_TABLE'):
                _, text_table, content = debug_info_t

                def parse_new_format(line):
                    # (0, (('', '', 0), 0, 0))
                    num, ((text_indexes, fname_idx, offset), start, end), tag = line
                    text = ''.join(text_table[x] for x in text_indexes)  # type: ignore[index]
                    fname = text_table[fname_idx]  # type: ignore[index]
                    return num, ((text, fname, offset), start, end), tag

                # Lazy map is safe here: it is fully consumed by list()
                # below, within this same loop iteration.
                debug_info_t = map(parse_new_format, content)

            debug_info = list(debug_info_t)
            # Ensure sentinel entries at offset 0 and len(raw_code) so the
            # pairwise zip below covers the whole source.
            if not debug_info:
                debug_info.append((0, (('', '', 0), 0, 0)))
            if debug_info[-1][0] != len(raw_code):
                debug_info.append((len(raw_code), (('', '', 0), 0, 0)))

            # Slice the raw code into consecutive segments, each tagged
            # with its (interned) originating file/text and range.
            code_parts = []
            for di, di_next in zip(debug_info, debug_info[1:]):
                start, source_range, *_ = di
                end = di_next[0]
                assert end > start
                source, s_start, s_end = source_range
                s_text, s_file, s_line = source
                # TODO: Handle this case better.  TorchScript ranges are in bytes,
                # but JS doesn't really handle byte strings.
                # if bytes and chars are not equivalent for this string,
                # zero out the ranges so we don't highlight the wrong thing.
                if len(s_text) != len(s_text.encode("utf-8")):
                    s_start = 0
                    s_end = 0
                text = raw_code[start:end]
                code_parts.append([text.decode("utf-8"), ist(s_file), s_line, ist(s_text), s_start, s_end])
            code_files[zi.filename] = code_parts

        # Collect small JSON files from the archive's "extra" directory.
        extra_files_json_pattern = re.compile(re.escape(path_prefix) + "/extra/.*\\.json")
        extra_files_jsons = {}
        for zi in zf.infolist():
            if not extra_files_json_pattern.fullmatch(zi.filename):
                continue
            if zi.file_size > extra_file_size_limit:
                continue
            with zf.open(zi) as handle:
                try:
                    json_content = json.load(handle)
                    extra_files_jsons[zi.filename] = json_content
                except json.JSONDecodeError:
                    extra_files_jsons[zi.filename] = "INVALID JSON"

        # These pickles are rendered regardless of the size limit.
        always_render_pickles = {
            "bytecode.pkl",
        }
        extra_pickles = {}
        for zi in zf.infolist():
            if not zi.filename.endswith(".pkl"):
                continue
            with zf.open(zi) as handle:
                # TODO: handle errors here and just ignore the file?
                # NOTE: For a lot of these files (like bytecode),
                # we could get away with just unpickling, but this should be safer.
                obj = torch.utils.show_pickle.DumpUnpickler(handle, catch_invalid_utf8=True).load()
            buf = io.StringIO()
            pprint.pprint(obj, buf)
            contents = buf.getvalue()
            # Checked the rendered length instead of the file size
            # because pickles with shared structure can explode in size during rendering.
            if os.path.basename(zi.filename) not in always_render_pickles and \
                    len(contents) > extra_file_size_limit:
                continue
            extra_pickles[zi.filename] = contents

    return {"model": dict(
        title=title,
        file_size=file_size,
        version=version,
        zip_files=zip_files,
        interned_strings=list(interned_strings),
        code_files=code_files,
        model_data=model_data,
        constants=constants,
        extra_files_jsons=extra_files_jsons,
        extra_pickles=extra_pickles,
    )}
350
351
def get_inline_skeleton():
    """Get a fully-inlined skeleton of the frontend.

    The returned HTML page has no external network dependencies for code.
    It can load model_info.json over HTTP, or be passed to burn_in_info.
    """
    import importlib.resources

    page = importlib.resources.read_text(__package__, "skeleton.html")
    script = importlib.resources.read_text(__package__, "code.js")
    # Replace the CDN URLs for the JS dependencies with data: URLs
    # so the page works with no network access.
    for dep in ("preact", "htm"):
        dep_source = importlib.resources.read_binary(__package__, f"{dep}.mjs")
        data_url = "data:application/javascript," + urllib.parse.quote(dep_source)
        script = script.replace(f"https://unpkg.com/{dep}?module", data_url)
    # Inline the main script directly into the <script> tag.
    return page.replace(' src="./code.js">', ">\n" + script)
369
370
def burn_in_info(skeleton, info):
    """Burn model info into the HTML skeleton.

    The result will render the hard-coded model info and
    have no external network dependencies for code or data.
    """
    # Python's json serializer does not escape slashes in strings, and this
    # JSON is inlined directly into a <script> tag: a string containing
    # "</script>" would end the script prematurely and mess up the page.
    # Unconditionally escaping every "/" fixes that.
    payload = json.dumps(info, sort_keys=True).replace("/", "\\/")
    return skeleton.replace(
        "BURNED_IN_MODEL_INFO = null",
        "BURNED_IN_MODEL_INFO = " + payload)
385
386
def get_info_and_burn_skeleton(path_or_bytesio, **kwargs):
    """Produce a self-contained HTML page for the given model.

    Extracts model info (kwargs are forwarded to get_model_info)
    and burns it into the inline skeleton.
    """
    model_info = get_model_info(path_or_bytesio, **kwargs)
    return burn_in_info(get_inline_skeleton(), model_info)
392
393
def main(argv, *, stdout=None):
    """Command-line entry point.

    Dumps model info for ``argv``'s model file, either as JSON
    (--style=json) or as a self-contained HTML page (--style=html),
    to `stdout` (or sys.stdout if not given).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--style", choices=["json", "html"])
    parser.add_argument("--title")
    parser.add_argument("model")
    args = parser.parse_args(argv[1:])

    info = get_model_info(args.model, title=args.title)
    output = stdout or sys.stdout

    if args.style == "json":
        output.write(json.dumps(info, sort_keys=True) + "\n")
    elif args.style == "html":
        output.write(burn_in_info(get_inline_skeleton(), info))
    else:
        # Reached when --style is omitted or unrecognized.
        raise Exception("Invalid style")  # noqa: TRY002
413