paddlenlp

Форк
0
64 строки · 2.1 Кб
1
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

15

16
def process_optim_configs(config):
    """
    Adjust the optimizer section of ``config`` for auto parallel.

    Multiplies ``config["Optimizer"]["lr"]["decay_steps"]`` in place by
    ``config["Global"]["global_batch_size"]``.
    NOTE(review): presumably this converts the decay horizon from steps to
    samples -- confirm against the LR scheduler that consumes it.

    Args:
        config: nested mapping with ``Optimizer.lr.decay_steps`` and
            ``Global.global_batch_size`` entries. Mutated in place.
    """
    lr_settings = config["Optimizer"]["lr"]
    scale = config["Global"]["global_batch_size"]
    lr_settings["decay_steps"] *= scale
21

22

23
def process_model_configs(config):
    """
    Fill in derived defaults in the ``Model`` section for auto parallel.

    * When ``ffn_hidden_size`` is ``None`` it is set to four times
      ``hidden_size``.
    * When ``use_recompute`` is truthy and ``recompute_granularity`` is
      missing or falsy, the granularity is set to ``"full"``.

    Args:
        config: nested mapping with a ``Model`` entry. Mutated in place.
    """
    model_section = config["Model"]

    # Derive the feed-forward width from the hidden size when not given.
    if model_section["ffn_hidden_size"] is None:
        model_section["ffn_hidden_size"] = 4 * model_section["hidden_size"]

    # Only choose a granularity when recompute is on and none was set.
    granularity_unset = model_section["use_recompute"] and not model_section.get(
        "recompute_granularity", None
    )
    if granularity_unset:
        model_section["recompute_granularity"] = "full"
34

35

36
def process_data_configs(config):
    """
    Populate per-mode dataset settings in the ``Data`` section for auto parallel.

    For every mode among ``"Train"``, ``"Eval"`` and ``"Test"`` that is present
    in ``config["Data"]``, the mode's ``dataset`` mapping receives:

    * ``num_samples`` -- ``global_batch_size`` multiplied by
        - ``Engine.max_steps`` for ``Train``;
        - ``(Engine.max_steps // Engine.eval_freq + 1) * Engine.eval_iters``
          for ``Eval``;
        - ``Engine.test_iters`` for ``Test``.
    * ``mode`` -- the mode name itself.
    * ``seed`` -- ``Global.seed``.

    The sample count is computed only for modes that are actually configured,
    so a config with only a ``Train`` dataset does not need ``eval_freq``,
    ``eval_iters`` or ``test_iters``, and an unused ``eval_freq`` of zero
    cannot trigger a division error.

    Args:
        config: nested mapping with ``Global``, ``Data`` and ``Engine``
            entries. Mutated in place.

    Raises:
        KeyError: if a key required by a configured mode is missing.
        ZeroDivisionError: if ``Eval`` is configured and ``eval_freq`` is 0.
    """
    cfg_global = config["Global"]
    cfg_data = config["Data"]
    cfg_engine = config["Engine"]
    batch_size = cfg_global["global_batch_size"]

    # Each entry is a thunk so that Engine keys are read lazily, only when
    # the corresponding mode is present in the Data section.
    mode_to_num_samples = {
        "Train": lambda: batch_size * cfg_engine["max_steps"],
        "Eval": lambda: batch_size
        * (cfg_engine["max_steps"] // cfg_engine["eval_freq"] + 1)
        * cfg_engine["eval_iters"],
        "Test": lambda: batch_size * cfg_engine["test_iters"],
    }

    for mode, compute_num_samples in mode_to_num_samples.items():
        if mode not in cfg_data:
            continue
        dataset_cfg = cfg_data[mode]["dataset"]
        dataset_cfg["num_samples"] = compute_num_samples()
        dataset_cfg["mode"] = mode
        dataset_cfg["seed"] = cfg_global["seed"]
56

57

58
def process_configs(config):
    """
    Run every auto-parallel config pass on ``config`` and return it.

    The passes run in this order: model, data, optimizer. Each pass mutates
    ``config`` in place; the same object is returned for convenience.
    """
    passes = (
        process_model_configs,
        process_data_configs,
        process_optim_configs,
    )
    for apply_pass in passes:
        apply_pass(config)

    return config
65

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.