pytorch / test_fsdp_overlap.py
# Owner(s): ["oncall: distributed"]

import sys
import time
from statistics import mean
from unittest.mock import patch

import torch
import torch.nn as nn
from torch import distributed as dist
from torch.cuda import Event
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import (
    get_cycles_per_ms,
    run_tests,
    TEST_WITH_DEV_DBG_ASAN,
)

if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


class Layer(nn.Module):
    def __init__(self, compute_cycles, has_params: bool):
        super().__init__()
        self.sleep_cycles = compute_cycles
        self.optional_param = None
        if has_params:
            self.optional_param = nn.Parameter(torch.rand(1))

    def forward(self, x):
        # Get 2 events.
        self.e1 = Event(enable_timing=True)
        self.e2 = Event(enable_timing=True)

        # Record the fake forward compute time.
        self.e1.record()
        if self.sleep_cycles > 0:
            torch.cuda._sleep(self.sleep_cycles)
        if self.optional_param is not None:
            x = x + self.optional_param  # force the param to be part of the graph
        self.e2.record()
        return x

    def get_time(self):
        # Return the recorded duration.
        return self.e1.elapsed_time(self.e2)


def _create_model(compute_cycles, has_params: bool):
    # Use `limit_all_gathers=False` since the timing being tested relies on the
    # CPU running ahead of the GPU.
    model = FSDP(
        nn.Sequential(
            FSDP(Layer(compute_cycles, has_params), limit_all_gathers=False),
            FSDP(Layer(compute_cycles, has_params), limit_all_gathers=False),
            FSDP(Layer(compute_cycles, has_params), limit_all_gathers=False),
            FSDP(Layer(compute_cycles, has_params), limit_all_gathers=False),
        ),
        limit_all_gathers=False,
    ).cuda()
    return model


class Min10:
    def __init__(self):
        self.data = []

    def add(self, new_data):
        if len(self.data) < 10:
            self.data.append(new_data)
        else:
            self.data = sorted(self.data)
            if new_data < self.data[-1]:
                self.data[-1] = new_data

    def avg(self):
        return mean(self.data)


class TestForwardOverlapWorldSizeOne(FSDPTest):
    @property
    def world_size(self):
        return 1

    def _dist_train(self):
        rank = self.rank
        world_size = self.world_size
        # Save the original torch.distributed.all_gather_into_tensor function since we will
        # patch it to include an artificial delay.
        orig_all_gather = torch.distributed.all_gather_into_tensor

        def run(compute_cycles, all_gather_cycles):
            has_params = all_gather_cycles > 0
            model = _create_model(compute_cycles, has_params)

            # Get the input and set its requires_grad to True because
            # we have a fake compute in the forward pass.
            batch = torch.rand(1).cuda()
            batch.requires_grad = True

            # Run one dummy iteration to trigger the execution order validation
            # all-gathers.
            out = model(batch)
            out.backward()
            model.zero_grad(set_to_none=True)

            # We run 20 iterations but only collect timing data from the smallest 10
            # data points because nondeterministic system events can disturb the timing.
            cpu_iter = Min10()
            cpu_wait = Min10()
            gpu_compute = Min10()
            gpu_total = Min10()
            for _ in range(20):
                # Get two events for measuring the overall time.
                e1 = Event(enable_timing=True)
                e2 = Event(enable_timing=True)

                cpu_start = time.process_time()

                all_gather_called = False

                def _delayed_all_gather(*args, **kwargs):
                    nonlocal all_gather_called
                    all_gather_called = True
                    torch.cuda._sleep(all_gather_cycles)
                    assert orig_all_gather
                    return orig_all_gather(*args, **kwargs)

                # forward pass
                #
                # Even though both e1 & e2 are on the compute stream, since
                # compute depends on all_gather, e2-e1 includes all_gather time.
                e1.record()
                with patch(
                    "torch.distributed.all_gather_into_tensor", _delayed_all_gather
                ):
                    out = model(batch)
                    if has_params and world_size > 1:
                        self.assertTrue(all_gather_called)
                    else:
                        self.assertFalse(all_gather_called)
                e2.record()

                # backward pass
                out.backward()
                model.zero_grad(set_to_none=True)

                cpu_iter_time = time.process_time() - cpu_start

                # wait for gpu
                out.item()
                cpu_wait_for_gpu_time = time.process_time() - cpu_start - cpu_iter_time

                # get sum of the compute time
                times = []
                for mod in model.modules():
                    if not isinstance(mod, Layer):
                        continue
                    times.append(mod.get_time())

                # get gpu compute + all_gather time
                overall_gpu_time = e1.elapsed_time(e2)

                cpu_iter.add(cpu_iter_time)
                cpu_wait.add(cpu_wait_for_gpu_time)
                gpu_compute.add(sum(times))
                gpu_total.add(overall_gpu_time)

            del model
            return {
                "cpu_iter": cpu_iter.avg(),
                "cpu_wait": cpu_wait.avg(),
                "gpu_compute": gpu_compute.avg(),
                "gpu_total": gpu_total.avg(),
            }

        sleep_cycles = int(100 * get_cycles_per_ms())

        e1 = run(0, 0)  # no compute, no all-gather
        e2 = run(0, sleep_cycles)  # no compute, only all-gather
        e3 = run(sleep_cycles, 0)  # only compute, no all-gather
        e4 = run(sleep_cycles, sleep_cycles)  # both compute and all-gather
        debug_string = f"\nrank{rank}:\n  e1: {e1}\n  e2: {e2}\n  e3: {e3}\n  e4: {e4}"
        print(debug_string)

        # Check the cpu/gpu timing. CPU should run ahead of GPU. Therefore, the cpu-gpu
        # wait should be long, except when there is no real work on GPU.
        #
        # If the assertions below fail, we likely have a cpu-gpu wait in the forward/backward pass.
        # e4["cpu_iter"] may not be short as the cpu may take some time to queue both compute and all-gather.
        short = [
            e1["cpu_iter"],
            e2["cpu_iter"],
            e3["cpu_iter"],
            e1["cpu_wait"],
        ]
        long = [e3["cpu_wait"], e4["cpu_wait"]]
        if world_size == 1:
            short.append(e2["cpu_wait"])  # all gather should not be happening.
        else:
            long.append(
                e2["cpu_wait"]
            )  # all gather should happen and prolong the cpu-gpu wait.
        for s in short:
            for l in long:
                # 10X longer is a safe margin, since the GPU work timing is around 100X
                # that of the CPU.
                self.assertTrue(s * 10 < l)

        # Check the GPU timing.
        short = [e1["gpu_compute"], e1["gpu_total"], e2["gpu_compute"]]
        long = [e3["gpu_compute"], e3["gpu_total"], e4["gpu_compute"], e4["gpu_total"]]
        if world_size == 1:
            short.append(e2["gpu_total"])  # all gather should not be happening.
        else:
            long.append(
                e2["gpu_total"]
            )  # all gather should happen and prolong the cpu-gpu wait.
        for s in short:
            for l in long:
                # 10X longer is a safe margin, since the time is around 100X longer
                # when there is work on GPU vs. no work.
                self.assertTrue(s * 10 < l)

        # Check the GPU overlapping when there is all-gather.
        if world_size > 1:
            compute_only = e3["gpu_compute"]
            all_gather_only = e2["gpu_total"]
            both = e4["gpu_total"]
            self.assertTrue(compute_only + all_gather_only > 1.1 * both)

    @skip_if_lt_x_gpu(2)
    def test_forward_overlap(self):
        self._dist_train()


class TestForwardOverlapWorldSizeTwo(TestForwardOverlapWorldSizeOne):
    @property
    def world_size(self):
        return 2


if __name__ == "__main__":
    run_tests()
