import argparse
import os
import sys

import paddle
import paddle.distributed as dist
import paddle.distributed.auto_parallel as auto

from .log import logger
def process_dist_configs(config):
    """
    Process the distributed strategy for auto parallel.
    """
    nranks = dist.get_world_size()

    configs = config["Distributed"]

    mp_degree = configs.setdefault("mp_degree", 1)
    pp_degree = configs.setdefault("pp_degree", 1)

    # Sequence parallel only makes sense with tensor (model) parallelism.
    sequence_parallel = config["Model"]["sequence_parallel"]
    if mp_degree < 2 and sequence_parallel:
        config["Model"]["sequence_parallel"] = False
        logger.warning("sequence_parallel is turned off since mp_degree < 2.")

    # Sharding defaults.
    sharding_config = configs["sharding"]
    sharding_degree = sharding_config.setdefault("sharding_degree", 1)
    sharding_config.setdefault("sharding_stage", 2)
    sharding_config.setdefault("reduce_overlap", False)
    sharding_config.setdefault("broadcast_overlap", False)

    other_degree = mp_degree * pp_degree

    assert nranks % other_degree == 0, "Requires nranks to be divisible by mp_degree * pp_degree."
    dp_degree = configs.setdefault("dp_degree", nranks // other_degree)
    assert nranks % dp_degree == 0, "unreasonable config of dist_strategy."
    assert nranks == dp_degree * other_degree, (
        "Mismatched config using {} cards with dp_degree[{}], "
        "mp_degree[{}], pp_degree[{}] and sharding_degree[{}]".format(
            nranks, dp_degree, mp_degree, pp_degree, sharding_degree
        )
    )
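
# Illustrative Distributed block (values assumed): on 8 GPUs, the YAML below
# yields dp_degree = 8 // (mp_degree * pp_degree) = 2; sharding_degree is
# configured separately and does not enter the dp * mp * pp product checked above.
#   Distributed:
#     mp_degree: 2
#     pp_degree: 2
#     sharding:
#       sharding_degree: 2
#       sharding_stage: 2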
def process_global_configs(config):
    """
    Process the global configs for auto parallel.
    """
    dp_degree = config["Distributed"]["dp_degree"]

    global_cfg = config["Global"]

    # Set framework flags and log them.
    flags = global_cfg.get("flags", {})
    paddle.set_flags(flags)
    for k, v in flags.items():
        logger.info("Flag {} is set to {}.".format(k, v))

    # Resolve global/local batch sizes against the data-parallel degree.
    if global_cfg["global_batch_size"] is None and global_cfg["local_batch_size"] is None:
        raise ValueError("global_batch_size or local_batch_size should be set.")
    elif global_cfg["global_batch_size"] is not None and global_cfg["local_batch_size"] is not None:
        assert (
            global_cfg["global_batch_size"] // global_cfg["local_batch_size"] == dp_degree
        ), "global_batch_size[{}] should be divisible by local_batch_size[{}] when dp_degree is [{}]".format(
            global_cfg["global_batch_size"], global_cfg["local_batch_size"], dp_degree
        )
    elif global_cfg["global_batch_size"] is not None and global_cfg["local_batch_size"] is None:
        assert (
            global_cfg["global_batch_size"] % dp_degree == 0
        ), "global_batch_size[{}] should be divisible by dp_degree[{}]".format(
            global_cfg["global_batch_size"], dp_degree
        )
        global_cfg["local_batch_size"] = global_cfg["global_batch_size"] // dp_degree
    else:
        global_cfg["global_batch_size"] = global_cfg["local_batch_size"] * dp_degree
    assert global_cfg["local_batch_size"] % global_cfg["micro_batch_size"] == 0
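
# Illustrative arithmetic (values assumed): with dp_degree = 4 and
# global_batch_size = 64, local_batch_size becomes 64 // 4 = 16; with
# micro_batch_size = 4, process_engine_configs later derives accumulate_steps = 16 // 4 = 4.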
def process_engine_configs(config):
    """
    Process the engine configs for auto parallel.
    """
    if config.Engine.get("verbose", None) is None:
        config.Engine["verbose"] = 2
    if config.Engine.get("logging_freq", None) is None:
        config.Engine["logging_freq"] = 10

    config.Engine["save_load"] = config.Engine.get("save_load", {})
    save_load_cfg = config.Engine.save_load
    save_steps = save_load_cfg.get("save_steps", None)
    save_epoch = save_load_cfg.get("save_epoch", None)
    if save_steps is None or save_steps == -1:
        save_load_cfg["save_steps"] = sys.maxsize if sys.version > "3" else sys.maxint

    if save_epoch is None or save_epoch == -1:
        save_load_cfg["save_epoch"] = 1

    save_load_cfg["output_dir"] = save_load_cfg.get("output_dir", "./output")
    save_load_cfg["ckpt_dir"] = save_load_cfg.get("ckpt_dir", None)

    config.Engine["max_steps"] = config.Engine.get("max_steps", 500000)
    config.Engine["eval_freq"] = config.Engine.get("eval_freq", -1)
    config.Engine["eval_iters"] = config.Engine.get("eval_iters", 0)
    config.Engine["logging_freq"] = config.Engine.get("logging_freq", 1)
    config.Engine["num_train_epochs"] = config.Engine.get("num_train_epochs", 1)
    config.Engine["test_iters"] = (
        config.Engine["eval_iters"] * 10
        if config.Engine.get("test_iters", None) is None
        else config.Engine["test_iters"]
    )
    # Gradient accumulation steps are derived from the batch-size settings.
    config.Engine["accumulate_steps"] = config.Global.local_batch_size // config.Global.micro_batch_size


def use_pir():
    # Helper name assumed; it only reports whether the PIR executor flag is enabled.
    is_pir_mode = os.environ.get("FLAGS_enable_pir_in_executor", None)
    return str(is_pir_mode).lower() not in ("false", "off", "0", "none")
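
# Illustrative behaviour (assumed environment values): FLAGS_enable_pir_in_executor=1
# or "true" makes use_pir() return True; unset, "0", "false", "off" or "none"
# return False, selecting the legacy "fuse_gemm_epilogue" pass name below.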

def process_strategy(config):
    """
    Process the auto-parallel strategy.
    """
    strategy = auto.Strategy()
    strategy.auto_mode = "semi"

    # Fused passes.
    if config.get("FusedPasses", None) is not None:
        fused_passes_list = []
        fused_linear = config.FusedPasses.pop("fused_linear", False)
        fused_adamw = config.FusedPasses.pop("fused_adamw", False)
        if fused_linear:
            # The pass name differs between the PIR and the legacy executor.
            if use_pir():
                fused_passes_list.append("fused_gemm_epilogue_pass")
            else:
                fused_passes_list.append("fuse_gemm_epilogue")
        if fused_adamw:
            fused_passes_list.append("fuse_adamw")
        fused_passes = strategy.fused_passes
        fused_passes.enable = len(fused_passes_list) > 0
        fused_passes.fused_passes_list = fused_passes_list
    # Recompute.
    if config.get("Model", None) is not None:
        if not config.Model.get("no_recompute_layers", None):
            config.Model["no_recompute_layers"] = []
        else:
            assert isinstance(config.Model["no_recompute_layers"], list), "no_recompute_layers should be a list"
            for i in config.Model["no_recompute_layers"]:
                assert isinstance(i, int), "all values in no_recompute_layers should be integers"
            assert min(config.Model["no_recompute_layers"]) >= 0, "the min value in no_recompute_layers should be >= 0"
            assert (
                max(config.Model["no_recompute_layers"]) < config.Model["num_layers"]
            ), "the max value in no_recompute_layers should be < num_layers"
            config.Model["no_recompute_layers"] = sorted(list(set(config.Model["no_recompute_layers"])))
        recompute = strategy.recompute
        recompute.enable = config.Model.get("use_recompute", False)
        recompute.sr = config.Model.pop("sr", 0)
        recompute.refined_ops_patterns = config.Model.pop("refined_ops_patterns", [])
        recompute.no_recompute_segments = config.Model.pop("no_recompute_layers", [])
        recompute.enable_tuning = config.get("Tuning", False) and config.Tuning.get("tuning_recompute", False)
    # Automatic mixed precision.
    amp_cfg = config.Engine.get("mix_precision", {})
    amp = strategy.amp
    amp.enable = amp_cfg.get("enable", False)
    amp.dtype = amp_cfg.get("dtype", "float16")
    amp.level = amp_cfg.get("level", "o2")
    amp.init_loss_scaling = amp_cfg.get("scale_loss", 32768)
    amp.custom_black_list = amp_cfg.get("custom_black_list", [])
    amp.custom_white_list = amp_cfg.get("custom_white_list", [])
    amp.use_fp16_guard = amp_cfg.get("use_fp16_guard", False)
    amp.use_bf16_guard = amp_cfg.get("use_bf16_guard", False)
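
    # Illustrative mix_precision block in the YAML config (values assumed):
    #   Engine:
    #     mix_precision:
    #       enable: True
    #       dtype: "bfloat16"
    #       level: "o2"
    #       scale_loss: 32768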
    # Tensor (model) parallel optimization.
    mp_degree = config.Distributed.get("mp_degree", 1)
    if mp_degree > 1:
        mp_cfg = config.Distributed.get("mp_optimization", {})
        strategy.mp_optimization.allreduce_matmul_grad_overlapping = mp_cfg.get(
            "allreduce_matmul_grad_overlapping", False
        )
    # Sharding (ZeRO-style) optimization.
    sharding_cfg = config.Distributed.get("sharding", {})
    sharding = strategy.sharding
    sharding.enable = sharding_cfg.get("sharding_degree", 1) > 1
    sharding.degree = sharding_cfg.get("sharding_degree", 1)
    sharding.stage = sharding_cfg.get("sharding_stage", 1)
    sharding.enable_overlap = sharding_cfg.get("reduce_overlap", False) and sharding_cfg.get(
        "broadcast_overlap", False
    )
    sharding.param_comm_stream_num = sharding_cfg.get("param_comm_stream_num", 1)
    sharding.grad_comm_stream_num = sharding_cfg.get("grad_comm_stream_num", 1)
    sharding.param_bucket_size_numel = sharding_cfg.get("param_bucket_size_numel", 1)
    sharding.grad_bucket_size_numel = sharding_cfg.get("grad_bucket_size_numel", 1)
    sharding.enable_hierarchical_comm = sharding_cfg.get("enable_hierarchical_comm", False)
    # Pipeline parallelism vs. plain gradient accumulation.
    pp_degree = config["Distributed"]["pp_degree"]
    accumulate_steps = config.Engine.get("accumulate_steps", 1)
    if pp_degree > 1 and accumulate_steps > 1:
        pipeline_cfg = config.Distributed.get("pipeline", {})
        pipeline = strategy.pipeline
        pipeline.enable = True
        pipeline.enable_send_recv_overlap = pipeline_cfg.get("enable_send_recv_overlap", False)
        pipeline.schedule_mode = pipeline_cfg.get("schedule_mode", "1F1B")
        pipeline.micro_batch_size = config.Global.micro_batch_size
        pipeline.accumulate_steps = accumulate_steps
        pipeline.job_schedule_profiler_start = pipeline_cfg.get("job_schedule_profiler_start", -1)
        pipeline.job_schedule_profiler_stop = pipeline_cfg.get("job_schedule_profiler_stop", -1)
    elif accumulate_steps > 1:
        gradient_merge = strategy.gradient_merge
        gradient_merge.enable = True
        gradient_merge.k_steps = accumulate_steps
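
    # Illustrative decision (values assumed): pp_degree=2 with accumulate_steps=4
    # enables the 1F1B pipeline schedule with 4 micro-batches per step, while
    # pp_degree=1 with accumulate_steps=4 falls back to gradient merge with k_steps=4.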
    # Quantization-aware training.
    qat_cfg = config.get("Quantization", {})
    qat = strategy.qat
    qat.enable = qat_cfg.get("enable", False)
    qat.channel_wise_abs_max = qat_cfg.get("channel_wise_abs_max", True)
    qat.weight_bits = qat_cfg.get("weight_bits", 8)
    qat.activation_bits = qat_cfg.get("activation_bits", 8)
    qat.onnx_format = qat_cfg.get("onnx_format", True)
    # Parallel-strategy tuning.
    tuning_cfg = config.get("Tuning", {})
    tuning = strategy.tuning
    tuning.enable = tuning_cfg.get("enable", False)
    tuning.profile_start_step = tuning_cfg.get("profile_start_step", 1)
    tuning.profile_end_step = tuning_cfg.get("profile_end_step", 1)
    tuning.run_after_tuning = tuning_cfg.get("run_after_tuning", True)
    tuning.debug = tuning_cfg.get("debug", True)
    # Sequence parallel optimization (only meaningful when mp_degree >= 2, see above).
    if config.Model.get("sequence_parallel", False):
        sp_optimization = strategy.sp_optimization
        sp_optimization.enable = True

    engine_cfg = config["Engine"]
    engine_cfg["strategy"] = strategy
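
# Illustrative use (assumed): after processing, the populated strategy travels with
# the engine config and can be inspected, e.g.
#   process_strategy(config)
#   strategy = config["Engine"]["strategy"]
#   print(strategy.amp.enable, strategy.sharding.degree, strategy.pipeline.enable)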

def process_ckpt_dir(config):
    configs = config["Engine"]["save_load"]
    ckpt_dir = configs.get("ckpt_dir", None)
    if ckpt_dir is None:
        return

    # ckpt_dir must be a `dirname/prefix` path, not an existing folder or file.
    assert (
        os.path.isdir(ckpt_dir) is False
    ), "Wrong setting of ckpt_dir! ckpt_dir can't be a folder, but {} is a folder. Your `ckpt_dir` should be `dirname/prefix` like `output/auto` if your model path is `output/auto_dist0.pdparams`".format(
        ckpt_dir
    )

    assert os.path.exists(ckpt_dir) is False, (
        "Wrong setting of ckpt_dir! "
        "If you want to load weights, set ckpt_dir like this, "
        "for example:\ngpt_auto_model_save\n\t--auto_dist0.pdparams\n\t--auto_dist0.pdopt\n"
        '\t--auto_dist0.pdattr\nyou should set ckpt_dir="gpt_auto_model_save/auto"'
    )

    parent_path = os.path.split(ckpt_dir)[0]

    if os.path.exists(parent_path) is False:
        logger.warning("{} does not exist, so ckpt_dir will be set to None.".format(parent_path))
        configs["ckpt_dir"] = None
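
# Illustrative layout matching the message above (the .pdopt file name is assumed):
#   gpt_auto_model_save/auto_dist0.pdparams
#   gpt_auto_model_save/auto_dist0.pdopt
#   gpt_auto_model_save/auto_dist0.pdattr
# ckpt_dir should then be the dirname/prefix "gpt_auto_model_save/auto", not the folder itself.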

def get_config(fname, overrides=None, show=False):
    """
    Read the config from file for auto parallel.
    """
    assert os.path.exists(fname), "config file({}) does not exist".format(fname)
    config = parse_config(fname)
    override_config(config, overrides)

    process_dist_configs(config)
    process_global_configs(config)
    process_engine_configs(config)
    process_strategy(config)
    process_ckpt_dir(config)
    create_attr_dict(AttrDict(config))

    if show:
        # print_config is assumed to be provided alongside parse_config and override_config.
        print_config(config)
    return config
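
# Illustrative call (config path and override syntax assumed):
#   config = get_config("configs/gpt_345M_single_card.yaml", overrides=["Engine.max_steps=100"], show=True)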

def parse_args():
    # Function name assumed for this command-line wrapper.
    parser = argparse.ArgumentParser("train script")
    parser.add_argument("-c", "--config", type=str, default="configs/config.yaml", help="config file path")
    parser.add_argument("-o", "--override", action="append", default=[], help="config options to be overridden")
    args = parser.parse_args()
    return args
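
# Illustrative command line (script and config names assumed):
#   python tools/train.py -c configs/gpt_345M_single_card.yaml -o Engine.max_steps=100 -o Global.local_batch_size=8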