pytorch

Форк
0
/
test_cpp_api_parity.py 
62 строки · 2.8 Кб
1
# Owner(s): ["module: cpp"]
2

3

4
import os
5

6
import torch
7
import torch.testing._internal.common_utils as common
8
import torch.testing._internal.common_nn as common_nn
9
from cpp_api_parity.parity_table_parser import parse_parity_tracker_table
10
from cpp_api_parity.utils import is_torch_nn_functional_test
11
from cpp_api_parity import module_impl_check, functional_impl_check, sample_module, sample_functional
12

13
# NOTE: turn this on if you want to print source code of all C++ tests (e.g. for debugging purpose)
14
PRINT_CPP_SOURCE = False
15

16
devices = ['cpu', 'cuda']
17

18
PARITY_TABLE_PATH = os.path.join(os.path.dirname(__file__), 'cpp_api_parity', 'parity-tracker.md')
19

20
parity_table = parse_parity_tracker_table(PARITY_TABLE_PATH)
21

22
@torch.testing._internal.common_utils.markDynamoStrictTest
23
class TestCppApiParity(common.TestCase):
24
    module_test_params_map = {}
25
    functional_test_params_map = {}
26

27
expected_test_params_dicts = []
28

29
if not common.IS_ARM64:
30
    for test_params_dicts, test_instance_class in [
31
        (sample_module.module_tests, common_nn.NewModuleTest),
32
        (sample_functional.functional_tests, common_nn.NewModuleTest),
33
        (common_nn.module_tests, common_nn.NewModuleTest),
34
        (common_nn.new_module_tests, common_nn.NewModuleTest),
35
        (common_nn.criterion_tests, common_nn.CriterionTest),
36
    ]:
37
        for test_params_dict in test_params_dicts:
38
            if test_params_dict.get('test_cpp_api_parity', True):
39
                if is_torch_nn_functional_test(test_params_dict):
40
                    functional_impl_check.write_test_to_test_class(
41
                        TestCppApiParity, test_params_dict, test_instance_class, parity_table, devices)
42
                else:
43
                    module_impl_check.write_test_to_test_class(
44
                        TestCppApiParity, test_params_dict, test_instance_class, parity_table, devices)
45
                expected_test_params_dicts.append(test_params_dict)
46

47
    # Assert that all NN module/functional test dicts appear in the parity test
48
    assert len([name for name in TestCppApiParity.__dict__ if 'test_torch_nn_' in name]) == \
49
        len(expected_test_params_dicts) * len(devices)
50

51
    # Assert that there exists auto-generated tests for `SampleModule` and `sample_functional`.
52
    # 4 == 2 (number of test dicts that are not skipped) * 2 (number of devices)
53
    assert len([name for name in TestCppApiParity.__dict__ if 'SampleModule' in name]) == 4
54
    # 4 == 2 (number of test dicts that are not skipped) * 2 (number of devices)
55
    assert len([name for name in TestCppApiParity.__dict__ if 'sample_functional' in name]) == 4
56

57
    module_impl_check.build_cpp_tests(TestCppApiParity, print_cpp_source=PRINT_CPP_SOURCE)
58
    functional_impl_check.build_cpp_tests(TestCppApiParity, print_cpp_source=PRINT_CPP_SOURCE)
59

60
if __name__ == "__main__":
61
    common.TestCase._default_dtype_check_enabled = True
62
    common.run_tests()
63

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.