1"""
2Update commited CSV files used as reference points by dynamo/inductor CI.
3
4Currently only cares about graph breaks, so only saves those columns.
5
6Hardcodes a list of job names and artifacts per job, but builds the lookup
7by querying github sha and finding associated github actions workflow ID and CI jobs,
8downloading artifact zips, extracting CSVs and filtering them.
9
10Usage:
11
12python benchmarks/dynamo/ci_expected_accuracy.py <sha of pytorch commit that has completed inductor benchmark jobs>
13
14Known limitations:
15- doesn't handle 'retry' jobs in CI, if the same hash has more than one set of artifacts, gets the first one
16"""

import argparse
import json
import os
import pathlib
import subprocess
import sys
import urllib
from io import BytesIO
from itertools import product
from urllib.request import urlopen
from zipfile import ZipFile

import pandas as pd
import requests

# Note: the public query url targets this rockset lambda:
# https://console.rockset.com/lambdas/details/commons.artifacts
ARTIFACTS_QUERY_URL = "https://api.usw2a1.rockset.com/v1/public/shared_lambdas/4ca0033e-0117-41f5-b043-59cde19eff35"
CSV_LINTER = str(
    pathlib.Path(__file__).absolute().parent.parent.parent.parent
    / "tools/linter/adapters/no_merge_conflict_csv_linter.py"
)

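# Ask the Rockset lambda above which CI jobs/artifacts are associated with the
# given commit sha on the given repo.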
def query_job_sha(repo, sha):
    params = {
        "parameters": [
            {"name": "sha", "type": "string", "value": sha},
            {"name": "repo", "type": "string", "value": repo},
        ]
    }

    r = requests.post(url=ARTIFACTS_QUERY_URL, json=params)
    data = r.json()
    return data["results"]

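# CI job names look roughly like
#   "<build config> / test (<suite>, <shard_id>, <num_shards>, <machine>, ...)".
# parse_job_name splits on "/" to separate the config from the "test (...)" part;
# parse_test_str drops the leading "test (" (6 characters) and the trailing ")"
# before splitting the comma-separated fields.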
def parse_job_name(job_str):
    return (part.strip() for part in job_str.split("/"))


def parse_test_str(test_str):
    return (part.strip() for part in test_str[6:].strip(")").split(","))


S3_BASE_URL = "https://gha-artifacts.s3.amazonaws.com"

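# For every inductor test job that matches one of the expected suites, build the
# S3 URL of its test-reports zip, keyed by (suite, shard_id). Note: the URL uses
# the module-level `repo` defined in the __main__ block below.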
def get_artifacts_urls(results, suites):
    urls = {}
    for r in results:
        if (
            r["workflowName"] in ("inductor", "inductor-periodic")
            and "test" in r["jobName"]
        ):
            config_str, test_str = parse_job_name(r["jobName"])
            suite, shard_id, num_shards, machine, *_ = parse_test_str(test_str)
            workflowId = r["workflowId"]
            id = r["id"]
            runAttempt = r["runAttempt"]

            if suite in suites:
                artifact_filename = f"test-reports-test-{suite}-{shard_id}-{num_shards}-{machine}_{id}.zip"
                s3_url = f"{S3_BASE_URL}/{repo}/{workflowId}/{runAttempt}/artifact/{artifact_filename}"
                urls[(suite, int(shard_id))] = s3_url
                print(f"{suite} {shard_id}, {num_shards}: {s3_url}")
    return urls

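# Report CSVs are named after the benchmark suite (the last "_"-separated token
# of the job's suite name); "timm" jobs write their reports as "timm_models".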
def normalize_suite_filename(suite_name):
    strs = suite_name.split("_")
    subsuite = strs[-1]
    if "timm" in subsuite:
        subsuite = subsuite.replace("timm", "timm_models")

    return subsuite

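# Download each shard's zip, pull out the training/inference CSVs, and concatenate
# shards belonging to the same (suite, phase). Missing CSVs and not-yet-uploaded
# artifacts are reported as warnings rather than treated as fatal errors.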
def download_artifacts_and_extract_csvs(urls):
    dataframes = {}
    for (suite, shard), url in urls.items():
        try:
            resp = urlopen(url)
            subsuite = normalize_suite_filename(suite)
            artifact = ZipFile(BytesIO(resp.read()))
            for phase in ("training", "inference"):
                name = f"test/test-reports/{phase}_{subsuite}.csv"
                try:
                    df = pd.read_csv(artifact.open(name))
                    df["graph_breaks"] = df["graph_breaks"].fillna(0).astype(int)
                    prev_df = dataframes.get((suite, phase), None)
                    dataframes[(suite, phase)] = (
                        pd.concat([prev_df, df]) if prev_df is not None else df
                    )
                except KeyError:
                    print(
                        f"Warning: Unable to find {name} in artifacts file from {url}, continuing"
                    )
        except urllib.error.HTTPError:
            print(f"Unable to download {url}, perhaps the CI job isn't finished?")

    return dataframes

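# Write one CSV per (suite, phase) into the expected-accuracy directory, keeping
# only the name/accuracy/graph_breaks columns, then run the CSV linter on it.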
def write_filtered_csvs(root_path, dataframes):
    for (suite, phase), df in dataframes.items():
        out_fn = os.path.join(root_path, f"{suite}_{phase}.csv")
        df.to_csv(out_fn, index=False, columns=["name", "accuracy", "graph_breaks"])
        apply_lints(out_fn)

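# The linter emits a JSON object; if it contains a "replacement" for some
# "original" text, apply that patch to the file in place.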
def apply_lints(filename):
    patch = json.loads(subprocess.check_output([sys.executable, CSV_LINTER, filename]))
    if patch.get("replacement"):
        with open(filename) as fd:
            data = fd.read().replace(patch["original"], patch["replacement"])
        with open(filename, "w") as fd:
            fd.write(data)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
    )

    parser.add_argument("sha")
    args = parser.parse_args()

    repo = "pytorch/pytorch"

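    # Expected suite names are the cross product of compiler configs and
    # benchmark suites, e.g. "inductor_torchbench".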
    suites = {
        f"{a}_{b}"
        for a, b in product(
            [
                "aot_eager",
                "aot_inductor",
                "cpu_inductor",
                "dynamic_aot_eager",
                "dynamic_cpu_inductor",
                "dynamic_inductor",
                "dynamo_eager",
                "inductor",
            ],
            ["huggingface", "timm", "torchbench"],
        )
    }

    root_path = "benchmarks/dynamo/ci_expected_accuracy/"
    assert os.path.exists(root_path), f"cd <pytorch root> and ensure {root_path} exists"

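    # Look up the artifacts for the given sha, download and merge the CSVs, and
    # rewrite the committed reference files.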
    results = query_job_sha(repo, args.sha)
    urls = get_artifacts_urls(results, suites)
    dataframes = download_artifacts_and_extract_csvs(urls)
    write_filtered_csvs(root_path, dataframes)
    print("Success. Now, confirm the changes to .csvs and `git add` them if satisfied.")