import os

import torch
import torch.distributed


def is_global_master(args):
    return args.rank == 0


def is_local_master(args):
    return args.local_rank == 0


def is_master(args, local=False):
    return is_local_master(args) if local else is_global_master(args)


def is_using_distributed():
    if "WORLD_SIZE" in os.environ:
        return int(os.environ["WORLD_SIZE"]) > 1
    if "SLURM_NTASKS" in os.environ:
        return int(os.environ["SLURM_NTASKS"]) > 1
    return False


def world_info_from_env():
    # Probe the env vars set by common launchers (torchrun, Intel MPI/PMI,
    # SLURM, OpenMPI); the first match wins.
    local_rank = 0
    for v in (
        "LOCAL_RANK",
        "MPI_LOCALRANKID",
        "SLURM_LOCALID",
        "OMPI_COMM_WORLD_LOCAL_RANK",
    ):
        if v in os.environ:
            local_rank = int(os.environ[v])
            break
    global_rank = 0
    for v in ("RANK", "PMI_RANK", "SLURM_PROCID", "OMPI_COMM_WORLD_RANK"):
        if v in os.environ:
            global_rank = int(os.environ[v])
            break
    world_size = 1
    for v in ("WORLD_SIZE", "PMI_SIZE", "SLURM_NTASKS", "OMPI_COMM_WORLD_SIZE"):
        if v in os.environ:
            world_size = int(os.environ[v])
            break
    return local_rank, global_rank, world_size
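
# Example: under `torchrun --nproc_per_node=4`, each worker sees LOCAL_RANK,
# RANK, and WORLD_SIZE in its environment, so world_info_from_env() on
# worker 3 returns (3, 3, 4).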


def init_distributed_device(args):
    # Distributed training = training on more than one GPU.
    # Works in both single and multi-node scenarios.
    args.distributed = False
    args.world_size = 1
    args.rank = 0  # global rank
    args.local_rank = 0
    if is_using_distributed():
        if "SLURM_PROCID" in os.environ:
            # DDP via SLURM
            args.local_rank, args.rank, args.world_size = world_info_from_env()
            # SLURM var -> torch.distributed vars in case needed
            os.environ["LOCAL_RANK"] = str(args.local_rank)
            os.environ["RANK"] = str(args.rank)
            os.environ["WORLD_SIZE"] = str(args.world_size)
            torch.distributed.init_process_group(
                backend=args.dist_backend,
                init_method=args.dist_url,
                world_size=args.world_size,
                rank=args.rank,
            )
        else:
            # DDP via torchrun, torch.distributed.launch
            args.local_rank, _, _ = world_info_from_env()
            torch.distributed.init_process_group(
                backend=args.dist_backend,
                init_method=args.dist_url,
            )
            args.world_size = torch.distributed.get_world_size()
            args.rank = torch.distributed.get_rank()
        args.distributed = True
    else:
        # needed to run on single gpu
        torch.distributed.init_process_group(
            backend=args.dist_backend,
            init_method=args.dist_url,
            world_size=1,
            rank=0,
        )

    if torch.cuda.is_available():
        if args.distributed and not args.no_set_device_rank:
            # pin each process to the GPU matching its local rank
            device = "cuda:%d" % args.local_rank
        else:
            device = "cuda:0"
        torch.cuda.set_device(device)
    else:
        device = "cpu"
    args.device = device
    device = torch.device(device)
    return device
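

if __name__ == "__main__":
    # Minimal smoke-test sketch (an addition, not part of the training
    # entrypoint): builds an argparse namespace carrying the attributes this
    # module reads (dist_backend, dist_url, no_set_device_rank). The CLI flag
    # spellings below are assumptions that simply mirror those attribute names.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dist-backend",
        default="nccl" if torch.cuda.is_available() else "gloo",
    )
    parser.add_argument("--dist-url", default="env://")
    parser.add_argument("--no-set-device-rank", action="store_true")
    args = parser.parse_args()

    # env:// rendezvous needs a master address even for a single process;
    # torchrun sets these itself, so setdefault() leaves them untouched there.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")

    device = init_distributed_device(args)
    print(f"rank={args.rank} local_rank={args.local_rank} "
          f"world_size={args.world_size} device={device}")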