1"""
2Update commited CSV files used as reference points by dynamo/inductor CI.
3
4Currently only cares about graph breaks, so only saves those columns.
5
6Hardcodes a list of job names and artifacts per job, but builds the lookup
7by querying github sha and finding associated github actions workflow ID and CI jobs,
8downloading artifact zips, extracting CSVs and filtering them.
9
10Usage:
11
12python benchmarks/dynamo/ci_expected_accuracy.py <sha of pytorch commit that has completed inductor benchmark jobs>
13
14Known limitations:
15- doesn't handle 'retry' jobs in CI, if the same hash has more than one set of artifacts, gets the first one
16"""
17
import argparse
import json
import os
import subprocess
import sys
import urllib.error
from io import BytesIO
from itertools import product
from pathlib import Path
from urllib.request import urlopen
from zipfile import ZipFile

import pandas as pd
import requests

# Note: the public query url targets this rockset lambda:
# https://console.rockset.com/lambdas/details/commons.artifacts
ARTIFACTS_QUERY_URL = "https://api.usw2a1.rockset.com/v1/public/shared_lambdas/4ca0033e-0117-41f5-b043-59cde19eff35"
CSV_LINTER = str(
    Path(__file__).absolute().parent.parent.parent.parent
    / "tools/linter/adapters/no_merge_conflict_csv_linter.py"
)


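# Query the Rockset lambda above for all CI jobs (and their artifact metadata)
# associated with a commit sha in the given repo.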
def query_job_sha(repo, sha):
    params = {
        "parameters": [
            {"name": "sha", "type": "string", "value": sha},
            {"name": "repo", "type": "string", "value": repo},
        ]
    }

    r = requests.post(url=ARTIFACTS_QUERY_URL, json=params)
    data = r.json()
    return data["results"]


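# Job names look like "<config> / test (<suite>, <shard_id>, <num_shards>,
# <machine>)" (format inferred from the parsing below); split off the build
# config from the test description.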
def parse_job_name(job_str):
    return (part.strip() for part in job_str.split("/"))


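# Drop the leading 'test (' (6 chars) and trailing ')', then split the
# remaining comma-separated fields.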
def parse_test_str(test_str):
    return (part.strip() for part in test_str[6:].strip(")").split(","))


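# Public S3 bucket that CI uploads GitHub Actions artifacts to.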
S3_BASE_URL = "https://gha-artifacts.s3.amazonaws.com"


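# Map each tracked (suite, shard_id) to the S3 URL of its test-reports
# artifact zip. Note: relies on the module-level `repo` set under __main__.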
def get_artifacts_urls(results, suites):
    urls = {}
    for r in results:
        if (
            r["workflowName"] in ("inductor", "inductor-periodic")
            and "test" in r["jobName"]
        ):
            config_str, test_str = parse_job_name(r["jobName"])
            suite, shard_id, num_shards, machine, *_ = parse_test_str(test_str)
            workflowId = r["workflowId"]
            id = r["id"]
            runAttempt = r["runAttempt"]
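
            # Reconstruct the artifact zip's S3 key from the job metadata.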
            if suite in suites:
                artifact_filename = f"test-reports-test-{suite}-{shard_id}-{num_shards}-{machine}_{id}.zip"
                s3_url = f"{S3_BASE_URL}/{repo}/{workflowId}/{runAttempt}/artifact/{artifact_filename}"
                urls[(suite, int(shard_id))] = s3_url
                print(f"{suite} {shard_id}, {num_shards}: {s3_url}")
    return urls


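# CSV filenames inside the artifact use "timm_models" where the job names use
# "timm", so translate the suite name accordingly.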
def normalize_suite_filename(suite_name):
    strs = suite_name.split("_")
    subsuite = strs[-1]
    if "timm" in subsuite:
        subsuite = subsuite.replace("timm", "timm_models")

    return subsuite


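# Download each artifact zip, pull out the training/inference CSVs, and
# concatenate shards of the same (suite, phase) into a single dataframe.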
def download_artifacts_and_extract_csvs(urls):
    dataframes = {}
    for (suite, shard), url in urls.items():
        try:
            resp = urlopen(url)
            subsuite = normalize_suite_filename(suite)
            artifact = ZipFile(BytesIO(resp.read()))
            for phase in ("training", "inference"):
                name = f"test/test-reports/{phase}_{subsuite}.csv"
                try:
                    df = pd.read_csv(artifact.open(name))
                    df["graph_breaks"] = df["graph_breaks"].fillna(0).astype(int)
                    prev_df = dataframes.get((suite, phase), None)
                    dataframes[(suite, phase)] = (
                        pd.concat([prev_df, df]) if prev_df is not None else df
                    )
                except KeyError:
                    print(
                        f"Warning: Unable to find {name} in artifacts file from {url}, continuing"
                    )
        except urllib.error.HTTPError:
            print(f"Unable to download {url}, perhaps the CI job isn't finished?")

    return dataframes


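# Write one CSV per (suite, phase), keeping only the columns that the CI
# accuracy checks compare against.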
def write_filtered_csvs(root_path, dataframes):
    for (suite, phase), df in dataframes.items():
        out_fn = os.path.join(root_path, f"{suite}_{phase}.csv")
        df.to_csv(out_fn, index=False, columns=["name", "accuracy", "graph_breaks"])
        apply_lints(out_fn)


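# The CSV linter emits a JSON patch on stdout; apply its suggested replacement
# so the committed CSVs pass lint.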
def apply_lints(filename):
    patch = json.loads(subprocess.check_output([sys.executable, CSV_LINTER, filename]))
    if patch.get("replacement"):
        with open(filename) as fd:
            data = fd.read().replace(patch["original"], patch["replacement"])
        with open(filename, "w") as fd:
            fd.write(data)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
    )

    parser.add_argument("sha")
    args = parser.parse_args()

    repo = "pytorch/pytorch"
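
    # All (compiler config x benchmark suite) combinations whose expected
    # accuracy CSVs are tracked in this directory.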
    suites = {
        f"{a}_{b}"
        for a, b in product(
            [
                "aot_eager",
                "aot_inductor",
                "cpu_aot_inductor",
                "cpu_aot_inductor_amp_freezing",
                "cpu_aot_inductor_freezing",
                "cpu_inductor",
                "cpu_inductor_amp_freezing",
                "cpu_inductor_freezing",
                "dynamic_aot_eager",
                "dynamic_cpu_aot_inductor",
                "dynamic_cpu_aot_inductor_amp_freezing",
                "dynamic_cpu_aot_inductor_freezing",
                "dynamic_cpu_inductor",
                "dynamic_inductor",
                "dynamo_eager",
                "inductor",
            ],
            ["huggingface", "timm", "torchbench"],
        )
    }

    root_path = "benchmarks/dynamo/ci_expected_accuracy/"
    assert os.path.exists(root_path), f"cd <pytorch root> and ensure {root_path} exists"
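
    # End-to-end: look up CI jobs for the sha, resolve artifact URLs, download
    # and filter the CSVs, then rewrite the committed reference files.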
    results = query_job_sha(repo, args.sha)
    urls = get_artifacts_urls(results, suites)
    dataframes = download_artifacts_and_extract_csvs(urls)
    write_filtered_csvs(root_path, dataframes)
    print("Success. Now, confirm the changes to .csvs and `git add` them if satisfied.")