import pytest
import torch
import torch.distributed as dist

from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.testing import rerun_if_address_is_in_use, spawn
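
# Each check_* function below builds a ShardingSpec on a 2x2 DeviceMesh (4 ranks),
# applies one CollectiveCommPattern through CommSpec.covert_spec_to_action, and
# asserts the resulting tensor on every rank.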

def check_all_gather(device_mesh, rank):
    # tensor to comm
    if rank in (0, 2):
        sharded_tensor_to_comm = torch.ones(2, 2).cuda()
    else:
        sharded_tensor_to_comm = torch.zeros(2, 2).cuda()

    # tensor to check
    tensor_to_check = torch.cat((torch.ones(2, 2), torch.zeros(2, 2)), 1).cuda()

    # test all gather
    dim_partition_dict = {1: [1]}

    # DistSpec:
    #     shard_sequence: R,S1
    #     device_mesh_shape: (2, 2)
    sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict)

    # CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:1)
    comm_spec = CommSpec(
        CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, sharding_spec, gather_dim=1, logical_process_axis=1
    )
    sharded_tensor_to_comm = comm_spec.covert_spec_to_action(sharded_tensor_to_comm)

    assert sharded_tensor_to_comm.equal(tensor_to_check)

def check_shard(device_mesh, rank):
    # tensor to comm
    sharded_tensor_to_comm_0 = torch.zeros(2, 2).cuda()
    sharded_tensor_to_comm_1 = torch.ones(2, 2).cuda()
    # tensor([[0., 0., 1., 1.],
    #         [0., 0., 1., 1.]])
    tensor_to_shard = torch.cat((sharded_tensor_to_comm_0, sharded_tensor_to_comm_1), 1)

    # test shard
    dim_partition_dict = {}

    # DistSpec:
    #     shard_sequence: R,R
    #     device_mesh_shape: (2, 2)
    sharding_spec = ShardingSpec(device_mesh, tensor_to_shard.shape, dim_partition_dict=dim_partition_dict)

    # CommSpec:(comm_pattern:shard, shard_dim:1, logical_process_axis:1)
    comm_spec = CommSpec(CollectiveCommPattern.SPLIT_FWD_GATHER_BWD, sharding_spec, shard_dim=1, logical_process_axis=1)
    tensor_to_shard = comm_spec.covert_spec_to_action(tensor_to_shard)

    if rank in (0, 2):
        assert tensor_to_shard.equal(sharded_tensor_to_comm_0)
    if rank in (1, 3):
        assert tensor_to_shard.equal(sharded_tensor_to_comm_1)

def check_all_to_all(device_mesh, rank):
    # tensor to comm
    if rank in (0, 1):
        sharded_tensor_0 = torch.zeros(2, 1)
        sharded_tensor_1 = torch.ones(2, 1)
        # tensor([[0., 1.],
        #         [0., 1.]])
        tensor_to_comm = torch.cat((sharded_tensor_0, sharded_tensor_1), 1).cuda()
    if rank in (2, 3):
        sharded_tensor_0 = torch.ones(2, 1) * 2
        sharded_tensor_1 = torch.ones(2, 1) * 3
        # tensor([[2., 3.],
        #         [2., 3.]])
        tensor_to_comm = torch.cat((sharded_tensor_0, sharded_tensor_1), 1).cuda()

    if rank in (0, 1):
        # tensor([[0.],
        #         [0.],
        #         [2.],
        #         [2.]])
        tensor_to_check = torch.tensor([[0], [0], [2], [2]], dtype=tensor_to_comm.dtype).cuda()
    if rank in (2, 3):
        # tensor([[1.],
        #         [1.],
        #         [3.],
        #         [3.]])
        tensor_to_check = torch.tensor([[1], [1], [3], [3]], dtype=tensor_to_comm.dtype).cuda()

    # test all to all
    dim_partition_dict = {0: [0]}

    # DistSpec:
    #     shard_sequence: S0,R
    #     device_mesh_shape: (2, 2)
    sharding_spec = ShardingSpec(device_mesh, torch.Size((4, 2)), dim_partition_dict=dim_partition_dict)

    # CommSpec:(comm_pattern:all2all, gather_dim:0, shard_dim:1, logical_process_axis:0)
    comm_spec = CommSpec(
        CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD, sharding_spec, gather_dim=0, shard_dim=1, logical_process_axis=0
    )
    tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)

    assert tensor_to_comm.equal(tensor_to_check)

def check_all_reduce_fwd(device_mesh, rank):
    # tensor to comm
    tensor_to_comm = torch.ones(2, 2).cuda() * rank

    # reduce through logical process axis 0
    # tensor to check
    if rank in (0, 2):
        # tensor([[2., 2.],
        #         [2., 2.]])
        tensor_to_check = torch.tensor([[2, 2], [2, 2]], dtype=tensor_to_comm.dtype).cuda()
    if rank in (1, 3):
        # tensor([[4., 4.],
        #         [4., 4.]])
        tensor_to_check = torch.tensor([[4, 4], [4, 4]], dtype=tensor_to_comm.dtype).cuda()

    dim_partition_dict = {}
    # DistSpec:
    #     shard_sequence: R,R
    #     device_mesh_shape: (2, 2)
    sharding_spec = ShardingSpec(device_mesh, tensor_to_comm.shape, dim_partition_dict=dim_partition_dict)

    comm_spec = CommSpec(CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, sharding_spec, logical_process_axis=0)
    tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)

    assert tensor_to_comm.equal(tensor_to_check)

def check_all_reduce_bwd(device_mesh, rank):
    # tensor to comm
    tensor_to_comm = torch.ones(2, 2).cuda() * rank

    # the forward pass of this pattern is an identity op, so the tensor should stay unchanged
    tensor_to_check = torch.ones(2, 2).cuda() * rank

    dim_partition_dict = {}
    # DistSpec:
    #     shard_sequence: R,R
    #     device_mesh_shape: (2, 2)
    sharding_spec = ShardingSpec(device_mesh, tensor_to_comm.shape, dim_partition_dict=dim_partition_dict)

    comm_spec = CommSpec(CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, sharding_spec, logical_process_axis=0)
    tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)

    assert tensor_to_comm.equal(tensor_to_check)

def check_all_reduce_in_flatten_device_mesh(device_mesh, rank):
    # tensor to comm
    tensor_to_comm = torch.ones(2, 2).cuda() * rank

    # reduce through logical process axis 0 of the flattened device mesh,
    # i.e. across all 4 ranks: 0 + 1 + 2 + 3 = 6
    # tensor to check
    # tensor([[6., 6.],
    #         [6., 6.]])
    tensor_to_check = torch.tensor([[6, 6], [6, 6]], dtype=tensor_to_comm.dtype).cuda()

    dim_partition_dict = {}
    # DistSpec:
    #     shard_sequence: R,R
    #     device_mesh_shape: (2, 2)
    sharding_spec = ShardingSpec(device_mesh, tensor_to_comm.shape, dim_partition_dict=dim_partition_dict)

    # CommSpec:(comm_pattern:all_reduce, logical_process_axis:[0, 1])
    comm_spec = CommSpec(CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, sharding_spec, logical_process_axis=[0, 1])
    tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)

    assert tensor_to_comm.equal(tensor_to_check)

def check_comm(rank, world_size, port):
    disable_existing_loggers()
    launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")

    physical_mesh_id = torch.arange(0, 4)
    assert rank == dist.get_rank()

    mesh_shape = (2, 2)
    # logical mesh:
    # [[0, 1],
    #  [2, 3]]
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)

    # test all gather
    check_all_gather(device_mesh, rank)

    # test shard
    check_shard(device_mesh, rank)

    # test all to all
    check_all_to_all(device_mesh, rank)

    # test all reduce
    check_all_reduce_fwd(device_mesh, rank)
    check_all_reduce_bwd(device_mesh, rank)

    # test all reduce in 1D flatten device mesh
    check_all_reduce_in_flatten_device_mesh(device_mesh, rank)

@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_comm_spec():
    world_size = 4
    spawn(check_comm, world_size)
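
# Running this file directly launches the test locally; it assumes a machine with
# 4 CUDA devices, since spawn starts 4 processes that communicate over NCCL.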
if __name__ == "__main__":
    test_comm_spec()