#!/usr/bin/env python3
#
# Measure distributed training iteration time.
#
# This program performs a sweep over a) a number of model architectures, and
# b) an increasing number of processes. This produces a 1-GPU baseline,
# an 8-GPU baseline (if applicable), as well as measurements for however
# many processes can participate in training.
#
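# Example invocation (a sketch: the file name "benchmark.py", the master
# address 10.0.0.1, and port 29500 below are placeholders for your own setup).
# The script expects 8 GPUs per machine, one process per GPU, and the RANK
# environment variable set for each process, e.g. for 2 machines / 16 GPUs:
#
#   RANK=0 python benchmark.py \
#       --world-size 16 \
#       --master-addr 10.0.0.1 \
#       --master-port 29500 \
#       --model resnet50 \
#       --json results.json
#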

import argparse
import itertools
import json
import os
import shlex
import subprocess
import sys
import time

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torchvision


def allgather_object(obj):
    out = [None for _ in range(dist.get_world_size())]
    dist.all_gather_object(out, obj)
    return out


def allgather_run(cmd):
    proc = subprocess.run(shlex.split(cmd), capture_output=True)
    assert proc.returncode == 0
    return allgather_object(proc.stdout.decode("utf-8"))


def allequal(iterator):
    iterator = iter(iterator)
    try:
        first = next(iterator)
    except StopIteration:
        return True
    return all(first == rest for rest in iterator)


def benchmark_process_group(pg, benchmark, use_ddp_for_single_rank=True):
    torch.manual_seed(pg.rank())
    torch.cuda.manual_seed(pg.rank())

    model = benchmark.create_model()
    data = [(benchmark.generate_inputs(), benchmark.generate_target())]
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), 0.001, momentum=0.9, weight_decay=1e-4)
    if use_ddp_for_single_rank or pg.size() > 1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False,
            process_group=pg,
            bucket_cap_mb=benchmark.bucket_size,
        )

    measurements = []
    warmup_iterations = 5
    measured_iterations = 10
    for inputs, target in data * (warmup_iterations + measured_iterations):
        start = time.time()
        output = model(*inputs)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        torch.cuda.synchronize()
        measurements.append(time.time() - start)

    # Throw away measurements for warmup iterations
    return measurements[warmup_iterations:]


def run_benchmark(benchmark, ranks, opts):
    group = dist.new_group(ranks=ranks, backend=benchmark.distributed_backend)
    measurements = []
    if dist.get_rank() in set(ranks):
        if not opts:
            opts = {}
        measurements = benchmark_process_group(group, benchmark, **opts)
    dist.destroy_process_group(group)
    dist.barrier()

    # Aggregate measurements for better estimation of percentiles
    return list(itertools.chain(*allgather_object(measurements)))


def sweep(benchmark):
    # Synthesize the set of benchmarks to run.
    # This list contains tuples of ("string prefix", [rank...]).
    benchmarks = []

    def append_benchmark(prefix, ranks, opts=None):
        prefix = f"{len(ranks):4} GPUs -- {prefix}"
        benchmarks.append((prefix, ranks, opts))

    def local_print(msg):
        if dist.get_rank() == 0:
            print(msg, end="", flush=True)  # noqa: E999

    def print_header():
        local_print("\n")
        local_print("%22s" % "")
        for p in [50, 75, 90, 95]:
            local_print("%14s%10s" % ("sec/iter", "ex/sec"))
        local_print("\n")

    def print_measurements(prefix, nelem, measurements):
        measurements = sorted(measurements)
        local_print("%8s:" % prefix)
        for p in [50, 75, 90, 95]:
            v = np.percentile(measurements, p)
            local_print("  p%02d:  %1.3fs  %6d/s" % (p, v, nelem / v))
        local_print("\n")

    # Every process runs once by itself to warm up (CUDA init, etc).
    append_benchmark("  warmup", [dist.get_rank()], {"use_ddp_for_single_rank": False})

    # Single machine baselines
    append_benchmark("  no ddp", range(1), {"use_ddp_for_single_rank": False})
    append_benchmark("   1M/1G", range(1))
    append_benchmark("   1M/2G", range(2))
    append_benchmark("   1M/4G", range(4))

    # Multi-machine benchmarks
    for i in range(1, (dist.get_world_size() // 8) + 1):
        append_benchmark("   %dM/8G" % i, range(i * 8))

    # Run benchmarks in order of increasing number of GPUs
    print_header()
    results = []
    for prefix, ranks, opts in sorted(benchmarks, key=lambda tup: len(tup[1])):
        # Turn range into materialized list.
        ranks = list(ranks)
        measurements = run_benchmark(benchmark, ranks, opts)
        if "warmup" not in prefix:
            print_measurements(prefix, benchmark.batch_size, measurements)
            results.append({"ranks": ranks, "measurements": measurements})

    return results


class Benchmark:
    def __init__(self, device, distributed_backend, bucket_size):
        self.device = device
        self.batch_size = 32
        self.distributed_backend = distributed_backend
        self.bucket_size = bucket_size

    def __str__(self):
        raise NotImplementedError

    def create_model(self):
        raise NotImplementedError

    def generate_inputs(self):
        raise NotImplementedError

    def generate_target(self):
        raise NotImplementedError


class TorchvisionBenchmark(Benchmark):
    def __init__(self, device, distributed_backend, bucket_size, model):
        super().__init__(
            device,
            distributed_backend,
            bucket_size,
        )
        self.model = model

    def __str__(self):
        return f"{self.model} with batch size {self.batch_size}"

    def create_model(self):
        return torchvision.models.__dict__[self.model]().to(self.device)

    def generate_inputs(self):
        return [torch.rand([self.batch_size, 3, 224, 224], device=self.device)]

    def generate_target(self):
        return torch.tensor([1] * self.batch_size, dtype=torch.long, device=self.device)


def main():
    parser = argparse.ArgumentParser(description="PyTorch distributed benchmark suite")
    parser.add_argument("--rank", type=int, default=os.environ["RANK"])
    parser.add_argument("--world-size", type=int, required=True)
    parser.add_argument("--distributed-backend", type=str, default="nccl")
    parser.add_argument("--bucket-size", type=int, default=25)
    parser.add_argument("--master-addr", type=str, required=True)
    parser.add_argument("--master-port", type=str, required=True)
    parser.add_argument("--model", type=str)
    parser.add_argument(
        "--json", type=str, metavar="PATH", help="Write file with benchmark results"
    )
    args = parser.parse_args()

    num_gpus_per_node = torch.cuda.device_count()
    assert num_gpus_per_node == 8, "Expected 8 GPUs per machine"

    # The global process group is used only for communicating benchmark
    # metadata, like measurements. Not for benchmarking itself.
    dist.init_process_group(
        backend="gloo",
        init_method=f"tcp://{args.master_addr}:{args.master_port}",
        rank=args.rank,
        world_size=args.world_size,
    )

    output = allgather_run("nvidia-smi topo -m")
    if not allequal(output):
        print('Output of "nvidia-smi topo -m" differs between machines')
        sys.exit(1)

    if args.rank == 0:
        print("-----------------------------------")
        print("PyTorch distributed benchmark suite")
        print("-----------------------------------")
        print("")
        print(f"* PyTorch version: {torch.__version__}")
        print(f"* CUDA version: {torch.version.cuda}")
        print(f"* Distributed backend: {args.distributed_backend}")
        print(f"* Maximum bucket size: {args.bucket_size}MB")
        print("")
        print("--- nvidia-smi topo -m ---")
        print("")
        print(output[0])
        print("--------------------------")
        print("")

    torch.cuda.set_device(dist.get_rank() % 8)
    device = torch.device("cuda:%d" % (dist.get_rank() % 8))

    benchmarks = []
    if args.model:
        benchmarks.append(
            TorchvisionBenchmark(
                device=device,
                distributed_backend=args.distributed_backend,
                bucket_size=args.bucket_size,
                model=args.model,
            )
        )
    else:
        for model in ["resnet50", "resnet101", "resnext50_32x4d", "resnext101_32x8d"]:
            benchmarks.append(
                TorchvisionBenchmark(
                    device=device,
                    distributed_backend=args.distributed_backend,
                    bucket_size=args.bucket_size,
                    model=model,
                )
            )

    benchmark_results = []
    for benchmark in benchmarks:
        if args.rank == 0:
            print(f"\nBenchmark: {str(benchmark)}")
        result = sweep(benchmark)
        benchmark_results.append(
            {
                "model": benchmark.model,
                "batch_size": benchmark.batch_size,
                "result": result,
            }
        )

    # Write file with benchmark results if applicable
    if args.rank == 0 and args.json:
        report = {
            "pytorch_version": torch.__version__,
            "cuda_version": torch.version.cuda,
            "distributed_backend": args.distributed_backend,
            "bucket_size": args.bucket_size,
            "benchmark_results": benchmark_results,
        }
        with open(args.json, "w") as f:
            json.dump(report, f)


if __name__ == "__main__":
    main()
