pytorch

Форк
0
/
error_reproduction.py 
176 строк · 3.4 Кб
1
"""Error reproduction utilities for op consistency tests."""
2

3
from __future__ import annotations
4

5
import difflib
6
import pathlib
7
import platform
8
import sys
9
import time
10
import traceback
11

12
import numpy as np
13

14
import onnx
15
import onnxruntime as ort
16
import onnxscript
17
import torch
18

19
_MISMATCH_MARKDOWN_TEMPLATE = """\
20
### Summary
21

22
The output of ONNX Runtime does not match that of PyTorch when executing test
23
`{test_name}`, `sample {sample_num}` in ONNX Script `TorchLib`.
24

25
To recreate this report, use
26

27
```bash
28
CREATE_REPRODUCTION_REPORT=1 python -m pytest onnxscript/tests/function_libs/torch_lib/ops_test.py -k {short_test_name}
29
```
30

31
### ONNX Model
32

33
```
34
{onnx_model_text}
35
```
36

37
### Inputs
38

39
Shapes: `{input_shapes}`
40

41
<details><summary>Details</summary>
42
<p>
43

44
```python
45
kwargs = {kwargs}
46
inputs = {inputs}
47
```
48

49
</p>
50
</details>
51

52
### Expected output
53

54
Shape: `{expected_shape}`
55

56
<details><summary>Details</summary>
57
<p>
58

59
```python
60
expected = {expected}
61
```
62

63
</p>
64
</details>
65

66
### Actual output
67

68
Shape: `{actual_shape}`
69

70
<details><summary>Details</summary>
71
<p>
72

73
```python
74
actual = {actual}
75
```
76

77
</p>
78
</details>
79

80
### Difference
81

82
<details><summary>Details</summary>
83
<p>
84

85
```diff
86
{diff}
87
```
88

89
</p>
90
</details>
91

92
### Full error stack
93

94
```
95
{error_stack}
96
```
97

98
### Environment
99

100
```
101
{sys_info}
102
```
103

104
"""
105

106

107
def create_mismatch_report(
108
    test_name: str,
109
    sample_num: int,
110
    onnx_model: onnx.ModelProto,
111
    inputs,
112
    kwargs,
113
    actual,
114
    expected,
115
    error: Exception,
116
) -> None:
117
    torch.set_printoptions(threshold=sys.maxsize)
118

119
    error_text = str(error)
120
    error_stack = error_text + "\n" + "".join(traceback.format_tb(error.__traceback__))
121
    short_test_name = test_name.split(".")[-1]
122
    diff = difflib.unified_diff(
123
        str(actual).splitlines(),
124
        str(expected).splitlines(),
125
        fromfile="actual",
126
        tofile="expected",
127
        lineterm="",
128
    )
129
    onnx_model_text = onnx.printer.to_text(onnx_model)
130
    input_shapes = repr(
131
        [
132
            f"Tensor<{inp.shape}, dtype={inp.dtype}>"
133
            if isinstance(inp, torch.Tensor)
134
            else inp
135
            for inp in inputs
136
        ]
137
    )
138
    sys_info = f"""\
139
OS: {platform.platform()}
140
Python version: {sys.version}
141
onnx=={onnx.__version__}
142
onnxruntime=={ort.__version__}
143
onnxscript=={onnxscript.__version__}
144
numpy=={np.__version__}
145
torch=={torch.__version__}"""
146

147
    markdown = _MISMATCH_MARKDOWN_TEMPLATE.format(
148
        test_name=test_name,
149
        short_test_name=short_test_name,
150
        sample_num=sample_num,
151
        input_shapes=input_shapes,
152
        inputs=inputs,
153
        kwargs=kwargs,
154
        expected=expected,
155
        expected_shape=expected.shape if isinstance(expected, torch.Tensor) else None,
156
        actual=actual,
157
        actual_shape=actual.shape if isinstance(actual, torch.Tensor) else None,
158
        diff="\n".join(diff),
159
        error_stack=error_stack,
160
        sys_info=sys_info,
161
        onnx_model_text=onnx_model_text,
162
    )
163

164
    markdown_file_name = f'mismatch-{short_test_name.replace("/", "-").replace(":", "-")}-{str(time.time()).replace(".", "_")}.md'
165
    markdown_file_path = save_error_report(markdown_file_name, markdown)
166
    print(f"Created reproduction report at {markdown_file_path}")
167

168

169
def save_error_report(file_name: str, text: str):
170
    reports_dir = pathlib.Path("error_reports")
171
    reports_dir.mkdir(parents=True, exist_ok=True)
172
    file_path = reports_dir / file_name
173
    with open(file_path, "w", encoding="utf-8") as f:
174
        f.write(text)
175

176
    return file_path
177

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.