pytorch

Форк
0
/
test_symbolic_helper.py 
71 строка · 2.2 Кб
1
# Owner(s): ["module: onnx"]
2
"""Unit tests on `torch.onnx.symbolic_helper`."""
3

4
import torch
5
from torch.onnx import symbolic_helper
6
from torch.onnx._globals import GLOBALS
7
from torch.testing._internal import common_utils
8

9

10
class TestHelperFunctions(common_utils.TestCase):
11
    def setUp(self):
12
        super().setUp()
13
        self._initial_training_mode = GLOBALS.training_mode
14

15
    def tearDown(self):
16
        GLOBALS.training_mode = self._initial_training_mode
17

18
    @common_utils.parametrize(
19
        "op_train_mode,export_mode",
20
        [
21
            common_utils.subtest(
22
                [1, torch.onnx.TrainingMode.PRESERVE], name="export_mode_is_preserve"
23
            ),
24
            common_utils.subtest(
25
                [0, torch.onnx.TrainingMode.EVAL],
26
                name="modes_match_op_train_mode_0_export_mode_eval",
27
            ),
28
            common_utils.subtest(
29
                [1, torch.onnx.TrainingMode.TRAINING],
30
                name="modes_match_op_train_mode_1_export_mode_training",
31
            ),
32
        ],
33
    )
34
    def test_check_training_mode_does_not_warn_when(
35
        self, op_train_mode: int, export_mode: torch.onnx.TrainingMode
36
    ):
37
        GLOBALS.training_mode = export_mode
38
        self.assertNotWarn(
39
            lambda: symbolic_helper.check_training_mode(op_train_mode, "testop")
40
        )
41

42
    @common_utils.parametrize(
43
        "op_train_mode,export_mode",
44
        [
45
            common_utils.subtest(
46
                [0, torch.onnx.TrainingMode.TRAINING],
47
                name="modes_do_not_match_op_train_mode_0_export_mode_training",
48
            ),
49
            common_utils.subtest(
50
                [1, torch.onnx.TrainingMode.EVAL],
51
                name="modes_do_not_match_op_train_mode_1_export_mode_eval",
52
            ),
53
        ],
54
    )
55
    def test_check_training_mode_warns_when(
56
        self,
57
        op_train_mode: int,
58
        export_mode: torch.onnx.TrainingMode,
59
    ):
60
        with self.assertWarnsRegex(
61
            UserWarning, f"ONNX export mode is set to {export_mode}"
62
        ):
63
            GLOBALS.training_mode = export_mode
64
            symbolic_helper.check_training_mode(op_train_mode, "testop")
65

66

67
common_utils.instantiate_parametrized_tests(TestHelperFunctions)
68

69

70
if __name__ == "__main__":
71
    common_utils.run_tests()
72

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.