# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle.distributed as dist
from paddle.distributed import fleet
from paddle.distributed.fleet.base.orthogonal_strategy import OrthogonalStrategy
from paddle.distributed.fleet.base.strategy_group import (
    DPGroup,
    MPGroup,
    PPGroup,
    ShardingGroup,
)


def create_hcg(strategy, hcg_name):
    # fleet's built-in HybridCommunicateGroup is obtained through fleet.init();
    # any other name is resolved to a class defined in this module.
    if hcg_name == "HybridCommunicateGroup":
        fleet.init(is_collective=True, strategy=strategy)
        hcg = fleet.get_hybrid_communicate_group()
    else:
        dist.init_parallel_env()
        hcg = eval("{}".format(hcg_name))(strategy)

    return hcg


class Hybrid4DCommGroup(OrthogonalStrategy):
    """4D hybrid communication topology composing data parallel ("dp"),
    model parallel ("mp"), pipeline parallel ("pp") and sharding
    ("sharding") strategies on top of OrthogonalStrategy."""

    def __init__(self, list_of_strategy=None, fused_strategy_dict=None):
        # Avoid sharing a mutable default argument across instances.
        if fused_strategy_dict is None:
            fused_strategy_dict = {}
        # Default topology: every strategy has degree 1 (single-card fallback).
        list_of_strategy = (
            [
                ("dp", 1, DPGroup),
                ("mp", 1, MPGroup),
                ("pp", 1, PPGroup),
                ("sharding", 1, ShardingGroup),
            ]
            if list_of_strategy is None
            else list_of_strategy
        )
        fused_strategy_dict["check"] = ["mp", "pp"]

        super().__init__(list_of_strategy, fused_strategy_dict)

    # data parallel
    def get_data_parallel_rank(self):
        return self.rank_in_strategy("dp")

    def get_data_parallel_world_size(self):
        return self.strategy_group("dp").world_size

    def get_data_parallel_group(self):
        return self.strategy_group("dp").group

    def get_data_parallel_group_src_rank(self):
        return self.strategy_group("dp").group.ranks[0]

    # model parallel
    def get_model_parallel_rank(self):
        return self.rank_in_strategy("mp")

    def get_model_parallel_world_size(self):
        return self.strategy_group("mp").world_size

    def get_model_parallel_group(self):
        return self.strategy_group("mp").group

    def get_model_parallel_group_src_rank(self):
        return self.strategy_group("mp").group.ranks[0]

    # pipeline parallel
    def get_stage_id(self):
        return self.rank_in_strategy("pp")

    def get_pipe_parallel_world_size(self):
        return self.strategy_group("pp").world_size

    def get_pipe_parallel_group(self):
        return self.strategy_group("pp").group

    def get_p2p_groups(self):
        return self.strategy_group("pp").p2p_groups

    # group sharded parallel
    def get_sharding_parallel_rank(self):
        return self.rank_in_strategy("sharding")

    def get_sharding_parallel_world_size(self):
        return self.strategy_group("sharding").world_size

    def get_sharding_parallel_group(self):
        return self.strategy_group("sharding").group

    def get_sharding_parallel_group_src_rank(self):
        return self.strategy_group("sharding").group.ranks[0]

    # check parallel group
    def get_check_parallel_group(self):
        return self.strategy_group("check").group
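

# Minimal usage sketch (illustrative only, not part of the original module):
# under a multi-card launch such as ``python -m paddle.distributed.launch``,
# the topology can be built via create_hcg and queried per strategy, e.g.
#
#     hcg = create_hcg(strategy=None, hcg_name="Hybrid4DCommGroup")
#     dp_rank = hcg.get_data_parallel_rank()
#     mp_group = hcg.get_model_parallel_group()
#
# The ``strategy=None`` argument here is an assumption for illustration; real
# callers pass whatever strategy object or list their launch script builds.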