1
"""Error reproduction utilities for op consistency tests."""
3
from __future__ import annotations
15
import onnxruntime as ort
19
_MISMATCH_MARKDOWN_TEMPLATE = """\
22
The output of ONNX Runtime does not match that of PyTorch when executing test
23
`{test_name}`, `sample {sample_num}` in ONNX Script `TorchLib`.
25
To recreate this report, use
28
CREATE_REPRODUCTION_REPORT=1 python -m pytest onnxscript/tests/function_libs/torch_lib/ops_test.py -k {short_test_name}
39
Shapes: `{input_shapes}`
41
<details><summary>Details</summary>
54
Shape: `{expected_shape}`
56
<details><summary>Details</summary>
68
Shape: `{actual_shape}`
70
<details><summary>Details</summary>
82
<details><summary>Details</summary>
107
def create_mismatch_report(
110
onnx_model: onnx.ModelProto,
117
torch.set_printoptions(threshold=sys.maxsize)
119
error_text = str(error)
120
error_stack = error_text + "\n" + "".join(traceback.format_tb(error.__traceback__))
121
short_test_name = test_name.split(".")[-1]
122
diff = difflib.unified_diff(
123
str(actual).splitlines(),
124
str(expected).splitlines(),
129
onnx_model_text = onnx.printer.to_text(onnx_model)
132
f"Tensor<{inp.shape}, dtype={inp.dtype}>"
133
if isinstance(inp, torch.Tensor)
139
OS: {platform.platform()}
140
Python version: {sys.version}
141
onnx=={onnx.__version__}
142
onnxruntime=={ort.__version__}
143
onnxscript=={onnxscript.__version__}
144
numpy=={np.__version__}
145
torch=={torch.__version__}"""
147
markdown = _MISMATCH_MARKDOWN_TEMPLATE.format(
149
short_test_name=short_test_name,
150
sample_num=sample_num,
151
input_shapes=input_shapes,
155
expected_shape=expected.shape if isinstance(expected, torch.Tensor) else None,
157
actual_shape=actual.shape if isinstance(actual, torch.Tensor) else None,
158
diff="\n".join(diff),
159
error_stack=error_stack,
161
onnx_model_text=onnx_model_text,
164
markdown_file_name = f'mismatch-{short_test_name.replace("/", "-").replace(":", "-")}-{str(time.time()).replace(".", "_")}.md'
165
markdown_file_path = save_error_report(markdown_file_name, markdown)
166
print(f"Created reproduction report at {markdown_file_path}")
169
def save_error_report(file_name: str, text: str):
170
reports_dir = pathlib.Path("error_reports")
171
reports_dir.mkdir(parents=True, exist_ok=True)
172
file_path = reports_dir / file_name
173
with open(file_path, "w", encoding="utf-8") as f: