# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import codecs
import copy
import os
import sys

import paddle
import paddle.distributed as dist
import yaml
from paddle.base.reader import use_pinned_memory

from . import check
from .log import advertise, logger

__all__ = ["get_config", "print_config"]

def process_dist_config(configs):
    """
    Process the distributed strategy for hybrid parallelism.
    """
    nranks = dist.get_world_size()

    config = configs["Distributed"]

    config.setdefault("hcg", "HybridCommunicateGroup")
    mp_degree = config.setdefault("mp_degree", 1)
    pp_degree = config.setdefault("pp_degree", 1)
    config.setdefault("pp_recompute_interval", 1)
    sep_degree = config.setdefault("sep_degree", 1)

    # sharding defaults
    sharding_config = config["sharding"]
    sharding_degree = sharding_config.setdefault("sharding_degree", 1)
    sharding_config.setdefault("sharding_stage", 2)
    sharding_config.setdefault("sharding_offload", False)
    reduce_overlap = sharding_config.setdefault("reduce_overlap", False)
    broadcast_overlap = sharding_config.setdefault("broadcast_overlap", False)

    other_degree = sep_degree * mp_degree * pp_degree * sharding_degree

    assert nranks % other_degree == 0, "unreasonable config of dist_strategy."
    dp_degree = config.setdefault("dp_degree", nranks // other_degree)
    assert nranks % dp_degree == 0, "unreasonable config of dist_strategy."
    assert nranks == dp_degree * other_degree, (
        "Mismatched config using {} cards with dp_degree[{}], "
        "sep_degree[{}], mp_degree[{}], pp_degree[{}] and sharding_degree[{}]".format(
            nranks, dp_degree, sep_degree, mp_degree, pp_degree, sharding_degree
        )
    )

    if sharding_config["sharding_degree"] > 1 and reduce_overlap:
        if sharding_config["sharding_stage"] == 3 or sharding_config["sharding_offload"]:
            sharding_config["reduce_overlap"] = False
            logger.warning("reduce_overlap is only valid for sharding stage 2 without offload")

    if sharding_config["sharding_degree"] > 1 and broadcast_overlap:
        if sharding_config["sharding_stage"] == 3 or sharding_config["sharding_offload"]:
            sharding_config["broadcast_overlap"] = False
            logger.warning("broadcast_overlap is only valid for sharding stage 2 without offload")

    if broadcast_overlap and configs["Engine"]["logging_freq"] == 1:
        logger.warning(
            "Setting logging_freq to 1 disables broadcast_overlap. "
            "If you want to overlap the broadcast, please increase logging_freq."
        )
        sharding_config["broadcast_overlap"] = False

    if sharding_config["sharding_degree"] > 1:
        if sharding_config.get("broadcast_overlap", False):
            logger.warning("Enabling broadcast_overlap for sharding disables pinned memory for the dataloader")
            use_pinned_memory(False)

    if "fuse_sequence_parallel_allreduce" not in config:
        config["fuse_sequence_parallel_allreduce"] = False

    if "use_main_grad" in config and config["use_main_grad"] is True:
        logger.warning("If use_main_grad is True, fuse_sequence_parallel_allreduce will be forced to False")
        config["fuse_sequence_parallel_allreduce"] = False

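# Worked example (illustrative numbers, not from any shipped config): with 16 ranks and
# mp_degree=2, pp_degree=2, sharding_degree=2, sep_degree=1, other_degree = 1 * 2 * 2 * 2 = 8,
# so dp_degree defaults to 16 // 8 = 2 and the final check 16 == 2 * 8 passes.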

def process_global_configs(config):
    """
    Process global configs for hybrid parallelism.
    """
    dp_degree = config["Distributed"]["dp_degree"]
    pp_degree = config["Distributed"]["pp_degree"]
    sharding_degree = config["Distributed"]["sharding"]["sharding_degree"]

    config["Global"]["enable_partial_send_recv"] = (
        config["Global"]["enable_partial_send_recv"] if "enable_partial_send_recv" in config["Global"] else True
    )
    if "sequence_parallel" in config["Model"] and pp_degree > 1:
        if config["Model"]["sequence_parallel"]:
            config["Global"]["enable_partial_send_recv"] = False
            logger.warning(
                "if config.Distributed.pp_degree > 1 and config.Model.sequence_parallel is True, "
                "config.Global.enable_partial_send_recv will be set to False."
            )

    global_cfg = config["Global"]

    # Set paddle framework flags
    flags = global_cfg.get("flags", {})
    paddle.set_flags(flags)
    for k, v in flags.items():
        logger.info("Flag {} is set to {}.".format(k, v))

    if global_cfg["global_batch_size"] is None and global_cfg["local_batch_size"] is None:
        raise ValueError("global_batch_size or local_batch_size should be set.")
    elif global_cfg["global_batch_size"] is not None and global_cfg["local_batch_size"] is not None:
        assert global_cfg["global_batch_size"] // global_cfg["local_batch_size"] == (dp_degree * sharding_degree), (
            "global_batch_size[{}] should be local_batch_size[{}] times "
            "dp_degree[{}] times sharding_degree[{}]".format(
                global_cfg["global_batch_size"], global_cfg["local_batch_size"], dp_degree, sharding_degree
            )
        )
    elif global_cfg["global_batch_size"] is not None and global_cfg["local_batch_size"] is None:
        assert (
            global_cfg["global_batch_size"] % (dp_degree * sharding_degree) == 0
        ), "global_batch_size[{}] should be divisible by dp_degree[{}] times sharding_degree[{}]".format(
            global_cfg["global_batch_size"], dp_degree, sharding_degree
        )
        global_cfg["local_batch_size"] = global_cfg["global_batch_size"] // (dp_degree * sharding_degree)
    else:
        global_cfg["global_batch_size"] = global_cfg["local_batch_size"] * dp_degree * sharding_degree
    assert (
        global_cfg["local_batch_size"] % global_cfg["micro_batch_size"] == 0
    ), "local_batch_size should be divisible by micro_batch_size"

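# Worked example (illustrative numbers): with dp_degree=2, sharding_degree=2 and
# local_batch_size=8, global_batch_size defaults to 8 * 2 * 2 = 32; micro_batch_size must
# then divide 8 (e.g. micro_batch_size=2 gives 4 accumulation steps, see
# process_engine_config below).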

def process_engine_config(config):
    """
    Process the Engine section: save/load, mixed precision and run defaults.
    """
    # save_load
    config.Engine["save_load"] = config.Engine.get("save_load", {})
    save_load_cfg = config.Engine.save_load
    save_steps = save_load_cfg.get("save_steps", None)
    save_epoch = save_load_cfg.get("save_epoch", None)
    if save_steps is None or save_steps == -1:
        save_load_cfg["save_steps"] = sys.maxsize

    if save_epoch is None or save_epoch == -1:
        save_load_cfg["save_epoch"] = 1

    save_load_cfg["output_dir"] = save_load_cfg.get("output_dir", "./output")
    save_load_cfg["ckpt_dir"] = save_load_cfg.get("ckpt_dir", None)

    # mix_precision
    config.Engine["mix_precision"] = config.Engine.get("mix_precision", {})
    amp_cfg = config.Engine.mix_precision

    amp_cfg["enable"] = amp_cfg.get("enable", False)
    amp_cfg["scale_loss"] = amp_cfg.get("scale_loss", 32768)
    amp_cfg["custom_black_list"] = amp_cfg.get("custom_black_list", None)
    amp_cfg["custom_white_list"] = amp_cfg.get("custom_white_list", None)

    # engine
    config.Engine["max_steps"] = config.Engine.get("max_steps", 500000)
    config.Engine["eval_freq"] = config.Engine.get("eval_freq", -1)
    config.Engine["eval_iters"] = config.Engine.get("eval_iters", 0)
    config.Engine["logging_freq"] = config.Engine.get("logging_freq", 1)
    config.Engine["num_train_epochs"] = config.Engine.get("num_train_epochs", 1)
    config.Engine["test_iters"] = (
        config.Engine["eval_iters"] * 10
        if config.Engine.get("test_iters", None) is None
        else config.Engine["test_iters"]
    )
    config.Engine["accumulate_steps"] = config.Global.local_batch_size // config.Global.micro_batch_size

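# Example (sketch of the Engine section this function fills in; the keys mirror the
# defaults set above, the values are illustrative, not from any shipped config):
#   Engine:
#     max_steps: 500000
#     logging_freq: 1
#     mix_precision:
#       enable: True
#       scale_loss: 32768
#     save_load:
#       save_steps: 1000
#       output_dir: ./output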

class AttrDict(dict):
    """Dict subclass that also exposes its keys as attributes."""

    def __getattr__(self, key):
        return self[key]

    def __setattr__(self, key, value):
        if key in self.__dict__:
            self.__dict__[key] = value
        else:
            self[key] = value

    def __copy__(self):
        cls = self.__class__
        result = cls.__new__(cls)
        result.__dict__.update(self.__dict__)
        return result

    def __deepcopy__(self, memo):
        cls = self.__class__
        result = cls.__new__(cls)
        memo[id(self)] = result
        for k, v in self.__dict__.items():
            setattr(result, k, copy.deepcopy(v, memo))
        for k, v in self.items():
            setattr(result, k, copy.deepcopy(v, memo))
        return result

    def setdefault(self, k, default=None):
        # unlike dict.setdefault, this also replaces an existing None value
        if k not in self or self[k] is None:
            self[k] = default
            return default
        else:
            return self[k]

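# Example (sketch): keys double as attributes; nested plain dicts only become AttrDict
# once create_attr_dict() below has been applied.
#   cfg = AttrDict({"Global": {"device": "gpu"}})
#   cfg.Global is cfg["Global"]   # True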

def create_attr_dict(yaml_config):
    """Recursively wrap nested dicts in AttrDict and literal-eval string values."""
    from ast import literal_eval

    for key, value in yaml_config.items():
        if type(value) is dict:
            yaml_config[key] = value = AttrDict(value)
        if isinstance(value, str):
            try:
                value = literal_eval(value)
            except BaseException:
                pass
        if isinstance(value, AttrDict):
            create_attr_dict(yaml_config[key])
        else:
            yaml_config[key] = value

def parse_config(cfg_file):
    """Load a config file into an AttrDict"""

    def _update_dic(dic, base_dic):
        """Update base_dic with dic and return the merged result"""
        base_dic = base_dic.copy()
        dic = dic.copy()

        if dic.get("_inherited_", True) is False:
            dic.pop("_inherited_")
            return dic

        for key, val in dic.items():
            if isinstance(val, dict) and key in base_dic:
                base_dic[key] = _update_dic(val, base_dic[key])
            else:
                base_dic[key] = val
        dic = base_dic
        return dic

    def _parse_from_yaml(path):
        """Parse a yaml file and build config"""

        with codecs.open(path, "r", "utf-8") as file:
            dic = yaml.load(file, Loader=yaml.FullLoader)

        if "_base_" in dic:
            cfg_dir = os.path.dirname(path)
            base_path = dic.pop("_base_")
            base_path = os.path.join(cfg_dir, base_path)
            base_dic = _parse_from_yaml(base_path)
            dic = _update_dic(dic, base_dic)
        return dic

    yaml_dict = _parse_from_yaml(cfg_file)
    yaml_config = AttrDict(yaml_dict)

    create_attr_dict(yaml_config)
    return yaml_config

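# Example (sketch): "_base_" pulls in a parent YAML resolved relative to the child file,
# and the child's values win during the merge; file names and values are illustrative.
#   base.yaml:   Global: {device: gpu, micro_batch_size: 1}
#   child.yaml:  {_base_: ./base.yaml, Global: {micro_batch_size: 2}}
#   result:      Global: {device: gpu, micro_batch_size: 2}
# A section containing "_inherited_: False" replaces the base section instead of merging into it.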

def print_dict(d, delimiter=0):
    """
    Recursively visualize a dict,
    indenting according to the nesting of keys.
    """
    placeholder = "-" * 60
    for k, v in sorted(d.items()):
        if isinstance(v, dict):
            logger.info("{}{} : ".format(delimiter * " ", k))
            print_dict(v, delimiter + 4)
        elif isinstance(v, list) and len(v) >= 1 and isinstance(v[0], dict):
            logger.info("{}{} : ".format(delimiter * " ", k))
            for value in v:
                print_dict(value, delimiter + 4)
        else:
            logger.info("{}{} : {}".format(delimiter * " ", k, v))
        if k.isupper():
            logger.info(placeholder)


def print_config(config):
    """
    Visualize the configs.
    Arguments:
        config: configs
    """
    advertise()
    print_dict(config)

def check_config(config):
    """
    Check config
    """
    # global_batch_size = config.get("")

    global_config = config.get("Global")
    check.check_version()
    device = global_config.get("device", "gpu")
    device = device.lower()
    if device in ["gpu", "xpu", "rocm", "npu", "cpu"]:
        check.check_device(device)
    else:
        raise ValueError(
            f"device({device}) is not in ['gpu', 'xpu', 'rocm', 'npu', 'cpu'],\n"
            "Please ensure the config option Global.device is one of these devices"
        )

def override(dl, ks, v):
    """
    Recursively replace a value in a dict or list.
    Args:
        dl(dict or list): dict or list to be replaced
        ks(list): list of keys addressing the value
        v(str): value to be set
    """

    def str2num(v):
        try:
            return eval(v)
        except Exception:
            return v

    assert isinstance(dl, (list, dict)), "{} should be a list or a dict".format(dl)
    assert len(ks) > 0, "length of keys should be larger than 0"
    if isinstance(dl, list):
        k = str2num(ks[0])
        if len(ks) == 1:
            assert k < len(dl), "index({}) out of range({})".format(k, dl)
            dl[k] = str2num(v)
        else:
            override(dl[k], ks[1:], v)
    else:
        if len(ks) == 1:
            # assert ks[0] in dl, ('{} is not exist in {}'.format(ks[0], dl))
            if ks[0] not in dl:
                print(f"A new field ({ks[0]}) detected!")
            dl[ks[0]] = str2num(v)
        else:
            if ks[0] not in dl.keys():
                dl[ks[0]] = {}
                print(f"A new Series field ({ks[0]}) detected!")
            override(dl[ks[0]], ks[1:], v)

def override_config(config, options=None):
    """
    Recursively override the config.
    Args:
        config(dict): dict to be replaced
        options(list): list of pairs(key0.key1.idx.key2=value)
            such as: [
                'topk=2',
                'VALID.transforms.1.ResizeImage.resize_short=300'
            ]
    Returns:
        config(dict): replaced config
    """
    if options is not None:
        for opt in options:
            assert isinstance(opt, str), "option({}) should be a str".format(opt)
            assert "=" in opt, "option({}) should contain a '=' to distinguish between key and value".format(opt)
            pair = opt.split("=")
            assert len(pair) == 2, "there can be only one '=' in the option"
            key, value = pair
            keys = key.split(".")
            override(config, keys, value)
    return config


386
def get_config(fname, overrides=None, show=False):
387
    """
388
    Read config from file
389
    """
390
    assert os.path.exists(fname), "config file({}) is not exist".format(fname)
391
    config = parse_config(fname)
392
    override_config(config, overrides)
393

394
    process_dist_config(config)
395
    process_global_configs(config)
396
    process_engine_config(config)
397
    create_attr_dict(AttrDict(config))
398

399
    if show:
400
        print_config(config)
401
    check_config(config)
402
    return config
403

404

def parse_args():
    parser = argparse.ArgumentParser("train script")
    parser.add_argument("-c", "--config", type=str, default="configs/config.yaml", help="config file path")
    parser.add_argument("-o", "--override", action="append", default=[], help="config options to be overridden")
    args = parser.parse_args()
    return args

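# Example (sketch): typical entry-point wiring; the script name below is illustrative.
#   args = parse_args()
#   config = get_config(args.config, overrides=args.override, show=True)
# which corresponds to a command line such as:
#   python train.py -c configs/config.yaml -o Distributed.dp_degree=2 -o Engine.max_steps=1000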
