pytorch

Форк
0
/
check_graph_breaks.py 
85 строк · 2.4 Кб
1
import argparse
2
import os
3
import sys
4
import textwrap
5

6
import pandas as pd
7

8

9
def get_field(csv, model_name: str, field: str):
10
    try:
11
        return csv.loc[csv["name"] == model_name][field].item()
12
    except Exception as e:
13
        return None
14

15

16
def check_graph_breaks(actual_csv, expected_csv, expected_filename):
17
    failed = []
18
    improved = []
19

20
    for model in actual_csv["name"]:
21
        graph_breaks = get_field(actual_csv, model, "graph_breaks")
22
        expected_graph_breaks = get_field(expected_csv, model, "graph_breaks")
23

24
        if graph_breaks == expected_graph_breaks:
25
            status = "PASS"
26
            print(f"{model:34}  {status}")
27
            continue
28

29
        elif graph_breaks > expected_graph_breaks:
30
            status = "FAIL:"
31
            failed.append(model)
32
        elif graph_breaks < expected_graph_breaks:
33
            status = "IMPROVED:"
34
            improved.append(model)
35
        print(
36
            f"{model:34}  {status:9} graph_breaks={graph_breaks}, expected={expected_graph_breaks}"
37
        )
38

39
    msg = ""
40
    if failed or improved:
41
        if failed:
42
            msg += textwrap.dedent(
43
                f"""
44
            Error: {len(failed)} models have new dynamo graph breaks:
45
                {' '.join(failed)}
46

47
            """
48
            )
49
        if improved:
50
            msg += textwrap.dedent(
51
                f"""
52
            Improvement: {len(improved)} models have fixed dynamo graph breaks:
53
                {' '.join(improved)}
54

55
            """
56
            )
57
        sha = os.getenv("SHA1", "{your CI commit sha}")
58
        msg += textwrap.dedent(
59
            f"""
60
        If this change is expected, you can update `{expected_filename}` to reflect the new baseline.
61
        from pytorch/pytorch root, run
62
        `python benchmarks/dynamo/ci_expected_accuracy/update_expected.py {sha}`
63
        and then `git add` the resulting local changes to expected CSVs to your commit.
64
        """
65
        )
66
    return failed or improved, msg
67

68

69
def main():
70
    parser = argparse.ArgumentParser()
71
    parser.add_argument("--actual", type=str, required=True)
72
    parser.add_argument("--expected", type=str, required=True)
73
    args = parser.parse_args()
74

75
    actual = pd.read_csv(args.actual)
76
    expected = pd.read_csv(args.expected)
77

78
    failed, msg = check_graph_breaks(actual, expected, args.expected)
79
    if failed:
80
        print(msg)
81
        sys.exit(1)
82

83

84
if __name__ == "__main__":
85
    main()
86

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.