pytorch

Форк
0
86 строк · 2.5 Кб
1
#!/usr/bin/env python3
2
#
3
# Computes difference between measurements produced by ./benchmark.py.
4
#
5

6
import argparse
7
import json
8

9
import numpy as np
10

11

12
def load(path):
13
    with open(path) as f:
14
        return json.load(f)
15

16

17
def main():
18
    parser = argparse.ArgumentParser(description="PyTorch distributed benchmark diff")
19
    parser.add_argument("file", nargs=2)
20
    args = parser.parse_args()
21

22
    if len(args.file) != 2:
23
        raise RuntimeError("Must specify 2 files to diff")
24

25
    ja = load(args.file[0])
26
    jb = load(args.file[1])
27

28
    keys = (set(ja.keys()) | set(jb.keys())) - {"benchmark_results"}
29
    print(f"{'':20s} {'baseline':>20s}      {'test':>20s}")
30
    print(f"{'':20s} {'-' * 20:>20s}      {'-' * 20:>20s}")
31
    for key in sorted(keys):
32
        va = str(ja.get(key, "-"))
33
        vb = str(jb.get(key, "-"))
34
        print(f"{key + ':':20s} {va:>20s}  vs  {vb:>20s}")
35
    print()
36

37
    ba = ja["benchmark_results"]
38
    bb = jb["benchmark_results"]
39
    for ra, rb in zip(ba, bb):
40
        if ra["model"] != rb["model"]:
41
            continue
42
        if ra["batch_size"] != rb["batch_size"]:
43
            continue
44

45
        model = ra["model"]
46
        batch_size = int(ra["batch_size"])
47
        name = f"{model} with batch size {batch_size}"
48
        print(f"Benchmark: {name}")
49

50
        # Print header
51
        print()
52
        print(f"{'':>10s}", end="")  # noqa: E999
53
        for _ in [75, 95]:
54
            print(
55
                f"{'sec/iter':>16s}{'ex/sec':>10s}{'diff':>10s}", end=""
56
            )  # noqa: E999
57
        print()
58

59
        # Print measurements
60
        for i, (xa, xb) in enumerate(zip(ra["result"], rb["result"])):
61
            # Ignore round without ddp
62
            if i == 0:
63
                continue
64
            # Sanity check: ignore if number of ranks is not equal
65
            if len(xa["ranks"]) != len(xb["ranks"]):
66
                continue
67

68
            ngpus = len(xa["ranks"])
69
            ma = sorted(xa["measurements"])
70
            mb = sorted(xb["measurements"])
71
            print(f"{ngpus:>4d} GPUs:", end="")  # noqa: E999
72
            for p in [75, 95]:
73
                va = np.percentile(ma, p)
74
                vb = np.percentile(mb, p)
75
                # We're measuring time, so lower is better (hence the negation)
76
                delta = -100 * ((vb - va) / va)
77
                print(
78
                    f"  p{p:02d}: {vb:8.3f}s {int(batch_size / vb):7d}/s {delta:+8.1f}%",
79
                    end="",
80
                )  # noqa: E999
81
            print()
82
        print()
83

84

85
if __name__ == "__main__":
86
    main()
87

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.