11
from os.path import abspath, dirname, join
16
# Select protobuf's pure-Python implementation for this test session.
# NOTE(review): presumably chosen to sidestep issues with the compiled protobuf
# backend — confirm the motivation.
os.environ.update(PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION='python')

# Make the "src" directory that sits next to this file's parent directory
# importable. It is inserted at index 1 rather than 0 so that sys.path[0]
# (normally the directory of the script being run) keeps top priority.
git_repo_path = join(dirname(dirname(__file__)), "src")
git_repo_path = abspath(git_repo_path)
sys.path.insert(1, git_repo_path)
24
def pytest_configure(config):
    """Apply session-wide pytest reporting defaults.

    Forces colored output and verbose mode. It also asks pytest to report the
    durations of all tests (``durations=0``) that took at least one second
    (``durations_min=1``).

    Args:
        config: the pytest ``Config`` object; its ``option`` namespace is
            mutated in place.
    """
    overrides = {
        "color": "yes",
        "durations": 0,
        "durations_min": 1,
        "verbose": True,
    }
    for option_name, option_value in overrides.items():
        setattr(config.option, option_name, option_value)
31
def pytest_addoption(parser):
    """Register the ``--torch_ver`` and ``--cuda_ver`` command-line options.

    Both options take a string and default to ``None`` when not given.

    Args:
        parser: the pytest option parser to register the options on.
    """
    for flag in ("--torch_ver", "--cuda_ver"):
        parser.addoption(flag, default=None, type=str)
36
def validate_version(expected, found):
    """Return True if ``found`` matches ``expected`` up to expected's precision.

    Only as many dot-separated components of ``found`` as ``expected`` has
    are compared, so ``expected="2.1"`` accepts ``found="2.1.0"``.

    A PEP 440 local version label on ``found`` (the ``+cu118`` part of
    ``"2.1.0+cu118"``) is ignored unless ``expected`` carries a label too.
    So ``expected="2.1.0"`` accepts ``found="2.1.0+cu118"``, while
    ``expected="2.1.0+cu118"`` still requires that exact label.

    Args:
        expected: the version string the user asked for.
        found: the version string reported by the installed library. It may
            be ``None``; for example ``torch.version.cuda`` is ``None`` when
            torch was built without CUDA. ``None`` is treated as a mismatch.

    Returns:
        True if the versions match at the requested precision, else False.
    """
    if found is None:
        return False
    # Drop the local version label so it cannot leak into the last compared
    # component, unless the caller explicitly asked to match a label.
    if '+' not in expected:
        found = found.split('+', 1)[0]
    version_depth = expected.count('.') + 1
    truncated = '.'.join(found.split('.')[:version_depth])
    return truncated == expected
42
@pytest.fixture(scope="session", autouse=True)
43
def check_environment(pytestconfig):
44
expected_torch_version = pytestconfig.getoption("torch_ver")
45
expected_cuda_version = pytestconfig.getoption("cuda_ver")
46
if expected_torch_version is None:
48
"Running test without verifying torch version, please provide an expected torch version with --torch_ver")
49
elif not validate_version(expected_torch_version, torch.__version__):
51
f"expected torch version {expected_torch_version} did not match found torch version {torch.__version__}",
53
if expected_cuda_version is None:
55
"Running test without verifying cuda version, please provide an expected cuda version with --cuda_ver")
56
elif not validate_version(expected_cuda_version, torch.version.cuda):
58
f"expected cuda version {expected_cuda_version} did not match found cuda version {torch.version.cuda}",
64
@pytest.hookimpl(tryfirst=True)
def pytest_runtest_call(item):
    """Hand tests from classes flagged ``is_dist_test`` to their own launcher.

    This runs before pytest's default ``pytest_runtest_call`` (``tryfirst``).
    For opted-in classes, a fresh instance of the test class is called with
    the item's request object. The item's ``runtest`` is then replaced with
    a no-op, so the default hook does not execute the test body a second
    time. Items whose class lacks a truthy ``is_dist_test`` (including
    module-level tests, where ``item.cls`` is ``None``) are left untouched.
    """
    test_cls = item.cls
    if not getattr(test_cls, "is_dist_test", False):
        return
    # NOTE(review): the instance call presumably launches the test itself
    # (e.g. across worker processes) — confirm in the test base class.
    launcher = test_cls()
    launcher(item._request)
    item.runtest = lambda: True
76
def pytest_runtest_teardown(item, nextitem):
    """Force-close cached worker pools once no further test is scheduled.

    The hook only acts when the test's class has a truthy ``reuse_dist_env``
    attribute and ``nextitem`` is falsy. pytest passes ``None`` as
    ``nextitem`` when no further test item is scheduled. In that case, every
    entry of the class's ``_pool_cache`` is closed with ``force=True``.

    Args:
        item: the test item that just finished.
        nextitem: the next scheduled test item, or ``None``.
    """
    test_cls = item.cls
    if not getattr(test_cls, "reuse_dist_env", False) or nextitem:
        return
    # NOTE(review): a fresh instance is created here, which presumably relies
    # on ``_pool_cache`` being shared (class-level) state — confirm.
    owner = test_cls()
    # Keys are process counts, values are the pools created for them.
    for proc_count, cached_pool in owner._pool_cache.items():
        owner._close_pool(cached_pool, proc_count, force=True)
83
@pytest.hookimpl(tryfirst=True)
def pytest_fixture_setup(fixturedef, request):
    """Run fixtures flagged ``is_dist_fixture`` through their own class.

    When the fixture's function object has a truthy ``is_dist_fixture``
    attribute, it is instantiated with no arguments and the instance is
    called with the fixture ``request``. The hook always returns ``None``,
    so pytest's remaining ``pytest_fixture_setup`` implementations still run.
    """
    fixture_factory = fixturedef.func
    if not getattr(fixture_factory, "is_dist_fixture", False):
        return
    fixture_factory()(request)