import pytest
import torch
import torch.distributed as dist
from torch.distributed import ReduceOp

from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.testing import rerun_if_address_is_in_use, spawn


def check_layer(rank, world_size, port):
    launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")

    physical_mesh_id = torch.arange(0, 4)
    assert rank == dist.get_rank()

    tensor_to_check = torch.tensor([2, 2, 2, 2]).cuda()
    mesh_shape = (2, 2)
    # The 4 physical devices are arranged on a 2 x 2 logical mesh:
    # [[0, 1],
    #  [2, 3]]
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)

    # All-reduce a ones tensor along each mesh axis; every per-axis process
    # group spans 2 ranks, so the summed result should equal [2, 2, 2, 2].
    for axis in range(len(mesh_shape)):
        tensor = torch.ones(4).cuda()
        pg = device_mesh.get_process_group(axis=axis)
        dist.all_reduce(tensor, op=ReduceOp.SUM, group=pg)
        assert tensor.equal(tensor_to_check)


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_logical_pg():
    spawn(check_layer, 4)


if __name__ == "__main__":
    test_logical_pg()