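"""Test that ShapeConsistencyManager.apply converts a tensor sharded as
[S0, R] into [R, S0] on a 2x2 device mesh spanning 4 GPUs, checking both
the resulting local shard values and the updated sharding spec.
"""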
import pytest
import torch

from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.testing import rerun_if_address_is_in_use, spawn


def check_apply(rank, world_size, port):
    disable_existing_loggers()
    launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")

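    # Build a 2x2 logical device mesh over 4 physical devices. In the sharding
    # specs below, S0 means "sharded along mesh dimension 0" and R means
    # "replicated"; a dim_partition dict maps tensor dims to the mesh dims they
    # are sharded over (e.g. {0: [0]} shards tensor dim 0 along mesh dim 0).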
    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    # [[0, 1],
    #  [2, 3]]
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
    entire_shape = torch.Size((4, 2))
    shape_consistency_manager = ShapeConsistencyManager()
    dim_partition_source = {0: [0]}
    dim_partition_target = {1: [0]}

    # DistSpec:
    #     shard_sequence: S0,R
    #     device_mesh_shape: (2, 2)
    sharding_spec_source = ShardingSpec(device_mesh, entire_shape, dim_partition_source)

    # DistSpec:
    #     shard_sequence: R,S0
    #     device_mesh_shape: (2, 2)
    sharding_spec_target = ShardingSpec(device_mesh, entire_shape, dim_partition_target)

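    # Under the source spec [S0, R], tensor dim 0 is split along mesh dim 0:
    # ranks 0 and 1 (mesh row 0) both hold rows 0-1 of the global (4, 2)
    # tensor, while ranks 2 and 3 (mesh row 1) both hold rows 2-3.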
    if rank in (0, 1):
        sharded_tensor_0 = torch.zeros(2, 1)
        sharded_tensor_1 = torch.ones(2, 1)
        # tensor([[0., 1.],
        #         [0., 1.]])
        tensor_to_comm = torch.cat((sharded_tensor_0, sharded_tensor_1), 1).cuda()
    if rank in (2, 3):
        sharded_tensor_0 = torch.ones(2, 1) * 2
        sharded_tensor_1 = torch.ones(2, 1) * 3
        # tensor([[2., 3.],
        #         [2., 3.]])
        tensor_to_comm = torch.cat((sharded_tensor_0, sharded_tensor_1), 1).cuda()

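    # Under the target spec [R, S0], tensor dim 1 is split along mesh dim 0
    # instead: ranks 0 and 1 should end up with column 0 of the full tensor,
    # and ranks 2 and 3 with column 1.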
    if rank in (0, 1):
        # tensor([[0.],
        #         [0.],
        #         [2.],
        #         [2.]])
        tensor_to_check = torch.tensor([[0], [0], [2], [2]], dtype=tensor_to_comm.dtype).cuda()
    if rank in (2, 3):
        # tensor([[1.],
        #         [1.],
        #         [3.],
        #         [3.]])
        tensor_to_check = torch.tensor([[1], [1], [3], [3]], dtype=tensor_to_comm.dtype).cuda()

    tensor_to_comm.sharding_spec = sharding_spec_source
    tensor_to_comm = shape_consistency_manager.apply(tensor_to_comm, sharding_spec_target)
    assert tensor_to_comm.equal(tensor_to_check)
    assert str(tensor_to_comm.sharding_spec.sharding_sequence) == str(sharding_spec_target.sharding_sequence)


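# spawn launches world_size processes, each running check_apply with its own
# rank and a free port; rerun_if_address_is_in_use retries the test if the
# chosen port is already taken.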
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_apply():
    world_size = 4
    spawn(check_apply, world_size)


if __name__ == "__main__":
    test_apply()