paddlenlp

# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle.distributed.fleet as fleet
from paddle.distributed.fleet.meta_parallel import TensorParallel
from paddle.distributed.parallel import sync_params_buffers
from paddle.distributed.sharding import group_sharded_parallel
from ppfleetx.distributed.apis import env


def wrap_with_fleet(dist_config, model, optimizer=None, scaler=None):
    # Stage 2/3 sharding goes through Paddle's group-sharded wrappers;
    # everything else uses the standard fleet distributed wrappers (dp/mp/pp).
    if dist_config.sharding.sharding_stage in [2, 3]:
        assert dist_config.pp_degree == 1, "sharding stage2/3 will support pipeline parallel later"
        return wrap_sharding_2_3(dist_config, model, optimizer, scaler)
    else:
        return wrap_3D_parallel(dist_config, model, optimizer, scaler)


def wrap_sharding_2_3(dist_config, model, optimizer=None, scaler=None):
    hcg = env.get_hcg()
    dp_group = hcg.get_data_parallel_group()
    sharding_group = hcg.get_sharding_parallel_group()

    # With stage 3 + data parallelism, broadcast parameters from the first
    # data-parallel rank so every replica starts from identical weights.
    if dist_config.dp_degree > 1 and dist_config.sharding.sharding_stage == 3:
        sync_params_buffers(model, comm_group=dp_group, src_rank=dp_group.ranks[0])

    if dist_config.mp_degree > 1:
        assert dist_config.sharding.sharding_stage == 2, "only support mp + sharding stage2 hybrid parallel now."
        model = TensorParallel(model, hcg, strategy=None)

    # Stage 3 shards parameters, gradients and optimizer states ("p_g_os");
    # stage 2 shards optimizer states and gradients ("os_g").
    level = "p_g_os" if dist_config.sharding.sharding_stage == 3 else "os_g"
    origin_model = model
    model, optimizer, scaler = group_sharded_parallel(
        model=model,
        optimizer=optimizer,
        level=level,
        scaler=scaler,
        group=sharding_group,
        offload=dist_config.sharding.sharding_offload,
        dp_group=dp_group if dp_group.nranks > 1 else None,
    )

    # Optionally overlap gradient-reduce and parameter-broadcast communication
    # with computation.
    if dist_config.sharding.reduce_overlap:
        model._set_reduce_overlap(dist_config.sharding.reduce_overlap)

    if dist_config.sharding.broadcast_overlap:
        optimizer._set_broadcast_overlap(dist_config.sharding.broadcast_overlap, layers=origin_model, num_groups=2)

    return model, optimizer, scaler


def wrap_3D_parallel(dist_config, model, optimizer=None, scaler=None):
    # Plain hybrid-parallel (dp/mp/pp) path: wrap model, optimizer and scaler
    # with fleet's distributed wrappers.
    model = fleet.distributed_model(model)
    optimizer = fleet.distributed_optimizer(optimizer) if optimizer is not None else optimizer
    scaler = fleet.distributed_scaler(scaler) if scaler is not None else scaler

    return model, optimizer, scaler

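A minimal usage sketch, not part of the original file: it assumes the distributed environment and hybrid communication group have already been initialized through ppfleetx's env helpers (so that env.get_hcg() works), and that dist_config is any object exposing the fields read above (dp_degree, mp_degree, pp_degree and the sharding sub-config). The SimpleNamespace layout and the Linear model are only illustrative; in ppfleetx these values normally come from the YAML config.

from types import SimpleNamespace

import paddle

# Hypothetical config mirroring the fields wrap_with_fleet reads.
dist_config = SimpleNamespace(
    dp_degree=2,
    mp_degree=1,
    pp_degree=1,
    sharding=SimpleNamespace(
        sharding_stage=2,
        sharding_offload=False,
        reduce_overlap=False,
        broadcast_overlap=False,
    ),
)

model = paddle.nn.Linear(1024, 1024)  # stand-in for a real paddle.nn.Layer
optimizer = paddle.optimizer.AdamW(learning_rate=1e-4, parameters=model.parameters())
scaler = paddle.amp.GradScaler(init_loss_scaling=32768)

# wrap_with_fleet is the function defined above; requires fleet / the hcg
# to be set up first (e.g. via ppfleetx's distributed env initialization).
model, optimizer, scaler = wrap_with_fleet(dist_config, model, optimizer, scaler)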