import argparse
import json
import os
import xml.etree.ElementTree as ET
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any, Dict, Generator, Tuple

from tools.stats.upload_stats_lib import (
    download_gha_artifacts,
    download_s3_artifacts,
    is_rerun_disabled_tests,
    unzip,
    upload_workflow_stats_to_s3,
)
from tools.stats.upload_test_stats import process_xml_element

TESTCASE_TAG = "testcase"
SEPARATOR = ";"
def process_report(
    report: Path,
) -> Dict[str, Dict[str, int]]:
    """
    Return the disabled tests that should be re-enabled and those that are still
    flaky (failed or skipped).

    Args:
        report: path to a single test-report XML file.

    Returns:
        Mapping from the "name;classname;filename" test id to its aggregated
        counters, e.g. {"num_green": 3, "num_red": 0}.
    """
    root = ET.parse(report)

    # All rerun tests from a report are grouped here:
    #
    # * Success test should be re-enabled if it's green after rerunning in all platforms
    #   where it is currently disabled
    # * Failures from pytest because pytest-flakefinder is used to run the same test
    #   multiple times, some could fail
    # * Skipped tests from unittest
    #
    # We want to keep track of how many times the test fails (num_red) or passes (num_green)
    all_tests: Dict[str, Dict[str, int]] = {}

    for test_case in root.iter(TESTCASE_TAG):
        parsed_test_case = process_xml_element(test_case)

        # Under --rerun-disabled-tests mode, a test is skipped when:
        # * it's skipped explicitly inside PyTorch code
        # * it's skipped because it's a normal enabled test
        # * or it's flaky (num_red > 0 and num_green > 0)
        # * or it's failing (num_red > 0 and num_green == 0)
        #
        # We care only about the latter two here
        skipped = parsed_test_case.get("skipped", None)

        # NB: Regular ONNX tests could return a list of subskips here where each item in the
        # list is a skipped message. In the context of rerunning disabled tests, we could
        # ignore this case as returning a list of subskips only happens when tests are run
        # normally
        if skipped and (
            type(skipped) is list or "num_red" not in skipped.get("message", "")
        ):
            continue

        name = parsed_test_case.get("name", "")
        classname = parsed_test_case.get("classname", "")
        filename = parsed_test_case.get("file", "")

        # A test case without all three fields cannot be mapped back to a disabled test
        if not name or not classname or not filename:
            continue

        # Check if the test is a failure
        failure = parsed_test_case.get("failure", None)

        disabled_test_id = SEPARATOR.join([name, classname, filename])
        if disabled_test_id not in all_tests:
            all_tests[disabled_test_id] = {
                "num_green": 0,
                "num_red": 0,
            }

        # Under --rerun-disabled-tests mode, if a test is not skipped or failed, it's
        # counted as a success. Otherwise, it's still flaky or failing
        if skipped:
            try:
                stats = json.loads(skipped.get("message", ""))
            except json.JSONDecodeError:
                # The skipped message doesn't carry the JSON stats payload,
                # so there is nothing to aggregate for this entry
                continue

            all_tests[disabled_test_id]["num_green"] += stats.get("num_green", 0)
            all_tests[disabled_test_id]["num_red"] += stats.get("num_red", 0)
        elif failure:
            # As a failure, increase the failure count
            all_tests[disabled_test_id]["num_red"] += 1
        else:
            all_tests[disabled_test_id]["num_green"] += 1

    return all_tests
def get_test_reports(
    repo: str, workflow_run_id: int, workflow_run_attempt: int
) -> Generator[Path, None, None]:
    """
    Gather all the test reports from S3 and GHA. It is currently not possible to guess which
    test reports are from rerun_disabled_tests workflow because the name doesn't include the
    test config. So, all reports will need to be downloaded and examined.

    Yields paths to the extracted XML report files, relative to a temporary
    working directory that is removed once the generator is exhausted.

    NOTE(review): `repo` is currently unused by the download helpers — kept
    for interface compatibility with callers.
    """
    with TemporaryDirectory() as temp_dir:
        print("Using temporary directory:", temp_dir)
        # Download and unzip everything inside the temp dir so the glob
        # below only sees this run's artifacts
        os.chdir(temp_dir)

        artifact_paths = download_s3_artifacts(
            "test-reports", workflow_run_id, workflow_run_attempt
        )
        for path in artifact_paths:
            unzip(path)

        artifact_paths = download_gha_artifacts(
            "test-report", workflow_run_id, workflow_run_attempt
        )
        for path in artifact_paths:
            unzip(path)

        yield from Path(".").glob("**/*.xml")
def get_disabled_test_name(test_id: str) -> Tuple[str, str, str, str]:
    """
    Follow flaky bot convention here, if that changes, this will also need to be updated.

    Splits a "name;classname;filename" test id and returns the tuple
    (flaky-bot disabled test name, name, classname, filename).
    """
    name, classname, filename = test_id.split(SEPARATOR)
    return f"{name} (__main__.{classname})", name, classname, filename
def prepare_record(
    workflow_id: int,
    workflow_run_attempt: int,
    name: str,
    classname: str,
    filename: str,
    flaky: bool,
    num_red: int = 0,
    num_green: int = 0,
) -> Tuple[Any, Dict[str, Any]]:
    """
    Prepare the record to save onto S3.

    Args:
        workflow_id: id of the workflow run being processed.
        workflow_run_attempt: which retry of that workflow this is.
        name / classname / filename: identify the disabled test.
        flaky: whether the test is still flaky after the reruns.
        num_red / num_green: failure and success counters for the test.

    Returns:
        (key, record) where key uniquely identifies the test within the
        workflow attempt and record is the JSON-serializable payload.
    """
    key = (
        workflow_id,
        workflow_run_attempt,
        name,
        classname,
        filename,
    )
    record = {
        "workflow_id": workflow_id,
        "workflow_run_attempt": workflow_run_attempt,
        "name": name,
        "classname": classname,
        "filename": filename,
        "flaky": flaky,
        "num_green": num_green,
        "num_red": num_red,
    }
    return key, record
def save_results(
    workflow_id: int,
    workflow_run_attempt: int,
    all_tests: Dict[str, Dict[str, int]],
) -> None:
    """
    Save the result to S3, so it can go to Rockset.

    Partitions `all_tests` into tests that should be re-enabled (at least one
    green rerun and zero red ones) and tests that are still flaky, prints a
    human-readable summary, and uploads one record per test.
    """
    # A test can be re-enabled only if it passed at least once and never failed
    should_be_enabled_tests = {
        name: stats
        for name, stats in all_tests.items()
        if "num_green" in stats
        and stats["num_green"]
        and "num_red" in stats
        and stats["num_red"] == 0
    }
    # Everything else is still flaky (or outright failing)
    still_flaky_tests = {
        name: stats
        for name, stats in all_tests.items()
        if name not in should_be_enabled_tests
    }

    records = {}
    for test_id, stats in all_tests.items():
        num_green = stats.get("num_green", 0)
        num_red = stats.get("num_red", 0)
        disabled_test_name, name, classname, filename = get_disabled_test_name(test_id)

        key, record = prepare_record(
            workflow_id=workflow_id,
            workflow_run_attempt=workflow_run_attempt,
            name=name,
            classname=classname,
            filename=filename,
            flaky=test_id in still_flaky_tests,
            num_green=num_green,
            num_red=num_red,
        )
        records[key] = record

    # Log the results
    print(f"The following {len(should_be_enabled_tests)} tests should be re-enabled:")
    for test_id, stats in should_be_enabled_tests.items():
        disabled_test_name, name, classname, filename = get_disabled_test_name(test_id)
        print(f"  {disabled_test_name} from {filename}")

    print(f"The following {len(still_flaky_tests)} are still flaky:")
    for test_id, stats in still_flaky_tests.items():
        num_green = stats.get("num_green", 0)
        num_red = stats.get("num_red", 0)

        disabled_test_name, name, classname, filename = get_disabled_test_name(test_id)
        print(
            f"  {disabled_test_name} from {filename}, failing {num_red}/{num_red + num_green}"
        )

    upload_workflow_stats_to_s3(
        workflow_id,
        workflow_run_attempt,
        "rerun_disabled_tests",
        list(records.values()),
    )
def main(repo: str, workflow_run_id: int, workflow_run_attempt: int) -> None:
    """
    Find the list of all disabled tests that should be re-enabled and upload
    the aggregated per-test stats to S3.
    """
    # Aggregated across all jobs
    all_tests: Dict[str, Dict[str, int]] = {}

    # BUG FIX: use the function's own parameters here instead of the
    # module-global `args` — referencing `args` made `main` fail with a
    # NameError when imported and called directly
    for report in get_test_reports(
        repo, workflow_run_id, workflow_run_attempt
    ):
        tests = process_report(report)

        # The scheduled workflow has both rerun disabled tests and memory leak check jobs.
        # We are only interested in the former here
        if not is_rerun_disabled_tests(tests):
            continue

        for name, stats in tests.items():
            if name not in all_tests:
                all_tests[name] = stats.copy()
            else:
                all_tests[name]["num_green"] += stats.get("num_green", 0)
                all_tests[name]["num_red"] += stats.get("num_red", 0)

    save_results(
        workflow_run_id,
        workflow_run_attempt,
        all_tests,
    )
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Upload test artifacts from GHA to S3")
    parser.add_argument(
        "--workflow-run-id",
        type=int,
        required=True,
        help="id of the workflow to get artifacts from",
    )
    parser.add_argument(
        "--workflow-run-attempt",
        type=int,
        required=True,
        help="which retry of the workflow this is",
    )
    parser.add_argument(
        "--repo",
        type=str,
        required=True,
        help="which GitHub repo this workflow run belongs to",
    )
    args = parser.parse_args()
    main(args.repo, args.workflow_run_id, args.workflow_run_attempt)