7
import xml.etree.ElementTree as ET
8
from collections import defaultdict
9
from types import MethodType
10
from typing import Any, List, Optional, TYPE_CHECKING, Union
13
from _pytest.config import Config, filename_arg
14
from _pytest.config.argparsing import Parser
15
from _pytest.junitxml import _NodeReporter, bin_xml_escape, LogXML
16
from _pytest.python import Module
17
from _pytest.reports import TestReport
18
from _pytest.stash import StashKey
19
from _pytest.terminal import _get_raw_skip_reason
21
from pytest_shard_custom import pytest_addoptions as shard_addoptions, PytestShardPlugin
25
from _pytest._code.code import ReprFileLocation
29
# Stash key under which pytest_configure stores the LogXMLReruns instance so
# that pytest_unconfigure can find and unregister it again.
xml_key = StashKey["LogXMLReruns"]()
# Cache key prefix used by StepcurrentPlugin; the value of the stepcurrent
# option is appended to it to form the full cache key.
STEPCURRENT_CACHE_DIR = "cache/stepcurrent"
33
def pytest_addoption(parser: Parser) -> None:
    """Register the command-line options and ini settings used by this file.

    Adds the stepcurrent options (``--scs``, ``--sc``, ``--rs``), the
    ``--use-main-module`` flag, the ``*_reruns`` JUnit XML options and ini
    values, and finally delegates to ``shard_addoptions`` for the sharding
    options.

    Args:
        parser: The pytest option parser to register options on.
    """
    group = parser.getgroup("general")
    group.addoption("--scs", action="store", default=None, dest="stepcurrent_skip")
    group.addoption("--sc", action="store", default=None, dest="stepcurrent")
    group.addoption("--rs", action="store", default=None, dest="run_single")

    parser.addoption("--use-main-module", action="store_true")

    group = parser.getgroup("terminal reporting")
    # NOTE(review): the call openers/closers and the ``action``, ``metavar``,
    # ``type`` and ``default`` keywords marked below were lost from the source
    # and are reconstructed to mirror pytest's own junitxml options - confirm
    # against the upstream file.
    group.addoption(
        "--junit-xml-reruns",
        action="store",  # reconstructed
        dest="xmlpath_reruns",
        metavar="path",  # reconstructed
        type=functools.partial(filename_arg, optname="--junit-xml-reruns"),
        default=None,  # reconstructed
        help="create junit-xml style report file at given path.",
    )
    group.addoption(
        "--junit-prefix-reruns",
        action="store",  # reconstructed
        metavar="str",  # reconstructed
        default=None,  # reconstructed
        help="prepend prefix to classnames in junit-xml output",
    )
    parser.addini(
        "junit_suite_name_reruns", "Test suite name for JUnit report", default="pytest"
    )
    parser.addini(
        "junit_logging_reruns",
        "Write captured log messages to JUnit report: "
        "one of no|log|system-out|system-err|out-err|all",
        default="no",  # reconstructed
    )
    parser.addini(
        "junit_log_passing_tests_reruns",
        "Capture log information for passing tests to JUnit report: ",
        type="bool",  # reconstructed
        default=True,  # reconstructed
    )
    parser.addini(
        "junit_duration_report_reruns",
        "Duration time to report: one of total|call",
        default="total",  # reconstructed
    )
    parser.addini(
        "junit_family_reruns",
        "Emit XML for schema: one of legacy|xunit1|xunit2",
        default="xunit2",  # reconstructed
    )
    shard_addoptions(parser)
85
def pytest_configure(config: Config) -> None:
    """Register the optional plugins selected by the command-line options.

    * When ``--junit-xml-reruns`` is given and ``config`` has no
      ``workerinput`` attribute, a ``LogXMLReruns`` instance is created,
      stored in ``config.stash`` under ``xml_key`` and registered.
    * ``--scs`` and ``--rs`` both forward their value into the ``stepcurrent``
      option; if ``stepcurrent`` ends up set, ``StepcurrentPlugin`` is
      registered.
    * When ``num_shards`` is set, ``PytestShardPlugin`` is registered.

    Args:
        config: The pytest configuration object.
    """
    xmlpath = config.option.xmlpath_reruns
    if xmlpath and not hasattr(config, "workerinput"):
        junit_family = config.getini("junit_family_reruns")
        # NOTE(review): the ``xmlpath`` and ``junit_family`` positional
        # arguments were lost from the source; their positions are
        # reconstructed from the two otherwise-unused locals above - confirm
        # against the LogXML constructor signature.
        # NOTE(review): this reads ``junitprefix`` rather than the value of the
        # ``--junit-prefix-reruns`` option registered in this file - confirm
        # that this is intended.
        config.stash[xml_key] = LogXMLReruns(
            xmlpath,
            config.option.junitprefix,
            config.getini("junit_suite_name_reruns"),
            config.getini("junit_logging_reruns"),
            config.getini("junit_duration_report_reruns"),
            junit_family,
            config.getini("junit_log_passing_tests_reruns"),
        )
        config.pluginmanager.register(config.stash[xml_key])
    # --scs and --rs imply --sc with the same key.
    if config.getoption("stepcurrent_skip"):
        config.option.stepcurrent = config.getoption("stepcurrent_skip")
    if config.getoption("run_single"):
        config.option.stepcurrent = config.getoption("run_single")
    if config.getoption("stepcurrent"):
        config.pluginmanager.register(StepcurrentPlugin(config), "stepcurrentplugin")
    if config.getoption("num_shards"):
        config.pluginmanager.register(PytestShardPlugin(config), "pytestshardplugin")
110
def pytest_unconfigure(config: Config) -> None:
    """Remove and unregister the object stored under ``xml_key``, if any.

    Args:
        config: The pytest configuration object.
    """
    xml = config.stash.get(xml_key, None)
    # NOTE(review): the guard line was lost from the source; without it the
    # ``del`` below would raise whenever nothing was stashed. Reconstructed as
    # a truthiness check - confirm.
    if xml:
        del config.stash[xml_key]
        config.pluginmanager.unregister(xml)
117
class _NodeReporterReruns(_NodeReporter):
    """``_NodeReporter`` subclass that XML-escapes captured content and skip reasons."""

    def _prepare_content(self, content: str, header: str) -> str:
        """Return ``content``; ``header`` is not used."""
        # NOTE(review): the body line was lost from the source; an identity
        # return is inferred from the signature - confirm.
        return content

    def _write_content(self, report: TestReport, content: str, jheader: str) -> None:
        """Append an element named ``jheader`` whose text is the escaped ``content``."""
        # NOTE(review): the empty-content early return and the final
        # ``self.append(tag)`` were lost from the source and are
        # reconstructed - confirm.
        if content == "":
            return
        tag = ET.Element(jheader)
        tag.text = bin_xml_escape(content)
        self.append(tag)

    def append_skipped(self, report: TestReport) -> None:
        """Record a skipped test, escaping the skip reason with ``bin_xml_escape``.

        Reports that have a ``wasxfail`` attribute are delegated to the parent
        implementation. Otherwise ``report.longrepr`` must be a
        ``(filename, lineno, skipreason)`` tuple.
        """
        if hasattr(report, "wasxfail"):
            super().append_skipped(report)
        else:
            assert isinstance(report.longrepr, tuple)
            filename, lineno, skipreason = report.longrepr
            # Drop the "Skipped: " prefix (9 characters) from the reason.
            if skipreason.startswith("Skipped: "):
                skipreason = skipreason[9:]
            details = f"{filename}:{lineno}: {skipreason}"

            skipped = ET.Element(
                "skipped", type="pytest.skip", message=bin_xml_escape(skipreason)
            )
            skipped.text = bin_xml_escape(details)
            # NOTE(review): this append was lost from the source; reconstructed
            # because ``skipped`` is otherwise never used - confirm.
            self.append(skipped)
            self.write_captured_output(report)
150
class LogXMLReruns(LogXML):
    """``LogXML`` subclass that also records reports whose outcome is ``"rerun"``."""

    def __init__(self, *args, **kwargs):
        """Forward all arguments unchanged to ``LogXML.__init__``."""
        super().__init__(*args, **kwargs)

    def append_rerun(self, reporter: _NodeReporter, report: TestReport) -> None:
        """Add a ``rerun`` entry (or a ``skipped`` entry for xfail reports) to ``reporter``.

        Args:
            reporter: Node reporter that receives the new element.
            report: The rerun report; ``longrepr`` must not be ``None`` unless
                the report has a ``wasxfail`` attribute.
        """
        if hasattr(report, "wasxfail"):
            reporter._add_simple("skipped", "xfail-marked test passes unexpectedly")
        else:
            assert report.longrepr is not None
            reprcrash: Optional[ReprFileLocation] = getattr(
                report.longrepr, "reprcrash", None
            )
            # Prefer the crash message; fall back to the whole representation.
            if reprcrash is not None:
                message = reprcrash.message
            else:
                message = str(report.longrepr)
            message = bin_xml_escape(message)
            reporter._add_simple("rerun", message, str(report.longrepr))

    def pytest_runtest_logreport(self, report: TestReport) -> None:
        """Run the parent hook, then handle ``rerun`` and ``skipped`` outcomes.

        For ``rerun`` reports a test case is opened and ``append_rerun`` is
        called. For ``skipped`` reports whose ``longrepr`` is a tuple, the
        reason element is replaced with ``"<nodeid>: <raw skip reason>"``.
        """
        super().pytest_runtest_logreport(report)
        if report.outcome == "rerun":
            reporter = self._opentestcase(report)
            self.append_rerun(reporter, report)
        if report.outcome == "skipped":
            if isinstance(report.longrepr, tuple):
                fspath, lineno, reason = report.longrepr
                reason = f"{report.nodeid}: {_get_raw_skip_reason(report)}"
                report.longrepr = (fspath, lineno, reason)

    def node_reporter(self, report: Union[TestReport, str]) -> _NodeReporterReruns:
        """Return the reporter for ``report``, creating and caching it on first use.

        Reporters are keyed by ``(nodeid, node)``, where ``nodeid`` is
        ``report.nodeid`` if present (otherwise ``report`` itself) and ``node``
        is ``report.node`` if present (otherwise ``None``).
        """
        nodeid: Union[str, TestReport] = getattr(report, "nodeid", report)
        workernode = getattr(report, "node", None)

        key = nodeid, workernode

        if key in self.node_reporters:
            return self.node_reporters[key]

        reporter = _NodeReporterReruns(nodeid, self)

        self.node_reporters[key] = reporter
        self.node_reporters_ordered.append(reporter)

        # NOTE(review): the return statement was lost from the source;
        # reconstructed from the declared return type - confirm.
        return reporter
201
@pytest.hookimpl(hookwrapper=True, tryfirst=True)
def pytest_terminal_summary(terminalreporter, exitstatus, config):
    """Print a "RERUNS" section with the details of every rerun report.

    Nothing is printed when the traceback style is ``"no"``. With style
    ``"line"`` one crash line is written per report; otherwise each report
    gets a headline, its summary and its teardown sections.
    """
    if terminalreporter.config.option.tbstyle != "no":
        reports = terminalreporter.getreports("rerun")
        # NOTE(review): this guard, the two loop headers and the ``else`` were
        # lost from the source and are reconstructed from the use of ``rep`` -
        # confirm.
        if reports:
            terminalreporter.write_sep("=", "RERUNS")
            if terminalreporter.config.option.tbstyle == "line":
                for rep in reports:
                    line = terminalreporter._getcrashline(rep)
                    terminalreporter.write_line(line)
            else:
                for rep in reports:
                    msg = terminalreporter._getfailureheadline(rep)
                    terminalreporter.write_sep("_", msg, red=True, bold=True)
                    terminalreporter._outrep_summary(rep)
                    terminalreporter._handle_teardown_sections(rep.nodeid)
    # A hookwrapper must be a generator; the wrapped hooks run at this yield.
    yield
221
@pytest.hookimpl(tryfirst=True)
def pytest_pycollect_makemodule(module_path, path, parent) -> Optional[Module]:
    """Create a ``Module`` whose object is ``sys.modules["__main__"]``.

    Only active when ``--use-main-module`` is passed; otherwise the function
    falls through and returns ``None``.
    """
    if parent.config.getoption("--use-main-module"):
        mod = Module.from_parent(parent, path=module_path)
        # Replace the module lookup so collection uses the already-loaded
        # __main__ module object.
        mod._getobj = MethodType(lambda x: sys.modules["__main__"], mod)
        # NOTE(review): the return statement was lost from the source;
        # reconstructed from the declared return type - confirm.
        return mod
    return None
229
@pytest.hookimpl(hookwrapper=True)
def pytest_report_teststatus(report, config):
    """Append the report's duration to the verbose status string.

    Wraps the other implementations of this hook: after they have produced
    ``(outcome, letter, verbose)``, the verbose word is replaced by
    ``"<verbose> [<duration>s]"`` with four decimal places.
    """
    pluggy_result = yield
    # NOTE(review): the early return here and the ``if verbose`` guard below
    # were lost from the source and are reconstructed - confirm.
    if not isinstance(report, pytest.TestReport):
        return
    outcome, letter, verbose = pluggy_result.get_result()
    if verbose:
        pluggy_result.force_result(
            (outcome, letter, f"{verbose} [{report.duration:.4f}s]")
        )
243
@pytest.hookimpl(trylast=True)
def pytest_collection_modifyitems(items: List[Any]) -> None:
    """
    This hook is used when rerunning disabled tests to get rid of all skipped tests
    instead of running and skipping them N times. This avoids flooding the console
    and XML outputs with junk. So we want this to run last when collecting tests.

    Active only when ``PYTORCH_TEST_RERUN_DISABLED_TESTS`` is ``"1"`` and
    ``DISABLED_TESTS_FILE`` names an existing JSON file. ``items`` is modified
    in place.
    """
    rerun_disabled_tests = os.getenv("PYTORCH_TEST_RERUN_DISABLED_TESTS", "0") == "1"
    if not rerun_disabled_tests:
        return

    # Entries look like "<test_name> (<module>.<test_class>)".
    disabled_regex = re.compile(r"(?P<test_name>.+)\s+\([^\.]+\.(?P<test_class>.+)\)")
    disabled_tests = defaultdict(set)

    disabled_tests_file = os.getenv("DISABLED_TESTS_FILE", "")
    if not disabled_tests_file or not os.path.exists(disabled_tests_file):
        return

    with open(disabled_tests_file) as fp:
        for disabled_test in json.load(fp):
            m = disabled_regex.match(disabled_test)
            # NOTE(review): guard reconstructed; ``m`` is None for entries that
            # do not match the pattern - confirm.
            if m:
                test_name = m["test_name"]
                test_class = m["test_class"]
                disabled_tests[test_class].add(test_name)

    # Keep only the collected items that are listed as disabled.
    filtered_items = []

    for item in items:
        test_name = item.name
        test_class = item.parent.name

        if (
            test_class not in disabled_tests
            or test_name not in disabled_tests[test_class]
        ):
            continue

        cpy = copy.copy(item)
        # NOTE(review): one statement between the copy and the append was lost
        # from the source; reconstructed as re-initialising the copy's fixture
        # request - confirm against the upstream file.
        cpy._initrequest()

        filtered_items.append(cpy)

    # NOTE(review): the reset before ``extend`` was lost from the source and is
    # reconstructed - confirm. The list is edited in place because the caller
    # holds a reference to this same list object.
    items.clear()
    items.extend(filtered_items)
293
class StepcurrentPlugin:
    """Record the currently running test in the pytest cache and resume from it.

    The node id of each test is written to the cache before it runs. On the
    next collection, every item before the recorded one is deselected.
    """

    def __init__(self, config: Config) -> None:
        """Load the previously recorded node id and the mode flags from ``config``.

        Args:
            config: The pytest configuration object; ``config.cache`` must not
                be ``None``.
        """
        # Read later by pytest_report_collectionfinish.
        self.config = config
        self.report_status = ""
        assert config.cache is not None
        self.cache: pytest.Cache = config.cache
        self.directory = f"{STEPCURRENT_CACHE_DIR}/{config.getoption('stepcurrent')}"
        self.lastrun: Optional[str] = self.cache.get(self.directory, None)
        # Kept so a successful session can restore the cache to this value.
        self.initial_val = self.lastrun
        self.skip: bool = config.getoption("stepcurrent_skip")
        self.run_single: bool = config.getoption("run_single")

    def pytest_collection_modifyitems(self, config: Config, items: List[Any]) -> None:
        """Deselect the items that come before the recorded test.

        ``items`` is modified in place and the removed items are reported
        through ``config.hook.pytest_deselected``. In skip mode the recorded
        test itself is deselected too; in run-single mode everything after the
        first remaining item is deselected as well.
        """
        if not self.lastrun:
            self.report_status = "Cannot find last run test, not skipping"
            return

        # NOTE(review): the index bookkeeping inside this loop was lost from
        # the source and is reconstructed from how ``failed_index`` and
        # ``self.skip`` are used - confirm.
        failed_index = None
        for index, item in enumerate(items):
            if item.nodeid == self.lastrun:
                failed_index = index
                if self.skip:
                    failed_index += 1
                break

        if failed_index is None:
            self.report_status = "previously run test not found, not skipping."
        else:
            self.report_status = f"skipping {failed_index} already run items."
            deselected = items[:failed_index]
            del items[:failed_index]
            # ``items`` can be empty here (skip mode with the recorded test
            # last), so check it before indexing ``items[0]``.
            if self.run_single and items:
                self.report_status += f" Running only {items[0].nodeid}"
                deselected += items[1:]
                del items[1:]
            config.hook.pytest_deselected(items=deselected)

    def pytest_report_collectionfinish(self) -> Optional[str]:
        """Return the status line for the collection summary, or ``None``."""
        if self.config.getoption("verbose") >= 0 and self.report_status:
            return f"stepcurrent: {self.report_status}"
        return None

    def pytest_runtest_protocol(self, item, nextitem) -> None:
        """Store ``item.nodeid`` in the cache before the test runs."""
        self.lastrun = item.nodeid
        self.cache.set(self.directory, self.lastrun)

    def pytest_sessionfinish(self, session, exitstatus):
        """Restore the initial cache value after a successful non-single run."""
        if exitstatus == 0 and not self.run_single:
            self.cache.set(self.directory, self.initial_val)