pytorch

Форк
0
181 строка · 5.9 Кб
1
"""
2
Update commited CSV files used as reference points by dynamo/inductor CI.
3

4
Currently only cares about graph breaks, so only saves those columns.
5

6
Hardcodes a list of job names and artifacts per job, but builds the lookup
7
by querying github sha and finding associated github actions workflow ID and CI jobs,
8
downloading artifact zips, extracting CSVs and filtering them.
9

10
Usage:
11

12
python benchmarks/dynamo/ci_expected_accuracy.py <sha of pytorch commit that has completed inductor benchmark jobs>
13

14
Known limitations:
15
- doesn't handle 'retry' jobs in CI, if the same hash has more than one set of artifacts, gets the first one
16
"""
17

18
import argparse
19
import json
20
import os
21
import subprocess
22
import sys
23
import urllib
24
from io import BytesIO
25
from itertools import product
26
from pathlib import Path
27
from urllib.request import urlopen
28
from zipfile import ZipFile
29

30
import pandas as pd
31
import requests
32

33

34
# Note: the public query url targets this rockset lambda:
35
# https://console.rockset.com/lambdas/details/commons.artifacts
36
ARTIFACTS_QUERY_URL = "https://api.usw2a1.rockset.com/v1/public/shared_lambdas/4ca0033e-0117-41f5-b043-59cde19eff35"
37
CSV_LINTER = str(
38
    Path(__file__).absolute().parent.parent.parent.parent
39
    / "tools/linter/adapters/no_merge_conflict_csv_linter.py"
40
)
41

42

43
def query_job_sha(repo, sha):
44
    params = {
45
        "parameters": [
46
            {"name": "sha", "type": "string", "value": sha},
47
            {"name": "repo", "type": "string", "value": repo},
48
        ]
49
    }
50

51
    r = requests.post(url=ARTIFACTS_QUERY_URL, json=params)
52
    data = r.json()
53
    return data["results"]
54

55

56
def parse_job_name(job_str):
57
    return (part.strip() for part in job_str.split("/"))
58

59

60
def parse_test_str(test_str):
61
    return (part.strip() for part in test_str[6:].strip(")").split(","))
62

63

64
S3_BASE_URL = "https://gha-artifacts.s3.amazonaws.com"
65

66

67
def get_artifacts_urls(results, suites):
68
    urls = {}
69
    for r in results:
70
        if (
71
            r["workflowName"] in ("inductor", "inductor-periodic")
72
            and "test" in r["jobName"]
73
        ):
74
            config_str, test_str = parse_job_name(r["jobName"])
75
            suite, shard_id, num_shards, machine, *_ = parse_test_str(test_str)
76
            workflowId = r["workflowId"]
77
            id = r["id"]
78
            runAttempt = r["runAttempt"]
79

80
            if suite in suites:
81
                artifact_filename = f"test-reports-test-{suite}-{shard_id}-{num_shards}-{machine}_{id}.zip"
82
                s3_url = f"{S3_BASE_URL}/{repo}/{workflowId}/{runAttempt}/artifact/{artifact_filename}"
83
                urls[(suite, int(shard_id))] = s3_url
84
                print(f"{suite} {shard_id}, {num_shards}: {s3_url}")
85
    return urls
86

87

88
def normalize_suite_filename(suite_name):
89
    strs = suite_name.split("_")
90
    subsuite = strs[-1]
91
    if "timm" in subsuite:
92
        subsuite = subsuite.replace("timm", "timm_models")
93

94
    return subsuite
95

96

97
def download_artifacts_and_extract_csvs(urls):
98
    dataframes = {}
99
    for (suite, shard), url in urls.items():
100
        try:
101
            resp = urlopen(url)
102
            subsuite = normalize_suite_filename(suite)
103
            artifact = ZipFile(BytesIO(resp.read()))
104
            for phase in ("training", "inference"):
105
                name = f"test/test-reports/{phase}_{subsuite}.csv"
106
                try:
107
                    df = pd.read_csv(artifact.open(name))
108
                    df["graph_breaks"] = df["graph_breaks"].fillna(0).astype(int)
109
                    prev_df = dataframes.get((suite, phase), None)
110
                    dataframes[(suite, phase)] = (
111
                        pd.concat([prev_df, df]) if prev_df is not None else df
112
                    )
113
                except KeyError:
114
                    print(
115
                        f"Warning: Unable to find {name} in artifacts file from {url}, continuing"
116
                    )
117
        except urllib.error.HTTPError:
118
            print(f"Unable to download {url}, perhaps the CI job isn't finished?")
119

120
    return dataframes
121

122

123
def write_filtered_csvs(root_path, dataframes):
124
    for (suite, phase), df in dataframes.items():
125
        out_fn = os.path.join(root_path, f"{suite}_{phase}.csv")
126
        df.to_csv(out_fn, index=False, columns=["name", "accuracy", "graph_breaks"])
127
        apply_lints(out_fn)
128

129

130
def apply_lints(filename):
131
    patch = json.loads(subprocess.check_output([sys.executable, CSV_LINTER, filename]))
132
    if patch.get("replacement"):
133
        with open(filename) as fd:
134
            data = fd.read().replace(patch["original"], patch["replacement"])
135
        with open(filename, "w") as fd:
136
            fd.write(data)
137

138

139
if __name__ == "__main__":
140
    parser = argparse.ArgumentParser(
141
        description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
142
    )
143

144
    parser.add_argument("sha")
145
    args = parser.parse_args()
146

147
    repo = "pytorch/pytorch"
148

149
    suites = {
150
        f"{a}_{b}"
151
        for a, b in product(
152
            [
153
                "aot_eager",
154
                "aot_inductor",
155
                "cpu_aot_inductor",
156
                "cpu_aot_inductor_amp_freezing",
157
                "cpu_aot_inductor_freezing",
158
                "cpu_inductor",
159
                "cpu_inductor_amp_freezing",
160
                "cpu_inductor_freezing",
161
                "dynamic_aot_eager",
162
                "dynamic_cpu_aot_inductor",
163
                "dynamic_cpu_aot_inductor_amp_freezing",
164
                "dynamic_cpu_aot_inductor_freezing",
165
                "dynamic_cpu_inductor",
166
                "dynamic_inductor",
167
                "dynamo_eager",
168
                "inductor",
169
            ],
170
            ["huggingface", "timm", "torchbench"],
171
        )
172
    }
173

174
    root_path = "benchmarks/dynamo/ci_expected_accuracy/"
175
    assert os.path.exists(root_path), f"cd <pytorch root> and ensure {root_path} exists"
176

177
    results = query_job_sha(repo, args.sha)
178
    urls = get_artifacts_urls(results, suites)
179
    dataframes = download_artifacts_and_extract_csvs(urls)
180
    write_filtered_csvs(root_path, dataframes)
181
    print("Success. Now, confirm the changes to .csvs and `git add` them if satisfied.")
182

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.