# colossalai — Fork 0
# /test_shape_consistency.py
# 142 lines · 5.9 KB
1
import torch
2

3
from colossalai.device.device_mesh import DeviceMesh
4
from colossalai.tensor.shape_consistency import CollectiveCommPattern, ShapeConsistencyManager
5
from colossalai.tensor.sharding_spec import ShardingSpec
6

7
# Logical 4x4 device mesh over 16 physical device ids, laid out as:
#   [[ 0,  1,  2,  3],
#    [ 4,  5,  6,  7],
#    [ 8,  9, 10, 11],
#    [12, 13, 14, 15]]
physical_mesh_id = torch.arange(16)
mesh_shape = (4, 4)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)

# Global (unsharded) tensor shape used by every ShardingSpec in this module.
entire_shape = torch.Size([64, 32, 16])

# Shared by both tests; test_shape_consistency also inspects the transform-path
# cache that this instance fills.
shape_consistency_manager = ShapeConsistencyManager()
16

17

18
def test_one_step_transform():
    """Check the specs reachable from a source spec with a single collective.

    Queries ``shape_consistency_manager`` for the one-step all-gather,
    all-to-all and shard transforms of a source ``ShardingSpec`` on the 4x4
    ``device_mesh``. For each query, asserts that every expected target
    sharding sequence appears among the returned dict's keys. Only membership
    is checked, not the full result set or the attached communication specs.
    """
    # Each query gets its own freshly built zero-cost dict, exactly as before.
    # NOTE(review): confirm whether the manager mutates this argument; if not,
    # it could be shared.

    # --- all-gather ---------------------------------------------------------
    # Source S0,S1,R. Expected targets (per the original author's notes):
    #   R,S1,R  <- allgather(gather_dim=0, logical_process_axis=0)
    #   S0,R,R  <- allgather(gather_dim=1, logical_process_axis=1)
    sharding_spec = ShardingSpec(device_mesh, entire_shape, {0: [0], 1: [1]})
    rst_dict = shape_consistency_manager.get_all_all_gather_spec(
        sharding_spec, {"forward": 0, "backward": 0, "total": 0}
    )
    # Build the set of reachable sequences once rather than once per assert.
    all_gather_targets = {str(spec.sharding_sequence) for spec in rst_dict}
    assert "[R, S1, R]" in all_gather_targets
    assert "[S0, R, R]" in all_gather_targets

    # --- all-to-all ---------------------------------------------------------
    # Source S0,S1,R. Expected targets:
    #   S01,R,R <- all2all(gather_dim=1, shard_dim=0, logical_process_axis=1)
    #   R,S1,S0 <- all2all(gather_dim=0, shard_dim=2, logical_process_axis=0)
    #   S0,R,S1 <- all2all(gather_dim=1, shard_dim=2, logical_process_axis=1)
    sharding_spec_all2all = ShardingSpec(device_mesh, entire_shape, {0: [0], 1: [1]})
    rst_dict_all2all = shape_consistency_manager.get_all_all_to_all_spec(
        sharding_spec_all2all, {"forward": 0, "backward": 0, "total": 0}
    )
    all2all_targets = {str(spec.sharding_sequence) for spec in rst_dict_all2all}
    assert "[S01, R, R]" in all2all_targets
    assert "[R, S1, S0]" in all2all_targets
    assert "[S0, R, S1]" in all2all_targets

    # --- shard --------------------------------------------------------------
    # Source S0,R,R. Expected targets (all shard along mesh axis 1):
    #   S01,R,R <- shard(shard_dim=0, logical_process_axis=1)
    #   S0,S1,R <- shard(shard_dim=1, logical_process_axis=1)
    #   S0,R,S1 <- shard(shard_dim=2, logical_process_axis=1)
    sharding_spec_shard = ShardingSpec(device_mesh, entire_shape, {0: [0]})
    rst_dict_shard = shape_consistency_manager.get_all_shard_spec(
        sharding_spec_shard, {"forward": 0, "backward": 0, "total": 0}
    )
    shard_targets = {str(spec.sharding_sequence) for spec in rst_dict_shard}
    assert "[S01, R, R]" in shard_targets
    assert "[S0, S1, R]" in shard_targets
    assert "[S0, R, S1]" in shard_targets
92

93

94
def test_shape_consistency():
    """Check the multi-step transform from R,S01,R to S01,R,R and its caching.

    Asserts the exact path of sharding sequences returned by
    ``shape_consistency_manager.shape_consistency``. Also asserts the pattern,
    dims and mesh axis of each of the first three communication actions.
    Finally, asserts that the manager's ``cached_spec_pairs_transform_path``
    holds the same path and action sequence under the string-keyed
    (source, target) pair.
    """
    # Source R,S01,R: tensor dim 1 sharded over mesh axes 0 and 1.
    sharding_spec_source = ShardingSpec(device_mesh, entire_shape, {1: [0, 1]})
    # Target S01,R,R: tensor dim 0 sharded over mesh axes 0 and 1.
    sharding_spec_target = ShardingSpec(device_mesh, entire_shape, {0: [0, 1]})

    transform_path, comm_action_sequence, _total_cost = shape_consistency_manager.shape_consistency(
        sharding_spec_source, sharding_spec_target
    )

    transform_path_str = "->".join(str(spec.sharding_sequence) for spec in transform_path)
    assert transform_path_str == "[R, S01, R]->[R, S0, R]->[S0, R, R]->[S01, R, R]"

    # Step 1: [R, S01, R] -> [R, S0, R], all-gather tensor dim 1 over mesh axis 1.
    gather_step = comm_action_sequence[0]
    assert gather_step.comm_pattern == CollectiveCommPattern.GATHER_FWD_SPLIT_BWD
    assert gather_step.gather_dim == 1
    assert gather_step.logical_process_axis == 1

    # Step 2: [R, S0, R] -> [S0, R, R], all-to-all moving the mesh-axis-0 shard
    # from tensor dim 1 to tensor dim 0.
    all2all_step = comm_action_sequence[1]
    assert all2all_step.comm_pattern == CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD
    assert all2all_step.gather_dim == 1
    assert all2all_step.shard_dim == 0
    assert all2all_step.logical_process_axis == 0

    # Step 3: [S0, R, R] -> [S01, R, R], shard tensor dim 0 again over mesh axis 1.
    shard_step = comm_action_sequence[2]
    assert shard_step.comm_pattern == CollectiveCommPattern.SPLIT_FWD_GATHER_BWD
    assert shard_step.shard_dim == 0
    assert shard_step.logical_process_axis == 1

    # The cache is keyed by the string forms of the (source, target) sequences.
    # Look the entry up once and compare both recorded components.
    cached_entry = shape_consistency_manager.cached_spec_pairs_transform_path[("[R, S01, R]", "[S01, R, R]")]
    assert cached_entry[0] == transform_path
    assert cached_entry[1] == comm_action_sequence
138

139

140
if __name__ == "__main__":
    # Allow running the checks directly, without pytest.
    test_one_step_transform()
    test_shape_consistency()
143

# Use of cookies
#
# We use cookies in accordance with the Privacy Policy and the Cookie Policy.
#
# By clicking "Accept", you consent to SberTech JSC processing your personal data in order to improve our website and the GitVerse Service, and to make them more convenient to use.
#
# You can disable cookies yourself in your browser settings.