from __future__ import annotations

import contextlib
import dataclasses
import io
import logging
import typing
import unittest
from typing import AbstractSet, Protocol, Tuple

import torch
from torch.onnx import errors
from torch.onnx._internal import diagnostics
from torch.onnx._internal.diagnostics import infra
from torch.onnx._internal.diagnostics.infra import formatter, sarif
from torch.onnx._internal.fx import diagnostics as fx_diagnostics
from torch.testing._internal import common_utils, logging_utils


class _SarifLogBuilder(Protocol):
    """Structural type for any object that exposes a ``sarif_log()`` method."""

    def sarif_log(self) -> sarif.SarifLog:
        """Return the SARIF log to inspect."""
        ...


def _assert_has_diagnostics(
    sarif_log_builder: _SarifLogBuilder,
    rule_level_pairs: AbstractSet[Tuple[infra.Rule, infra.Level]],
):
    """Assert that every expected (rule, level) pair appears in the SARIF log.

    Args:
        sarif_log_builder: Object whose ``sarif_log()`` returns the log to inspect.
        rule_level_pairs: The rule and level pairs that must have been reported.

    Raises:
        AssertionError: If at least one expected pair has no matching result.
    """
    sarif_log = sarif_log_builder.sarif_log()
    # Results are matched on (rule id, lower-cased level name).
    unseen_pairs = {(rule.id, level.name.lower()) for rule, level in rule_level_pairs}
    actual_results = []
    for run in sarif_log.runs:
        if run.results is None:
            continue
        for result in run.results:
            id_level_pair = (result.rule_id, result.level)
            unseen_pairs.discard(id_level_pair)
            # Keep every observed pair so a failure message can show what was found.
            actual_results.append(id_level_pair)

    if unseen_pairs:
        raise AssertionError(
            f"Expected diagnostic results of rule id and level pair {unseen_pairs} not found. "
            f"Actual diagnostic results: {actual_results}"
        )


@dataclasses.dataclass
class _RuleCollectionForTest(infra.RuleCollection):
    """Rule collection exposing the single rule used by the tests in this file."""

    # NOTE(review): the line carrying the rule id was lost from this view of the
    # file; "1" is a reconstructed value — confirm against the original.
    rule_without_message_args: infra.Rule = dataclasses.field(
        default=infra.Rule(
            "1",
            "rule-without-message-args",
            message_default_template="rule message",
        )
    )


@contextlib.contextmanager
def assert_all_diagnostics(
    test_suite: unittest.TestCase,
    sarif_log_builder: _SarifLogBuilder,
    rule_level_pairs: AbstractSet[Tuple[infra.Rule, infra.Level]],
):
    """Context manager to assert that all diagnostics are emitted.

    Usage:
        with assert_all_diagnostics(
            self,
            diagnostics.engine,
            {(rule, infra.Level.ERROR)},
        ):
            torch.onnx.export(...)

    Args:
        test_suite: The test suite instance.
        sarif_log_builder: The SARIF log builder.
        rule_level_pairs: A set of rule and level pairs to assert.

    Returns:
        A context manager.

    Raises:
        AssertionError: If not all diagnostics are emitted.
    """

    try:
        yield
    except errors.OnnxExporterError:
        # An exporter error is only acceptable when an ERROR-level diagnostic
        # was among the expected pairs.
        test_suite.assertIn(infra.Level.ERROR, {level for _, level in rule_level_pairs})
    finally:
        # Check the log whether or not the wrapped code raised.
        _assert_has_diagnostics(sarif_log_builder, rule_level_pairs)


def assert_diagnostic(
    test_suite: unittest.TestCase,
    sarif_log_builder: _SarifLogBuilder,
    rule: infra.Rule,
    level: infra.Level,
):
    """Context manager to assert that a diagnostic is emitted.

    Usage:
        with assert_diagnostic(
            self,
            diagnostics.engine,
            rule,
            infra.Level.ERROR,
        ):
            torch.onnx.export(...)

    Args:
        test_suite: The test suite instance.
        sarif_log_builder: The SARIF log builder.
        rule: The rule to assert.
        level: The level to assert.

    Returns:
        A context manager.

    Raises:
        AssertionError: If the diagnostic is not emitted.
    """

    # Single-pair convenience wrapper around assert_all_diagnostics.
    return assert_all_diagnostics(test_suite, sarif_log_builder, {(rule, level)})


class TestDynamoOnnxDiagnostics(common_utils.TestCase):
    """Test cases for diagnostics emitted by the Dynamo ONNX export code."""

    def setUp(self):
        """Create the Dynamo diagnostic context and the shared test rules."""
        self.diagnostic_context = fx_diagnostics.DiagnosticContext("dynamo_export", "")
        self.rules = _RuleCollectionForTest()
        return super().setUp()

    def test_log_is_recorded_in_sarif_additional_messages_according_to_diagnostic_options_verbosity_level(
        self,
    ):
        """A log call adds an additional message only at or above the verbosity level."""
        # NOTE(review): the contents of this list were lost from this view of the
        # file; the four standard levels are a reconstruction — confirm.
        logging_levels = [
            logging.DEBUG,
            logging.INFO,
            logging.WARNING,
            logging.ERROR,
        ]
        for verbosity_level in logging_levels:
            self.diagnostic_context.options.verbosity_level = verbosity_level
            with self.diagnostic_context:
                diagnostic = fx_diagnostics.Diagnostic(
                    self.rules.rule_without_message_args, infra.Level.NONE
                )
                # Baseline taken once per diagnostic, before any log call.
                additional_messages_count = len(diagnostic.additional_messages)
                for log_level in logging_levels:
                    diagnostic.log(level=log_level, message="log message")
                    if log_level >= verbosity_level:
                        self.assertGreater(
                            len(diagnostic.additional_messages),
                            additional_messages_count,
                            f"Additional message should be recorded when log level is {log_level} "
                            f"and verbosity level is {verbosity_level}",
                        )
                    else:
                        self.assertEqual(
                            len(diagnostic.additional_messages),
                            additional_messages_count,
                            f"Additional message should not be recorded when log level is "
                            f"{log_level} and verbosity level is {verbosity_level}",
                        )

    def test_torch_logs_environment_variable_precedes_diagnostic_options_verbosity_level(
        self,
    ):
        """With the log artifact enabled, a debug message is recorded despite ERROR verbosity."""
        self.diagnostic_context.options.verbosity_level = logging.ERROR
        with logging_utils.log_settings("onnx_diagnostics"), self.diagnostic_context:
            diagnostic = fx_diagnostics.Diagnostic(
                self.rules.rule_without_message_args, infra.Level.NONE
            )
            additional_messages_count = len(diagnostic.additional_messages)
            diagnostic.debug("message")
            self.assertGreater(
                len(diagnostic.additional_messages), additional_messages_count
            )

    def test_log_is_not_emitted_to_terminal_when_log_artifact_is_not_enabled(self):
        """Without the log artifact, ``diagnostic.info`` sends nothing to the logger."""
        self.diagnostic_context.options.verbosity_level = logging.INFO
        with self.diagnostic_context:
            diagnostic = fx_diagnostics.Diagnostic(
                self.rules.rule_without_message_args, infra.Level.NONE
            )

            with self.assertLogs(
                diagnostic.logger, level=logging.INFO
            ) as assert_log_context:
                diagnostic.info("message")
                # assertLogs fails when nothing is logged, so emit one record
                # directly on the logger; the count below then shows that
                # diagnostic.info contributed none.
                diagnostic.logger.log(logging.ERROR, "dummy message")

            self.assertEqual(len(assert_log_context.records), 1)

    def test_log_is_emitted_to_terminal_when_log_artifact_is_enabled(self):
        """With the log artifact enabled, ``diagnostic.info`` reaches the logger."""
        self.diagnostic_context.options.verbosity_level = logging.INFO

        with logging_utils.log_settings("onnx_diagnostics"), self.diagnostic_context:
            diagnostic = fx_diagnostics.Diagnostic(
                self.rules.rule_without_message_args, infra.Level.NONE
            )

            with self.assertLogs(diagnostic.logger, level=logging.INFO):
                diagnostic.info("message")

    def test_diagnostic_log_emit_correctly_formatted_string(self):
        """A ``LazyString`` argument is formatted into the recorded message."""
        verbosity_level = logging.INFO
        self.diagnostic_context.options.verbosity_level = verbosity_level
        with self.diagnostic_context:
            diagnostic = fx_diagnostics.Diagnostic(
                self.rules.rule_without_message_args, infra.Level.NOTE
            )
            diagnostic.log(
                logging.INFO,
                "%s",
                formatter.LazyString(lambda x, y: f"{x} {y}", "hello", "world"),
            )
            self.assertIn("hello world", diagnostic.additional_messages)

    def test_log_diagnostic_to_diagnostic_context_raises_when_diagnostic_type_is_wrong(
        self,
    ):
        """Logging a base ``infra.Diagnostic`` to the Dynamo context raises TypeError."""
        with self.diagnostic_context:
            # Deliberately the base class, not fx_diagnostics.Diagnostic.
            diagnostic = infra.Diagnostic(
                self.rules.rule_without_message_args, infra.Level.NOTE
            )
            with self.assertRaises(TypeError):
                self.diagnostic_context.log(diagnostic)


class TestTorchScriptOnnxDiagnostics(common_utils.TestCase):
    """Test cases for diagnostics emitted by the TorchScript ONNX export code."""

    def setUp(self):
        """Pick the sample rule used by the tests in this class."""
        engine = diagnostics.engine
        # NOTE(review): the statement that used `engine` was lost from this view
        # of the file; clearing the engine between tests is a reconstruction — confirm.
        engine.clear()
        self._sample_rule = diagnostics.rules.missing_custom_symbolic_function
        super().setUp()

    def _trigger_node_missing_onnx_shape_inference_warning_diagnostic_from_cpp(
        self,
    ) -> diagnostics.TorchScriptOnnxExportDiagnostic:
        """Export a model containing a custom op and return the matching warning.

        Returns:
            The first diagnostic in the most recent engine context whose rule is
            ``node_missing_onnx_shape_inference`` and whose level is WARNING.

        Raises:
            AssertionError: If no such diagnostic was recorded.
        """

        class CustomAdd(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x, y):
                return x + y

            @staticmethod
            def symbolic(g, x, y):
                return g.op("custom::CustomAdd", x, y)

        class M(torch.nn.Module):
            def forward(self, x):
                return CustomAdd.apply(x, x)

        rule = diagnostics.rules.node_missing_onnx_shape_inference
        torch.onnx.export(M(), torch.randn(3, 4), io.BytesIO())

        # The export above is the most recent context registered on the engine.
        context = diagnostics.engine.contexts[-1]
        for diagnostic in context.diagnostics:
            if (
                diagnostic.rule == rule
                and diagnostic.level == diagnostics.levels.WARNING
            ):
                return typing.cast(
                    diagnostics.TorchScriptOnnxExportDiagnostic, diagnostic
                )
        raise AssertionError("No diagnostic found.")

    def test_assert_diagnostic_raises_when_diagnostic_not_found(self):
        """``assert_diagnostic`` raises when the expected diagnostic never appears."""
        with self.assertRaises(AssertionError):
            with assert_diagnostic(
                self,
                diagnostics.engine,
                diagnostics.rules.node_missing_onnx_shape_inference,
                diagnostics.levels.WARNING,
            ):
                pass

    def test_cpp_diagnose_emits_warning(self):
        """Exporting the custom-op model records the shape-inference warning."""
        with assert_diagnostic(
            self,
            diagnostics.engine,
            diagnostics.rules.node_missing_onnx_shape_inference,
            diagnostics.levels.WARNING,
        ):
            self._trigger_node_missing_onnx_shape_inference_warning_diagnostic_from_cpp()

    def test_py_diagnose_emits_error(self):
        """Exporting an op that needs a newer opset records an ERROR diagnostic."""

        class M(torch.nn.Module):
            def forward(self, x):
                return torch.diagonal(x)

        with assert_diagnostic(
            self,
            diagnostics.engine,
            diagnostics.rules.operator_supported_in_newer_opset_version,
            diagnostics.levels.ERROR,
        ):
            # NOTE(review): the export call was lost from this view of the file;
            # the arguments, including opset_version=9, are a reconstruction — confirm.
            torch.onnx.export(
                M(),
                torch.randn(3, 4),
                io.BytesIO(),
                opset_version=9,
            )

    def test_diagnostics_engine_records_diagnosis_reported_outside_of_export(
        self,
    ):
        """A diagnostic logged directly to the export context is seen by the engine."""
        sample_level = diagnostics.levels.ERROR
        with assert_diagnostic(
            self,
            diagnostics.engine,
            self._sample_rule,
            sample_level,
        ):
            diagnostic = infra.Diagnostic(self._sample_rule, sample_level)
            diagnostics.export_context().log(diagnostic)

    def test_diagnostics_records_python_call_stack(self):
        """The first recorded Python frame points at the line creating the diagnostic."""
        # Keep this call on a single line: the snippet assertion below looks for
        # "self._sample_rule" in the source line of the recorded frame.
        diagnostic = diagnostics.TorchScriptOnnxExportDiagnostic(self._sample_rule, diagnostics.levels.NOTE)  # fmt: skip
        stack = diagnostic.python_call_stack
        assert stack is not None
        self.assertGreater(len(stack.frames), 0)
        frame = stack.frames[0]
        assert frame.location.snippet is not None
        self.assertIn("self._sample_rule", frame.location.snippet)
        assert frame.location.uri is not None
        self.assertIn("test_diagnostics.py", frame.location.uri)

    def test_diagnostics_records_cpp_call_stack(self):
        """The recorded C++ call stack mentions ``torch::jit::NodeToONNX``."""
        diagnostic = (
            self._trigger_node_missing_onnx_shape_inference_warning_diagnostic_from_cpp()
        )
        stack = diagnostic.cpp_call_stack
        assert stack is not None
        self.assertGreater(len(stack.frames), 0)
        frame_messages = [frame.location.message for frame in stack.frames]
        self.assertTrue(
            any(
                isinstance(message, str) and "torch::jit::NodeToONNX" in message
                for message in frame_messages
            )
        )


@common_utils.instantiate_parametrized_tests
class TestDiagnosticsInfra(common_utils.TestCase):
    """Test cases for diagnostics infra."""

    def setUp(self):
        """Enter a fresh diagnostic context and register its exit as cleanup."""
        self.rules = _RuleCollectionForTest()
        with contextlib.ExitStack() as stack:
            self.context: infra.DiagnosticContext[
                infra.Diagnostic
            ] = stack.enter_context(infra.DiagnosticContext("test", "1.0.0"))
            # pop_all() moves the context onto a new stack so it stays entered
            # after this `with` block and is closed by the test cleanup instead.
            self.addCleanup(stack.pop_all().close)
        return super().setUp()

    def test_diagnostics_engine_records_diagnosis_with_custom_rules(self):
        """Diagnostics for rules from a custom collection are recorded."""
        # NOTE(review): the rule id and rule name lines were lost from this view
        # of the file; the names are inferred from the `custom_rule` and
        # `custom_rule_2` attributes used below, the ids are placeholders — confirm.
        custom_rules = infra.RuleCollection.custom_collection_from_list(
            "CustomRuleCollection",
            [
                infra.Rule(
                    "1",
                    "custom-rule",
                    message_default_template="custom rule message",
                ),
                infra.Rule(
                    "2",
                    "custom-rule-2",
                    message_default_template="custom rule message 2",
                ),
            ],
        )

        with assert_all_diagnostics(
            self,
            self.context,
            {
                (custom_rules.custom_rule, infra.Level.WARNING),  # type: ignore[attr-defined]
                (custom_rules.custom_rule_2, infra.Level.ERROR),  # type: ignore[attr-defined]
            },
        ):
            diagnostic1 = infra.Diagnostic(
                custom_rules.custom_rule, infra.Level.WARNING  # type: ignore[attr-defined]
            )
            self.context.log(diagnostic1)

            diagnostic2 = infra.Diagnostic(
                custom_rules.custom_rule_2, infra.Level.ERROR  # type: ignore[attr-defined]
            )
            self.context.log(diagnostic2)

    def test_diagnostic_log_is_not_emitted_when_level_less_than_diagnostic_options_verbosity_level(
        self,
    ):
        """A DEBUG log is neither emitted nor recorded when verbosity is INFO."""
        verbosity_level = logging.INFO
        self.context.options.verbosity_level = verbosity_level
        with self.context:
            diagnostic = infra.Diagnostic(
                self.rules.rule_without_message_args, infra.Level.NOTE
            )

            with self.assertLogs(
                diagnostic.logger, level=verbosity_level
            ) as assert_log_context:
                diagnostic.log(logging.DEBUG, "debug message")
                # assertLogs fails when nothing is logged, so also log at INFO
                # and then inspect the captured records.
                diagnostic.log(logging.INFO, "info message")

            for record in assert_log_context.records:
                self.assertGreaterEqual(record.levelno, logging.INFO)
            self.assertFalse(
                any(
                    "debug message" in recorded
                    for recorded in diagnostic.additional_messages
                )
            )

    def test_diagnostic_log_is_emitted_when_level_not_less_than_diagnostic_options_verbosity_level(
        self,
    ):
        """Logs at or above the verbosity level are emitted and recorded."""
        verbosity_level = logging.INFO
        self.context.options.verbosity_level = verbosity_level
        with self.context:
            diagnostic = infra.Diagnostic(
                self.rules.rule_without_message_args, infra.Level.NOTE
            )

            level_message_pairs = [
                (logging.INFO, "info message"),
                (logging.WARNING, "warning message"),
                (logging.ERROR, "error message"),
            ]

            for level, message in level_message_pairs:
                with self.assertLogs(diagnostic.logger, level=verbosity_level):
                    diagnostic.log(level, message)

                # The recorded messages need their own loop variable: reusing
                # `message` would compare each recorded message with itself,
                # which is always true and checks nothing.
                self.assertTrue(
                    any(
                        message in recorded
                        for recorded in diagnostic.additional_messages
                    )
                )

    @common_utils.parametrize(
        "log_api, log_level",
        [
            ("debug", logging.DEBUG),
            ("info", logging.INFO),
            ("warning", logging.WARNING),
            ("error", logging.ERROR),
        ],
    )
    def test_diagnostic_log_is_emitted_according_to_api_level_and_diagnostic_options_verbosity_level(
        self, log_api: str, log_level: int
    ):
        """Each level-specific API records its message only at or above verbosity."""
        verbosity_level = logging.INFO
        self.context.options.verbosity_level = verbosity_level
        with self.context:
            diagnostic = infra.Diagnostic(
                self.rules.rule_without_message_args, infra.Level.NOTE
            )

            message = "log message"
            with self.assertLogs(
                diagnostic.logger, level=verbosity_level
            ) as assert_log_context:
                getattr(diagnostic, log_api)(message)
                # assertLogs fails when nothing is logged, so also log at ERROR
                # and then inspect the captured records.
                diagnostic.log(logging.ERROR, "dummy message")

            for record in assert_log_context.records:
                self.assertGreaterEqual(record.levelno, logging.INFO)

            if log_level >= verbosity_level:
                self.assertIn(message, diagnostic.additional_messages)
            else:
                self.assertNotIn(message, diagnostic.additional_messages)

    def test_diagnostic_log_lazy_string_is_not_evaluated_when_level_less_than_diagnostic_options_verbosity_level(
        self,
    ):
        """A ``LazyString`` passed to a suppressed log call is never evaluated."""
        verbosity_level = logging.INFO
        self.context.options.verbosity_level = verbosity_level
        with self.context:
            diagnostic = infra.Diagnostic(
                self.rules.rule_without_message_args, infra.Level.NOTE
            )

            reference_val = 0

            def expensive_formatting_function() -> str:
                # Count evaluations so the test can observe them.
                nonlocal reference_val
                reference_val += 1
                return f"expensive formatting {reference_val}"

            diagnostic.debug("%s", formatter.LazyString(expensive_formatting_function))
            self.assertEqual(
                reference_val,
                0,
                "expensive_formatting_function should not be evaluated after being wrapped under LazyString",
            )

    def test_diagnostic_log_lazy_string_is_evaluated_once_when_level_not_less_than_diagnostic_options_verbosity_level(
        self,
    ):
        """A ``LazyString`` passed to an emitted log call is evaluated exactly once."""
        verbosity_level = logging.INFO
        self.context.options.verbosity_level = verbosity_level
        with self.context:
            diagnostic = infra.Diagnostic(
                self.rules.rule_without_message_args, infra.Level.NOTE
            )

            reference_val = 0

            def expensive_formatting_function() -> str:
                # Count evaluations so the test can observe them.
                nonlocal reference_val
                reference_val += 1
                return f"expensive formatting {reference_val}"

            diagnostic.info("%s", formatter.LazyString(expensive_formatting_function))
            self.assertEqual(
                reference_val,
                1,
                "expensive_formatting_function should only be evaluated once after being wrapped under LazyString",
            )

    def test_diagnostic_log_emit_correctly_formatted_string(self):
        """A ``LazyString`` argument is formatted into the recorded message."""
        verbosity_level = logging.INFO
        self.context.options.verbosity_level = verbosity_level
        with self.context:
            diagnostic = infra.Diagnostic(
                self.rules.rule_without_message_args, infra.Level.NOTE
            )
            diagnostic.log(
                logging.INFO,
                "%s",
                formatter.LazyString(lambda x, y: f"{x} {y}", "hello", "world"),
            )
            self.assertIn("hello world", diagnostic.additional_messages)

    def test_diagnostic_nested_log_section_emits_messages_with_correct_section_title_indentation(
        self,
    ):
        """Section titles are recorded with a heading depth matching their nesting."""
        verbosity_level = logging.INFO
        self.context.options.verbosity_level = verbosity_level
        with self.context:
            diagnostic = infra.Diagnostic(
                self.rules.rule_without_message_args, infra.Level.NOTE
            )

            with diagnostic.log_section(logging.INFO, "My Section"):
                diagnostic.log(logging.INFO, "My Message")
                with diagnostic.log_section(logging.INFO, "My Subsection"):
                    diagnostic.log(logging.INFO, "My Submessage")

            with diagnostic.log_section(logging.INFO, "My Section 2"):
                diagnostic.log(logging.INFO, "My Message 2")

            self.assertIn("## My Section", diagnostic.additional_messages)
            self.assertIn("### My Subsection", diagnostic.additional_messages)
            self.assertIn("## My Section 2", diagnostic.additional_messages)

    def test_diagnostic_log_source_exception_emits_exception_traceback_and_error_message(
        self,
    ):
        """``log_source_exception`` records both the traceback and the error text."""
        verbosity_level = logging.INFO
        self.context.options.verbosity_level = verbosity_level
        with self.context:
            try:
                raise ValueError("original exception")
            except ValueError as e:
                diagnostic = infra.Diagnostic(
                    self.rules.rule_without_message_args, infra.Level.NOTE
                )
                diagnostic.log_source_exception(logging.ERROR, e)

            diagnostic_message = "\n".join(diagnostic.additional_messages)

            self.assertIn("ValueError: original exception", diagnostic_message)
            self.assertIn("Traceback (most recent call last):", diagnostic_message)

    def test_log_diagnostic_to_diagnostic_context_raises_when_diagnostic_type_is_wrong(
        self,
    ):
        """Logging a plain string to the context raises TypeError."""
        with self.context:
            with self.assertRaises(TypeError):
                self.context.log("This is a str message.")

    def test_diagnostic_context_raises_if_diagnostic_is_error(self):
        """``log_and_raise_if_error`` raises for an ERROR-level diagnostic."""
        with self.assertRaises(infra.RuntimeErrorWithDiagnostic):
            self.context.log_and_raise_if_error(
                infra.Diagnostic(
                    self.rules.rule_without_message_args, infra.Level.ERROR
                )
            )

    def test_diagnostic_context_raises_original_exception_from_diagnostic_created_from_it(
        self,
    ):
        """The source exception attached to a diagnostic is the one re-raised."""
        with self.assertRaises(ValueError):
            try:
                raise ValueError("original exception")
            except ValueError as e:
                diagnostic = infra.Diagnostic(
                    self.rules.rule_without_message_args, infra.Level.ERROR
                )
                diagnostic.log_source_exception(logging.ERROR, e)
                self.context.log_and_raise_if_error(diagnostic)

    def test_diagnostic_context_raises_if_diagnostic_is_warning_and_warnings_as_errors_is_true(
        self,
    ):
        """With ``warnings_as_errors`` set, a WARNING-level diagnostic raises."""
        with self.assertRaises(infra.RuntimeErrorWithDiagnostic):
            self.context.options.warnings_as_errors = True
            self.context.log_and_raise_if_error(
                infra.Diagnostic(
                    self.rules.rule_without_message_args, infra.Level.WARNING
                )
            )


# Run the tests in this module when the file is executed as a script.
if __name__ == "__main__":
    common_utils.run_tests()