# colossalai — test_comm_spec_apply.py
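"""Tests that CommSpec.covert_spec_to_action applies the expected collective
communication on a 2x2 DeviceMesh.

The mesh is built from ranks [[0, 1], [2, 3]], so logical process axis 0
groups ranks {0, 2} and {1, 3}, while logical process axis 1 groups ranks
{0, 1} and {2, 3}; the per-rank expectations in the checks below rely on
exactly these groups. Covered patterns: all-gather, shard (split),
all-to-all, all-reduce forward, all-reduce backward (identity forward), and
all-reduce on the flattened mesh. Requires 4 ranks with CUDA devices and the
NCCL backend.
"""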
import pytest
import torch
import torch.distributed as dist

from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.testing import rerun_if_address_is_in_use, spawn


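# All-gather along logical axis 1 (groups {0, 1} and {2, 3}): the ones shard held
# by ranks 0/2 and the zeros shard held by ranks 1/3 are concatenated along dim 1.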
def check_all_gather(device_mesh, rank):
    # tensor to comm
    if rank in (0, 2):
        sharded_tensor_to_comm = torch.ones(2, 2).cuda()
    else:
        sharded_tensor_to_comm = torch.zeros(2, 2).cuda()

    # tensor to check
    tensor_to_check = torch.cat((torch.ones(2, 2), torch.zeros(2, 2)), 1).cuda()

    # test all gather
    dim_partition_dict = {1: [1]}

    # DistSpec:
    #     shard_sequence: R,S1
    #     device_mesh_shape: (2, 2)
    sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict)

    # CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:1)
    comm_spec = CommSpec(
        CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, sharding_spec, gather_dim=1, logical_process_axis=1
    )
    sharded_tensor_to_comm = comm_spec.covert_spec_to_action(sharded_tensor_to_comm)

    assert sharded_tensor_to_comm.equal(tensor_to_check)


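# Split along logical axis 1: each rank keeps its own slice of dim 1, so ranks 0/2
# (position 0 in their axis-1 group) get the zeros half and ranks 1/3 get the ones half.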
def check_shard(device_mesh, rank):
    # tensor to comm
    sharded_tensor_to_comm_0 = torch.zeros(2, 2).cuda()
    sharded_tensor_to_comm_1 = torch.ones(2, 2).cuda()
    # tensor([[0., 0., 1., 1.],
    #         [0., 0., 1., 1.]])
    tensor_to_shard = torch.cat((sharded_tensor_to_comm_0, sharded_tensor_to_comm_1), 1)

    # test shard
    dim_partition_dict = {}

    # DistSpec:
    #     shard_sequence: R,R
    #     device_mesh_shape: (2, 2)
    sharding_spec = ShardingSpec(device_mesh, tensor_to_shard.shape, dim_partition_dict=dim_partition_dict)

    # CommSpec:(comm_pattern:shard, shard_dim:1, logical_process_axis:1)
    comm_spec = CommSpec(CollectiveCommPattern.SPLIT_FWD_GATHER_BWD, sharding_spec, shard_dim=1, logical_process_axis=1)
    tensor_to_shard = comm_spec.covert_spec_to_action(tensor_to_shard)

    if rank in (0, 2):
        assert tensor_to_shard.equal(sharded_tensor_to_comm_0)
    if rank in (1, 3):
        assert tensor_to_shard.equal(sharded_tensor_to_comm_1)


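# All-to-all within the axis-0 groups {0, 2} and {1, 3}: the (2, 2) shards are
# gathered along dim 0 and re-sharded along dim 1, leaving each rank a (4, 1) slice.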
def check_all_to_all(device_mesh, rank):
    # tensor to comm
    if rank in (0, 1):
        sharded_tensor_0 = torch.zeros(2, 1)
        sharded_tensor_1 = torch.ones(2, 1)
        # tensor([[0., 1.],
        #         [0., 1.]])
        tensor_to_comm = torch.cat((sharded_tensor_0, sharded_tensor_1), 1).cuda()
    if rank in (2, 3):
        sharded_tensor_0 = torch.ones(2, 1) * 2
        sharded_tensor_1 = torch.ones(2, 1) * 3
        # tensor([[2., 3.],
        #         [2., 3.]])
        tensor_to_comm = torch.cat((sharded_tensor_0, sharded_tensor_1), 1).cuda()

    if rank in (0, 1):
        # tensor([[0.],
        #         [0.],
        #         [2.],
        #         [2.]])
        tensor_to_check = torch.tensor([[0], [0], [2], [2]], dtype=tensor_to_comm.dtype).cuda()
    if rank in (2, 3):
        # tensor([[1.],
        #         [1.],
        #         [3.],
        #         [3.]])
        tensor_to_check = torch.tensor([[1], [1], [3], [3]], dtype=tensor_to_comm.dtype).cuda()

    # test all to all
    dim_partition_dict = {0: [0]}

    # DistSpec:
    #     shard_sequence: S0,R
    #     device_mesh_shape: (2, 2)
    sharding_spec = ShardingSpec(device_mesh, torch.Size((4, 2)), dim_partition_dict=dim_partition_dict)

    # CommSpec:(comm_pattern:all2all, gather_dim:0, shard_dim:1, logical_process_axis:0)
    comm_spec = CommSpec(
        CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD, sharding_spec, gather_dim=0, shard_dim=1, logical_process_axis=0
    )
    tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)

    assert tensor_to_comm.equal(tensor_to_check)


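# All-reduce along logical axis 0: ranks {0, 2} sum to 0 + 2 = 2 and ranks {1, 3}
# sum to 1 + 3 = 4, matching the per-rank expectations below.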
def check_all_reduce_fwd(device_mesh, rank):
    # tensor to comm
    tensor_to_comm = torch.ones(2, 2).cuda() * rank

    # reduce through logical process axis 0
    # tensor to check
    if rank in (0, 2):
        # tensor([[2., 2.],
        #         [2., 2.]])
        tensor_to_check = torch.tensor([[2, 2], [2, 2]], dtype=tensor_to_comm.dtype).cuda()
    if rank in (1, 3):
        # tensor([[4., 4.],
        #         [4., 4.]])
        tensor_to_check = torch.tensor([[4, 4], [4, 4]], dtype=tensor_to_comm.dtype).cuda()

    dim_partition_dict = {}
    # DistSpec:
    #     shard_sequence: R,R
    #     device_mesh_shape: (2, 2)
    sharding_spec = ShardingSpec(device_mesh, tensor_to_comm.shape, dim_partition_dict=dim_partition_dict)

    comm_spec = CommSpec(CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, sharding_spec, logical_process_axis=0)
    tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)

    assert tensor_to_comm.equal(tensor_to_check)


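# IDENTITY_FWD_ALLREDUCE_BWD only communicates in the backward pass, so the
# forward action is expected to leave the tensor unchanged.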
def check_all_reduce_bwd(device_mesh, rank):
    # tensor to comm
    tensor_to_comm = torch.ones(2, 2).cuda() * rank

    tensor_to_check = torch.ones(2, 2).cuda() * rank

    dim_partition_dict = {}
    # DistSpec:
    #     shard_sequence: R,R
    #     device_mesh_shape: (2, 2)
    sharding_spec = ShardingSpec(device_mesh, tensor_to_comm.shape, dim_partition_dict=dim_partition_dict)

    comm_spec = CommSpec(CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, sharding_spec, logical_process_axis=0)
    tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)

    assert tensor_to_comm.equal(tensor_to_check)


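# Flattening the 2x2 mesh into one logical axis all-reduces across all four
# ranks: 0 + 1 + 2 + 3 = 6.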
def check_all_reduce_in_flatten_device_mesh(device_mesh, rank):
    # tensor to comm
    tensor_to_comm = torch.ones(2, 2).cuda() * rank

    # reduce through logical process axis 0 of the flattened device mesh
    # tensor to check
    # tensor([[6., 6.],
    #         [6., 6.]])
    tensor_to_check = torch.tensor([[6, 6], [6, 6]], dtype=tensor_to_comm.dtype).cuda()

    dim_partition_dict = {}
    # DistSpec:
    #     shard_sequence: R,R
    #     device_mesh_shape: (2, 2)
    sharding_spec = ShardingSpec(device_mesh, tensor_to_comm.shape, dim_partition_dict=dim_partition_dict)

    # CommSpec:(comm_pattern:all_reduce, logical_process_axis:[0, 1])
    comm_spec = CommSpec(CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, sharding_spec, logical_process_axis=[0, 1])
    tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)

    assert tensor_to_comm.equal(tensor_to_check)


def check_comm(rank, world_size, port):
    disable_existing_loggers()
    launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")

    physical_mesh_id = torch.arange(0, 4)
    assert rank == dist.get_rank()

    mesh_shape = (2, 2)
    # [[0, 1],
    #  [2, 3]]
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
    # test all gather
    check_all_gather(device_mesh, rank)

    # test shard
    check_shard(device_mesh, rank)

    # test all to all
    check_all_to_all(device_mesh, rank)

    # test all reduce
    check_all_reduce_fwd(device_mesh, rank)
    check_all_reduce_bwd(device_mesh, rank)

    # test all reduce in 1D flattened device mesh
    check_all_reduce_in_flatten_device_mesh(device_mesh, rank)


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_comm_spec():
    world_size = 4
    spawn(check_comm, world_size)


if __name__ == "__main__":
    test_comm_spec()