# pytorch
# 43 lines · 1004.0 bytes
import argparse
import sys
import textwrap

import pandas as pd
6
7
def check_perf_csv(filename, threshold):
    """
    Check every model's speedup in a benchmark CSV against a threshold.

    Reads ``filename`` with ``pandas.read_csv``, prints one
    ``<name> <speedup>`` line per row, and collects the names of the rows
    whose ``speedup`` is strictly below ``threshold``.  If any were
    collected, an error summary is printed and the process exits with
    status 1; otherwise the function returns ``None``.

    Args:
        filename: Path or file-like object accepted by ``pandas.read_csv``.
            The CSV must provide ``name`` and ``speedup`` columns.
        threshold: Minimum acceptable speedup; a row fails only when its
            speedup is strictly less than this value.

    Raises:
        SystemExit: With code 1 when at least one row is below ``threshold``.
        KeyError: If the ``name`` or ``speedup`` column is missing.
    """
    df = pd.read_csv(filename)

    failed = []
    # Walk the two columns directly rather than building a Series per row.
    for model_name, speedup in zip(df["name"], df["speedup"]):
        # NOTE: a NaN speedup compares False here, so it is not flagged.
        if speedup < threshold:
            # Stringify so the summary join below also works when the
            # name column was parsed as a non-string dtype.
            failed.append(str(model_name))

        print(f"{model_name:34} {speedup}")

    if failed:
        print(
            textwrap.dedent(
                f"""
                Error {len(failed)} models performance regressed
                    {' '.join(failed)}
                """
            )
        )
        sys.exit(1)
34
35
if __name__ == "__main__":
    # Command-line entry point: parse the CSV path and the speedup threshold,
    # then run the check (which exits with status 1 on regression).
    parser = argparse.ArgumentParser()
    # Both options are required: argparse would otherwise default them to
    # None, and check_perf_csv would receive None instead of usable values.
    parser.add_argument(
        "--file", "-f", type=str, required=True, help="csv file name"
    )
    parser.add_argument(
        "--threshold",
        "-t",
        type=float,
        required=True,
        help="threshold speedup value to check against",
    )
    args = parser.parse_args()
    check_perf_csv(args.file, args.threshold)
44