# Owner(s): ["module: c10d"]
import unittest
from typing import List

import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol
from torch._C import FileCheck
from torch._inductor.utils import fresh_inductor_cache, run_and_get_triton_code
from torch.distributed._functional_collectives import (
    all_gather_into_tensor_coalesced,
    all_gather_tensor,
    all_reduce,
    all_reduce_coalesced,
    all_to_all_single,
    AsyncCollectiveTensor,
    reduce_scatter_tensor,
    reduce_scatter_tensor_coalesced,
)
from torch.testing._internal.common_distributed import (
    MultiProcessTestCase,
    requires_nccl,
    run_with_native_funcol,
    skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import (  # type: ignore[attr-defined]
    run_tests,
    TestCase,
)
from torch.testing._internal.distributed.fake_pg import FakeStore
from torch.utils._triton import has_triton


def load_test_module(name):
    import sys
    from importlib.machinery import SourceFileLoader
    from pathlib import Path
    from unittest import mock

    testdir = Path(__file__).absolute().parent.parent
    with mock.patch("sys.path", [*sys.path, str(testdir)]):
        return SourceFileLoader(
            name, str(testdir / f"{name.replace('.', '/')}.py")
        ).load_module()


AOTIRunnerUtil = load_test_module("inductor.test_aot_inductor_utils").AOTIRunnerUtil

import sys

if not dist.is_available():
    print("distributed package not available, skipping tests", file=sys.stderr)
    sys.exit(0)

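
# Eager-mode tests for the native _c10d_functional ops. Each test spawns
# world_size processes and runs real NCCL collectives, one GPU per rank.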
@requires_nccl()
class C10DFunctionalNativeTest(MultiProcessTestCase):
    def setUp(self) -> None:
        super().setUp()
        self._spawn_processes()

    @property
    def world_size(self) -> int:
        return 2

    @property
    def ranks(self) -> List[int]:
        return list(range(self.world_size))

    @property
    def device(self) -> torch.device:
        return torch.device(f"cuda:{self.rank}")

    def _init_process_group(self) -> None:
        # Allow testing aoti after torch.compile
        torch._inductor.config.triton.store_cubin = True
        torch._inductor.config.debug = True

        torch.cuda.set_device(self.device)
        store = dist.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            backend="nccl",
            world_size=self.world_size,
            rank=self.rank,
            store=store,
        )
        torch._C._distributed_c10d._register_process_group("default", dist.group.WORLD)

    @skip_if_lt_x_gpu(2)
    @run_with_native_funcol
    def test_all_reduce_single(self) -> None:
        self._init_process_group()

        input = torch.full((10, 10), float(self.rank), device=self.device)
        output = torch.ops._c10d_functional.all_reduce(
            input,
            "avg",
            "default",
        )
        output = torch.ops._c10d_functional.wait_tensor(output)
        assert id(output) != id(input)
        expect = sum(self.ranks) / self.world_size
        assert output.eq(expect).all()
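
        # The Python functional-collectives API returns an AsyncCollectiveTensor
        # wrapper that waits lazily: `completed` only flips to True after the
        # first operation that reads the data (here, `eq`).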
        # Test Python API and AsyncCollectiveTensor
        output = all_reduce(
            input,
            "avg",
            "default",
        )
        assert isinstance(output, AsyncCollectiveTensor)
        assert not output.completed
        assert output.eq(expect).all()
        assert output.completed

    @skip_if_lt_x_gpu(2)
    @run_with_native_funcol
    def test_all_reduce_single_(self) -> None:
        self._init_process_group()

        input = torch.full((10, 10), float(self.rank), device=self.device)
        output = torch.ops._c10d_functional.all_reduce_(
            input,
            "avg",
            "default",
        )
        output = torch.ops._c10d_functional.wait_tensor(output)
        assert id(output) == id(input)
        expect = sum(self.ranks) / self.world_size
        assert output.eq(expect).all()

    @skip_if_lt_x_gpu(2)
    @run_with_native_funcol
    def test_all_reduce_coalesced(self) -> None:
        self._init_process_group()

        inputs = [
            torch.full((i, i), float(self.rank * i), device=self.device)
            for i in range(10)
        ]
        outputs = torch.ops._c10d_functional.all_reduce_coalesced(
            inputs,
            "avg",
            "default",
        )
        for i, (output, input) in enumerate(zip(outputs, inputs)):
            output = torch.ops._c10d_functional.wait_tensor(output)
            assert id(output) != id(input)
            assert output.eq(sum(self.ranks) / self.world_size * i).all()

        # Test Python API and AsyncCollectiveTensor
        outputs = all_reduce_coalesced(
            inputs,
            "avg",
            "default",
        )
        for i, (output, input) in enumerate(zip(outputs, inputs)):
            assert not output.completed
            assert output.eq(sum(self.ranks) / self.world_size * i).all()
            assert output.completed

    @skip_if_lt_x_gpu(2)
    @run_with_native_funcol
    def test_all_reduce_coalesced_(self) -> None:
        self._init_process_group()

        inputs = [
            torch.full((i, i), float(self.rank * i), device=self.device)
            for i in range(10)
        ]
        outputs = torch.ops._c10d_functional.all_reduce_coalesced_(
            inputs,
            "avg",
            "default",
        )
        for i, (output, input) in enumerate(zip(outputs, inputs)):
            output = torch.ops._c10d_functional.wait_tensor(output)
            assert id(output) == id(input)
            assert output.eq(sum(self.ranks) / self.world_size * i).all()

    @skip_if_lt_x_gpu(2)
    @run_with_native_funcol
    def test_all_gather_into_tensor_single(self) -> None:
        self._init_process_group()

        input = torch.full((10, 10), float(self.rank), device=self.device)
        output = torch.ops._c10d_functional.all_gather_into_tensor(
            input,
            self.world_size,
            "default",
        )
        output = torch.ops._c10d_functional.wait_tensor(output)
        expect = torch.cat(
            [
                torch.full((10, 10), float(rank), device=self.device)
                for rank in self.ranks
            ]
        )
        assert torch.allclose(output, expect)
        assert output.eq(expect).all()

        # Test Python API and AsyncCollectiveTensor
        output = all_gather_tensor(
            input,
            0,
            "default",
        )
        assert isinstance(output, AsyncCollectiveTensor)
        assert not output.completed
        assert output.eq(expect).all()
        assert output.completed

    @skip_if_lt_x_gpu(2)
    @run_with_native_funcol
    def test_all_gather_into_tensor_coalesced(self) -> None:
        self._init_process_group()

        inputs = [
            torch.full((10, 10), float(self.rank * i), device=self.device)
            for i in range(10)
        ]
        outputs = torch.ops._c10d_functional.all_gather_into_tensor_coalesced(
            inputs,
            self.world_size,
            "default",
        )
        expect = [
            torch.cat(
                [
                    torch.full((10, 10), float(rank) * i, device=self.device)
                    for rank in self.ranks
                ]
            )
            for i in range(10)
        ]
        for i, output in enumerate(outputs):
            output = torch.ops._c10d_functional.wait_tensor(output)
            assert output.eq(expect[i]).all()

        # Test Python API and AsyncCollectiveTensor
        outputs = all_gather_into_tensor_coalesced(
            inputs,
            "default",
        )
        for i, output in enumerate(outputs):
            assert not output.completed
            assert output.eq(expect[i]).all()
            assert output.completed

    @skip_if_lt_x_gpu(2)
    @run_with_native_funcol
    def test_reduce_scatter_tensor_single(self) -> None:
        self._init_process_group()

        input = torch.tensor(self.ranks, device=self.device)
        output = torch.ops._c10d_functional.reduce_scatter_tensor(
            input,
            "avg",
            self.world_size,
            "default",
        )
        output = torch.ops._c10d_functional.wait_tensor(output)
        assert output.eq(self.rank).all()

        # Test Python API and AsyncCollectiveTensor
        output = reduce_scatter_tensor(
            input,
            "avg",
            0,
            "default",
        )
        assert isinstance(output, AsyncCollectiveTensor)
        assert not output.completed
        assert output.eq(self.rank).all()
        assert output.completed

    @skip_if_lt_x_gpu(2)
    @run_with_native_funcol
    def test_reduce_scatter_tensor_coalesced(self) -> None:
        self._init_process_group()

        inputs = [torch.tensor(self.ranks, device=self.device) * i for i in range(10)]
        outputs = torch.ops._c10d_functional.reduce_scatter_tensor_coalesced(
            inputs,
            "avg",
            self.world_size,
            "default",
        )
        for i, output in enumerate(outputs):
            output = torch.ops._c10d_functional.wait_tensor(output)
            assert output.eq(self.rank * i).all()
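
        # The coalesced Python API takes a per-input scatter dim: [0] * 10
        # scatters each of the ten inputs along dim 0.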
        # Test Python API and AsyncCollectiveTensor
        outputs = reduce_scatter_tensor_coalesced(
            inputs,
            "avg",
            [0] * 10,
            "default",
        )
        for i, output in enumerate(outputs):
            assert not output.completed
            assert output.eq(self.rank * i).all()
            assert output.completed

    @skip_if_lt_x_gpu(2)
    @run_with_native_funcol
    def test_all_to_all_single(self) -> None:
        self._init_process_group()
        torch.cuda.set_device(self.device)

        torch.manual_seed(42)
        send_sz_matrix = torch.randint(0, 20, (self.world_size, self.world_size))

        input_split_sizes = send_sz_matrix[self.rank].tolist()
        output_split_sizes = send_sz_matrix[:, self.rank].tolist()
        input = torch.full((sum(input_split_sizes),), float(self.rank)).cuda()

        output = torch.ops._c10d_functional.all_to_all_single(
            input,
            output_split_sizes,
            input_split_sizes,
            "default",
        )
        output = torch.ops._c10d_functional.wait_tensor(output)
        expect = torch.cat(
            [
                torch.full((sz,), float(rank)).cuda()
                for rank, sz in enumerate(output_split_sizes)
            ]
        )
        assert output.eq(expect).all()

        # Test Python API and AsyncCollectiveTensor
        output = all_to_all_single(
            input, output_split_sizes, input_split_sizes, "default"
        )
        assert not output.completed
        assert output.eq(expect).all()
        assert output.completed

    @skip_if_lt_x_gpu(2)
    @run_with_native_funcol
    def test_broadcast(self) -> None:
        self._init_process_group()

        input = torch.full((10, 10), float(self.rank), device=self.device)
        output = torch.ops._c10d_functional.broadcast(
            input,
            1,
            "default",
        )
        output = torch.ops._c10d_functional.wait_tensor(output)
        assert id(output) != id(input)
        expect = 1
        assert output.eq(expect).all()

        # Test Python API and AsyncCollectiveTensor
        output = funcol.broadcast(
            input,
            1,
            "default",
        )
        assert isinstance(output, AsyncCollectiveTensor)
        assert not output.completed
        assert output.eq(expect).all()
        assert output.completed

    @skip_if_lt_x_gpu(2)
    @run_with_native_funcol
    def test_unwaited(self) -> None:
        # Verify that the process can terminate gracefully
        # even with unwaited tensors
        self._init_process_group()

        input = torch.full((10, 10), float(self.rank), device=self.device)
        output = torch.ops._c10d_functional.all_reduce(
            input,
            "avg",
            "default",
        )
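

# Inductor/AOTI compilation tests. These run in a single process against a
# fake process group (FakeStore + "fake" backend), so no communication
# actually happens; the tests FileCheck the generated Triton wrapper code.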
class C10DFunctionalNativeCompileTest(TestCase):
    def setUp(self):
        # Allow testing aoti after torch.compile
        torch._inductor.config.triton.store_cubin = True
        torch._inductor.config.debug = True

        self.rank = 0
        self.world_size = 2
        torch.cuda.set_device("cuda:0")

        store = FakeStore()
        dist.init_process_group(
            backend="fake",
            world_size=self.world_size,
            rank=self.rank,
            store=store,
        )

    def tearDown(self):
        dist.destroy_process_group()

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @fresh_inductor_cache()
    @run_with_native_funcol
    def test_inductor_all_reduce_single(self):
        def func(arg: torch.Tensor) -> torch.Tensor:
            buf0 = arg + 42
            # Expect in-place with inductor allocated buf
            ar0 = funcol.all_reduce(buf0, "avg", "0")
            ar0 = funcol.wait_tensor(ar0)
            # Expect no in-place with graph input
            ar1 = funcol.all_reduce(arg, "avg", "0")
            ar1 = funcol.wait_tensor(ar1)
            return ar0, ar1

        arg = torch.rand(4, 4, device="cuda")
        compiled = torch.compile(func)

        code = run_and_get_triton_code(compiled, arg)
        (
            FileCheck()
            .check("buf0 = empty")
            .check("buf7 = empty")
            # Expect in-place with inductor allocated buf
            .check("torch.ops._c10d_functional.all_reduce_.default(buf0")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf0")
            # Expect no in-place with graph input (buf7 is a clone)
            .check("torch.ops._c10d_functional.all_reduce_.default(buf7")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf7")
            # Expect no extra copy on return
            .check("return (buf0, buf7, )")
            .run(code)
        )

        # Test aoti
        out = AOTIRunnerUtil.run("cuda", func, (arg,))
        torch.cuda.synchronize()

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @fresh_inductor_cache()
    @run_with_native_funcol
    def test_inductor_all_reduce_coalesced(self):
        def func(args: List[torch.Tensor]) -> torch.Tensor:
            bufs = [arg + 42 for arg in args]
            # Expect in-place with inductor allocated buf
            ar0 = funcol.all_reduce_coalesced(bufs, "avg", "0")
            ar0 = [funcol.wait_tensor(out) for out in ar0]
            # Expect no in-place with graph input
            ar1 = funcol.all_reduce_coalesced(args, "avg", "0")
            ar1 = [funcol.wait_tensor(out) for out in ar1]
            return ar0, ar1

        args = [torch.rand(4, 4, device="cuda") for _ in range(2)]
        compiled = torch.compile(func)
        code = run_and_get_triton_code(compiled, args)
        (
            FileCheck()
            .check("buf0 = empty")
            .check("buf5 = empty")
            .check("buf1 = empty")
            .check("buf6 = empty")
            # Expect in-place with inductor allocated buf
            .check(
                "torch.ops._c10d_functional.all_reduce_coalesced_"
                ".default([buf0, buf1]"
            )
            # Expect no in-place with graph input (buf5, buf6 are clones)
            .check(
                "torch.ops._c10d_functional.all_reduce_coalesced_"
                ".default([buf5, buf6]"
            )
            .check("torch.ops._c10d_functional.wait_tensor.default(buf0")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf1")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf5")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf6")
            # Expect no extra copy on return
            .check("return (buf0, buf1, buf5, buf6, )")
            .run(code)
        )

        # Test aoti
        out = AOTIRunnerUtil.run("cuda", func, (args,))
        torch.cuda.synchronize()

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @fresh_inductor_cache()
    @run_with_native_funcol
    def test_inductor_inplace_op_on_view(self):
        def func(arg: torch.Tensor) -> torch.Tensor:
            buf0 = (arg + 10)[:2]
            ar0 = funcol.all_reduce(buf0, "avg", "0")
            ar0 = funcol.wait_tensor(ar0)
            return ar0

        arg = torch.rand(4, 4, device="cuda")
        compiled = torch.compile(func)

        code = run_and_get_triton_code(compiled, arg)
        (
            FileCheck()
            .check("buf0 = empty")
            # Ensure the all_reduce_ input is a view
            .check(
                "torch.ops._c10d_functional.all_reduce_.default(reinterpret_tensor(buf0"
            )
            .check(
                "torch.ops._c10d_functional.wait_tensor.default(reinterpret_tensor(buf0"
            )
            .check("return (reinterpret_tensor(buf0")
            .run(code)
        )

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @fresh_inductor_cache()
    @run_with_native_funcol
    def test_inductor_reuse_buffer_after_inplace_collective(self):
        def func(arg: torch.Tensor) -> torch.Tensor:
            # Expect allocation
            buf0 = arg + 42
            ar0 = funcol.all_reduce(buf0, "avg", "0")
            ar0 = funcol.wait_tensor(ar0)
            # Expect allocation
            buf1 = torch.mm(arg, ar0)
            # Expect buf0 to be reused
            buf2 = torch.mm(arg, buf1)
            return buf1, buf2

        arg = torch.rand(4, 4, device="cuda")
        compiled = torch.compile(func)
        code = run_and_get_triton_code(compiled, arg)
        (
            FileCheck()
            # Expect allocation
            .check("buf0 = empty")
            .check("torch.ops._c10d_functional.all_reduce_.default(buf0")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf0")
            # Expect allocation
            .check("buf7 = empty")
            .check("extern_kernels.mm(arg0_1, buf0, out=buf7")
            # Expect buf0 to be reused
            .check("buf8 = buf0; del buf0 # reuse")
            .check("extern_kernels.mm(arg0_1, buf7, out=buf8")
            # Expect no extra copy on return
            .check("return (buf7, buf8, )")
            .run(code)
        )

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @fresh_inductor_cache()
    @run_with_native_funcol
    def test_inductor_all_gather_into_tensor_single(self):
        def func(arg: torch.Tensor) -> torch.Tensor:
            ag0 = funcol.all_gather_tensor(arg, 0, "0")
            ag0 = funcol.wait_tensor(ag0)
            return ag0

        arg = torch.rand(4, 4, device="cuda")
        compiled = torch.compile(func)
        code = run_and_get_triton_code(compiled, arg)
        (
            FileCheck()
            .check(
                "buf0 = torch.ops._c10d_functional.all_gather_into_tensor.default(arg0_1"
            )
            .check("torch.ops._c10d_functional.wait_tensor.default(buf0")
            # Expect no extra copy on return
            .check("return (buf0, )")
            .run(code)
        )

        # Test aoti
        out = AOTIRunnerUtil.run("cuda", func, (arg,))
        torch.cuda.synchronize()

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @fresh_inductor_cache()
    @run_with_native_funcol
    def test_inductor_all_gather_into_tensor_coalesced(self):
        def func(args: List[torch.Tensor]) -> torch.Tensor:
            ag0 = funcol.all_gather_into_tensor_coalesced(args, "0")
            ag0 = [funcol.wait_tensor(out) for out in ag0]
            return ag0

        args = [torch.rand(4, 4, device="cuda") for _ in range(4)]
        compiled = torch.compile(func)
        code = run_and_get_triton_code(compiled, args)
        (
            FileCheck()
            .check(
                "buf0 = torch.ops._c10d_functional.all_gather_into_tensor_coalesced"
                ".default([arg0_1, arg1_1, arg2_1, arg3_1]"
            )
            .check("buf1 = buf0[0]")
            .check("buf2 = buf0[1]")
            .check("buf3 = buf0[2]")
            .check("buf4 = buf0[3]")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf1")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf2")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf3")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf4")
            # Expect no extra copy on return
            .check("return (buf1, buf2, buf3, buf4, )")
            .run(code)
        )

        # Test aoti
        out = AOTIRunnerUtil.run("cuda", func, (args,))
        torch.cuda.synchronize()

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @fresh_inductor_cache()
    @run_with_native_funcol
    def test_inductor_reduce_scatter_tensor_single(self):
        def func(arg: torch.Tensor) -> torch.Tensor:
            rs0 = funcol.reduce_scatter_tensor(arg, "avg", 0, "0")
            rs0 = funcol.wait_tensor(rs0)
            return rs0

        arg = torch.rand(4, 4, device="cuda")
        compiled = torch.compile(func)
        code = run_and_get_triton_code(compiled, arg)
        (
            FileCheck()
            .check(
                "buf0 = torch.ops._c10d_functional.reduce_scatter_tensor.default(arg0_1"
            )
            .check("torch.ops._c10d_functional.wait_tensor.default(buf0")
            # Expect no extra copy on return
            .check("return (buf0, )")
            .run(code)
        )

        # Test aoti
        out = AOTIRunnerUtil.run("cuda", func, (arg,))
        torch.cuda.synchronize()

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @fresh_inductor_cache()
    @run_with_native_funcol
    def test_inductor_reduce_scatter_tensor_coalesced(self):
        def func(args: List[torch.Tensor]) -> torch.Tensor:
            rs0 = funcol.reduce_scatter_tensor_coalesced(
                args, "avg", [0] * len(args), "0"
            )
            rs0 = [funcol.wait_tensor(out) for out in rs0]
            return rs0

        args = [torch.rand(4, 4, device="cuda") for _ in range(4)]
        compiled = torch.compile(func)
        code = run_and_get_triton_code(compiled, args)
        (
            FileCheck()
            .check(
                "buf0 = torch.ops._c10d_functional.reduce_scatter_tensor_coalesced"
                ".default([arg0_1, arg1_1, arg2_1, arg3_1]"
            )
            .check("buf1 = buf0[0]")
            .check("buf2 = buf0[1]")
            .check("buf3 = buf0[2]")
            .check("buf4 = buf0[3]")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf1")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf2")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf3")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf4")
            # Expect no extra copy on return
            .check("return (buf1, buf2, buf3, buf4, )")
            .run(code)
        )

        # Test aoti
        AOTIRunnerUtil.run("cuda", func, (args,))
        torch.cuda.synchronize()

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @fresh_inductor_cache()
    @run_with_native_funcol
    def test_inductor_all_to_all_single(self):
        def _tolist_with_constrain_as_size(tensor):
            lst = tensor.tolist()
            for elem in lst:
                torch._constrain_as_size(elem)
            return lst
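
        # .tolist() on a tensor yields unbacked symints under dynamo;
        # _constrain_as_size marks them as sizes so the data-dependent
        # output shape of all_to_all_single can be compiled.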
        def func(
            input: torch.Tensor,
            output_split_sizes: torch.Tensor,
            input_split_sizes: torch.Tensor,
        ) -> torch.Tensor:
            output = funcol.all_to_all_single(
                input,
                _tolist_with_constrain_as_size(output_split_sizes),
                _tolist_with_constrain_as_size(input_split_sizes),
                "0",
            )
            return funcol.wait_tensor(output)

        torch.manual_seed(42)
        send_sz_matrix = torch.randint(0, 20, (self.world_size, self.world_size))

        input_split_sizes = send_sz_matrix[self.rank]
        output_split_sizes = send_sz_matrix[:, self.rank].contiguous()
        input = torch.full((input_split_sizes.sum().item(),), float(self.rank)).cuda()
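
        # capture_scalar_outputs lets dynamo trace .tolist(), and
        # capture_dynamic_output_shape_ops permits ops whose output shape
        # depends on tensor data, which all_to_all_single requires here.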
        with torch._dynamo.config.patch(
            dynamic_shapes=True,
            capture_dynamic_output_shape_ops=True,
            capture_scalar_outputs=True,
        ):
            compiled = torch.compile(func, dynamic=True)
            code = run_and_get_triton_code(
                compiled, input, output_split_sizes, input_split_sizes
            )
            (
                FileCheck()
                .check_regex(
                    "torch.ops._c10d_functional.all_to_all_single.default\\("
                    "arg\\d+_\\d+, \\[u\\d+, u\\d+\\], \\[u\\d+, u\\d+\\]"
                )
                .check("torch.ops._c10d_functional.wait_tensor.default(")
                .run(code)
            )

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @fresh_inductor_cache()
    @run_with_native_funcol
    def test_inductor_broadcast(self):
        def func(arg: torch.Tensor) -> torch.Tensor:
            buf0 = arg + 42
            # Expect in-place with inductor allocated buf
            br0 = funcol.broadcast(buf0, 1, "0")
            br0 = funcol.wait_tensor(br0)
            # Expect no in-place with graph input
            br1 = funcol.broadcast(arg, 0, "0")
            br1 = funcol.wait_tensor(br1)
            return br0, br1

        arg = torch.rand(4, 4, device="cuda")
        compiled = torch.compile(func)

        code = run_and_get_triton_code(compiled, arg)
        (
            FileCheck()
            .check("buf0 = empty")
            .check("buf7 = empty")
            # Expect in-place with inductor allocated buf
            .check("torch.ops._c10d_functional.broadcast_.default(buf0")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf0")
            # Expect no in-place with graph input (buf7 is a clone)
            .check("torch.ops._c10d_functional.broadcast_.default(buf7")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf7")
            # Expect no extra copy on return
            .check("return (buf0, buf7, )")
            .run(code)
        )

        # Test aoti
        out = AOTIRunnerUtil.run("cuda", func, (arg,))
        torch.cuda.synchronize()

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @fresh_inductor_cache()
    @run_with_native_funcol
    def test_ranks_and_tag(self):
        def func(arg: torch.Tensor) -> torch.Tensor:
            buf0 = arg + 42
            # Expect in-place with inductor allocated buf
            ar0 = funcol.all_reduce(buf0, "avg", [0, 1], "")
            ar0 = funcol.wait_tensor(ar0)
            # Expect no in-place with graph input
            ar1 = funcol.all_reduce(arg, "avg", [0, 1], "")
            ar1 = funcol.wait_tensor(ar1)
            return ar0, ar1

        arg = torch.rand(4, 4, device="cuda")
        compiled = torch.compile(func, fullgraph=True)
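
        # The legacy (ranks, tag) overload should resolve to a registered
        # process-group name: the generated code targets group "0".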
        code = run_and_get_triton_code(compiled, arg)
        (FileCheck().check("all_reduce_.default(buf0, 'avg', '0')").run(code))


if __name__ == "__main__":
    run_tests()