pytorch
144 lines · 4.4 KB
1import logging2import os3import re4from collections import defaultdict5
6import click7import pandas as pd8from tabulate import tabulate9
10
def gmean(s):
    """Return the geometric mean of the series *s* (the nth root of the product)."""
    count = len(s)
    return s.product() ** (1 / count)
14
def find_csv_files(path, perf_compare):
    """
    Recursively search for all CSV files in directory and subdirectories whose
    name contains a target string.
    """
    # Pre-compile the perf-compare filename pattern once, outside the walk.
    perf_compare_re = re.compile(
        r"training_(torchbench|huggingface|timm_models)\.csv"
    )

    def is_csv(name):
        if perf_compare:
            return perf_compare_re.match(name) is not None
        return name.endswith("_performance.csv")

    return [
        os.path.join(root, name)
        for root, _dirs, names in os.walk(path)
        for name in names
        if is_csv(name)
    ]
35
@click.command()
@click.argument("directory", default="artifacts")
@click.option("--amp", is_flag=True)
@click.option("--float32", is_flag=True)
@click.option(
    "--perf-compare",
    is_flag=True,
    help="Set if the CSVs were generated by running manually the action rather than picking them from the nightly job",
)
def main(directory, amp, float32, perf_compare):
    """
    Given a directory containing multiple CSVs from --performance benchmark
    runs, aggregates and generates summary statistics similar to the web UI at
    https://torchci-git-fork-huydhn-add-compilers-bench-74abf8-fbopensource.vercel.app/benchmark/compilers

    This is most useful if you've downloaded CSVs from CI and need to quickly
    look at aggregate stats. The CSVs are expected to follow exactly the same
    naming convention that is used in CI.

    You may also be interested in
    https://docs.google.com/document/d/1DQQxIgmKa3eF0HByDTLlcJdvefC4GwtsklJUgLs09fQ/edit#
    which explains how to interpret the raw csv data.
    """
    # --amp / --float32 restrict the report to one dtype; default is both.
    dtypes = ["amp", "float32"]
    if amp and not float32:
        dtypes = ["amp"]
    if float32 and not amp:
        dtypes = ["float32"]

    # Group frames by basename so shards of the same benchmark found in
    # different subdirectories get concatenated together below.
    dfs = defaultdict(list)
    for f in find_csv_files(directory, perf_compare):
        try:
            dfs[os.path.basename(f)].append(pd.read_csv(f))
        except Exception:
            # Name the offending file before propagating the parse error.
            logging.warning("failed parsing %s", f)
            raise

    # dtype -> statistic -> benchmark -> compiler -> value
    results = defaultdict(  # dtype
        lambda: defaultdict(  # statistic
            lambda: defaultdict(dict)  # benchmark  # compiler -> value
        )
    )

    for k, v in sorted(dfs.items()):
        if perf_compare:
            # Manually-run action: the filename only encodes the suite;
            # compiler/dtype/mode/device are fixed by that workflow.
            regex = r"training_(torchbench|huggingface|timm_models)\.csv"
            m = re.match(regex, k)
            assert m is not None, k
            compiler = "inductor"
            benchmark = m.group(1)
            dtype = "float32"
            mode = "training"
            device = "cuda"
        else:
            # Nightly job: filename encodes compiler, suite, dtype, mode
            # and device.
            regex = (
                "(.+)_"
                "(torchbench|huggingface|timm_models)_"
                "(float32|amp)_"
                "(inference|training)_"
                "(cpu|cuda)_"
                r"performance\.csv"
            )
            m = re.match(regex, k)
            # Fail with the offending filename instead of an opaque
            # AttributeError when an unexpected CSV is in the directory
            # (matches the assertion style of the perf_compare branch).
            assert m is not None, k
            compiler = m.group(1)
            benchmark = m.group(2)
            dtype = m.group(3)
            mode = m.group(4)
            device = m.group(5)

        # Skip filtered dtypes before concatenating/aggregating so no work
        # is wasted on frames that would be discarded.
        if dtype not in dtypes:
            continue

        df = pd.concat(v)
        # Drop failed rows: missing values and zero speedups would poison
        # the geometric means.
        df = df.dropna().query("speedup != 0")

        statistics = {
            "speedup": gmean(df["speedup"]),
            "comptime": df["compilation_latency"].mean(),
            "memory": gmean(df["compression_ratio"]),
        }

        for statistic, value in statistics.items():
            results[f"{device} {dtype} {mode}"][statistic][benchmark][compiler] = value

    descriptions = {
        "speedup": "Geometric mean speedup",
        "comptime": "Mean compilation time",
        "memory": "Peak memory compression ratio",
    }

    # One markdown section per (device, dtype, mode), one table per
    # statistic: rows are compilers, columns are benchmark suites.
    for dtype_mode, r in results.items():
        print(f"# {dtype_mode} performance results")
        for statistic, data in r.items():
            print(f"## {descriptions[statistic]}")

            table = []
            # NOTE(review): assumes every benchmark column has the same
            # compiler set as the first one; a missing compiler would raise
            # KeyError — confirm against CI output.
            for row_name in data[next(iter(data.keys()))]:
                row = [row_name]
                for col_name in data:
                    row.append(round(data[col_name][row_name], 2))
                table.append(row)

            headers = list(data.keys())
            print(tabulate(table, headers=headers))
            print()
143
# Guard the CLI entry point so importing this module (e.g. for testing or
# reuse of gmean/find_csv_files) does not execute the command.
if __name__ == "__main__":
    main()