1
from _pytest.junitxml import LogXML, _NodeReporter, bin_xml_escape
2
from _pytest.terminal import _get_raw_skip_reason
3
from _pytest.stash import StashKey
4
from _pytest.reports import TestReport
5
from _pytest.config.argparsing import Parser
6
from _pytest.config import filename_arg
7
from _pytest.config import Config
8
from _pytest._code.code import ReprFileLocation
9
from _pytest.python import Module
10
from typing import Any, List, Union
11
from typing import Optional
12
from types import MethodType
13
import xml.etree.ElementTree as ET
21
from collections import defaultdict
22
from pytest_shard_custom import PytestShardPlugin, pytest_addoptions as shard_addoptions
26
# StashKey under which pytest_configure keeps the LogXMLReruns plugin, so that
# pytest_unconfigure can later look it up, delete it and unregister it.
xml_key = StashKey["LogXMLReruns"]()

# Prefix of the pytest-cache keys StepcurrentPlugin reads and writes; the full key
# is f"{STEPCURRENT_CACHE_DIR}/<value of the stepcurrent option>".
STEPCURRENT_CACHE_DIR: str = "cache/stepcurrent"
# Register this conftest's command-line options and ini keys: step-current
# option(s) in the "general" group, --use-main-module, the --junit-xml-reruns /
# --junit-prefix-reruns options in the "terminal reporting" group, the *_reruns
# ini keys that pytest_configure reads, and (via shard_addoptions) the options of
# pytest_shard_custom.
# NOTE(review): lines of this definition appear to be missing from this listing:
# the keyword arguments and strings below have no visible enclosing
# group.addoption(...) / parser.addini(...) calls. Verify against the full file.
def pytest_addoption(parser: Parser) -> None:
31
group = parser.getgroup("general")
# Read back via getoption("stepcurrent_skip") in pytest_configure and StepcurrentPlugin.
36
dest="stepcurrent_skip",
# Makes pytest_pycollect_makemodule back collected modules with sys.modules['__main__'].
45
parser.addoption("--use-main-module", action='store_true')
46
group = parser.getgroup("terminal reporting")
# Read back as config.option.xmlpath_reruns in pytest_configure.
50
dest="xmlpath_reruns",
# The report path is validated by pytest's filename_arg helper.
52
type=functools.partial(filename_arg, optname="--junit-xml-reruns"),
54
help="create junit-xml style report file at given path.",
57
"--junit-prefix-reruns",
61
help="prepend prefix to classnames in junit-xml output",
# The ini keys below are pytest's junit_* settings with a "_reruns" suffix; each
# is read with config.getini(...) in pytest_configure.
64
"junit_suite_name_reruns", "Test suite name for JUnit report", default="pytest"
67
"junit_logging_reruns",
68
"Write captured log messages to JUnit report: "
69
"one of no|log|system-out|system-err|out-err|all",
73
"junit_log_passing_tests_reruns",
74
"Capture log information for passing tests to JUnit report: ",
79
"junit_duration_report_reruns",
80
"Duration time to report: one of total|call",
84
"junit_family_reruns",
85
"Emit XML for schema: one of legacy|xunit1|xunit2",
# Options for PytestShardPlugin (registered in pytest_configure when num_shards is set).
88
shard_addoptions(parser)
91
# Create and register the plugins selected on the command line: LogXMLReruns when
# an xml path for the reruns report is given, StepcurrentPlugin when the
# stepcurrent option (or stepcurrent_skip) is set, and PytestShardPlugin when
# num_shards is set.
def pytest_configure(config: Config) -> None:
92
xmlpath = config.option.xmlpath_reruns
# Skip processes that carry a `workerinput` attribute (presumably pytest-xdist
# workers), so that presumably only one process writes the XML report.
94
if xmlpath and not hasattr(config, "workerinput"):
# NOTE(review): junit_family (and xmlpath) are not among the LogXMLReruns
# arguments visible below; lines appear to be missing from this listing.
95
junit_family = config.getini("junit_family_reruns")
96
config.stash[xml_key] = LogXMLReruns(
# NOTE(review): this reads pytest's standard --junit-prefix option
# (config.option.junitprefix), not the --junit-prefix-reruns option added in
# pytest_addoption -- confirm this is intended.
98
config.option.junitprefix,
99
config.getini("junit_suite_name_reruns"),
100
config.getini("junit_logging_reruns"),
101
config.getini("junit_duration_report_reruns"),
103
config.getini("junit_log_passing_tests_reruns"),
# Register the instance so its hook methods run; pytest_unconfigure removes it.
105
config.pluginmanager.register(config.stash[xml_key])
# The skip variant implies stepcurrent with the same value; StepcurrentPlugin
# reads stepcurrent_skip separately.
106
if config.getoption("stepcurrent_skip"):
107
config.option.stepcurrent = config.getoption("stepcurrent_skip")
108
if config.getoption("stepcurrent"):
109
config.pluginmanager.register(StepcurrentPlugin(config), "stepcurrentplugin")
110
if config.getoption("num_shards"):
111
config.pluginmanager.register(PytestShardPlugin(config), "pytestshardplugin")
114
def pytest_unconfigure(config: Config) -> None:
    """Unregister and discard the LogXMLReruns plugin created by pytest_configure.

    pytest_configure only stashes the plugin when an xml path for the reruns
    report is configured and ``config`` has no ``workerinput`` attribute, so the
    stash entry may be absent; in that case there is nothing to clean up.
    """
    xml = config.stash.get(xml_key, None)
    if xml is None:
        # Without this guard ``del config.stash[xml_key]`` raises KeyError
        # whenever no reruns report was configured.
        return
    del config.stash[xml_key]
    config.pluginmanager.unregister(xml)
# _NodeReporter subclass that LogXMLReruns.node_reporter creates for each test;
# it overrides how captured output and skips are written to the JUnit XML.
# NOTE(review): lines of this class appear to be missing from this listing:
# _prepare_content has no visible body, the tag/skipped elements built below are
# never visibly attached to the test case, and no `else:`/`return` separates the
# xfail branch of append_skipped from the plain-skip code. Verify against the
# full file.
class _NodeReporterReruns(_NodeReporter):
122
def _prepare_content(self, content: str, header: str) -> str:
# Emit `content` as a child element named `jheader`, passing the text through
# bin_xml_escape so characters that are invalid in XML are escaped.
125
def _write_content(self, report: TestReport, content: str, jheader: str) -> None:
128
tag = ET.Element(jheader)
129
tag.text = bin_xml_escape(content)
# Record a skip. xfail reports are delegated to the base implementation; for
# ordinary skips the reason is passed through bin_xml_escape before it is used
# as the <skipped> element's message attribute and text.
132
def append_skipped(self, report: TestReport) -> None:
136
if hasattr(report, "wasxfail"):
138
super().append_skipped(report)
# Plain skips: longrepr is expected to be a (filename, lineno, reason) tuple.
140
assert isinstance(report.longrepr, tuple)
141
filename, lineno, skipreason = report.longrepr
# Drop a leading "Skipped: " (9 characters) from the reason.
142
if skipreason.startswith("Skipped: "):
143
skipreason = skipreason[9:]
144
details = f"{filename}:{lineno}: {skipreason}"
146
skipped = ET.Element("skipped", type="pytest.skip", message=bin_xml_escape(skipreason))
147
skipped.text = bin_xml_escape(details)
149
self.write_captured_output(report)
151
class LogXMLReruns(LogXML):
    """LogXML variant that also records "rerun" outcomes in the JUnit XML.

    Per-test reporting goes through _NodeReporterReruns, and skip reasons are
    prefixed with the test's node id.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def append_rerun(self, reporter: _NodeReporter, report: TestReport) -> None:
        """Add one entry for a rerun *report* to *reporter*.

        xfail-marked reports get a ``skipped`` entry and nothing else. All other
        reports get a ``rerun`` entry. Its message is the crash message when
        ``longrepr`` carries one, otherwise the full ``longrepr`` text, escaped
        with bin_xml_escape. Its data is the full ``longrepr`` text.
        """
        if hasattr(report, "wasxfail"):
            reporter._add_simple("skipped", "xfail-marked test passes unexpectedly")
            # One entry per report: do not also record it as a rerun below.
            return

        assert report.longrepr is not None
        reprcrash: Optional[ReprFileLocation] = getattr(
            report.longrepr, "reprcrash", None
        )
        # Prefer the short crash message and fall back to the full representation.
        # The fallback used to overwrite the crash message unconditionally.
        if reprcrash is not None:
            message = reprcrash.message
        else:
            message = str(report.longrepr)
        message = bin_xml_escape(message)
        reporter._add_simple("rerun", message, str(report.longrepr))

    def pytest_runtest_logreport(self, report: TestReport) -> None:
        """Handle *report* as LogXML does, then add rerun entries and tag skip reasons.

        For "rerun" outcomes, a rerun entry is appended to the test's reporter.
        For skips whose ``longrepr`` is a ``(path, lineno, reason)`` tuple, the
        reason is replaced by ``"<nodeid>: <raw skip reason>"``. This rewrite
        happens after ``super()`` has handled the report.
        """
        super().pytest_runtest_logreport(report)
        if report.outcome == "rerun":
            reporter = self._opentestcase(report)
            self.append_rerun(reporter, report)
        if report.outcome == "skipped" and isinstance(report.longrepr, tuple):
            fspath, lineno, _ = report.longrepr
            reason = f"{report.nodeid}: {_get_raw_skip_reason(report)}"
            report.longrepr = (fspath, lineno, reason)

    def node_reporter(self, report: Union[TestReport, str]) -> _NodeReporterReruns:
        """Return the _NodeReporterReruns for *report*, creating and caching it on first use.

        Reporters are keyed by ``(nodeid, node)``. ``node`` is the report's
        ``node`` attribute (presumably the xdist worker), or None when absent.
        *report* may also be a bare node id string.
        """
        nodeid: Union[str, TestReport] = getattr(report, "nodeid", report)
        workernode = getattr(report, "node", None)

        key = nodeid, workernode
        if key in self.node_reporters:
            return self.node_reporters[key]

        reporter = _NodeReporterReruns(nodeid, self)
        self.node_reporters[key] = reporter
        self.node_reporters_ordered.append(reporter)
        # The new reporter must be returned too; previously this path returned None.
        return reporter
# Hook wrapper that, unless --tb=no, writes a "RERUNS" section to the terminal
# summary for the reports with outcome "rerun". With --tb=line it writes one crash
# line per report; otherwise it writes a headline, the failure details and the
# teardown sections. It relies on private (underscore-prefixed) TerminalReporter
# helpers, which are not a stable pytest API.
# NOTE(review): lines of this function appear to be missing from this listing:
# - no `yield` is visible, although hookwrapper=True requires a generator that
#   yields exactly once;
# - `rep` is used without a visible `for rep in reports:` loop;
# - no `else:` separates the --tb=line branch from the detailed branch.
@pytest.hookimpl(hookwrapper=True, tryfirst=True)
203
def pytest_terminal_summary(terminalreporter, exitstatus, config):
205
if terminalreporter.config.option.tbstyle != "no":
206
reports = terminalreporter.getreports("rerun")
208
terminalreporter.write_sep("=", "RERUNS")
# Compact form: one crash line per rerun report.
209
if terminalreporter.config.option.tbstyle == "line":
211
line = terminalreporter._getcrashline(rep)
212
terminalreporter.write_line(line)
# Detailed form: headline separator, failure output, then teardown sections.
215
msg = terminalreporter._getfailureheadline(rep)
216
terminalreporter.write_sep("_", msg, red=True, bold=True)
217
terminalreporter._outrep_summary(rep)
218
terminalreporter._handle_teardown_sections(rep.nodeid)
222
@pytest.hookimpl(tryfirst=True)
def pytest_pycollect_makemodule(module_path, path, parent) -> Optional[Module]:
    """With --use-main-module, collect tests from the running ``__main__`` module.

    Builds a Module collector for *module_path* whose underlying object is
    ``sys.modules['__main__']``. Without the option it returns None. This hook
    stops at the first non-None result, so None lets the remaining
    implementations create the Module as usual.

    Args:
        module_path: path of the file being collected.
        path: unused here; kept to match the hook signature.
        parent: collector that will own the new Module.
    """
    if not parent.config.getoption("--use-main-module"):
        return None
    mod = Module.from_parent(parent, path=module_path)
    # Make the collector resolve to the already-running __main__ module
    # (presumably instead of importing the file again).
    mod._getobj = MethodType(lambda x: sys.modules['__main__'], mod)
    # The collector must be returned; previously it was built and then dropped,
    # so the option had no effect.
    return mod
@pytest.hookimpl(hookwrapper=True)
def pytest_report_teststatus(report, config):
    """Append the report's duration to its verbose status word.

    This wraps the other ``pytest_report_teststatus`` implementations and
    rewrites their result, e.g. ``"PASSED"`` becomes ``"PASSED [0.0123s]"``,
    using ``report.duration``.

    The result is left unchanged in two cases:
    - reports that are not ``pytest.TestReport`` instances;
    - statuses with an empty word. pytest uses these e.g. for passing
      setup/teardown phases and does not print them; appending a timing would
      make them print.

    The word may be a plain string or a ``(word, markup)`` pair; both forms are
    handled.
    """
    pluggy_result = yield
    if not isinstance(report, pytest.TestReport):
        return
    outcome, letter, verbose = pluggy_result.get_result()
    if isinstance(verbose, tuple):
        word, markup = verbose
    else:
        word, markup = verbose, None
    if not word:
        return
    timed_word = f"{word} [{report.duration:.4f}s]"
    new_verbose = timed_word if markup is None else (timed_word, markup)
    pluggy_result.force_result((outcome, letter, new_verbose))
# When PYTORCH_TEST_RERUN_DISABLED_TESTS is "1", restrict the collected items to
# the disabled tests listed in the JSON file named by DISABLED_TESTS_FILE.
# - Each JSON entry is matched against "<test_name> (<prefix>.<TestClass>)",
#   where the prefix contains no dots.
# - Presumably only items whose (parent name, name) pair is listed are copied
#   with copy.copy, and those copies are added to `items`.
# - trylast=True makes it run after the other collection-modifying hooks.
# NOTE(review): lines of this function appear to be missing from this listing:
# - the docstring's triple quotes;
# - the early `return`s after the two `if` checks;
# - a guard for entries the regex does not match (m would be None);
# - the `for item in items:` header;
# - the body of the "not disabled" check;
# - the initialisation of filtered_items;
# - any clearing of `items` before it is extended.
# Verify against the full file.
@pytest.hookimpl(trylast=True)
245
def pytest_collection_modifyitems(items: List[Any]) -> None:
247
This hook is used when rerunning disabled tests to get rid of all skipped tests
248
instead of running and skipping them N times. This avoids flooding the console
249
and XML outputs with junk. So we want this to run last when collecting tests.
# Opt-in switch: any value other than "1" means no filtering.
251
rerun_disabled_tests = os.getenv("PYTORCH_TEST_RERUN_DISABLED_TESTS", "0") == "1"
252
if not rerun_disabled_tests:
255
disabled_regex = re.compile(r"(?P<test_name>.+)\s+\([^\.]+\.(?P<test_class>.+)\)")
# Maps test class name -> set of disabled test names in that class.
256
disabled_tests = defaultdict(set)
259
disabled_tests_file = os.getenv("DISABLED_TESTS_FILE", "")
260
if not disabled_tests_file or not os.path.exists(disabled_tests_file):
263
with open(disabled_tests_file) as fp:
264
for disabled_test in json.load(fp):
265
m = disabled_regex.match(disabled_test)
267
test_name = m["test_name"]
268
test_class = m["test_class"]
269
disabled_tests[test_class].add(test_name)
275
test_name = item.name
# The item's parent collector name is compared with the class name from the entry.
276
test_class = item.parent.name
278
if test_class not in disabled_tests or test_name not in disabled_tests[test_class]:
281
cpy = copy.copy(item)
284
filtered_items.append(cpy)
288
items.extend(filtered_items)
291
# Plugin registered by pytest_configure when the stepcurrent option is set.
# - As each test starts, it stores the test's node id in the pytest cache under
#   "<STEPCURRENT_CACHE_DIR>/<stepcurrent value>".
# - At collection time, it looks for the node id stored by a previous run and
#   deselects the items before it.
# - At session finish, the cache entry is restored to the value read at startup.
# NOTE(review): lines appear to be missing from this listing:
# - self.config (used in pytest_report_collectionfinish) is never visibly assigned;
# - failed_index is never visibly initialised or set inside the loop;
# - an `else:` appears to be missing before the "skipping ..." branch;
# - self.skip is read but not visibly used.
class StepcurrentPlugin:
294
def __init__(self, config: Config) -> None:
# Status message returned by pytest_report_collectionfinish.
296
self.report_status = ""
297
assert config.cache is not None
298
self.cache: pytest.Cache = config.cache
# One cache key per value of the stepcurrent option.
299
self.directory = f"{STEPCURRENT_CACHE_DIR}/{config.getoption('stepcurrent')}"
# Node id stored by a previous run, or None when nothing is cached.
300
self.lastrun: Optional[str] = self.cache.get(self.directory, None)
# Kept so pytest_sessionfinish can restore the cache entry.
301
self.initial_val = self.lastrun
302
self.skip: bool = config.getoption("stepcurrent_skip")
# Locate the previously recorded test in `items`, then deselect items[:failed_index].
# The deselected items are reported through pytest_deselected and removed in place.
304
def pytest_collection_modifyitems(self, config: Config, items: List[Any]) -> None:
306
self.report_status = "Cannot find last run test, not skipping"
311
for index, item in enumerate(items):
312
if item.nodeid == self.lastrun:
320
if failed_index is None:
321
self.report_status = "previously run test not found, not skipping."
323
self.report_status = f"skipping {failed_index} already run items."
324
deselected = items[:failed_index]
325
del items[:failed_index]
326
config.hook.pytest_deselected(items=deselected)
# Show the status line after collection unless output is quieted (verbose < 0).
328
def pytest_report_collectionfinish(self) -> Optional[str]:
329
if self.config.getoption("verbose") >= 0 and self.report_status:
330
return f"stepcurrent: {self.report_status}"
# Record each test's node id in the cache before it runs (presumably so an
# interrupted session leaves the last started test recorded). Returning None
# lets pytest's default runtest protocol run the test.
333
def pytest_runtest_protocol(self, item, nextitem) -> None:
334
self.lastrun = item.nodeid
335
self.cache.set(self.directory, self.lastrun)
# Restore the cache entry to the value it had when this plugin was created.
337
def pytest_sessionfinish(self, session, exitstatus):
339
self.cache.set(self.directory, self.initial_val)