"""Unit tests on `torch.onnx.symbolic_helper`."""

import re

import torch
from torch.onnx import symbolic_helper
from torch.onnx._globals import GLOBALS
from torch.testing._internal import common_utils


class TestHelperFunctions(common_utils.TestCase):
    """Tests for `symbolic_helper.check_training_mode`.

    Every test below assigns to the shared `GLOBALS.training_mode`, so
    `setUp` records its value and `tearDown` restores it; otherwise one test's
    export mode would leak into the next.
    """

    def setUp(self) -> None:
        super().setUp()
        # Snapshot the shared global so tearDown can put it back.
        self._initial_training_mode = GLOBALS.training_mode

    def tearDown(self) -> None:
        GLOBALS.training_mode = self._initial_training_mode
        super().tearDown()

    @common_utils.parametrize(
        "op_train_mode,export_mode",
        [
            common_utils.subtest(
                [1, torch.onnx.TrainingMode.PRESERVE], name="export_mode_is_preserve"
            ),
            common_utils.subtest(
                [0, torch.onnx.TrainingMode.EVAL],
                name="modes_match_op_train_mode_0_export_mode_eval",
            ),
            common_utils.subtest(
                [1, torch.onnx.TrainingMode.TRAINING],
                name="modes_match_op_train_mode_1_export_mode_training",
            ),
        ],
    )
    def test_check_training_mode_does_not_warn_when(
        self, op_train_mode: int, export_mode: torch.onnx.TrainingMode
    ):
        """No warning is expected for PRESERVE or when op and export modes agree.

        Args:
            op_train_mode: The operator's train flag (0 or 1) passed to
                `check_training_mode`.
            export_mode: Value assigned to `GLOBALS.training_mode` before the call.
        """
        GLOBALS.training_mode = export_mode
        self.assertNotWarn(
            lambda: symbolic_helper.check_training_mode(op_train_mode, "testop")
        )

    @common_utils.parametrize(
        "op_train_mode,export_mode",
        [
            common_utils.subtest(
                [0, torch.onnx.TrainingMode.TRAINING],
                name="modes_do_not_match_op_train_mode_0_export_mode_training",
            ),
            common_utils.subtest(
                [1, torch.onnx.TrainingMode.EVAL],
                name="modes_do_not_match_op_train_mode_1_export_mode_eval",
            ),
        ],
    )
    def test_check_training_mode_warns_when(
        self,
        op_train_mode: int,
        export_mode: torch.onnx.TrainingMode,
    ):
        """A UserWarning naming the export mode is expected when the modes disagree.

        Args:
            op_train_mode: The operator's train flag (0 or 1) passed to
                `check_training_mode`.
            export_mode: Value assigned to `GLOBALS.training_mode` before the call.
        """
        # Configure the global outside the assertion context so that only the
        # call under test is what must emit the warning.
        GLOBALS.training_mode = export_mode
        # str(export_mode) contains regex metacharacters (e.g. '.'), so escape
        # it to match the message literally.
        with self.assertWarnsRegex(
            UserWarning, re.escape(f"ONNX export mode is set to {export_mode}")
        ):
            symbolic_helper.check_training_mode(op_train_mode, "testop")


# Expand the @common_utils.parametrize subtests above into concrete test methods.
common_utils.instantiate_parametrized_tests(TestHelperFunctions)


if __name__ == "__main__":
    common_utils.run_tests()