#!/usr/bin/env python3
#
# Measure distributed training iteration time.
#
# This program performs a sweep over a) a number of model architectures, and
# b) an increasing number of processes. This produces a 1-GPU baseline,
# an 8-GPU baseline (if applicable), as well as measurements for however
# many processes can participate in training.
#
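# A minimal launch sketch, assuming two 8-GPU machines (16 processes total,
# one per GPU); the hostnames and port below are placeholders. RANK must be
# set in the environment (or passed via --rank) for each process:
#
#   RANK=<0..15> python benchmark.py --world-size 16 \
#       --master-addr node0.example.com --master-port 29500
#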

import argparse
import itertools
import json
import os
import shlex
import subprocess
import sys
import time

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torchvision

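# Gather a Python object from every rank in the global process group and
# return the list of gathered objects (one entry per rank).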
def allgather_object(obj):
    out = [None for _ in range(dist.get_world_size())]
    dist.all_gather_object(out, obj)
    return out

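# Run a shell command locally and gather its stdout from every rank, so the
# caller can compare the output across machines.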
def allgather_run(cmd):
    proc = subprocess.run(shlex.split(cmd), capture_output=True)
    assert proc.returncode == 0
    return allgather_object(proc.stdout.decode("utf-8"))

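# Return True if all elements of the iterable are equal (vacuously True for
# an empty iterable).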
def allequal(iterator):
    iterator = iter(iterator)
    try:
        first = next(iterator)
    except StopIteration:
        return True
    return all(first == rest for rest in iterator)

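# Run a short training loop (5 warmup + 10 measured iterations) against the
# given process group and return the per-iteration wall-clock times for the
# measured iterations.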
def benchmark_process_group(pg, benchmark, use_ddp_for_single_rank=True):
    torch.manual_seed(pg.rank())
    torch.cuda.manual_seed(pg.rank())

    model = benchmark.create_model()
    data = [(benchmark.generate_inputs(), benchmark.generate_target())]
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), 0.001, momentum=0.9, weight_decay=1e-4)
    if use_ddp_for_single_rank or pg.size() > 1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False,
            process_group=pg,
            bucket_cap_mb=benchmark.bucket_size,
        )

    measurements = []
    warmup_iterations = 5
    measured_iterations = 10
    for inputs, target in data * (warmup_iterations + measured_iterations):
        # Zero gradients outside the timed region so the measurement covers
        # forward, backward, and the optimizer step.
        optimizer.zero_grad()
        start = time.time()
        output = model(*inputs)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        # CUDA kernels are asynchronous; synchronize so the timing covers
        # the whole iteration.
        torch.cuda.synchronize()
        measurements.append(time.time() - start)

    # Throw away measurements for warmup iterations.
    return measurements[warmup_iterations:]

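# Run a single benchmark configuration: create a subgroup for `ranks`, run
# the training loop on the participating ranks, and gather all measurements
# onto every rank.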
def run_benchmark(benchmark, ranks, opts):
    group = dist.new_group(ranks=ranks, backend=benchmark.distributed_backend)
    measurements = []
    if dist.get_rank() in set(ranks):
        if not opts:
            opts = {}
        measurements = benchmark_process_group(group, benchmark, **opts)
    dist.destroy_process_group(group)
    dist.barrier()

    # Aggregate measurements for better estimation of percentiles
    return list(itertools.chain(*allgather_object(measurements)))

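# Sweep over benchmark configurations in order of increasing GPU count,
# print a percentile table for each one, and return the raw measurements.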
def sweep(benchmark):
    # Synthesize the set of benchmarks to run.
    # This list contains tuples of ("string prefix", [rank...], opts).
    benchmarks = []

    def append_benchmark(prefix, ranks, opts=None):
        prefix = f"{len(ranks):4} GPUs -- {prefix}"
        benchmarks.append((prefix, ranks, opts))

    def local_print(msg):
        if dist.get_rank() == 0:
            print(msg, end="", flush=True)  # noqa: E999

    def print_header():
        local_print("\n")
        local_print("%22s" % "")
        # One ("sec/iter", "ex/sec") column pair per reported percentile.
        for _ in [50, 75, 90, 95]:
            local_print("%14s%10s" % ("sec/iter", "ex/sec"))
        local_print("\n")

    def print_measurements(prefix, nelem, measurements):
        measurements = sorted(measurements)
        local_print("%8s:" % prefix)
        for p in [50, 75, 90, 95]:
            v = np.percentile(measurements, p)
            local_print("   p%02d:  %1.3fs  %6d/s" % (p, v, nelem / v))
        local_print("\n")

    # Every process runs once by itself to warm up (CUDA init, etc).
    append_benchmark("  warmup", [dist.get_rank()], {"use_ddp_for_single_rank": False})

    # Single machine baselines
    append_benchmark("  no ddp", range(1), {"use_ddp_for_single_rank": False})
    append_benchmark("   1M/1G", range(1))
    append_benchmark("   1M/2G", range(2))
    append_benchmark("   1M/4G", range(4))

    # Multi-machine benchmarks
    for i in range(1, (dist.get_world_size() // 8) + 1):
        append_benchmark("   %dM/8G" % i, range(i * 8))

    # Run benchmarks in order of increasing number of GPUs
    print_header()
    results = []
    for prefix, ranks, opts in sorted(benchmarks, key=lambda tup: len(tup[1])):
        # Turn range into materialized list.
        ranks = list(ranks)
        measurements = run_benchmark(benchmark, ranks, opts)
        if "warmup" not in prefix:
            print_measurements(prefix, benchmark.batch_size, measurements)
            results.append({"ranks": ranks, "measurements": measurements})

    return results

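# Base class describing a single benchmark: which device to run on, which
# distributed backend to use, and the DDP bucket size. Subclasses supply the
# model and data.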
class Benchmark:
    def __init__(self, device, distributed_backend, bucket_size):
        self.device = device
        self.batch_size = 32
        self.distributed_backend = distributed_backend
        self.bucket_size = bucket_size

    def __str__(self):
        raise NotImplementedError

    def create_model(self):
        raise NotImplementedError

    def generate_inputs(self):
        raise NotImplementedError

    def generate_target(self):
        raise NotImplementedError

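# Benchmark for image classification models from torchvision.models
# (e.g. resnet50), using synthetic inputs and targets.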
class TorchvisionBenchmark(Benchmark):
    def __init__(self, device, distributed_backend, bucket_size, model):
        super().__init__(
            device,
            distributed_backend,
            bucket_size,
        )
        self.model = model

    def __str__(self):
        return f"{self.model} with batch size {self.batch_size}"

    def create_model(self):
        return torchvision.models.__dict__[self.model]().to(self.device)

    def generate_inputs(self):
        return [torch.rand([self.batch_size, 3, 224, 224], device=self.device)]

    def generate_target(self):
        return torch.tensor([1] * self.batch_size, dtype=torch.long, device=self.device)

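# Entry point: parse arguments, set up a global gloo process group for
# exchanging benchmark metadata, sanity-check the GPU topology, and run the
# sweep for each selected model.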
def main():
    parser = argparse.ArgumentParser(description="PyTorch distributed benchmark suite")
    # Fall back to the RANK environment variable if --rank is not passed;
    # argparse applies type=int to string defaults.
    parser.add_argument("--rank", type=int, default=os.environ.get("RANK"))
    parser.add_argument("--world-size", type=int, required=True)
    parser.add_argument("--distributed-backend", type=str, default="nccl")
    parser.add_argument("--bucket-size", type=int, default=25)
    parser.add_argument("--master-addr", type=str, required=True)
    parser.add_argument("--master-port", type=str, required=True)
    parser.add_argument("--model", type=str)
    parser.add_argument(
        "--json", type=str, metavar="PATH", help="Write file with benchmark results"
    )
    args = parser.parse_args()
    assert args.rank is not None, "Pass --rank or set RANK in the environment"

    num_gpus_per_node = torch.cuda.device_count()
    assert num_gpus_per_node == 8, "Expected 8 GPUs per machine"

    # The global process group is used only for communicating benchmark
    # metadata, like measurements. Not for benchmarking itself.
    dist.init_process_group(
        backend="gloo",
        init_method=f"tcp://{args.master_addr}:{args.master_port}",
        rank=args.rank,
        world_size=args.world_size,
    )

    output = allgather_run("nvidia-smi topo -m")
    if not allequal(output):
        print('Output of "nvidia-smi topo -m" differs between machines')
        sys.exit(1)

    if args.rank == 0:
        print("-----------------------------------")
        print("PyTorch distributed benchmark suite")
        print("-----------------------------------")
        print("")
        print(f"* PyTorch version: {torch.__version__}")
        print(f"* CUDA version: {torch.version.cuda}")
        print(f"* Distributed backend: {args.distributed_backend}")
        print(f"* Maximum bucket size: {args.bucket_size}MB")
        print("")
        print("--- nvidia-smi topo -m ---")
        print("")
        print(output[0])
        print("--------------------------")
        print("")

    # Pin each process to one of the 8 local GPUs.
    torch.cuda.set_device(dist.get_rank() % 8)
    device = torch.device("cuda:%d" % (dist.get_rank() % 8))

    benchmarks = []
    if args.model:
        benchmarks.append(
            TorchvisionBenchmark(
                device=device,
                distributed_backend=args.distributed_backend,
                bucket_size=args.bucket_size,
                model=args.model,
            )
        )
    else:
        for model in ["resnet50", "resnet101", "resnext50_32x4d", "resnext101_32x8d"]:
            benchmarks.append(
                TorchvisionBenchmark(
                    device=device,
                    distributed_backend=args.distributed_backend,
                    bucket_size=args.bucket_size,
                    model=model,
                )
            )

    benchmark_results = []
    for benchmark in benchmarks:
        if args.rank == 0:
            print(f"\nBenchmark: {str(benchmark)}")
        result = sweep(benchmark)
        benchmark_results.append(
            {
                "model": benchmark.model,
                "batch_size": benchmark.batch_size,
                "result": result,
            }
        )

    # Write file with benchmark results if applicable.
    if args.rank == 0 and args.json:
        report = {
            "pytorch_version": torch.__version__,
            "cuda_version": torch.version.cuda,
            "distributed_backend": args.distributed_backend,
            "bucket_size": args.bucket_size,
            "benchmark_results": benchmark_results,
        }
        with open(args.json, "w") as f:
            json.dump(report, f)

if __name__ == "__main__":
    main()