16
def process_optim_configs(config):
    """
    Adjust the optimizer section of ``config`` for auto parallel.

    Multiplies ``config["Optimizer"]["lr"]["decay_steps"]`` in place by
    ``config["Global"]["global_batch_size"]``. Nothing is returned; the
    caller's ``config`` object is mutated.
    """
    lr_settings = config["Optimizer"]["lr"]
    lr_settings["decay_steps"] *= config["Global"]["global_batch_size"]
23
def process_model_configs(config):
    """
    Fill in derived defaults of the model section for auto parallel.

    ``config["Model"]`` is updated in place:

    * ``ffn_hidden_size`` is set to ``4 * hidden_size`` when it is missing
      or explicitly ``None``.
    * ``recompute_granularity`` is set to ``"full"`` when ``use_recompute``
      is truthy and no (truthy) granularity has been configured.

    Optional keys are read with ``.get`` so that an omitted key is handled
    the same way as one that is present but unset, instead of raising
    ``KeyError``. ``hidden_size`` is still required whenever
    ``ffn_hidden_size`` has to be derived from it.
    """
    cfg_model = config["Model"]

    if cfg_model.get("ffn_hidden_size") is None:
        cfg_model["ffn_hidden_size"] = 4 * cfg_model["hidden_size"]

    # Only supply a granularity when recompute is on and none was chosen.
    if cfg_model.get("use_recompute") and not cfg_model.get("recompute_granularity"):
        cfg_model["recompute_granularity"] = "full"
36
def process_data_configs(config):
    """
    Populate the dataset settings of each configured data mode for auto
    parallel.

    For every mode among ``"Train"``, ``"Eval"`` and ``"Test"`` that is
    present in ``config["Data"]``, the following keys are written in place
    to ``config["Data"][mode]["dataset"]``:

    * ``num_samples`` - derived from ``Global.global_batch_size`` and the
      ``Engine`` settings:
        - Train: ``global_batch_size * max_steps``
        - Eval:  ``global_batch_size * (max_steps // eval_freq + 1) * eval_iters``
        - Test:  ``global_batch_size * test_iters``
    * ``mode`` - the mode name itself.
    * ``seed`` - ``Global.seed``.

    The sample count is computed only for modes that are actually
    configured, so e.g. a train-only config does not need ``eval_freq``,
    ``eval_iters`` or ``test_iters`` under ``Engine``.
    """
    cfg_global = config["Global"]
    cfg_data = config["Data"]
    cfg_engine = config["Engine"]
    batch_size = cfg_global["global_batch_size"]

    def _num_samples(mode):
        # Evaluated lazily per mode so unused Engine keys are never touched.
        if mode == "Train":
            return batch_size * cfg_engine["max_steps"]
        if mode == "Eval":
            return (
                batch_size
                * (cfg_engine["max_steps"] // cfg_engine["eval_freq"] + 1)
                * cfg_engine["eval_iters"]
            )
        return batch_size * cfg_engine["test_iters"]

    for mode in ("Train", "Eval", "Test"):
        if mode in cfg_data:
            dataset_cfg = cfg_data[mode]["dataset"]
            dataset_cfg["num_samples"] = _num_samples(mode)
            dataset_cfg["mode"] = mode
            dataset_cfg["seed"] = cfg_global["seed"]
58
def process_configs(config):
    """
    Run all config post-processing passes on ``config`` in place.

    The passes are applied in a fixed order: model, then data, then
    optimizer.
    """
    passes = (
        process_model_configs,
        process_data_configs,
        process_optim_configs,
    )
    for apply_pass in passes:
        apply_pass(config)