7
import torch.testing._internal.common_utils as common
8
import torch.testing._internal.common_nn as common_nn
9
from cpp_api_parity.parity_table_parser import parse_parity_tracker_table
10
from cpp_api_parity.utils import is_torch_nn_functional_test
11
from cpp_api_parity import module_impl_check, functional_impl_check, sample_module, sample_functional
# Forwarded as `print_cpp_source` to the C++ test builders; presumably set this
# to True to print the generated C++ source while debugging.
PRINT_CPP_SOURCE = False

# Devices handed to the test generators; the number of generated tests is
# expected to scale with len(devices).
devices = ['cpu', 'cuda']

# Markdown parity-tracker table that lives next to this file.
PARITY_TABLE_PATH = os.path.join(
    os.path.dirname(__file__),
    'cpp_api_parity',
    'parity-tracker.md',
)

parity_table = parse_parity_tracker_table(PARITY_TABLE_PATH)
22
@torch.testing._internal.common_utils.markDynamoStrictTest
23
class TestCppApiParity(common.TestCase):
24
module_test_params_map = {}
25
functional_test_params_map = {}
# Params dicts that opted into C++ parity testing; used by the sanity checks
# on the number of generated tests.
expected_test_params_dicts = []


def _count_generated_tests(name_fragment):
    """Return how many attributes of TestCppApiParity contain ``name_fragment``."""
    return sum(1 for name in TestCppApiParity.__dict__ if name_fragment in name)


# Nothing is generated on ARM64 hosts, which leaves TestCppApiParity without tests.
if not common.IS_ARM64:
    for test_params_dicts, test_instance_class in [
        (sample_module.module_tests, common_nn.NewModuleTest),
        (sample_functional.functional_tests, common_nn.NewModuleTest),
        (common_nn.module_tests, common_nn.NewModuleTest),
        (common_nn.new_module_tests, common_nn.NewModuleTest),
        (common_nn.criterion_tests, common_nn.CriterionTest),
    ]:
        for test_params_dict in test_params_dicts:
            # A dict opts out of parity testing with `test_cpp_api_parity=False`.
            if not test_params_dict.get('test_cpp_api_parity', True):
                continue
            # Each dict goes to exactly one generator. Without the `else`,
            # functional samples would also be registered as module tests.
            if is_torch_nn_functional_test(test_params_dict):
                functional_impl_check.write_test_to_test_class(
                    TestCppApiParity, test_params_dict, test_instance_class, parity_table, devices)
            else:
                module_impl_check.write_test_to_test_class(
                    TestCppApiParity, test_params_dict, test_instance_class, parity_table, devices)
            expected_test_params_dicts.append(test_params_dict)

    # The generated 'test_torch_nn_' tests must total one per
    # (opted-in params dict, device) pair.
    assert _count_generated_tests('test_torch_nn_') == len(expected_test_params_dicts) * len(devices), \
        "generated 'test_torch_nn_' test count != len(expected_test_params_dicts) * len(devices)"

    # The sample module and sample functional must both have produced tests.
    # NOTE(review): 4 presumably == 2 non-skipped sample dicts * len(devices) (2);
    # confirm against sample_module / sample_functional.
    assert _count_generated_tests('SampleModule') == 4, \
        "expected 4 generated 'SampleModule' tests"
    assert _count_generated_tests('sample_functional') == 4, \
        "expected 4 generated 'sample_functional' tests"

    # Build the C++ counterparts of everything registered above.
    module_impl_check.build_cpp_tests(TestCppApiParity, print_cpp_source=PRINT_CPP_SOURCE)
    functional_impl_check.build_cpp_tests(TestCppApiParity, print_cpp_source=PRINT_CPP_SOURCE)
60
if __name__ == "__main__":
61
common.TestCase._default_dtype_check_enabled = True