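"""Check which disabled tests should be re-enabled.

Summary (derived from the code below): download all test reports from a given
workflow run, aggregate how many times each rerun disabled test passed
(num_green) or failed (num_red), and upload the stats to S3.
"""
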
import argparse
import json
import os
import xml.etree.ElementTree as ET
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any, Dict, Generator, Tuple

from tools.stats.upload_stats_lib import (
    download_gha_artifacts,
    download_s3_artifacts,
    is_rerun_disabled_tests,
    unzip,
    upload_workflow_stats_to_s3,
)
from tools.stats.upload_test_stats import process_xml_element

TESTCASE_TAG = "testcase"
SEPARATOR = ";"
22
def process_report(
23
    report: Path,
24
) -> Dict[str, Dict[str, int]]:
25
    """
26
    Return a list of disabled tests that should be re-enabled and those that are still
27
    flaky (failed or skipped)
28
    """
29
    root = ET.parse(report)
30

31
    # All rerun tests from a report are grouped here:
32
    #
33
    # * Success test should be re-enable if it's green after rerunning in all platforms
34
    #   where it is currently disabled
35
    # * Failures from pytest because pytest-flakefinder is used to run the same test
36
    #   multiple times, some could fails
37
    # * Skipped tests from unittest
38
    #
39
    # We want to keep track of how many times the test fails (num_red) or passes (num_green)
40
    all_tests: Dict[str, Dict[str, int]] = {}
41

42
    for test_case in root.iter(TESTCASE_TAG):
43
        parsed_test_case = process_xml_element(test_case)
44

45
        # Under --rerun-disabled-tests mode, a test is skipped when:
46
        # * it's skipped explicitly inside PyTorch code
47
        # * it's skipped because it's a normal enabled test
48
        # * or it's falky (num_red > 0 and num_green > 0)
49
        # * or it's failing (num_red > 0 and num_green == 0)
50
        #
51
        # We care only about the latter two here
52
        skipped = parsed_test_case.get("skipped", None)
53

54
        # NB: Regular ONNX tests could return a list of subskips here where each item in the
55
        # list is a skipped message.  In the context of rerunning disabled tests, we could
56
        # ignore this case as returning a list of subskips only happens when tests are run
57
        # normally
58
        if skipped and (
59
            type(skipped) is list or "num_red" not in skipped.get("message", "")
60
        ):
61
            continue
62

63
        name = parsed_test_case.get("name", "")
64
        classname = parsed_test_case.get("classname", "")
65
        filename = parsed_test_case.get("file", "")
66

67
        if not name or not classname or not filename:
68
            continue
69

70
        # Check if the test is a failure
71
        failure = parsed_test_case.get("failure", None)
72

73
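        # A disabled test is identified by joining its name, classname, and file
        # with SEPARATOR, e.g. "test_foo;TestBar;test_baz.py" (illustrative names)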
        disabled_test_id = SEPARATOR.join([name, classname, filename])
        if disabled_test_id not in all_tests:
            all_tests[disabled_test_id] = {
                "num_green": 0,
                "num_red": 0,
            }

        # Under --rerun-disabled-tests mode, if a test is not skipped or failed,
        # it's counted as a success. Otherwise, it's still flaky or failing
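        # The skipped message in this mode is expected to be a JSON blob carrying
        # the rerun counts, e.g. {"num_red": 1, "num_green": 49} (illustrative
        # values)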
        if skipped:
            try:
                stats = json.loads(skipped.get("message", ""))
            except json.JSONDecodeError:
                stats = {}

            all_tests[disabled_test_id]["num_green"] += stats.get("num_green", 0)
            all_tests[disabled_test_id]["num_red"] += stats.get("num_red", 0)
        elif failure:
            # As a failure, increase the failure count
            all_tests[disabled_test_id]["num_red"] += 1
        else:
            all_tests[disabled_test_id]["num_green"] += 1

    return all_tests


def get_test_reports(
    repo: str, workflow_run_id: int, workflow_run_attempt: int
) -> Generator[Path, None, None]:
    """
    Gather all the test reports from S3 and GHA. It is currently not possible to
    tell which test reports come from the rerun_disabled_tests workflow because
    the name doesn't include the test config, so all reports need to be
    downloaded and examined
    """
    with TemporaryDirectory() as temp_dir:
        print("Using temporary directory:", temp_dir)
        os.chdir(temp_dir)

        artifact_paths = download_s3_artifacts(
            "test-reports", workflow_run_id, workflow_run_attempt
        )
        for path in artifact_paths:
            unzip(path)

        artifact_paths = download_gha_artifacts(
            "test-report", workflow_run_id, workflow_run_attempt
        )
        for path in artifact_paths:
            unzip(path)

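        # NB: the yielded paths are relative to the temporary directory (we chdir
        # into it above), so they are only valid while the generator is consumed
        # inside this context manager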
        yield from Path(".").glob("**/*.xml")


def get_disabled_test_name(test_id: str) -> Tuple[str, str, str, str]:
    """
    Follow the flaky bot naming convention here; if that changes, this will also
    need to be updated
    """
    name, classname, filename = test_id.split(SEPARATOR)
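    # Produces names like "test_foo (__main__.TestBar)" (illustrative), the
    # format the flaky bot uses to refer to disabled tests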
    return f"{name} (__main__.{classname})", name, classname, filename


def prepare_record(
    workflow_id: int,
    workflow_run_attempt: int,
    name: str,
    classname: str,
    filename: str,
    flaky: bool,
    num_red: int = 0,
    num_green: int = 0,
) -> Tuple[Any, Dict[str, Any]]:
    """
    Prepare the record to save to S3
    """
    key = (
        workflow_id,
        workflow_run_attempt,
        name,
        classname,
        filename,
    )

    record = {
        "workflow_id": workflow_id,
        "workflow_run_attempt": workflow_run_attempt,
        "name": name,
        "classname": classname,
        "filename": filename,
        "flaky": flaky,
        "num_green": num_green,
        "num_red": num_red,
    }

    return key, record


def save_results(
    workflow_id: int,
    workflow_run_attempt: int,
    all_tests: Dict[str, Dict[str, int]],
) -> None:
    """
    Save the results to S3 so they can be ingested by Rockset
    """
    should_be_enabled_tests = {
        name: stats
        for name, stats in all_tests.items()
        if "num_green" in stats
        and stats["num_green"]
        and "num_red" in stats
        and stats["num_red"] == 0
    }
    still_flaky_tests = {
        name: stats
        for name, stats in all_tests.items()
        if name not in should_be_enabled_tests
    }

    records = {}
    for test_id, stats in all_tests.items():
        num_green = stats.get("num_green", 0)
        num_red = stats.get("num_red", 0)
        disabled_test_name, name, classname, filename = get_disabled_test_name(test_id)

        key, record = prepare_record(
            workflow_id=workflow_id,
            workflow_run_attempt=workflow_run_attempt,
            name=name,
            classname=classname,
            filename=filename,
            flaky=test_id in still_flaky_tests,
            num_green=num_green,
            num_red=num_red,
        )
        records[key] = record

    # Log the results
    print(f"The following {len(should_be_enabled_tests)} tests should be re-enabled:")
    for test_id, stats in should_be_enabled_tests.items():
        disabled_test_name, name, classname, filename = get_disabled_test_name(test_id)
        print(f"  {disabled_test_name} from {filename}")

    print(f"The following {len(still_flaky_tests)} tests are still flaky:")
    for test_id, stats in still_flaky_tests.items():
        num_green = stats.get("num_green", 0)
        num_red = stats.get("num_red", 0)

        disabled_test_name, name, classname, filename = get_disabled_test_name(test_id)
        print(
            f"  {disabled_test_name} from {filename}, failing {num_red}/{num_red + num_green}"
        )

    upload_workflow_stats_to_s3(
        workflow_id,
        workflow_run_attempt,
        "rerun_disabled_tests",
        list(records.values()),
    )


def main(repo: str, workflow_run_id: int, workflow_run_attempt: int) -> None:
    """
    Find the list of all disabled tests that should be re-enabled
    """
    # Aggregated across all jobs
    all_tests: Dict[str, Dict[str, int]] = {}

    # Use the function parameters here rather than the global argparse namespace,
    # so that main can also be called from code that imports this module
    for report in get_test_reports(repo, workflow_run_id, workflow_run_attempt):
        tests = process_report(report)

        # The scheduled workflow has both rerun disabled tests and memory leak
        # check jobs. We are only interested in the former here
        if not is_rerun_disabled_tests(tests):
            continue

        for name, stats in tests.items():
            if name not in all_tests:
                all_tests[name] = stats.copy()
            else:
                all_tests[name]["num_green"] += stats.get("num_green", 0)
                all_tests[name]["num_red"] += stats.get("num_red", 0)

    save_results(
        workflow_run_id,
        workflow_run_attempt,
        all_tests,
    )

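# Example invocation (illustrative values; the module path assumes the script
# lives in tools/stats/, as the imports above suggest):
#
#   python -m tools.stats.check_disabled_tests --repo pytorch/pytorch \
#       --workflow-run-id 123456789 --workflow-run-attempt 1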
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Check which disabled tests should be re-enabled and upload the stats to S3"
    )
    parser.add_argument(
        "--workflow-run-id",
        type=int,
        required=True,
        help="id of the workflow to get artifacts from",
    )
    parser.add_argument(
        "--workflow-run-attempt",
        type=int,
        required=True,
        help="which retry of the workflow this is",
    )
    parser.add_argument(
        "--repo",
        type=str,
        required=True,
        help="which GitHub repo this workflow run belongs to",
    )

    args = parser.parse_args()
    main(args.repo, args.workflow_run_id, args.workflow_run_attempt)