# Owner(s): ["oncall: distributed"]
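# This test measures FSDP forward-pass overlap: it injects artificial CUDA
# sleeps into both the layer compute and the parameter all-gather, then checks
# that CPU queueing time, CPU-GPU wait time, and GPU execution time behave as
# expected when communication can be overlapped with compute.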

import sys
import time
from statistics import mean
from unittest.mock import patch

import torch
import torch.nn as nn
from torch import distributed as dist
from torch.cuda import Event
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import (
    get_cycles_per_ms,
    run_tests,
    TEST_WITH_DEV_DBG_ASAN,
)

if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


class Layer(nn.Module):
    def __init__(self, compute_cycles, has_params: bool):
        super().__init__()
        self.sleep_cycles = compute_cycles
        self.optional_param = None
        if has_params:
            self.optional_param = nn.Parameter(torch.rand(1))

    def forward(self, x):
        # Get 2 events.
        self.e1 = Event(enable_timing=True)
        self.e2 = Event(enable_timing=True)

        # Record the fake forward compute time.
        self.e1.record()
        if self.sleep_cycles > 0:
            torch.cuda._sleep(self.sleep_cycles)
        if self.optional_param is not None:
            x = x + self.optional_param  # force the param to be part of the graph
        self.e2.record()
        return x

    def get_time(self):
        # Return the recorded duration.
        return self.e1.elapsed_time(self.e2)


def _create_model(compute_cycles, has_params: bool):
    # Use `limit_all_gathers=False` since the timing being tested relies on the
    # CPU running ahead of the GPU
    model = FSDP(
        nn.Sequential(
            FSDP(Layer(compute_cycles, has_params), limit_all_gathers=False),
            FSDP(Layer(compute_cycles, has_params), limit_all_gathers=False),
            FSDP(Layer(compute_cycles, has_params), limit_all_gathers=False),
            FSDP(Layer(compute_cycles, has_params), limit_all_gathers=False),
        ),
        limit_all_gathers=False,
    ).cuda()
    return model


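# Helper that keeps only the 10 smallest values added to it; averaging the
# fastest samples filters out noise from nondeterministic system events.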
class Min10:
    def __init__(self):
        self.data = []

    def add(self, new_data):
        if len(self.data) < 10:
            self.data.append(new_data)
        else:
            self.data = sorted(self.data)
            if new_data < self.data[-1]:
                self.data[-1] = new_data

    def avg(self):
        return mean(self.data)


class TestForwardOverlapWorldSizeOne(FSDPTest):
    @property
    def world_size(self):
        return 1

    def _dist_train(self):
        rank = self.rank
        world_size = self.world_size
        # Save the original torch.distributed.all_gather_into_tensor function since we will
        # patch it to include an artificial delay.
        orig_all_gather = torch.distributed.all_gather_into_tensor

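        # run() builds a fresh model with the given artificial compute and
        # all-gather delays, times 20 measured iterations, and returns the
        # average of the 10 fastest samples for each metric.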
        def run(compute_cycles, all_gather_cycles):
            has_params = all_gather_cycles > 0
            model = _create_model(compute_cycles, has_params)

            # Create the input and set its requires_grad to True because
            # we have a fake compute in the forward pass.
            batch = torch.rand(1).cuda()
            batch.requires_grad = True

            # Run one dummy iteration to trigger the execution order validation
            # all-gathers
            out = model(batch)
            out.backward()
            model.zero_grad(set_to_none=True)

            # We run 20 iterations but only collect timing data from the smallest
            # 10 data points because nondeterministic system events can disturb the timing.
            cpu_iter = Min10()
            cpu_wait = Min10()
            gpu_compute = Min10()
            gpu_total = Min10()
            for _ in range(20):
                # Get two events for measuring the overall time.
                e1 = Event(enable_timing=True)
                e2 = Event(enable_timing=True)

                cpu_start = time.process_time()

                all_gather_called = False

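                # Wrap the real all-gather so we can record that it ran and
                # inject a CUDA sleep that emulates a slow collective.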
                def _delayed_all_gather(*args, **kwargs):
                    nonlocal all_gather_called
                    all_gather_called = True
                    torch.cuda._sleep(all_gather_cycles)
                    assert orig_all_gather
                    return orig_all_gather(*args, **kwargs)

                # forward pass
                #
                # Even though both e1 & e2 are on the compute stream, since
                # compute depends on all_gather, e2-e1 includes all_gather time.
                e1.record()
                with patch(
                    "torch.distributed.all_gather_into_tensor", _delayed_all_gather
                ):
                    out = model(batch)
                    if has_params and world_size > 1:
                        self.assertTrue(all_gather_called)
                    else:
                        self.assertFalse(all_gather_called)
                e2.record()

                # backward pass
                out.backward()
                model.zero_grad(set_to_none=True)

                cpu_iter_time = time.process_time() - cpu_start

                # wait for gpu
                out.item()
                cpu_wait_for_gpu_time = time.process_time() - cpu_start - cpu_iter_time

                # get sum of the compute time
                times = []
                for mod in model.modules():
                    if not isinstance(mod, Layer):
                        continue
                    times.append(mod.get_time())

                # get gpu compute + all_gather time
                overall_gpu_time = e1.elapsed_time(e2)

                cpu_iter.add(cpu_iter_time)
                cpu_wait.add(cpu_wait_for_gpu_time)
                gpu_compute.add(sum(times))
                gpu_total.add(overall_gpu_time)

            del model
            return {
                "cpu_iter": cpu_iter.avg(),
                "cpu_wait": cpu_wait.avg(),
                "gpu_compute": gpu_compute.avg(),
                "gpu_total": gpu_total.avg(),
            }

        sleep_cycles = int(100 * get_cycles_per_ms())

        e1 = run(0, 0)  # no compute, no all-gather
        e2 = run(0, sleep_cycles)  # no compute, only all-gather
        e3 = run(sleep_cycles, 0)  # only compute, no all-gather
        e4 = run(sleep_cycles, sleep_cycles)  # both compute and all-gather
        debug_string = f"\nrank{rank}:\n e1: {e1}\n e2: {e2}\n e3: {e3}\n e4: {e4}"
        print(debug_string)

        # Check the cpu/gpu timing. CPU should run ahead of GPU. Therefore, the cpu-gpu
        # wait should be long, except when there is no real work on GPU.
        #
        # If the assertions below fail, we likely have a cpu-gpu wait in the forward/backward pass.
        # e4["cpu_iter"] may not be short since the cpu may take some time to queue both compute and all-gather.
        short = [
            e1["cpu_iter"],
            e2["cpu_iter"],
            e3["cpu_iter"],
            e1["cpu_wait"],
        ]
        long = [e3["cpu_wait"], e4["cpu_wait"]]
        if world_size == 1:
            short.append(e2["cpu_wait"])  # all gather should not be happening.
        else:
            long.append(
                e2["cpu_wait"]
            )  # all gather should happen and prolong the cpu-gpu wait.
        for s in short:
            for l in long:
                # 10X longer is a safe margin, since the GPU work timing is around
                # 100X that of the CPU.
                self.assertTrue(s * 10 < l)

        # Check the GPU timing.
        short = [e1["gpu_compute"], e1["gpu_total"], e2["gpu_compute"]]
        long = [e3["gpu_compute"], e3["gpu_total"], e4["gpu_compute"], e4["gpu_total"]]
        if world_size == 1:
            short.append(e2["gpu_total"])  # all gather should not be happening.
        else:
            long.append(
                e2["gpu_total"]
            )  # all gather should happen and prolong the overall gpu time.
        for s in short:
            for l in long:
                # 10X longer is a safe margin, since the time is around 100X longer
                # when there is work on GPU vs. no work.
                self.assertTrue(s * 10 < l)

        # Check the GPU overlapping when there is all-gather.
        if world_size > 1:
            compute_only = e3["gpu_compute"]
            all_gather_only = e2["gpu_total"]
            both = e4["gpu_total"]
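            # With no overlap, `both` would roughly equal compute_only + all_gather_only,
            # so requiring that sum to exceed 1.1x `both` asserts that overlapping
            # saved at least roughly 10% of the serial time.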
            self.assertTrue(compute_only + all_gather_only > 1.1 * both)

    @skip_if_lt_x_gpu(2)
    def test_forward_overlap(self):
        self._dist_train()


class TestForwardOverlapWorldSizeTwo(TestForwardOverlapWorldSizeOne):
    @property
    def world_size(self):
        return 2


if __name__ == "__main__":
    run_tests()