pytorch

Форк
0
/
test_compute_comm_reordering.py 
395 строк · 16.8 Кб
1
# Owner(s): ["module: inductor"]
2
import unittest
3
from unittest.mock import patch
4

5
import torch
6
import torch._dynamo
7
import torch._dynamo.logging
8
import torch._dynamo.test_case
9

10
# for some reason importing functional collectives after dynamo breaks collectives handling!
11
import torch.distributed._functional_collectives as _functional_collectives
12
from torch._C import FileCheck
13
from torch._dynamo.utils import same
14
from torch._inductor import ir, scheduler
15
from torch._inductor.comm_analysis import (
16
    baseLat,
17
    hwLat,
18
    llMaxBws,
19
    NCCL_ALGO,
20
    NCCL_HW,
21
    NCCL_PROTO,
22
    NVIDIA_GPU_TYPE,
23
)
24
from torch._inductor.utils import run_and_get_triton_code
25
from torch.testing._internal.common_distributed import (
26
    _dynamo_dist_per_rank_init,
27
    at_least_x_gpu,
28
    DynamoDistributedMultiProcTestCase,
29
    requires_nccl,
30
)
31
from torch.utils._triton import has_triton
32

33

34
def get_snode_runtime_for_reorder_compute_test(snode):
35
    # NOTE: custom cost model to show that the compute reordering algorithm is working
36
    # Collective kernels
37
    if isinstance(snode.node, ir._CollectiveKernel):
38
        return 100
39
    elif isinstance(snode.node, ir._WaitKernel):
40
        return 0
41
    # High-arithmetic-intensity compute kernels
42
    elif isinstance(snode.node, ir.ExternKernel):
43
        return 5
44
    # All other kernels
45
    return 1
46

47

48
def create_grouped_node_for_allreduce_and_its_deps(snodes):
    """Custom pre-fusion pass that groups the in-place all_reduce with its input.

    Installed via ``torch._inductor.config._pre_fusion_custom_pass`` in
    ``test_grouped_scheduler_node``. It asserts that ``snodes`` contains exactly
    one ``_c10d_functional.all_reduce_`` collective, and that this collective
    has exactly one input. That input's scheduler node and the all_reduce node
    are then wrapped together in a single ``GroupedSchedulerNode``.

    Args:
        snodes: the scheduler nodes in their current order.

    Returns:
        A new list whose first element is the grouped node, followed by every
        other node of ``snodes`` in its original relative order.
    """

    def _is_inplace_all_reduce(candidate):
        # Only collective kernels are compared against the all_reduce_ overload;
        # the isinstance check short-circuits everything else.
        return (
            isinstance(candidate.node, ir._CollectiveKernel)
            and candidate.node.op_overload
            == torch.ops._c10d_functional.all_reduce_.default
        )

    # Map each IR buffer name back to the scheduler node that owns it.
    snode_by_name = {}
    for candidate in snodes:
        snode_by_name[candidate.node.name] = candidate

    matches = list(filter(_is_inplace_all_reduce, snodes))
    assert len(matches) == 1
    all_reduce = matches[0]

    producers = [snode_by_name[inp.name] for inp in all_reduce.node.inputs]
    assert len(producers) == 1
    producer = producers[0]

    grouped = scheduler.GroupedSchedulerNode.create([producer, all_reduce])
    # The grouped node goes first; its members are dropped from the remainder.
    return [grouped] + [other for other in snodes if other not in grouped.snodes]
74

75

76
@requires_nccl()
class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
    """
    Run correctness checks in multi-proc runner, mark with minimum # GPUs to run under.

    Most tests follow the same pattern:

    1. Patch a set of compute/communication reordering options into
       ``torch._inductor.config``.
    2. Compile a small function and check the order of the emitted kernels in
       the generated code with ``FileCheck``.
    3. Compare the compiled result against eager execution.
    """

    def get_world_trs(self):
        # Keyword arguments (tag / ranks / group_size) forwarded to each test's
        # ``func``; the ranks cover the whole world of ``self.world_size``.
        return {
            "tag": "",
            "ranks": list(range(self.world_size)),
            "group_size": self.world_size,
        }

    @property
    def world_size(self) -> int:
        # hack: no matter whether we have 2 or 3 or 4 gpus, just run on 2
        # works around issue with skipif<2 and workers with unpredictable #s gpu
        return 2

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @patch.object(torch._inductor.config, "allow_buffer_reuse", True)
    # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
    @patch.object(torch._inductor.config, "compile_threads", 1)
    @patch.object(torch._inductor.config, "reorder_for_locality", False)
    @patch.object(torch._inductor.config, "reorder_for_compute_comm_overlap", True)
    @patch.object(
        torch._inductor.config,
        "reorder_for_compute_comm_overlap_passes",
        [
            "sink_waits",
        ],
    )
    def test_sink_waits(self):
        def func(a):
            ar = _functional_collectives.all_reduce(a, "sum", "0")
            b = torch.matmul(a, a)
            return torch.matmul(ar, b)

        with _dynamo_dist_per_rank_init(
            self.rank, self.world_size, fake_pg=not at_least_x_gpu(2)
        ):
            inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank
            compiled = torch.compile(func)
            code = run_and_get_triton_code(compiled, inputs)
            # Verify that the wait_tensor is sunk below the 1st matmul but
            # above the 2nd matmul.
            (
                FileCheck()
                .check("torch.ops._c10d_functional.all_reduce_.default")
                .check("extern_kernels.mm")
                .check("torch.ops._c10d_functional.wait_tensor.default")
                .check("extern_kernels.mm")
                .run(code)
            )
            out = compiled(inputs)
            correct = func(inputs)
            self.assertTrue(same(out, correct))

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @patch.object(torch._inductor.config, "allow_buffer_reuse", True)
    # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
    @patch.object(torch._inductor.config, "compile_threads", 1)
    @patch.object(torch._inductor.config, "reorder_for_locality", False)
    @patch.object(torch._inductor.config, "reorder_for_compute_comm_overlap", True)
    @patch.object(
        torch._inductor.config,
        "reorder_for_compute_comm_overlap_passes",
        [
            "raise_comms",
        ],
    )
    def test_raise_comms(self):
        def func(a):
            b = torch.matmul(a, a)
            c = torch.relu(b)
            d = torch.matmul(c, c)
            e = _functional_collectives.all_reduce(b, "sum", "0")
            return torch.matmul(d, e)

        with _dynamo_dist_per_rank_init(
            self.rank, self.world_size, fake_pg=not at_least_x_gpu(2)
        ):
            inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank
            compiled = torch.compile(func)
            code = run_and_get_triton_code(compiled, inputs)
            # Verify that the all_reduce_ has been raised above the 2nd matmul
            # but below the 1st matmul. The all_reduce_ writes directly to the
            # output buffer of the 1st matmul, and that buffer is an input to
            # the first relu. Therefore, the all_reduce_ should be scheduled
            # after the first relu.
            (
                FileCheck()
                .check("extern_kernels.mm")
                .check("triton_poi_fused_relu")
                .check("torch.ops._c10d_functional.all_reduce_.default")
                .check("extern_kernels.mm")
                .check("torch.ops._c10d_functional.wait_tensor.default")
                .check("extern_kernels.mm")
                .run(code)
            )
            out = compiled(inputs)
            correct = func(inputs)
            self.assertTrue(same(out, correct))

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @patch.object(torch._inductor.config, "allow_buffer_reuse", True)
    # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
    @patch.object(torch._inductor.config, "compile_threads", 1)
    @patch.object(torch._inductor.config, "reorder_for_compute_comm_overlap", True)
    @patch.object(
        torch._inductor.config,
        "reorder_for_compute_comm_overlap_passes",
        [
            "sink_waits",
            "raise_comms",
        ],
    )
    def test_sink_waits_raise_comms(self):
        def func(a, *, tag, ranks, group_size):
            b = torch.matmul(a, a)
            c = torch.relu(b)
            d = torch.matmul(c, c)
            e = _functional_collectives.all_reduce(b, "sum", "0")
            f = torch.relu(d)
            g = torch.matmul(f, f)
            return torch.mm(e, g)

        with _dynamo_dist_per_rank_init(
            self.rank, self.world_size, fake_pg=not at_least_x_gpu(2)
        ):
            inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank
            compiled = torch.compile(func)
            code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
            # Things to verify:
            # - The clone prologue of the all_reduce_ should not be fused with
            #   any relus.
            # - The all_reduce_ and its prologue should be raised above the 2nd
            #   matmul but below the 1st matmul.
            # - The wait_tensor should be sunk below the 3rd matmul but above
            #   the 4th matmul.
            (
                FileCheck()
                .check("extern_kernels.mm")
                .check("triton_poi_fused_all_reduce_0")
                .check("torch.ops._c10d_functional.all_reduce_.default")
                .check("triton_poi_fused_relu")
                .check("extern_kernels.mm")
                .check("triton_poi_fused_relu")
                .check("extern_kernels.mm")
                .check("torch.ops._c10d_functional.wait_tensor.default")
                .check("extern_kernels.mm")
                .run(code)
            )
            out = compiled(inputs, **self.get_world_trs())
            correct = func(inputs, **self.get_world_trs())
            self.assertTrue(same(out, correct))

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @patch.object(torch._inductor.config, "allow_buffer_reuse", True)
    # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
    @patch.object(torch._inductor.config, "compile_threads", 1)
    @patch.object(torch._inductor.config, "reorder_for_compute_comm_overlap", True)
    @patch.object(
        torch._inductor.config,
        "reorder_for_compute_comm_overlap_passes",
        [
            "reorder_compute_for_overlap",
        ],
    )
    def test_reorder_compute_for_overlap(self):
        def func(a, *, tag, ranks, group_size):
            ar = _functional_collectives.all_reduce(a, "sum", ranks, tag)
            g = torch.matmul(a, a)
            c = torch.relu(a)
            d = torch.matmul(c, c)
            f = d * c * ar
            fr = _functional_collectives.all_reduce(f, "sum", ranks, tag)
            e = torch.matmul(d + ar + fr, g)
            return (e,)

        with _dynamo_dist_per_rank_init(
            self.rank, self.world_size, fake_pg=not at_least_x_gpu(2)
        ):
            inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank
            compiled = torch.compile(func)
            code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
            # NOTE: after scheduling the first all_reduce:
            # 1. we first schedule the ops (c and d) that ARE required for second all_reduce but DO NOT depend on first all_reduce.
            # 2. then, we schedule the ops (g) that ARE NOT required for second all_reduce and DO NOT depend on first all_reduce.
            # 3. then, we schedule the ops (f) that ARE required for second all_reduce and DO depend on first all_reduce.
            # and then, we schedule the second all_reduce. And then schedule all ops that depend on second all_reduce.
            (
                FileCheck()
                .check("torch.ops._c10d_functional.all_reduce_.default")
                .check("triton_poi_fused_relu")
                .check("extern_kernels.mm")
                .check("extern_kernels.mm")
                .check("torch.ops._c10d_functional.wait_tensor.default")
                .check("triton_poi_fused_mul")
                .check("torch.ops._c10d_functional.all_reduce_.default")
                .check("torch.ops._c10d_functional.wait_tensor.default")
                .check("triton_poi_fused_add")
                .check("extern_kernels.mm")
                .run(code)
            )
            out = compiled(inputs, **self.get_world_trs())
            correct = func(inputs, **self.get_world_trs())
            self.assertTrue(same(out, correct))

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @patch.object(torch._inductor.config, "allow_buffer_reuse", True)
    # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
    @patch.object(torch._inductor.config, "compile_threads", 1)
    @patch.object(torch._inductor.config, "reorder_for_compute_comm_overlap", True)
    @patch.object(
        torch._inductor.config,
        "reorder_for_compute_comm_overlap_passes",
        [
            "reorder_compute_for_overlap",
        ],
    )
    @patch.object(
        torch._inductor.config,
        "estimate_op_runtime",
        get_snode_runtime_for_reorder_compute_test,
    )
    def test_reorder_compute_for_overlap_custom_runtime_estimation(self):
        def func(a, *, tag, ranks, group_size):
            ar = _functional_collectives.all_reduce(a, "sum", ranks, tag)
            g = torch.matmul(a, a)
            c = torch.relu(a)
            d = torch.matmul(c, c)
            f = d * c * ar
            fr = _functional_collectives.all_reduce(f, "sum", ranks, tag)
            e = torch.matmul(d + ar + fr, g)
            return (e,)

        with _dynamo_dist_per_rank_init(
            self.rank, self.world_size, fake_pg=not at_least_x_gpu(2)
        ):
            inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank
            compiled = torch.compile(func)
            code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
            # NOTE: after scheduling the first all_reduce:
            # 1. we first schedule the ops (c and d) that ARE required for second all_reduce but DO NOT depend on first all_reduce.
            # 2. then, we schedule the ops (g) that ARE NOT required for second all_reduce and DO NOT depend on first all_reduce.
            # 3. then, we schedule the ops (f) that ARE required for second all_reduce and DO depend on first all_reduce.
            # and then, we schedule the second all_reduce. And then schedule all ops that depend on second all_reduce.
            (
                FileCheck()
                .check("torch.ops._c10d_functional.all_reduce_.default")
                .check("triton_poi_fused_relu")
                .check("extern_kernels.mm")
                .check("extern_kernels.mm")
                .check("torch.ops._c10d_functional.wait_tensor.default")
                .check("triton_poi_fused_mul")
                .check("torch.ops._c10d_functional.all_reduce_.default")
                .check("torch.ops._c10d_functional.wait_tensor.default")
                .check("triton_poi_fused_add")
                .check("extern_kernels.mm")
                .run(code)
            )
            out = compiled(inputs, **self.get_world_trs())
            correct = func(inputs, **self.get_world_trs())
            self.assertTrue(same(out, correct))

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
    @patch.object(torch._inductor.config, "compile_threads", 1)
    @patch.object(
        torch._inductor.config,
        "_pre_fusion_custom_pass",
        create_grouped_node_for_allreduce_and_its_deps,
    )
    def test_grouped_scheduler_node(self):
        def func(a, *, tag, ranks, group_size):
            add = a + a
            div = add / a
            ar = _functional_collectives.all_reduce(div, "sum", ranks, tag)
            # Normally, we would fuse `add = a + a`, `div = add / a` and `mul = a * a` together into a single fused op,
            # but here in this unit test, we intentionally put `add`, `div` and `ar` computation
            # into a GroupedSchedulerNode, which prevents them from being fused with any other ops.
            mul = a * a
            mm = torch.matmul(mul, ar)
            return (mm,)

        with _dynamo_dist_per_rank_init(
            self.rank, self.world_size, fake_pg=not at_least_x_gpu(2)
        ):
            inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank
            compiled = torch.compile(func)
            code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
            # Expectations:
            # 1. `add = a + a` and `div = add / a` are still fused, which means fusion
            #    still happens among nodes within a GroupedSchedulerNode.
            # 2. `mul = a * a` is not fused with `add` or `div`, because the latter two are within
            #    GroupedSchedulerNode and thus are prevented from being fused with any outside ops.
            (
                FileCheck()
                .check("triton_poi_fused_add_div_0.")
                .check("_c10d_functional.all_reduce_.")
                .check("triton_poi_fused_mul_1.")
                .run(code)
            )
            out = compiled(inputs, **self.get_world_trs())
            correct = func(inputs, **self.get_world_trs())
            self.assertTrue(same(out, correct))

    def test_nccl_heuristics(self):
        # Sanity-check that the NCCL lookup tables in comm_analysis have
        # dimensions matching the enums that index them. unittest assertions
        # (unlike bare `assert`) report the mismatching values and are not
        # stripped when Python runs with -O.
        self.assertEqual(len(baseLat), len(NCCL_ALGO))
        self.assertTrue(all(len(x) == len(NCCL_PROTO) for x in baseLat))

        self.assertEqual(len(hwLat), len(NCCL_HW))
        self.assertTrue(all(len(x) == len(NCCL_ALGO) for x in hwLat))
        self.assertTrue(all(len(y) == len(NCCL_PROTO) for x in hwLat for y in x))

        self.assertEqual(len(llMaxBws), len(NVIDIA_GPU_TYPE))
390

391

392
if __name__ == "__main__":
    # Executed directly (not imported): hand control to dynamo's test runner,
    # reached through the already-imported torch._dynamo.test_case module.
    torch._dynamo.test_case.run_tests()
396

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.