pytorch
43 строки · 1004.0 Байт
1import argparse2import sys3import textwrap4
5import pandas as pd6
7
def check_perf_csv(filename, threshold):
    """
    Basic performance checking.

    Read the CSV at ``filename`` with pandas, print every row's ``name`` and
    ``speedup`` values, and collect the names of the rows whose speedup is
    strictly below ``threshold``.

    Args:
        filename: CSV source handed unchanged to ``pandas.read_csv``; it must
            provide ``name`` and ``speedup`` columns.
        threshold: Minimum acceptable speedup. A row fails only when its
            speedup is strictly less than this value.

    Returns:
        None when no row is below the threshold.

    Raises:
        SystemExit: With exit code 1, after printing the list of failing
            names, when at least one row is below the threshold.
    """
    df = pd.read_csv(filename)

    failed = []
    for _, row in df.iterrows():
        model_name = row["name"]
        speedup = row["speedup"]
        # NOTE: a missing speedup is read as NaN, and NaN < threshold is
        # False, so such a row is printed but not counted as a regression.
        if speedup < threshold:
            failed.append(model_name)

        print(f"{model_name:34} {speedup}")

    if failed:
        # Names are not guaranteed to be strings (a numeric or blank "name"
        # cell is read as a number), and str.join() rejects non-str items,
        # so stringify them to make sure the report is actually emitted.
        failed_names = " ".join(map(str, failed))
        print(
            textwrap.dedent(
                f"""
                Error {len(failed)} models performance regressed
                    {failed_names}
                """
            )
        )
        sys.exit(1)
35
if __name__ == "__main__":
    # Command-line entry point: parse the CSV path and the speedup threshold,
    # then run the check. Both options are mandatory: without them the check
    # would fail later with an unhelpful error (comparison against None, or
    # pandas rejecting a None path), so let argparse report the usage error.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--file", "-f", type=str, required=True, help="csv file name"
    )
    parser.add_argument(
        "--threshold",
        "-t",
        type=float,
        required=True,
        help="threshold speedup value to check against",
    )
    args = parser.parse_args()
    check_perf_csv(args.file, args.threshold)