15
import paddle.distributed.fleet as fleet
16
from paddle.distributed.fleet.meta_parallel import TensorParallel
17
from paddle.distributed.parallel import sync_params_buffers
18
from paddle.distributed.sharding import group_sharded_parallel
19
from ppfleetx.distributed.apis import env
22
def wrap_with_fleet(dist_config, model, optimizer=None, scaler=None):
23
if dist_config.sharding.sharding_stage in [2, 3]:
24
assert dist_config.pp_degree == 1, "sharding stage2/3 will support pipeline parallel later"
25
return wrap_sharding_2_3(dist_config, model, optimizer, scaler)
27
return wrap_3D_parallel(dist_config, model, optimizer, scaler)
30
def wrap_sharding_2_3(dist_config, model, optimizer=None, scaler=None):
    """Wrap model/optimizer/scaler with GroupSharded (sharding stage 2/3).

    Stage 2 shards optimizer states + gradients ("os_g"); stage 3
    additionally shards parameters ("p_g_os"). Optionally overlaps
    reduce/broadcast communication when enabled in ``dist_config``.

    NOTE(review): the extracted source had several lines dropped
    (``hcg`` assignment, the keyword arguments of the
    ``group_sharded_parallel`` call, and the ``origin_model`` binding);
    they are restored below from the surrounding context — confirm
    against the upstream file.
    """
    hcg = env.get_hcg()  # hybrid communicate group set up by env init
    dp_group = hcg.get_data_parallel_group()
    sharding_group = hcg.get_sharding_parallel_group()

    # Stage 3 with dp > 1: parameters must be synchronized across the dp
    # group manually before sharding wraps the model.
    if dist_config.dp_degree > 1 and dist_config.sharding.sharding_stage == 3:
        sync_params_buffers(model, comm_group=dp_group, src_rank=dp_group.ranks[0])

    if dist_config.mp_degree > 1:
        assert dist_config.sharding.sharding_stage == 2, "only support mp + sharding stage2 hybrid parallel now."
        model = TensorParallel(model, hcg, strategy=None)

    level = "p_g_os" if dist_config.sharding.sharding_stage == 3 else "os_g"
    # Keep a handle on the pre-sharding layers: broadcast overlap below
    # needs the original layer structure, not the sharded wrapper.
    origin_model = model
    model, optimizer, scaler = group_sharded_parallel(
        model=model,
        optimizer=optimizer,
        level=level,
        scaler=scaler,
        group=sharding_group,
        offload=dist_config.sharding.sharding_offload,
        # Only pass a dp group when data parallelism is actually active.
        dp_group=dp_group if dp_group.nranks > 1 else None,
    )

    if dist_config.sharding.reduce_overlap:
        model._set_reduce_overlap(dist_config.sharding.reduce_overlap)

    if dist_config.sharding.broadcast_overlap:
        optimizer._set_broadcast_overlap(dist_config.sharding.broadcast_overlap, layers=origin_model, num_groups=2)

    return model, optimizer, scaler
63
def wrap_3D_parallel(dist_config, model, optimizer=None, scaler=None):
    """Wrap components with fleet's generic distributed (dp/mp/pp) APIs.

    ``optimizer`` and ``scaler`` pass through unchanged when ``None``.
    ``dist_config`` is accepted for signature parity with the sharding
    path but is not consulted here.
    """
    wrapped_model = fleet.distributed_model(model)

    wrapped_optimizer = optimizer
    if wrapped_optimizer is not None:
        wrapped_optimizer = fleet.distributed_optimizer(wrapped_optimizer)

    wrapped_scaler = scaler
    if wrapped_scaler is not None:
        wrapped_scaler = fleet.distributed_scaler(wrapped_scaler)

    return wrapped_model, wrapped_optimizer, wrapped_scaler