"""
Update committed CSV files used as reference points by dynamo/inductor CI.

Currently only cares about graph breaks, so only saves those columns.

Hardcodes a list of job names and artifacts per job, but builds the lookup
by querying the github sha and finding the associated github actions workflow ID and CI jobs,
downloading artifact zips, extracting CSVs and filtering them.

Usage:

python benchmarks/dynamo/ci_expected_accuracy.py <sha of pytorch commit that has completed inductor benchmark jobs>

Known limitations:
- doesn't handle 'retry' jobs in CI: if the same hash has more than one set of artifacts, it takes the first one
"""

import argparse
import json
import os
import pathlib
import subprocess
import sys
import urllib
from io import BytesIO
from itertools import product
from urllib.request import urlopen
from zipfile import ZipFile

import pandas as pd
import requests

# Note: the public query url targets this rockset lambda:
# https://console.rockset.com/lambdas/details/commons.artifacts
ARTIFACTS_QUERY_URL = "https://api.usw2a1.rockset.com/v1/public/shared_lambdas/4ca0033e-0117-41f5-b043-59cde19eff35"
CSV_LINTER = str(
    pathlib.Path(__file__).absolute().parent.parent.parent.parent
    / "tools/linter/adapters/no_merge_conflict_csv_linter.py"
)


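# Note: each entry in the returned "results" list is assumed to carry at least the
# fields read in get_artifacts_urls() below ("workflowName", "jobName", "workflowId",
# "id", "runAttempt"); this is inferred from usage, not from a documented schema.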
def query_job_sha(repo, sha):
    params = {
        "parameters": [
            {"name": "sha", "type": "string", "value": sha},
            {"name": "repo", "type": "string", "value": repo},
        ]
    }

    r = requests.post(url=ARTIFACTS_QUERY_URL, json=params)
    data = r.json()
    return data["results"]


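# The two helpers below assume CI job names shaped roughly like
#   "<build config> / test (<suite>, <shard_id>, <num_shards>, <machine>, ...)"
# (an illustrative format inferred from how the parts are consumed in get_artifacts_urls):
# parse_job_name() yields the "/"-separated pieces, and parse_test_str() drops the
# leading "test (" plus the trailing ")" and yields the comma-separated fields.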
def parse_job_name(job_str):
    return (part.strip() for part in job_str.split("/"))


def parse_test_str(test_str):
    return (part.strip() for part in test_str[6:].strip(")").split(","))


S3_BASE_URL = "https://gha-artifacts.s3.amazonaws.com"


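# Builds a map from (suite, shard_id) to the S3 URL of that job's test-reports zip,
# considering only test jobs from the "inductor" / "inductor-periodic" workflows.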
def get_artifacts_urls(results, suites):
    urls = {}
    for r in results:
        if (
            r["workflowName"] in ("inductor", "inductor-periodic")
            and "test" in r["jobName"]
        ):
            config_str, test_str = parse_job_name(r["jobName"])
            suite, shard_id, num_shards, machine, *_ = parse_test_str(test_str)
            workflowId = r["workflowId"]
            id = r["id"]
            runAttempt = r["runAttempt"]

            if suite in suites:
                artifact_filename = f"test-reports-test-{suite}-{shard_id}-{num_shards}-{machine}_{id}.zip"
                s3_url = f"{S3_BASE_URL}/{repo}/{workflowId}/{runAttempt}/artifact/{artifact_filename}"
                urls[(suite, int(shard_id))] = s3_url
                print(f"{suite} {shard_id}, {num_shards}: {s3_url}")
    return urls


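# Maps a suite name to the filename stem used by the CSVs inside the artifact,
# e.g. "aot_eager_huggingface" -> "huggingface" and "inductor_timm" -> "timm_models".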
def normalize_suite_filename(suite_name):
    strs = suite_name.split("_")
    subsuite = strs[-1]
    if "timm" in subsuite:
        subsuite = subsuite.replace("timm", "timm_models")

    return subsuite


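# Downloads each artifact zip and, for both the training and inference CSVs of the
# suite, concatenates all shards into a single DataFrame keyed by (suite, phase).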
def download_artifacts_and_extract_csvs(urls):
    dataframes = {}
    for (suite, shard), url in urls.items():
        try:
            resp = urlopen(url)
            subsuite = normalize_suite_filename(suite)
            artifact = ZipFile(BytesIO(resp.read()))
            for phase in ("training", "inference"):
                name = f"test/test-reports/{phase}_{subsuite}.csv"
                try:
                    df = pd.read_csv(artifact.open(name))
                    df["graph_breaks"] = df["graph_breaks"].fillna(0).astype(int)
                    prev_df = dataframes.get((suite, phase), None)
                    dataframes[(suite, phase)] = (
                        pd.concat([prev_df, df]) if prev_df is not None else df
                    )
                except KeyError:
                    print(
                        f"Warning: Unable to find {name} in artifacts file from {url}, continuing"
                    )
        except urllib.error.HTTPError:
            print(f"Unable to download {url}, perhaps the CI job isn't finished?")

    return dataframes


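# Writes one CSV per (suite, phase) under root_path, keeping only the
# name/accuracy/graph_breaks columns, then runs the CSV linter on the result.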
def write_filtered_csvs(root_path, dataframes):
    for (suite, phase), df in dataframes.items():
        out_fn = os.path.join(root_path, f"{suite}_{phase}.csv")
        df.to_csv(out_fn, index=False, columns=["name", "accuracy", "graph_breaks"])
        apply_lints(out_fn)


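# Assumption: the linter prints a JSON object whose "original"/"replacement" keys
# describe a textual patch for the file; this is inferred from how the output is
# consumed below rather than from the linter's documentation.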
def apply_lints(filename):
    patch = json.loads(subprocess.check_output([sys.executable, CSV_LINTER, filename]))
    if patch.get("replacement"):
        with open(filename) as fd:
            data = fd.read().replace(patch["original"], patch["replacement"])
        with open(filename, "w") as fd:
            fd.write(data)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
    )

    parser.add_argument("sha")
    args = parser.parse_args()

    repo = "pytorch/pytorch"

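    # Cross product of dynamo/inductor configs and benchmark families, e.g.
    # "inductor_huggingface", "dynamic_cpu_inductor_timm" (8 x 3 = 24 suite names).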
    suites = {
        f"{a}_{b}"
        for a, b in product(
            [
                "aot_eager",
                "aot_inductor",
                "cpu_inductor",
                "dynamic_aot_eager",
                "dynamic_cpu_inductor",
                "dynamic_inductor",
                "dynamo_eager",
                "inductor",
            ],
            ["huggingface", "timm", "torchbench"],
        )
    }

    root_path = "benchmarks/dynamo/ci_expected_accuracy/"
    assert os.path.exists(root_path), f"cd <pytorch root> and ensure {root_path} exists"

    results = query_job_sha(repo, args.sha)
    urls = get_artifacts_urls(results, suites)
    dataframes = download_artifacts_and_extract_csvs(urls)
    write_filtered_csvs(root_path, dataframes)
    print("Success. Now, confirm the changes to .csvs and `git add` them if satisfied.")
