# colossalai
# 142 lines · 5.9 KB
import torch

from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.shape_consistency import CollectiveCommPattern, ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec

# Module-level fixtures shared by every test below.
#
# Sixteen physical device ids arranged as a 4 x 4 logical mesh:
#   [[ 0,  1,  2,  3],
#    [ 4,  5,  6,  7],
#    [ 8,  9, 10, 11],
#    [12, 13, 14, 15]]
mesh_shape = (4, 4)
physical_mesh_id = torch.arange(16)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)

# Full tensor shape handed to every ShardingSpec constructed in the tests.
entire_shape = torch.Size([64, 32, 16])

# One manager instance is shared by both tests; test_shape_consistency also
# inspects its cached_spec_pairs_transform_path after running a transform.
shape_consistency_manager = ShapeConsistencyManager()


def _assert_has_sequences(spec_dict, *expected_sequences):
    """Assert that every expected sharding sequence appears among the keys of ``spec_dict``.

    Args:
        spec_dict: the mapping returned by a ``ShapeConsistencyManager.get_all_*_spec``
            call; its keys are sharding specs exposing ``sharding_sequence``.
        *expected_sequences: rendered sequences such as ``"[R, S1, R]"``.

    The rendered sequences are collected into a set once per result, rather than
    rebuilding a list for every membership check. A failing assertion reports
    what was actually produced.
    """
    produced = {str(spec.sharding_sequence) for spec in spec_dict}
    for expected in expected_sequences:
        assert expected in produced, f"{expected} not among candidate specs {sorted(produced)}"


def _zero_cost():
    """Return a new all-zero cost dict to pass to a ``get_all_*_spec`` call.

    A fresh dict is built on every call, matching the separate inline literals
    this replaces, so the calls never share one mutable object.
    NOTE(review): presumably the starting cost the manager accumulates onto — confirm.
    """
    return {"forward": 0, "backward": 0, "total": 0}


def test_one_step_transform():
    """Check the one-step candidate specs for all-gather, all-to-all and shard.

    Each ``get_all_*_spec`` call takes a source ShardingSpec on the module-level
    4 x 4 ``device_mesh`` and returns a dict keyed by candidate specs. The test
    asserts that specific expected sharding sequences are among those keys.
    The commented dumps record the reference output of each call.
    """
    # --- all-gather ------------------------------------------------------
    dim_partition_dict = {0: [0], 1: [1]}
    # DistSpec:
    #     shard_sequence: S0,S1,R
    #     device_mesh_shape: (4, 4)
    sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict)

    # {DistSpec:
    #     shard_sequence: R,S1,R
    #     device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:allgather, gather_dim:0, logical_process_axis:0), 0),
    #  DistSpec:
    #     shard_sequence: S0,R,R
    #     device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:1), 0)}
    rst_dict = shape_consistency_manager.get_all_all_gather_spec(sharding_spec, _zero_cost())
    _assert_has_sequences(rst_dict, "[R, S1, R]", "[S0, R, R]")

    # --- all-to-all ------------------------------------------------------
    dim_partition_dict_all2all = {0: [0], 1: [1]}
    # DistSpec:
    #     shard_sequence: S0,S1,R
    #     device_mesh_shape: (4, 4)
    sharding_spec_all2all = ShardingSpec(device_mesh, entire_shape, dim_partition_dict_all2all)

    # {DistSpec:
    #     shard_sequence: S01,R,R
    #     device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:all2all, gather_dim:1, shard_dim:0, logical_process_axis: 1), 0),
    #  DistSpec:
    #     shard_sequence: R,S1,S0
    #     device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:all2all, gather_dim:0, shard_dim:2, logical_process_axis: 0), 0),
    #  DistSpec:
    #     shard_sequence: S0,R,S1
    #     device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:all2all, gather_dim:1, shard_dim:2, logical_process_axis: 1), 0)}
    rst_dict_all2all = shape_consistency_manager.get_all_all_to_all_spec(
        sharding_spec_all2all, _zero_cost()
    )
    _assert_has_sequences(rst_dict_all2all, "[S01, R, R]", "[R, S1, S0]", "[S0, R, S1]")

    # --- shard -----------------------------------------------------------
    dim_partition_shard = {0: [0]}
    # DistSpec:
    #     shard_sequence: S0,R,R
    #     device_mesh_shape: (4, 4)
    sharding_spec_shard = ShardingSpec(device_mesh, entire_shape, dim_partition_shard)

    # {DistSpec:
    #     shard_sequence: S01,R,R
    #     device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:shard, shard_dim:0, logical_process_axis:1), 0),
    #  DistSpec:
    #     shard_sequence: S0,S1,R
    #     device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:shard, shard_dim:1, logical_process_axis:1), 0),
    #  DistSpec:
    #     shard_sequence: S0,R,S1
    #     device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:shard, shard_dim:2, logical_process_axis:1), 0)}
    rst_dict_shard = shape_consistency_manager.get_all_shard_spec(sharding_spec_shard, _zero_cost())
    _assert_has_sequences(rst_dict_shard, "[S01, R, R]", "[S0, S1, R]", "[S0, R, S1]")


def test_shape_consistency():
    """Check the multi-step transform from ``[R, S01, R]`` to ``[S01, R, R]``.

    Verifies three things:
      * the sequence of intermediate sharding sequences;
      * the communication action taken at each step;
      * that the manager caches the (path, actions) pair under the
        (source, target) sequence strings.
    """
    # Source R,S01,R: tensor dim 1 partitioned over mesh axes [0, 1].
    source_spec = ShardingSpec(device_mesh, entire_shape, {1: [0, 1]})
    # Target S01,R,R: tensor dim 0 partitioned over mesh axes [0, 1].
    target_spec = ShardingSpec(device_mesh, entire_shape, {0: [0, 1]})

    transform_path, comm_action_sequence, _total_cost = shape_consistency_manager.shape_consistency(
        source_spec, target_spec
    )

    rendered_path = "->".join(str(spec.sharding_sequence) for spec in transform_path)
    assert rendered_path == "[R, S01, R]->[R, S0, R]->[S0, R, R]->[S01, R, R]"

    # Step 1 — all-gather(S01) -> S0 on tensor dim 1, mesh axis 1.
    step = comm_action_sequence[0]
    assert step.comm_pattern == CollectiveCommPattern.GATHER_FWD_SPLIT_BWD
    assert step.gather_dim == 1
    assert step.logical_process_axis == 1

    # Step 2 — all-to-all [R, S0] -> [S0, R] along mesh axis 0.
    step = comm_action_sequence[1]
    assert step.comm_pattern == CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD
    assert step.gather_dim == 1
    assert step.shard_dim == 0
    assert step.logical_process_axis == 0

    # Step 3 — shard S0 -> S01 on tensor dim 0, mesh axis 1.
    step = comm_action_sequence[2]
    assert step.comm_pattern == CollectiveCommPattern.SPLIT_FWD_GATHER_BWD
    assert step.shard_dim == 0
    assert step.logical_process_axis == 1

    # The result is cached under the (source, target) sequence strings.
    cache_key = ("[R, S01, R]", "[S01, R, R]")
    cached_entry = shape_consistency_manager.cached_spec_pairs_transform_path[cache_key]
    assert cached_entry[0] == transform_path
    assert cached_entry[1] == comm_action_sequence


if __name__ == "__main__":
    # Allow running this module directly, without pytest, in declaration order.
    for _test in (test_one_step_transform, test_shape_consistency):
        _test()