pytorch
57 lines · 1.6 KB
1import argparse
2import sys
3import textwrap
4
5import pandas as pd
6
7
def main(args):
    """Check measured memory compression ratios against expected baselines.

    Both CSV files are read with ``pd.read_csv`` and must contain ``name`` and
    ``compression_ratio`` columns. Every row of the actual CSV is compared with
    the expected ratio recorded for the same model name. A model passes when
    its actual ratio is at least 95% of the expected one. Models with no entry
    in the expected CSV are reported and skipped rather than failed.

    Args:
        args: Namespace with ``actual`` and ``expected`` attributes holding
            paths to the measured and baseline CSV files.

    Raises:
        SystemExit: With status 1 when at least one model falls below 95% of
            its expected ratio.
    """
    actual = pd.read_csv(args.actual)
    expected = pd.read_csv(args.expected)

    # Build the name -> ratio lookup once instead of masking the whole frame
    # for every model. An explicit membership test keeps "model missing from
    # the baseline" distinct from any other error, and avoids float(Series),
    # which pandas deprecates. If the expected CSV lists a model more than
    # once, its first row is used.
    expected_ratios = (
        expected.drop_duplicates(subset="name", keep="first")
        .set_index("name")["compression_ratio"]
        .to_dict()
    )
    # A model may lose up to 5% of its expected compression ratio before it
    # is flagged as a regression.
    tolerance = 0.95
    failed = []

    # Evaluate each row on its own so duplicate names in the actual CSV are
    # each checked instead of crashing the lookup.
    for name, actual_ratio in zip(actual["name"], actual["compression_ratio"]):
        actual_memory_compression = float(actual_ratio)
        if name not in expected_ratios:
            print(f"{name:34} is missing from {args.expected}")
            continue
        expected_memory_compression = float(expected_ratios[name])

        if actual_memory_compression >= expected_memory_compression * tolerance:
            status = "PASS"
        else:
            status = "FAIL"
            failed.append(name)
        print(
            textwrap.dedent(
                f"""
                {name:34}:
                    actual_memory_compression={actual_memory_compression:.2f},
                    expected_memory_compression={expected_memory_compression:.2f},
                    {status}
                """
            )
        )

    if failed:
        print(
            textwrap.dedent(
                f"""
                Error: {len(failed)} models below expected memory compression ratio:
                    {' '.join(failed)}
                If this drop is expected, you can update `{args.expected}`.
                """
            )
        )
        sys.exit(1)
# Command-line interface. The parser is built at import time, but arguments
# are parsed only when the file runs as a script. argparse exits the process
# on bad or missing arguments, so parsing at import time would break any
# importer.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--actual",
    type=str,
    required=True,
    help="CSV of measured ratios (columns: name, compression_ratio)",
)
parser.add_argument(
    "--expected",
    type=str,
    required=True,
    help="CSV of baseline ratios (columns: name, compression_ratio)",
)

if __name__ == "__main__":
    args = parser.parse_args()
    main(args)