pytorch

Форк
0
/
test_distributed_spawn.py 
59 строк · 1.9 Кб
1
# Owner(s): ["oncall: distributed"]
2

3
import os
4
import sys
5

6
import torch
7
import torch.distributed as dist
8
from os import path
9

10
torch.backends.cuda.matmul.allow_tf32 = False
11

12
if not dist.is_available():
13
    print("Distributed not available, skipping tests", file=sys.stderr)
14
    sys.exit(0)
15

16
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN, NO_MULTIPROCESSING_SPAWN
17
from torch.testing._internal.distributed.distributed_test import (
18
    DistributedTest, TestDistBackend
19
)
20

21
if TEST_WITH_DEV_DBG_ASAN:
22
    print("Skip dev-asan as torch + multiprocessing spawn have known issues", file=sys.stderr)
23
    sys.exit(0)
24

25
if NO_MULTIPROCESSING_SPAWN:
26
    print("Spawn not available, skipping tests.", file=sys.stderr)
27
    sys.exit(0)
28

29
_allowed_backends = ("gloo", "nccl", "ucc")
30
if (
31
    "BACKEND" not in os.environ
32
    or "WORLD_SIZE" not in os.environ
33
    or "TEMP_DIR" not in os.environ
34
    or not path.exists(path.join(os.environ["TEMP_DIR"], "barrier"))
35
):
36
    # TODO can we actually have `run_tests.py` emit the complete instructions when it prints a repro command?
37
    raise RuntimeError(
38
        "Missing expected env vars for `test_distributed_spawn.py`.  Please ensure to specify the following:\n"
39
        f"'BACKEND' = one of {_allowed_backends}\n"
40
        f"'WORLD_SIZE' = int >= 2\n"
41
        "'TEMP_DIR' specifying a directory containing a barrier file named 'barrier'.\n\n"
42
        f"e.g.\ntouch /tmp/barrier && TEMP_DIR=/tmp BACKEND='nccl' WORLD_SIZE=2 python {__file__}",
43
    )
44

45
BACKEND = os.environ["BACKEND"]
46

47
if BACKEND in _allowed_backends:
48
    class TestDistBackendWithSpawn(TestDistBackend, DistributedTest._DistTestBase):
49

50
        def setUp(self):
51
            super().setUp()
52
            self._spawn_processes()
53
            torch.backends.cudnn.flags(enabled=True, allow_tf32=False).__enter__()
54
else:
55
    print(f"Invalid backend {BACKEND}. Tests will not be run!")
56

57

58
if __name__ == "__main__":
59
    run_tests()
60

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.