# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
import sys

import paddle
import paddle.distributed as dist
import paddle.distributed.auto_parallel as auto

from .config import (
    AttrDict,
    check_config,
    create_attr_dict,
    override_config,
    parse_config,
    print_config,
)
from .log import logger


def process_dist_configs(config):
    """
    Process the distributed strategy for auto parallel.
    """
    nranks = dist.get_world_size()

    configs = config["Distributed"]

    mp_degree = configs.setdefault("mp_degree", 1)
    pp_degree = configs.setdefault("pp_degree", 1)

    # Disable sequence parallel if mp_degree < 2.
    sequence_parallel = config["Model"]["sequence_parallel"]
    if mp_degree < 2 and sequence_parallel:
        config["Model"]["sequence_parallel"] = False
        logger.warning(
            "sequence_parallel is turned off since mp_degree < 2."
        )

    # sharding defaults
    sharding_config = configs["sharding"]
    sharding_degree = sharding_config.setdefault("sharding_degree", 1)
    sharding_config.setdefault("sharding_stage", 2)
    sharding_config.setdefault("reduce_overlap", False)
    sharding_config.setdefault("broadcast_overlap", False)

    other_degree = mp_degree * pp_degree

    assert nranks % other_degree == 0, "nranks should be divisible by mp_degree * pp_degree."
    dp_degree = configs.setdefault("dp_degree", nranks // other_degree)
    assert nranks % dp_degree == 0, "nranks should be divisible by dp_degree."
    assert nranks == dp_degree * other_degree, (
        "Mismatched config using {} cards with dp_degree[{}], "
        "mp_degree[{}], pp_degree[{}] and sharding_degree[{}]".format(
            nranks, dp_degree, mp_degree, pp_degree, sharding_degree
        )
    )
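
# Illustrative example (assumed numbers, not taken from any shipped config):
# with 8 ranks and Distributed: {mp_degree: 2, pp_degree: 2}, dp_degree
# defaults to 8 // (2 * 2) = 2, and the final check above requires
# 8 == dp_degree * mp_degree * pp_degree == 2 * 2 * 2.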


def process_global_configs(config):
    """
    Process the global configs for auto parallel.
    """
    dp_degree = config["Distributed"]["dp_degree"]
    # pp_degree = config["Distributed"]["pp_degree"]
    # sharding_degree = config["Distributed"]["sharding"]["sharding_degree"]

    # TODO: support partial_send_recv
    # config["Global"]["enable_partial_send_recv"] = True
    # if config.get("Model", None) is not None and "sequence_parallel" in config["Model"] and pp_degree > 1:
    #     if config["Model"]["sequence_parallel"]:
    #         config["Global"]["enable_partial_send_recv"] = False
    #         logger.warning(
    #             "if config.Distributed.pp_degree > 1 and config.Model.sequence_parallel is True, "
    #             "config.Global.enable_partial_send_recv will be set False."
    #         )

    global_cfg = config["Global"]

    # Set Paddle runtime flags.
    flags = global_cfg.get("flags", {})
    paddle.set_flags(flags)
    for k, v in flags.items():
        logger.info("Flag {} is set to {}.".format(k, v))

    if global_cfg["global_batch_size"] is None and global_cfg["local_batch_size"] is None:
        raise ValueError("global_batch_size or local_batch_size should be set.")
    elif global_cfg["global_batch_size"] is not None and global_cfg["local_batch_size"] is not None:
        assert (
            global_cfg["global_batch_size"] // global_cfg["local_batch_size"] == dp_degree
        ), "global_batch_size[{}] should equal local_batch_size[{}] * dp_degree[{}]".format(
            global_cfg["global_batch_size"], global_cfg["local_batch_size"], dp_degree
        )
    elif global_cfg["global_batch_size"] is not None and global_cfg["local_batch_size"] is None:
        assert (
            global_cfg["global_batch_size"] % dp_degree == 0
        ), "global_batch_size[{}] should be divisible by dp_degree[{}]".format(
            global_cfg["global_batch_size"], dp_degree
        )
        global_cfg["local_batch_size"] = global_cfg["global_batch_size"] // dp_degree
    else:
        global_cfg["global_batch_size"] = global_cfg["local_batch_size"] * dp_degree
    assert global_cfg["local_batch_size"] % global_cfg["micro_batch_size"] == 0, "local_batch_size should be divisible by micro_batch_size."
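
# Illustrative example (assumed numbers, not taken from any shipped config):
# with dp_degree = 4 and global_batch_size = 32, local_batch_size resolves to
# 32 // 4 = 8; with micro_batch_size = 2, process_engine_configs() below then
# derives accumulate_steps = 8 // 2 = 4.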


def process_engine_configs(config):
    """
    Process the engine configs for auto parallel.
    """
    if config.Engine.get("verbose", None) is None:
        config.Engine["verbose"] = 2
    if config.Engine.get("logging_freq", None) is None:
        config.Engine["logging_freq"] = 10
    config.Engine["save_load"] = config.Engine.get("save_load", {})
    save_load_cfg = config.Engine.save_load
    save_steps = save_load_cfg.get("save_steps", None)
    save_epoch = save_load_cfg.get("save_epoch", None)
    if save_steps is None or save_steps == -1:
        save_load_cfg["save_steps"] = sys.maxsize  # effectively disables step-based saving

    if save_epoch is None or save_epoch == -1:
        save_load_cfg["save_epoch"] = 1

    save_load_cfg["output_dir"] = save_load_cfg.get("output_dir", "./output")
    save_load_cfg["ckpt_dir"] = save_load_cfg.get("ckpt_dir", None)

    config.Engine["max_steps"] = config.Engine.get("max_steps", 500000)
    config.Engine["eval_freq"] = config.Engine.get("eval_freq", -1)
    config.Engine["eval_iters"] = config.Engine.get("eval_iters", 0)
    config.Engine["logging_freq"] = config.Engine.get("logging_freq", 1)
    config.Engine["num_train_epochs"] = config.Engine.get("num_train_epochs", 1)
    config.Engine["test_iters"] = (
        config.Engine["eval_iters"] * 10
        if config.Engine.get("test_iters", None) is None
        else config.Engine["test_iters"]
    )
    config.Engine["accumulate_steps"] = config.Global.local_batch_size // config.Global.micro_batch_size


def use_pir():
    is_pir_mode = os.environ.get("FLAGS_enable_pir_in_executor", None)
    return str(is_pir_mode).lower() not in ("false", "off", "0", "none")
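
# Note on the check above: use_pir() returns False only when
# FLAGS_enable_pir_in_executor is unset or set (case-insensitively) to
# "false", "off", "0", or "none"; any other value, e.g. "1" or "true",
# selects the PIR-specific pass names in process_strategy() below.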


def process_strategy(config):
    """
    Process the auto strategy for auto parallel.
    """
    strategy = auto.Strategy()
    strategy.auto_mode = "semi"
    # strategy.seed = config["Global"]["seed"]

    if config.get("FusedPasses", None) is not None:
        # fused passes config
        fused_passes_list = []
        fused_linear = config.FusedPasses.pop("fused_linear", False)
        fused_adamw = config.FusedPasses.pop("fused_adamw", False)
        if fused_linear:
            if use_pir():
                fused_passes_list.append("fused_gemm_epilogue_pass")
            else:
                fused_passes_list.append("fuse_gemm_epilogue")
        if fused_adamw:
            fused_passes_list.append("fuse_adamw")
        fused_passes = strategy.fused_passes
        fused_passes.enable = len(fused_passes_list) > 0
        fused_passes.fused_passes_list = fused_passes_list

    if config.get("Model", None) is not None:
        # recompute config
        if not config.Model.get("no_recompute_layers", None):
            config.Model["no_recompute_layers"] = []
        else:
            assert isinstance(config.Model["no_recompute_layers"], list), "no_recompute_layers should be a list"
            for i in config.Model["no_recompute_layers"]:
                assert isinstance(i, int), "all values in no_recompute_layers should be integers"
            assert min(config.Model["no_recompute_layers"]) >= 0, "the min value in no_recompute_layers should be >= 0"
            assert (
                max(config.Model["no_recompute_layers"]) < config.Model["num_layers"]
            ), "the max value in no_recompute_layers should be < num_layers"
            config.Model["no_recompute_layers"] = sorted(set(config.Model["no_recompute_layers"]))
        recompute = strategy.recompute
        recompute.enable = config.Model.get("use_recompute", False)
        recompute.sr = config.Model.pop("sr", 0)
        recompute.refined_ops_patterns = config.Model.pop("refined_ops_patterns", [])  # gpt.GPTModelAuto doesn't need this parameter
        recompute.no_recompute_segments = config.Model.pop("no_recompute_layers", [])
        recompute.enable_tuning = config.get("Tuning", False) and config.Tuning.get("tuning_recompute", False)

    # amp config
    amp_cfg = config.Engine.get("mix_precision", {})
    amp = strategy.amp
    amp.enable = amp_cfg.get("enable", False)
    amp.dtype = amp_cfg.get("dtype", "float16")
    amp.level = amp_cfg.get("level", "o2")
    amp.init_loss_scaling = amp_cfg.get("scale_loss", 32768)
    amp.custom_black_list = amp_cfg.get("custom_black_list", [])
    amp.custom_white_list = amp_cfg.get("custom_white_list", [])
    amp.use_fp16_guard = amp_cfg.get("use_fp16_guard", False)
    amp.use_bf16_guard = amp_cfg.get("use_bf16_guard", False)

    # mp_optimization config
    mp_degree = config.Distributed.get("mp_degree", 1)
    if mp_degree > 1:
        mp_cfg = config.Distributed.get("mp_optimization", {})
        strategy.mp_optimization.allreduce_matmul_grad_overlapping = mp_cfg.get("allreduce_matmul_grad_overlapping", False)

    # sharding config
    sharding_cfg = config.Distributed.get("sharding", {})
    sharding = strategy.sharding
    sharding.enable = sharding_cfg.get("sharding_degree", 1) > 1
    sharding.degree = sharding_cfg.get("sharding_degree", 1)
    sharding.stage = sharding_cfg.get("sharding_stage", 1)
    sharding.enable_overlap = sharding_cfg.get("reduce_overlap", False) and sharding_cfg.get("broadcast_overlap", False)
    sharding.param_comm_stream_num = sharding_cfg.get("param_comm_stream_num", 1)
    sharding.grad_comm_stream_num = sharding_cfg.get("grad_comm_stream_num", 1)
    sharding.param_bucket_size_numel = sharding_cfg.get("param_bucket_size_numel", 1)
    sharding.grad_bucket_size_numel = sharding_cfg.get("grad_bucket_size_numel", 1)
    sharding.enable_hierarchical_comm = sharding_cfg.get("enable_hierarchical_comm", False)

    pp_degree = config["Distributed"]["pp_degree"]
    accumulate_steps = config.Engine.get("accumulate_steps", 1)
    if pp_degree > 1 and accumulate_steps > 1:
        # pipeline config
        pipeline_cfg = config.Distributed.get("pipeline", {})
        pipeline = strategy.pipeline
        pipeline.enable = True
        pipeline.enable_send_recv_overlap = pipeline_cfg.get("enable_send_recv_overlap", False)
        pipeline.schedule_mode = pipeline_cfg.get("schedule_mode", "1F1B")
        pipeline.micro_batch_size = config.Global.micro_batch_size
        pipeline.accumulate_steps = accumulate_steps
        pipeline.job_schedule_profiler_start = pipeline_cfg.get("job_schedule_profiler_start", -1)
        pipeline.job_schedule_profiler_stop = pipeline_cfg.get("job_schedule_profiler_stop", -1)

    elif accumulate_steps > 1:
        # gradient merge config
        gradient_merge = strategy.gradient_merge
        gradient_merge.enable = True
        gradient_merge.k_steps = accumulate_steps

    # quantization config
    qat_cfg = config.get("Quantization", {})
    qat = strategy.qat
    qat.enable = qat_cfg.get("enable", False)
    qat.channel_wise_abs_max = qat_cfg.get("channel_wise_abs_max", True)
    qat.weight_bits = qat_cfg.get("weight_bits", 8)
    qat.activation_bits = qat_cfg.get("activation_bits", 8)
    qat.onnx_format = qat_cfg.get("onnx_format", True)

    # tuning config
    tuning_cfg = config.get("Tuning", {})
    tuning = strategy.tuning
    tuning.enable = tuning_cfg.get("enable", False)
    tuning.profile_start_step = tuning_cfg.get("profile_start_step", 1)
    tuning.profile_end_step = tuning_cfg.get("profile_end_step", 1)
    tuning.run_after_tuning = tuning_cfg.get("run_after_tuning", True)
    tuning.debug = tuning_cfg.get("debug", True)

    # sequence parallel config
    if config.Model.get("sequence_parallel", False):
        sp_optimization = strategy.sp_optimization
        sp_optimization.enable = True

    engine_cfg = config["Engine"]
    engine_cfg["strategy"] = strategy


def process_ckpt_dir(config):
    configs = config["Engine"]["save_load"]
    ckpt_dir = configs.get("ckpt_dir", None)
    if ckpt_dir is None:
        return

    assert not os.path.isdir(ckpt_dir), (
        "Wrong setting of ckpt_dir! ckpt_dir can't be a folder, but {} is a folder. "
        "Your `ckpt_dir` should be `dirname/prefix` like `output/auto` "
        "if your model path is `output/auto_dist0.pdparams`".format(ckpt_dir)
    )

    assert not os.path.exists(ckpt_dir), (
        "Wrong setting of ckpt_dir: to load weights, ckpt_dir should be a path prefix, not an existing file. "
        "For example, if the checkpoint files are:\n"
        "gpt_auto_model_save\n\t--auto_dist0.pdparams\n\t--auto_dist0.pdparams\n\t--auto_dist0.pdattr\n"
        'then set ckpt_dir="gpt_auto_model_save/auto"'
    )

    parent_path = os.path.split(ckpt_dir)[0]

    if not os.path.exists(parent_path):
        logger.warning("{} does not exist, so ckpt_dir is set to None.".format(parent_path))
        configs["ckpt_dir"] = None


def get_config(fname, overrides=None, show=False):
    """
    Read the config from a file for auto parallel.
    """
    assert os.path.exists(fname), "config file ({}) does not exist".format(fname)
    config = parse_config(fname)
    override_config(config, overrides)

    process_dist_configs(config)
    process_global_configs(config)
    process_engine_configs(config)
    process_strategy(config)
    process_ckpt_dir(config)
    create_attr_dict(AttrDict(config))

    if show:
        print_config(config)
    check_config(config)
    return config


def parse_args():
    parser = argparse.ArgumentParser("train script")
    parser.add_argument("-c", "--config", type=str, default="configs/config.yaml", help="config file path")
    parser.add_argument("-o", "--override", action="append", default=[], help="config options to be overridden")
    args = parser.parse_args()
    return args
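
# Usage sketch (assumed entry point, not part of this module): a training
# script would typically wire parse_args() and get_config() together like
#
#     args = parse_args()
#     config = get_config(args.config, overrides=args.override, show=True)
#
# where each "-o" option is an override string consumed by override_config.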