6
from cpp_api_parity import (
12
from cpp_api_parity.parity_table_parser import parse_parity_tracker_table
13
from cpp_api_parity.utils import is_torch_nn_functional_test
16
import torch.testing._internal.common_nn as common_nn
17
import torch.testing._internal.common_utils as common
21
PRINT_CPP_SOURCE = False
23
devices = ["cpu", "cuda"]
25
PARITY_TABLE_PATH = os.path.join(
26
os.path.dirname(__file__), "cpp_api_parity", "parity-tracker.md"
29
parity_table = parse_parity_tracker_table(PARITY_TABLE_PATH)
32
@torch.testing._internal.common_utils.markDynamoStrictTest
33
class TestCppApiParity(common.TestCase):
34
module_test_params_map = {}
35
functional_test_params_map = {}
38
expected_test_params_dicts = []
40
if not common.IS_ARM64:
41
for test_params_dicts, test_instance_class in [
42
(sample_module.module_tests, common_nn.NewModuleTest),
43
(sample_functional.functional_tests, common_nn.NewModuleTest),
44
(common_nn.module_tests, common_nn.NewModuleTest),
45
(common_nn.new_module_tests, common_nn.NewModuleTest),
46
(common_nn.criterion_tests, common_nn.CriterionTest),
48
for test_params_dict in test_params_dicts:
49
if test_params_dict.get("test_cpp_api_parity", True):
50
if is_torch_nn_functional_test(test_params_dict):
51
functional_impl_check.write_test_to_test_class(
59
module_impl_check.write_test_to_test_class(
66
expected_test_params_dicts.append(test_params_dict)
70
[name for name in TestCppApiParity.__dict__ if "test_torch_nn_" in name]
71
) == len(expected_test_params_dicts) * len(devices)
76
len([name for name in TestCppApiParity.__dict__ if "SampleModule" in name]) == 4
80
len([name for name in TestCppApiParity.__dict__ if "sample_functional" in name])
84
module_impl_check.build_cpp_tests(
85
TestCppApiParity, print_cpp_source=PRINT_CPP_SOURCE
87
functional_impl_check.build_cpp_tests(
88
TestCppApiParity, print_cpp_source=PRINT_CPP_SOURCE
91
if __name__ == "__main__":
92
common.TestCase._default_dtype_check_enabled = True