pytorch

Форк
0
/
test_cpp_api_parity.py 
93 строки · 3.2 Кб
1
# Owner(s): ["module: cpp"]
2

3

4
import os
5

6
from cpp_api_parity import (
7
    functional_impl_check,
8
    module_impl_check,
9
    sample_functional,
10
    sample_module,
11
)
12
from cpp_api_parity.parity_table_parser import parse_parity_tracker_table
13
from cpp_api_parity.utils import is_torch_nn_functional_test
14

15
import torch
16
import torch.testing._internal.common_nn as common_nn
17
import torch.testing._internal.common_utils as common
18

19

20
# NOTE: turn this on if you want to print source code of all C++ tests (e.g. for debugging purpose)
21
PRINT_CPP_SOURCE = False
22

23
devices = ["cpu", "cuda"]
24

25
PARITY_TABLE_PATH = os.path.join(
26
    os.path.dirname(__file__), "cpp_api_parity", "parity-tracker.md"
27
)
28

29
parity_table = parse_parity_tracker_table(PARITY_TABLE_PATH)
30

31

32
@torch.testing._internal.common_utils.markDynamoStrictTest
33
class TestCppApiParity(common.TestCase):
34
    module_test_params_map = {}
35
    functional_test_params_map = {}
36

37

38
expected_test_params_dicts = []
39

40
if not common.IS_ARM64:
41
    for test_params_dicts, test_instance_class in [
42
        (sample_module.module_tests, common_nn.NewModuleTest),
43
        (sample_functional.functional_tests, common_nn.NewModuleTest),
44
        (common_nn.module_tests, common_nn.NewModuleTest),
45
        (common_nn.new_module_tests, common_nn.NewModuleTest),
46
        (common_nn.criterion_tests, common_nn.CriterionTest),
47
    ]:
48
        for test_params_dict in test_params_dicts:
49
            if test_params_dict.get("test_cpp_api_parity", True):
50
                if is_torch_nn_functional_test(test_params_dict):
51
                    functional_impl_check.write_test_to_test_class(
52
                        TestCppApiParity,
53
                        test_params_dict,
54
                        test_instance_class,
55
                        parity_table,
56
                        devices,
57
                    )
58
                else:
59
                    module_impl_check.write_test_to_test_class(
60
                        TestCppApiParity,
61
                        test_params_dict,
62
                        test_instance_class,
63
                        parity_table,
64
                        devices,
65
                    )
66
                expected_test_params_dicts.append(test_params_dict)
67

68
    # Assert that all NN module/functional test dicts appear in the parity test
69
    assert len(
70
        [name for name in TestCppApiParity.__dict__ if "test_torch_nn_" in name]
71
    ) == len(expected_test_params_dicts) * len(devices)
72

73
    # Assert that there exists auto-generated tests for `SampleModule` and `sample_functional`.
74
    # 4 == 2 (number of test dicts that are not skipped) * 2 (number of devices)
75
    assert (
76
        len([name for name in TestCppApiParity.__dict__ if "SampleModule" in name]) == 4
77
    )
78
    # 4 == 2 (number of test dicts that are not skipped) * 2 (number of devices)
79
    assert (
80
        len([name for name in TestCppApiParity.__dict__ if "sample_functional" in name])
81
        == 4
82
    )
83

84
    module_impl_check.build_cpp_tests(
85
        TestCppApiParity, print_cpp_source=PRINT_CPP_SOURCE
86
    )
87
    functional_impl_check.build_cpp_tests(
88
        TestCppApiParity, print_cpp_source=PRINT_CPP_SOURCE
89
    )
90

91
if __name__ == "__main__":
92
    common.TestCase._default_dtype_check_enabled = True
93
    common.run_tests()
94

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.