# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

# tests directory-specific settings - this file is run automatically by pytest before any tests are run

import sys
import pytest
import os
from os.path import abspath, dirname, join
import torch
import warnings

# Set this environment variable for the T5 inference unittest(s) (e.g. google/t5-v1_1-small)
os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python'

# allow having multiple repository checkouts and not needing to remember to rerun
# 'pip install -e .[dev]' when switching between checkouts and running tests.
git_repo_path = abspath(join(dirname(dirname(__file__)), "src"))
sys.path.insert(1, git_repo_path)


def pytest_configure(config):
    config.option.color = "yes"
    config.option.durations = 0
    config.option.durations_min = 1
    config.option.verbose = True
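    # These option overrides are roughly equivalent to running pytest with
    # --color=yes --durations=0 --durations-min=1 -v on the command line.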


def pytest_addoption(parser):
    parser.addoption("--torch_ver", default=None, type=str)
    parser.addoption("--cuda_ver", default=None, type=str)


def validate_version(expected, found):
    version_depth = expected.count('.') + 1
    found = '.'.join(found.split('.')[:version_depth])
    return found == expected
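    # Compares only as many version components as were given, e.g.
    # validate_version("2.1", "2.1.0") is True, validate_version("2.1", "2.2.0") is False.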


@pytest.fixture(scope="session", autouse=True)
def check_environment(pytestconfig):
    expected_torch_version = pytestconfig.getoption("torch_ver")
    expected_cuda_version = pytestconfig.getoption("cuda_ver")
    if expected_torch_version is None:
        warnings.warn(
            "Running test without verifying torch version, please provide an expected torch version with --torch_ver")
    elif not validate_version(expected_torch_version, torch.__version__):
        pytest.exit(
            f"expected torch version {expected_torch_version} did not match found torch version {torch.__version__}",
            returncode=2)
    if expected_cuda_version is None:
        warnings.warn(
            "Running test without verifying cuda version, please provide an expected cuda version with --cuda_ver")
    elif not validate_version(expected_cuda_version, torch.version.cuda):
        pytest.exit(
            f"expected cuda version {expected_cuda_version} did not match found cuda version {torch.version.cuda}",
            returncode=2)


# Override of pytest "runtest" for DistributedTest class
# This hook is run before the default pytest_runtest_call
@pytest.hookimpl(tryfirst=True)
def pytest_runtest_call(item):
    # We want to use our own launching function for distributed tests
    if getattr(item.cls, "is_dist_test", False):
        dist_test_class = item.cls()
        dist_test_class(item._request)
        item.runtest = lambda: True  # Dummy function so test is not run twice
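        # DistributedTest subclasses are assumed to set is_dist_test = True and to be
        # callable with the pytest request; calling the instance presumably launches the
        # test body across the configured number of worker processes.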


# We allow DistributedTest to reuse distributed environments. When the last
# test for a class is run, we want to make sure those distributed environments
# are destroyed.
def pytest_runtest_teardown(item, nextitem):
    if getattr(item.cls, "reuse_dist_env", False) and not nextitem:
        dist_test_class = item.cls()
        for num_procs, pool in dist_test_class._pool_cache.items():
            dist_test_class._close_pool(pool, num_procs, force=True)
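        # _pool_cache presumably maps a process count to the worker pool kept alive for
        # reuse_dist_env; forcing _close_pool here tears those pools down once the final
        # test item has run (nextitem is None).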


@pytest.hookimpl(tryfirst=True)
def pytest_fixture_setup(fixturedef, request):
    if getattr(fixturedef.func, "is_dist_fixture", False):
        dist_fixture_class = fixturedef.func()
        dist_fixture_class(request)
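        # Fixture functions flagged with is_dist_fixture are assumed to be class-based
        # distributed fixtures; instantiating and calling them with the request lets them
        # do their distributed setup ahead of the default fixture handling.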
