#!/usr/bin/env python3
#
# Measure distributed training iteration time.
#
# This program performs a sweep over a) a number of model architectures, and
# b) an increasing number of processes. This produces a 1-GPU baseline,
# an 8-GPU baseline (if applicable), as well as measurements for however
# many processes can participate in training.
#
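# A hypothetical example launch on two 8-GPU machines (host name and port
# are made up; start one process per GPU and give each its own RANK):
#
#   RANK=0 python benchmark.py --world-size 16 \
#       --master-addr node0 --master-port 29500 --json results.json
#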
import argparse
import itertools
import json
import os
import shlex
import subprocess
import sys
import time

import numpy as np
import torchvision

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
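# Gather one picklable Python object from every rank in the global process
# group; returns a list of objects indexed by rank.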
def allgather_object(obj):
    out = [None for _ in range(dist.get_world_size())]
    dist.all_gather_object(out, obj)
    return out
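# Run a shell command locally and all-gather its stdout across ranks, so
# every rank can compare the output produced on all machines.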
def allgather_run(cmd):
    proc = subprocess.run(shlex.split(cmd), capture_output=True)
    assert proc.returncode == 0
    return allgather_object(proc.stdout.decode("utf-8"))
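# Return True if all elements of the iterable compare equal (vacuously
# True when the iterable is empty).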
def allequal(iterator):
    iterator = iter(iterator)
    try:
        first = next(iterator)
    except StopIteration:
        return True
    return all(first == rest for rest in iterator)
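# Time a short synthetic training loop on the given process group and return
# per-iteration timings. The model is wrapped in DistributedDataParallel
# unless this is the bare single-rank case.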
def benchmark_process_group(pg, benchmark, use_ddp_for_single_rank=True):
    torch.manual_seed(pg.rank())
    torch.cuda.manual_seed(pg.rank())

    model = benchmark.create_model()
    data = [(benchmark.generate_inputs(), benchmark.generate_target())]
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), 0.001, momentum=0.9, weight_decay=1e-4)
    if use_ddp_for_single_rank or pg.size() > 1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False,
            process_group=pg,
            bucket_cap_mb=benchmark.bucket_size,
        )

    measurements = []
    warmup_iterations = 5
    measured_iterations = 10
    for inputs, target in data * (warmup_iterations + measured_iterations):
        start = time.time()
        # Zero gradient buffers so every timed iteration performs the same
        # work, as in a regular training step.
        optimizer.zero_grad()
        output = model(*inputs)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        torch.cuda.synchronize()
        measurements.append(time.time() - start)

    # Throw away measurements for warmup iterations
    return measurements[warmup_iterations:]
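# Run one benchmark configuration on the given subset of ranks;
# non-participating ranks only join the barrier and the final all-gather.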
def run_benchmark(benchmark, ranks, opts):
    group = dist.new_group(ranks=ranks, backend=benchmark.distributed_backend)
    measurements = []
    if dist.get_rank() in set(ranks):
        if not opts:
            opts = {}
        measurements = benchmark_process_group(group, benchmark, **opts)
    dist.destroy_process_group(group)
    dist.barrier()

    # Aggregate measurements for better estimation of percentiles
    return list(itertools.chain(*allgather_object(measurements)))
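# Build the list of benchmark configurations (warmup, single-machine
# baselines, multi-machine runs), execute them in order of increasing GPU
# count, and print percentile statistics for each one.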
def sweep(benchmark):
    # Synthesize the set of benchmarks to run.
    # The list contains tuples of ("string prefix", [rank, ...]).
    benchmarks = []

    def append_benchmark(prefix, ranks, opts=None):
        prefix = f"{len(ranks):4} GPUs -- {prefix}"
        benchmarks.append((prefix, ranks, opts))

    def local_print(msg):
        if dist.get_rank() == 0:
            print(msg, end="", flush=True)

    def print_header():
        local_print("\n")
        local_print("%22s" % "")
        for _ in [50, 75, 90, 95]:
            local_print("%14s%10s" % ("sec/iter", "ex/sec"))
        local_print("\n")

    def print_measurements(prefix, nelem, measurements):
        measurements = sorted(measurements)
        local_print("%8s:" % prefix)
        for p in [50, 75, 90, 95]:
            v = np.percentile(measurements, p)
            local_print(" p%02d: %1.3fs %6d/s" % (p, v, nelem / v))
        local_print("\n")

    # Every process runs once by itself to warm up (CUDA init, etc.).
    append_benchmark(" warmup", [dist.get_rank()], {"use_ddp_for_single_rank": False})

    # Single machine baselines
    append_benchmark(" no ddp", range(1), {"use_ddp_for_single_rank": False})
    append_benchmark(" 1M/1G", range(1))
    append_benchmark(" 1M/2G", range(2))
    append_benchmark(" 1M/4G", range(4))

    # Multi-machine benchmarks
    for i in range(1, (dist.get_world_size() // 8) + 1):
        append_benchmark(" %dM/8G" % i, range(i * 8))

    # Run benchmarks in order of increasing number of GPUs
    print_header()
    results = []
    for prefix, ranks, opts in sorted(benchmarks, key=lambda tup: len(tup[1])):
        # Turn range into materialized list.
        ranks = list(ranks)
        measurements = run_benchmark(benchmark, ranks, opts)
        if "warmup" not in prefix:
            print_measurements(prefix, benchmark.batch_size, measurements)
            results.append({"ranks": ranks, "measurements": measurements})

    return results
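# Base class for a benchmark: carries the target device, the distributed
# backend, and the DDP bucket size. Subclasses supply the model and data.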
class Benchmark:
    def __init__(self, device, distributed_backend, bucket_size):
        self.device = device
        self.batch_size = 32
        self.distributed_backend = distributed_backend
        self.bucket_size = bucket_size

    def __str__(self):
        raise NotImplementedError

    def create_model(self):
        raise NotImplementedError

    def generate_inputs(self):
        raise NotImplementedError

    def generate_target(self):
        raise NotImplementedError
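# Benchmark wrapping a torchvision classification model, fed with random
# images and constant labels.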
class TorchvisionBenchmark(Benchmark):
    def __init__(self, device, distributed_backend, bucket_size, model):
        super().__init__(
            device,
            distributed_backend,
            bucket_size,
        )
        self.model = model

    def __str__(self):
        return f"{self.model} with batch size {self.batch_size}"

    def create_model(self):
        return torchvision.models.__dict__[self.model]().to(self.device)

    def generate_inputs(self):
        return [torch.rand([self.batch_size, 3, 224, 224], device=self.device)]

    def generate_target(self):
        return torch.tensor([1] * self.batch_size, dtype=torch.long, device=self.device)
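# Entry point: parse arguments, initialize the global gloo group used for
# metadata exchange, verify identical GPU topology across machines, then
# sweep all benchmarks and optionally write a JSON report.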
def main():
    parser = argparse.ArgumentParser(description="PyTorch distributed benchmark suite")
    parser.add_argument("--rank", type=int, default=os.environ["RANK"])
    parser.add_argument("--world-size", type=int, required=True)
    parser.add_argument("--distributed-backend", type=str, default="nccl")
    parser.add_argument("--bucket-size", type=int, default=25)
    parser.add_argument("--master-addr", type=str, required=True)
    parser.add_argument("--master-port", type=str, required=True)
    parser.add_argument("--model", type=str)
    parser.add_argument(
        "--json", type=str, metavar="PATH", help="Write file with benchmark results"
    )
    args = parser.parse_args()

    num_gpus_per_node = torch.cuda.device_count()
    assert num_gpus_per_node == 8, "Expected 8 GPUs per machine"

    # The global process group is used only for communicating benchmark
    # metadata, like measurements. Not for benchmarking itself.
    dist.init_process_group(
        backend="gloo",
        init_method=f"tcp://{args.master_addr}:{args.master_port}",
        rank=args.rank,
        world_size=args.world_size,
    )

    output = allgather_run("nvidia-smi topo -m")
    if not allequal(output):
        print('Output of "nvidia-smi topo -m" differs between machines')
        sys.exit(1)

    if args.rank == 0:
        print("-----------------------------------")
        print("PyTorch distributed benchmark suite")
        print("-----------------------------------")
        print()
        print(f"* PyTorch version: {torch.__version__}")
        print(f"* CUDA version: {torch.version.cuda}")
        print(f"* Distributed backend: {args.distributed_backend}")
        print(f"* Maximum bucket size: {args.bucket_size}MB")
        print()
        print("--- nvidia-smi topo -m ---")
        print()
        print(output[0])
        print("--------------------------")
        print()

    torch.cuda.set_device(dist.get_rank() % 8)
    device = torch.device("cuda:%d" % (dist.get_rank() % 8))

    benchmarks = []
    if args.model:
        benchmarks.append(
            TorchvisionBenchmark(
                device=device,
                distributed_backend=args.distributed_backend,
                bucket_size=args.bucket_size,
                model=args.model,
            )
        )
    else:
        for model in ["resnet50", "resnet101", "resnext50_32x4d", "resnext101_32x8d"]:
            benchmarks.append(
                TorchvisionBenchmark(
                    device=device,
                    distributed_backend=args.distributed_backend,
                    bucket_size=args.bucket_size,
                    model=model,
                )
            )

    benchmark_results = []
    for benchmark in benchmarks:
        if args.rank == 0:
            print(f"\nBenchmark: {str(benchmark)}")
        result = sweep(benchmark)
        benchmark_results.append(
            {
                "model": benchmark.model,
                "batch_size": benchmark.batch_size,
                "result": result,
            }
        )

    # Write file with benchmark results if applicable
    if args.rank == 0 and args.json:
        report = {
            "pytorch_version": torch.__version__,
            "cuda_version": torch.version.cuda,
            "distributed_backend": args.distributed_backend,
            "bucket_size": args.bucket_size,
            "benchmark_results": benchmark_results,
        }
        with open(args.json, "w") as f:
            json.dump(report, f)
if __name__ == "__main__":
    main()