# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import random

import numpy as np
import paddle
import paddle.distributed as dist
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
from ppfleetx.distributed.apis import comm_groups
from ppfleetx.utils.log import logger
from paddlenlp.trainer.trainer_utils import _get_distributed_seeds

__all__ = ["init_dist_env"]

_seed = None
_dp_seed = None
_hcg = None


def set_seed(seed):
    # NOTE(shenliang03): For the parameter-init seed:
    # seed: the same within dp / mp-undistributed-parameter / sharding groups; different elsewhere.
    # For the compute seed (dropout):
    # global seed: the same only within the mp group.
    # local seed: different across all groups.

    global_seed, local_seed, random_seed = _get_distributed_seeds(seed)

    # NOTE: add (1024 + world_size) to the seeds for CI cases
    global_seed = global_seed + 1024 + paddle.distributed.get_world_size()
    local_seed = local_seed + 1024 + paddle.distributed.get_world_size()

    tracker = get_rng_state_tracker()
    tracker.add("global_seed", global_seed)
    tracker.add("local_seed", local_seed)

    paddle.seed(global_seed)
    random.seed(random_seed)
    np.random.seed(random_seed)

    logger.info("The global seed is set to {}, local seed is set to {} and "
                "random seed is set to {}.".format(global_seed, local_seed, random_seed))

    global _seed
    global _dp_seed
    _seed = seed
    _dp_seed = global_seed
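
# NOTE: the "global_seed"/"local_seed" states registered above are typically
# consumed by model code through the tracker's context manager; a minimal
# sketch (assuming Paddle's RNGStatesTracker.rng_state API) would be:
#
#     with get_rng_state_tracker().rng_state("local_seed"):
#         x = paddle.nn.functional.dropout(x, p=0.1)  # rank-local randomness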


def set_hcg(hcg):
    global _hcg
    _hcg = hcg


def get_hcg():
    global _hcg
    return _hcg


def get_seed():
    global _seed
    return _seed


def get_dp_seed():
    global _dp_seed
    return _dp_seed


def init_dist_env(config):
    paddle.set_device(config.Global.device)
    strategy = fleet.DistributedStrategy()

    def is_segment_parallel_supported():
        import inspect

        members = [name for name, _ in inspect.getmembers(fleet.HybridCommunicateGroup)]
        support_sep = "get_sep_parallel_world_size" in members
        if not support_sep:
            logger.warning("segment parallel is not supported, ignoring it.")
        return support_sep

    if config.Distributed.mp_degree == 1 and config.Distributed.sharding.sharding_degree == 1:
        if is_segment_parallel_supported():
            order = ["pp", "dp", "sharding", "sep", "mp"]
        else:
            order = ["pp", "dp", "sharding", "mp"]
    else:
        if is_segment_parallel_supported():
            order = ["dp", "pp", "sharding", "sep", "mp"]
        else:
            order = ["dp", "pp", "sharding", "mp"]

    if is_segment_parallel_supported():
        strategy.hybrid_configs = {
            "dp_degree": config.Distributed.dp_degree,
            "mp_degree": config.Distributed.mp_degree,
            "pp_degree": config.Distributed.pp_degree,
            "sharding_degree": config.Distributed.sharding.sharding_degree,
            "sep_degree": config.Distributed.sep_degree,
            "order": order,
        }
    else:
        strategy.hybrid_configs = {
            "dp_degree": config.Distributed.dp_degree,
            "mp_degree": config.Distributed.mp_degree,
            "pp_degree": config.Distributed.pp_degree,
            "sharding_degree": config.Distributed.sharding.sharding_degree,
            "order": order,
        }

    if config.Distributed.pp_degree > 1:
        if "sequence_parallel" in config.Model:
            if config.Model.sequence_parallel:
                assert config.Global.enable_partial_send_recv is False, (
                    "if config.Distributed.pp_degree > 1 and config.Model.sequence_parallel is True, "
                    "config.Global.enable_partial_send_recv should be set to False."
                )

    strategy.pipeline_configs = {
        "accumulate_steps": config.Global.local_batch_size // config.Global.micro_batch_size,
        "micro_batch_size": config.Global.micro_batch_size,
        "enable_partial_send_recv": config.Global.enable_partial_send_recv,
    }

    # seed parameter initialization under tensor parallel
    seed = config.Global.seed
    strategy.tensor_parallel_configs = {"tensor_init_seed": seed}

    hcg = comm_groups.create_hcg(strategy, hcg_name=config.Distributed.hcg)
    set_hcg(hcg)
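
# A minimal sketch of the config fields read by `init_dist_env` (illustrative
# placeholders only; real ppfleetx configs carry many more entries):
#
#     Global:
#         device: gpu
#         seed: 1024
#         micro_batch_size: 1
#         local_batch_size: 8          # accumulate_steps = local_batch_size // micro_batch_size
#         enable_partial_send_recv: True
#     Distributed:
#         dp_degree: 2
#         mp_degree: 2
#         pp_degree: 2
#         sep_degree: 1                # only consulted when segment parallel is supported
#         sharding:
#             sharding_degree: 1
#         hcg: HybridCommunicateGroup  # name passed to comm_groups.create_hcg
#     Model:
#         sequence_parallel: False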


def get_local_rank():
    return int(os.getenv("PADDLE_RANK_IN_NODE", 0))


def get_data_world_size():
    if paddle.distributed.get_world_size() == 1:
        return 1

    hcg = get_hcg()
    dp_size = hcg.get_data_parallel_world_size()
    sharding_size = hcg.get_sharding_parallel_world_size()

    return dp_size * sharding_size


def get_data_world_rank():
    if paddle.distributed.get_world_size() == 1:
        return 0

    hcg = get_hcg()
    dp_rank = hcg.get_data_parallel_rank()
    sharding_rank = hcg.get_sharding_parallel_rank()
    sharding_size = hcg.get_sharding_parallel_world_size()

    return dp_rank * sharding_size + sharding_rank
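
# For example: with dp_degree = 2 and sharding_degree = 4, get_data_world_size()
# returns 2 * 4 = 8, and a worker with dp_rank = 1 and sharding_rank = 2 gets
# data rank 1 * 4 + 2 = 6.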


def work_at_local_rank0(func):
    def wrapper(*args, **kwargs):
        local_rank = 0
        if paddle.base.core.is_compiled_with_dist() and paddle.distributed.get_world_size() > 1:
            local_rank = paddle.distributed.ParallelEnv().dev_id
        if local_rank == 0:
            func(*args, **kwargs)

    return wrapper
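
# Example (illustrative, not part of the original module): the decorator is meant
# to guard per-node side effects such as creating directories, e.g.
#
#     @work_at_local_rank0
#     def make_output_dir(path):     # hypothetical helper
#         os.makedirs(path, exist_ok=True)
#
#     make_output_dir("./output")    # runs only on local rank 0, no-op elsewhere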
