import os
import random
from collections import namedtuple

import numpy as np
import paddle
import paddle.distributed as dist
import paddle.distributed.auto_parallel as auto

from paddlenlp.ops import Topology
from paddlenlp.trainer.trainer_utils import _get_distributed_seeds

from ppfleetx.utils.log import logger

_mesh = None


def get_mesh():
    global _mesh
    if _mesh is None and paddle.distributed.get_world_size() == 1:
        # Assumed fallback: single-card runs get a trivial 1x1x1 mesh.
        _mesh = Mesh(0, 1, 1, 1)
    return _mesh


class Mesh:
    def __init__(self, rank, dp_degree, mp_degree, pp_degree):
        # Keep axis names only for the parallel dimensions that are in use.
        self._dp_dim = "dp" if dp_degree > 1 else None
        self._mp_dim = "mp" if mp_degree > 1 else None
        self._dp_degree = dp_degree
        self._mp_degree = mp_degree
        self._pp_degree = pp_degree

        # Number all ranks dp-major, then move the pipeline axis to the front,
        # yielding a [pp, dp, mp] cube of rank ids.
        arr = np.arange(0, pp_degree * dp_degree * mp_degree).reshape([dp_degree, pp_degree, mp_degree])
        arr = arr.transpose(1, 0, 2)
        self.world_process_mesh = auto.ProcessMesh(arr, dim_names=["pp", "dp", "mp"])
        self.g_process_mesh = auto.ProcessMesh(list(range(pp_degree * dp_degree * mp_degree)))

        # Coordinates of this rank along the pp, dp and mp axes.
        ipp, idp, imp = np.where(arr == rank)

        # Mesh that pipeline stages are mapped onto: the full cube when both
        # dp and mp are active, otherwise the cube with this rank's coordinate
        # on the unused axis sliced away.
        if dp_degree > 1 and mp_degree > 1:
            self.pp_process_mesh = self.world_process_mesh
        elif mp_degree > 1:
            self.pp_process_mesh = self.world_process_mesh[:, idp, :]
        else:
            self.pp_process_mesh = self.world_process_mesh[:, :, imp]
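
        # Worked illustration (comments only): with dp_degree=2, mp_degree=2
        # and pp_degree=2 the transposed cube is
        #     [[[0, 1], [4, 5]],   # pipeline stage 0; rows = dp, cols = mp
        #      [[2, 3], [6, 7]]]   # pipeline stage 1
        # so ranks {0, 1, 4, 5} host stage 0, ranks {2, 3, 6, 7} host stage 1,
        # and since both degrees exceed 1, pp_process_mesh is the full cube.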

    @property
    def dp_degree(self):
        return self._dp_degree

    @property
    def mp_degree(self):
        return self._mp_degree

    @property
    def pp_degree(self):
        return self._pp_degree

    def __getitem__(self, idx):
        # Sub-mesh of the ranks that host pipeline stage `idx`.
        return self.pp_process_mesh[idx]
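

# Minimal usage sketch (comments only; assumes an 8-card dp=2/mp=2/pp=2 job):
#
#   mesh = Mesh(rank=0, dp_degree=2, mp_degree=2, pp_degree=2)
#   mesh.world_process_mesh   # ProcessMesh with dim_names ["pp", "dp", "mp"]
#   mesh[0]                   # dp x mp sub-mesh holding the first pipeline stage
#   mesh[mesh.pp_degree - 1]  # sub-mesh holding the last pipeline stage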


def init_dist_env(config):
    paddle.set_device(config.Global.device)

    # Build the module-level mesh from the configured parallel degrees; the
    # surrounding call is assumed to construct Mesh for the current rank.
    global _mesh
    _mesh = Mesh(
        dist.get_rank(),
        config.Distributed.dp_degree,
        config.Distributed.mp_degree,
        config.Distributed.pp_degree,
    )

    paddle.distributed.fleet.init(is_collective=True)


def get_local_rank():
    # Per-node rank exported by the Paddle launcher (wrapper name assumed).
    return int(os.getenv("PADDLE_RANK_IN_NODE", 0))


def set_random_seed(seed):
    # (Function name assumed.) Derive global/local/random seeds from the
    # hybrid-parallel topology so each parallel group gets a suitable stream.
    topo = None
    if dist.get_world_size() > 1:
        topo = Topology(
            dist.get_rank(),
            dist.get_world_size(),
            dp_degree=_mesh.dp_degree,
            pp_degree=_mesh.pp_degree,
            mp_degree=_mesh.mp_degree,
        )

    global_seed, local_seed, random_seed = _get_distributed_seeds(seed, topo)

    # Shift both seeds by a fixed offset plus the world size.
    global_seed = global_seed + 1024 + paddle.distributed.get_world_size()
    local_seed = local_seed + 1024 + paddle.distributed.get_world_size()

    paddle.seed(global_seed)
    random.seed(random_seed)
    np.random.seed(random_seed)

    logger.info(
        "The global seed is set to {}, local seed is set to {} and "
        "random seed is set to {}.".format(global_seed, local_seed, random_seed)
    )

    # Record the data-parallel seed at module scope for later consumers.
    global _dp_seed
    _dp_seed = global_seed
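

# Minimal driver sketch (comments only; the config fields mirror what
# init_dist_env reads, everything else is a hypothetical stand-in):
#
#   config = ...                 # must expose Global.device and
#                                # Distributed.{dp,mp,pp}_degree
#   init_dist_env(config)        # run under `python -m paddle.distributed.launch`
#   set_random_seed(1234)        # seeds paddle / random / numpy consistently
#   stage0_mesh = get_mesh()[0]  # ProcessMesh of the first pipeline stage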