otter / distributed.py
96 lines · 2.9 KB
import os
import torch


def is_global_master(args):
    # True only for the process with global rank 0 across all nodes.
    return args.rank == 0


def is_local_master(args):
    # True only for the rank-0 process on the current node.
    return args.local_rank == 0


def is_master(args, local=False):
    return is_local_master(args) if local else is_global_master(args)


def is_using_distributed():
    # The job counts as distributed when the launcher (torchrun or SLURM)
    # reports more than one process.
    if "WORLD_SIZE" in os.environ:
        return int(os.environ["WORLD_SIZE"]) > 1
    if "SLURM_NTASKS" in os.environ:
        return int(os.environ["SLURM_NTASKS"]) > 1
    return False


def world_info_from_env():
    # Read (local_rank, global_rank, world_size) from whichever launcher
    # environment variables are present (torchrun, Intel MPI, SLURM, Open MPI).
    local_rank = 0
    for v in (
        "LOCAL_RANK",
        "MPI_LOCALRANKID",
        "SLURM_LOCALID",
        "OMPI_COMM_WORLD_LOCAL_RANK",
    ):
        if v in os.environ:
            local_rank = int(os.environ[v])
            break
    global_rank = 0
    for v in ("RANK", "PMI_RANK", "SLURM_PROCID", "OMPI_COMM_WORLD_RANK"):
        if v in os.environ:
            global_rank = int(os.environ[v])
            break
    world_size = 1
    for v in ("WORLD_SIZE", "PMI_SIZE", "SLURM_NTASKS", "OMPI_COMM_WORLD_SIZE"):
        if v in os.environ:
            world_size = int(os.environ[v])
            break
    return local_rank, global_rank, world_size


def init_distributed_device(args):
    # Distributed training = training on more than one GPU.
    # Works in both single- and multi-node scenarios.
    args.distributed = False
    args.world_size = 1
    args.rank = 0  # global rank
    args.local_rank = 0
    if is_using_distributed():
        if "SLURM_PROCID" in os.environ:
            # DDP via SLURM
            args.local_rank, args.rank, args.world_size = world_info_from_env()
            # Mirror the SLURM values into the standard torch.distributed
            # environment variables in case they are needed downstream.
            os.environ["LOCAL_RANK"] = str(args.local_rank)
            os.environ["RANK"] = str(args.rank)
            os.environ["WORLD_SIZE"] = str(args.world_size)
            torch.distributed.init_process_group(
                backend=args.dist_backend,
                init_method=args.dist_url,
                world_size=args.world_size,
                rank=args.rank,
            )
        else:
            # DDP via torchrun or torch.distributed.launch
            args.local_rank, _, _ = world_info_from_env()
            torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url)
            args.world_size = torch.distributed.get_world_size()
            args.rank = torch.distributed.get_rank()
        args.distributed = True
    else:
        # Still initialize a single-process group so the same code path
        # works when running on a single GPU.
        torch.distributed.init_process_group(
            backend=args.dist_backend,
            init_method=args.dist_url,
            world_size=1,
            rank=0,
        )

    if torch.cuda.is_available():
        if args.distributed and not args.no_set_device_rank:
            device = "cuda:%d" % args.local_rank
        else:
            device = "cuda:0"
        torch.cuda.set_device(device)
    else:
        device = "cpu"
    args.device = device
    device = torch.device(device)
    return device
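

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal, hypothetical
# driver showing how init_distributed_device() might be wired up from a
# training script. The flag names (--dist-backend, --dist-url,
# --no-set-device-rank) are assumptions inferred from the attributes this
# module reads off `args`; the real training script may define them
# differently. Intended to be launched with torchrun, e.g.:
#     torchrun --nproc_per_node=2 distributed.py --dist-backend nccl
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="init_distributed_device sketch")
    parser.add_argument("--dist-backend", default="nccl", type=str,
                        help="torch.distributed backend, e.g. nccl or gloo")
    parser.add_argument("--dist-url", default="env://", type=str,
                        help="init_method URL; env:// reads MASTER_ADDR/MASTER_PORT set by torchrun")
    parser.add_argument("--no-set-device-rank", action="store_true",
                        help="do not pin this process to cuda:<local_rank>")
    args = parser.parse_args()

    device = init_distributed_device(args)
    if is_master(args):
        print(f"world_size={args.world_size} rank={args.rank} device={device}")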