pytorch
225 строк · 7.4 Кб
1#!/usr/bin/env python3
2import argparse3import os4import subprocess5from pathlib import Path6
7from common import (8get_testcases,9is_failure,10is_passing_skipped_test,11is_unexpected_success,12key,13open_test_results,14)
15from download_reports import download_reports16
17
"""
Usage: update_failures.py /path/to/dynamo_test_failures.py /path/to/test commit_sha

Best-effort updates the xfail and skip files under the test directory
by parsing test reports.

You'll need to provide the commit_sha for the latest commit on a PR
from which we will pull CI test results.

Instructions:
- On your PR, add the "keep-going" label to ensure that all the tests are
run and every failure is reported (as opposed to CI stopping on the first
failure). You may need to restart your test jobs by force-pushing to your
branch for CI to pick up the "keep-going" label.
- Wait for all the tests to finish running.
- Find the full SHA of your commit and run this command.

This script requires the `gh` cli. You'll need to install it and then
authenticate with it via `gh auth login` before using this script.
https://docs.github.com/en/github-cli/github-cli/quickstart
"""
39
40
def patch_file(
    filename, test_dir, unexpected_successes, new_xfails, new_skips, unexpected_skips
):
    """Sync the expected-failure and skip marker files with parsed test results.

    Every tracked test is an empty file named ``<classname>.<name>`` inside
    ``<test_dir>/dynamo_expected_failures`` or ``<test_dir>/dynamo_skips``.
    Markers are dropped with ``git rm`` and created with ``git add``; the exit
    status of those git commands is not checked.

    Args:
        filename: path of the file holding the hand-written
            ``extra_dynamo_skips = {...}`` set. It is only read, never edited.
        test_dir: directory containing the two marker sub-directories.
        unexpected_successes: mapping whose values are testcase objects
            (exposing ``attrib["classname"]`` and ``attrib["name"]``) whose
            expected-failure marker should be removed.
        new_xfails: mapping of testcases to record as expected failures.
        new_skips: mapping of testcases to record as skips.
        unexpected_skips: mapping of testcases whose skip marker should be
            removed.
    """
    failures_directory = os.path.join(test_dir, "dynamo_expected_failures")
    skips_directory = os.path.join(test_dir, "dynamo_skips")

    existing_xfails = set(os.listdir(failures_directory))
    existing_skips = set(os.listdir(skips_directory))

    # Hand-written skips: every line between "extra_dynamo_skips = {" and the
    # first line that is exactly "}", with surrounding commas/quotes stripped.
    handwritten_skips = set()
    with open(filename) as handle:
        inside_set = False
        for raw_line in handle:
            line = raw_line.strip()
            if not inside_set:
                inside_set = line == "extra_dynamo_skips = {"
                continue
            if line == "}":
                break
            handwritten_skips.add(line.strip(',"'))

    def qualified_name(testcase):
        # Marker file name for a testcase: "<classname>.<name>".
        return f'{testcase.attrib["classname"]}.{testcase.attrib["name"]}'

    successes_to_clear = {
        qualified_name(case) for case in unexpected_successes.values()
    }
    skips_to_clear = {qualified_name(case) for case in unexpected_skips.values()}
    xfails_to_add = [qualified_name(case) for case in new_xfails.values()]
    skips_to_add = [qualified_name(case) for case in new_skips.values()]

    def git_rm(directory, name):
        subprocess.run(["git", "rm", os.path.join(directory, name)])

    def git_add_empty(directory, name):
        target = os.path.join(directory, name)
        # Create (or truncate) the marker file before staging it.
        with open(target, "w"):
            pass
        subprocess.run(["git", "add", target])

    # dynamo_expected_failures
    cleared = set()
    for name in existing_xfails:
        if name not in successes_to_clear:
            continue
        cleared.add(name)
        git_rm(failures_directory, name)
    for name in xfails_to_add:
        git_add_empty(failures_directory, name)

    # Unexpected successes that had no marker file to remove.
    not_cleared = successes_to_clear - cleared
    if not_cleared:
        print(
            "WARNING: we were unable to remove these "
            f"{len(not_cleared)} expectedFailures:"
        )
        for name in not_cleared:
            print(name)

    # dynamo_skips
    for name in existing_skips:
        if name in skips_to_clear:
            git_rm(skips_directory, name)
    for name in handwritten_skips:
        if name not in skips_to_clear:
            continue
        print(
            f"WARNING: {name} in dynamo_test_failures.py needs to be removed manually"
        )
    for name in skips_to_add:
        git_add_empty(skips_directory, name)
123
def get_intersection_and_outside(a_dict, b_dict):
    """Split the keys of two dicts into shared keys and keys unique to one.

    Returns a pair ``(intersection, outside)`` of dicts: ``intersection``
    holds the keys present in both inputs, ``outside`` the keys present in
    exactly one of them. Whenever a key exists in ``a_dict`` its value is
    taken from there; otherwise the value comes from ``b_dict``.
    """
    a_keys = set(a_dict)
    b_keys = set(b_dict)
    shared = a_keys & b_keys
    exclusive = (a_keys | b_keys) - shared

    def pick_values(keys):
        # a_dict wins for any key it contains.
        return {k: (a_dict[k] if k in a_dict else b_dict[k]) for k in keys}

    return pick_values(shared), pick_values(exclusive)
141
def update(filename, test_dir, py38_dir, py311_dir, also_remove_skips):
    """Classify the test results of two report sets and patch the marker files.

    Results are read from ``py38_dir`` and ``py311_dir`` and combined as follows:

    - unexpected successes from either report are handed over for removal
      from the expected failures;
    - tests failing in both reports become new xfails;
    - tests failing, or unexpectedly succeeding, in exactly one report
      become new skips;
    - when ``also_remove_skips`` is true, skipped tests that pass in both
      reports are handed over for skip removal (otherwise none are).

    Prints a one-line summary and returns the result of ``patch_file``.
    """

    def read_test_results(directory):
        testcases = get_testcases(open_test_results(directory))

        def select(predicate):
            # One independent pass over the testcases per predicate.
            return {key(case): case for case in testcases if predicate(case)}

        return (
            select(is_unexpected_success),
            select(is_failure),
            select(is_passing_skipped_test),
        )

    py38_successes, py38_failures, py38_passing_skips = read_test_results(py38_dir)
    py311_successes, py311_failures, py311_passing_skips = read_test_results(
        py311_dir
    )

    # On key collisions the second report's entry wins.
    unexpected_successes = dict(py38_successes)
    unexpected_successes.update(py311_successes)

    skips = get_intersection_and_outside(py38_successes, py311_successes)[1]
    xfails, more_skips = get_intersection_and_outside(py38_failures, py311_failures)

    unexpected_skips = {}
    if also_remove_skips:
        unexpected_skips = get_intersection_and_outside(
            py38_passing_skips, py311_passing_skips
        )[0]

    all_skips = {**skips, **more_skips}
    print(
        f"Discovered {len(unexpected_successes)} new unexpected successes, "
        f"{len(xfails)} new xfails, {len(all_skips)} new skips, {len(unexpected_skips)} new unexpected skips"
    )
    return patch_file(
        filename, test_dir, unexpected_successes, xfails, all_skips, unexpected_skips
    )
186
if __name__ == "__main__":
    # Command-line entry point: parse arguments, validate the two paths,
    # download the CI reports for the given commit and apply the update.
    repo_root = Path(__file__).absolute().parent.parent.parent

    parser = argparse.ArgumentParser(
        prog="update_dynamo_test_failures",
        description="Read from logs and update the dynamo_test_failures file",
    )
    # dynamo_test_failures path
    parser.add_argument(
        "filename",
        nargs="?",
        default=str(repo_root / "torch/testing/_internal/dynamo_test_failures.py"),
        help="Optional path to dynamo_test_failures.py",
    )
    # test path
    parser.add_argument(
        "test_dir",
        nargs="?",
        default=str(repo_root / "test"),
        help="Optional path to test folder",
    )
    parser.add_argument(
        "commit",
        help=(
            "The commit sha for the latest commit on a PR from which we will "
            "pull CI test results, e.g. 7e5f597aeeba30c390c05f7d316829b3798064a5"
        ),
    )
    parser.add_argument(
        "--also-remove-skips",
        help="Also attempt to remove skips. WARNING: does not guard against test flakiness",
        action="store_true",
    )
    args = parser.parse_args()

    # Validate explicitly instead of with `assert`: assertions are stripped
    # under `python -O`, and parser.error() prints the usage plus a readable
    # message before exiting with a non-zero status.
    for required_path in (args.filename, args.test_dir):
        if not Path(required_path).exists():
            parser.error(f"path does not exist: {required_path}")

    dynamo39, dynamo311 = download_reports(args.commit, ("dynamo39", "dynamo311"))
    update(args.filename, args.test_dir, dynamo39, dynamo311, args.also_remove_skips)