import inspect
import os
import random

import numpy as np
import paddle
import paddle.distributed as dist
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker

from ppfleetx.distributed.apis import comm_groups
from ppfleetx.utils.log import logger

from paddlenlp.trainer.trainer_utils import _get_distributed_seeds

__all__ = ["init_dist_env"]


def set_seed(seed):
    # Seed every RNG stream (Paddle, Python, NumPy) and the hybrid-parallel RNG
    # state tracker; the function name is assumed for this excerpt.
    global_seed, local_seed, random_seed = _get_distributed_seeds(seed)

    global_seed = global_seed + 1024 + paddle.distributed.get_world_size()
    local_seed = local_seed + 1024 + paddle.distributed.get_world_size()

    # Register both seeds with the tensor-parallel RNG state tracker so that
    # later dropout calls can switch between the shared and rank-local streams.
    tracker = get_rng_state_tracker()
    tracker.add("global_seed", global_seed)
    tracker.add("local_seed", local_seed)

    paddle.seed(global_seed)
    random.seed(random_seed)
    np.random.seed(random_seed)

    logger.info(
        "The global seed is set to {}, local seed is set to {} and "
        "random seed is set to {}.".format(global_seed, local_seed, random_seed)
    )

    # Record the seed used for the data-parallel stream at module level.
    global _dp_seed
    _dp_seed = global_seed
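
# The two tracker states registered above are meant to be entered through the
# RNG state tracker wherever dropout must be coordinated across the hybrid
# topology. A minimal sketch of such a consumer (illustrative only, not part of
# this module):
#
#     with get_rng_state_tracker().rng_state("local_seed"):
#         y = paddle.nn.functional.dropout(x, p=0.1)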


def init_dist_env(config):
    paddle.set_device(config.Global.device)
    strategy = fleet.DistributedStrategy()

    def is_segment_parallel_supported():
        # Segment parallelism ("sep") only exists in Paddle versions whose
        # HybridCommunicateGroup exposes the sep world-size accessor.
        members = [name for (name, _) in inspect.getmembers(fleet.HybridCommunicateGroup)]
        support_sep = "get_sep_parallel_world_size" in members
        if not support_sep:
            logger.warning("segment parallel is not supported, ignoring it.")
        return support_sep

    # Choose the order in which the hybrid communication axes are laid out.
    if config.Distributed.mp_degree == 1 and config.Distributed.sharding.sharding_degree == 1:
        if is_segment_parallel_supported():
            order = ["pp", "dp", "sharding", "sep", "mp"]
        else:
            order = ["pp", "dp", "sharding", "mp"]
    else:
        if is_segment_parallel_supported():
            order = ["dp", "pp", "sharding", "sep", "mp"]
        else:
            order = ["dp", "pp", "sharding", "mp"]

    if is_segment_parallel_supported():
        strategy.hybrid_configs = {
            "dp_degree": config.Distributed.dp_degree,
            "mp_degree": config.Distributed.mp_degree,
            "pp_degree": config.Distributed.pp_degree,
            "sharding_degree": config.Distributed.sharding.sharding_degree,
            "sep_degree": config.Distributed.sep_degree,
            "order": order,
        }
    else:
        strategy.hybrid_configs = {
            "dp_degree": config.Distributed.dp_degree,
            "mp_degree": config.Distributed.mp_degree,
            "pp_degree": config.Distributed.pp_degree,
            "sharding_degree": config.Distributed.sharding.sharding_degree,
            "order": order,
        }

    if config.Distributed.pp_degree > 1:
        if "sequence_parallel" in config.Model:
            if config.Model.sequence_parallel:
                assert config.Global.enable_partial_send_recv is False, (
                    "if config.Distributed.pp_degree > 1 and config.Model.sequence_parallel is True, "
                    "config.Global.enable_partial_send_recv should be set False."
                )

        strategy.pipeline_configs = {
            "accumulate_steps": config.Global.local_batch_size // config.Global.micro_batch_size,
            "micro_batch_size": config.Global.micro_batch_size,
            "enable_partial_send_recv": config.Global.enable_partial_send_recv,
        }

    seed = config.Global.seed
    strategy.tensor_parallel_configs = {"tensor_init_seed": seed}

    hcg = comm_groups.create_hcg(strategy, hcg_name=config.Distributed.hcg)
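
# Illustrative sizing note (an assumption about typical configs, not taken from
# this file): the product dp_degree * mp_degree * pp_degree * sharding_degree
# must equal the number of launched processes. For example, on 8 GPUs a config
# with dp_degree=2, mp_degree=2, pp_degree=2, sharding_degree=1 yields two
# data-parallel replicas, each a 2-way tensor-parallel, 2-stage pipeline group.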


def get_local_rank():
    # Per-node rank exported by the Paddle launcher; function name assumed for
    # this excerpt.
    return int(os.getenv("PADDLE_RANK_IN_NODE", 0))


def get_data_world_size():
    if paddle.distributed.get_world_size() == 1:
        return 1

    # The data-parallel world spans both the dp and sharding axes of the hybrid
    # communication group built during init_dist_env.
    hcg = fleet.get_hybrid_communicate_group()
    dp_size = hcg.get_data_parallel_world_size()
    sharding_size = hcg.get_sharding_parallel_world_size()

    return dp_size * sharding_size


def get_data_world_rank():
    if paddle.distributed.get_world_size() == 1:
        return 0

    hcg = fleet.get_hybrid_communicate_group()
    dp_rank = hcg.get_data_parallel_rank()
    sharding_rank = hcg.get_sharding_parallel_rank()
    sharding_size = hcg.get_sharding_parallel_world_size()

    # Flatten the (dp_rank, sharding_rank) pair into a single data-parallel rank.
    return dp_rank * sharding_size + sharding_rank
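
# Worked example of the mapping above (values chosen for illustration): with
# sharding_size = 4, the worker at dp_rank = 1 and sharding_rank = 2 reads
# shard 1 * 4 + 2 = 6 of the data-parallel stream.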


def work_at_local_rank0(func):
    def wrapper(*args, **kwargs):
        local_rank = 0
        if paddle.base.core.is_compiled_with_dist() and paddle.distributed.get_world_size() > 1:
            local_rank = paddle.distributed.ParallelEnv().dev_id
        if local_rank == 0:
            func(*args, **kwargs)

    return wrapper
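
# Usage sketch for the decorator (illustrative only; `prepare_output_dir` is a
# hypothetical helper, not defined in this module):
#
#     @work_at_local_rank0
#     def prepare_output_dir(path):
#         os.makedirs(path, exist_ok=True)
#
# Only the process whose device id is 0 on its node runs the wrapped body;
# every other rank skips the call and returns None.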