pytorch
86 строк · 2.5 Кб
1#!/usr/bin/env python3
2#
3# Computes difference between measurements produced by ./benchmark.py.
4#
5
6import argparse
7import json
8
9import numpy as np
10
11
def load(path):
    """Read and return the parsed JSON document stored at ``path``.

    The file is decoded explicitly as UTF-8 (the encoding JSON text uses)
    instead of relying on the platform's locale-dependent default encoding.
    """
    with open(path, encoding="utf-8") as f:
        return json.load(f)
15
16
# Percentiles (in percent) reported for every benchmark round. Shared by the
# column header and the measurement rows so the two always stay in sync.
_PERCENTILES = (75, 95)


def _print_metadata(ja, jb):
    """Print every top-level key except the results, baseline vs test.

    Keys present in only one of the two files are shown as "-" on the
    other side.
    """
    keys = (set(ja) | set(jb)) - {"benchmark_results"}
    # Value rows separate the columns with " vs " (4 chars), so the header
    # uses 4 spaces there to keep the "test" column aligned with its values.
    print(f"{'':20s} {'baseline':>20s}    {'test':>20s}")
    print(f"{'':20s} {'-' * 20:>20s}    {'-' * 20:>20s}")
    for key in sorted(keys):
        va = str(ja.get(key, "-"))
        vb = str(jb.get(key, "-"))
        print(f"{key + ':':20s} {va:>20s} vs {vb:>20s}")
    print()


def _print_benchmark(ra, rb):
    """Print test (``rb``) timings per round and their change vs baseline (``ra``).

    The caller guarantees both entries have the same model and batch size.
    """
    model = ra["model"]
    batch_size = int(ra["batch_size"])
    print(f"Benchmark: {model} with batch size {batch_size}")

    # Column header. Each percentile group is 35 chars wide, matching the
    # row format below, so every group's columns line up with its values.
    print()
    print(f"{'':>10s}", end="")
    for _ in _PERCENTILES:
        print(f"{'sec/iter':>16s}{'ex/sec':>10s}{'diff':>9s}", end="")
    print()

    for i, (xa, xb) in enumerate(zip(ra["result"], rb["result"])):
        # Round 0 is the run without DDP; it is not compared.
        if i == 0:
            continue
        # Sanity check: only compare rounds that used the same number of ranks.
        if len(xa["ranks"]) != len(xb["ranks"]):
            continue

        ngpus = len(xa["ranks"])
        print(f"{ngpus:>4d} GPUs:", end="")
        for p in _PERCENTILES:
            # np.percentile handles unsorted input, so no pre-sort is needed.
            va = np.percentile(xa["measurements"], p)
            vb = np.percentile(xb["measurements"], p)
            # We're measuring time, so lower is better: a speed-up of the
            # test run over the baseline is reported as a positive percentage.
            delta = -100 * ((vb - va) / va)
            print(
                f" p{p:02d}: {vb:8.3f}s {int(batch_size / vb):7d}/s {delta:+8.1f}%",
                end="",
            )
        print()
    print()


def main():
    """Diff two ./benchmark.py JSON outputs: baseline (first) vs test (second).

    Prints the top-level metadata side by side. Then, for each pair of
    benchmark results at the same position in both files that agree on
    model and batch size, it prints the percentiles in ``_PERCENTILES`` of
    time per iteration, throughput, and relative change.
    """
    parser = argparse.ArgumentParser(description="PyTorch distributed benchmark diff")
    # nargs=2 makes argparse itself reject any other number of files, so no
    # separate count check is needed afterwards.
    parser.add_argument("file", nargs=2)
    args = parser.parse_args()

    ja = load(args.file[0])
    jb = load(args.file[1])

    _print_metadata(ja, jb)

    # Results are paired by position; pairs whose model or batch size differ
    # are skipped rather than compared.
    for ra, rb in zip(ja["benchmark_results"], jb["benchmark_results"]):
        if ra["model"] != rb["model"] or ra["batch_size"] != rb["batch_size"]:
            continue
        _print_benchmark(ra, rb)


if __name__ == "__main__":
    main()
87