# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import paddle
import paddle.distributed as dist
from paddle.incubate.distributed.utils.io import save_for_auto_inference

from ppfleetx.distributed.apis import env
from ppfleetx.utils.log import logger


def save(output_dir, model, optimizer=None, step=0, epoch=0, sharding_stage=2):
    """
    Save the state dicts of the model and the optimizer into a checkpoint.
    """
    nranks = dist.get_world_size()
    if nranks > 1:
        # Locate this worker inside the hybrid communication group so the
        # checkpoint can be sharded per parallel dimension.
        hcg = env.get_hcg()
        dp_rank = hcg.get_data_parallel_rank()
        mp_rank = hcg.get_model_parallel_rank()
        pp_rank = hcg.get_stage_id()
        sharding_rank = hcg.get_sharding_parallel_rank()
    else:
        dp_rank = 0

    # Parameters are replicated across the data-parallel group, so only its
    # rank 0 needs to write the checkpoint.
    if dp_rank != 0:
        logger.info("DP_Rank %d doesn't save model" % dp_rank)
        return

    if output_dir and isinstance(output_dir, str):
        output_dir = os.path.join(output_dir, "epoch_%d_step_%d" % (epoch, step))
        if not os.path.exists(output_dir):
            os.makedirs(output_dir, exist_ok=True)
        logger.info("Save model to %s" % output_dir)

        # Each (mp, sharding, pp) coordinate owns a distinct shard of the
        # parameters, so each such rank writes into its own subdirectory.
        save_dir = (
            "{}/mp_{:0>2d}_sharding_{:0>2d}_pp_{:0>2d}".format(output_dir, mp_rank, sharding_rank, pp_rank)
            if nranks > 1
            else output_dir
        )

        if sharding_stage == 3:
            # Stage-3 sharding scatters parameters across ranks; gather the
            # full parameters before serializing the state dict.
            model.get_all_parameters(convert2cpu=False)

        paddle.save(model.state_dict(), os.path.join(save_dir, "model.pdparams"))

        if optimizer is not None:
            paddle.save(optimizer.state_dict(), os.path.join(save_dir, "model_state.pdopt"))

        # Persist training progress and the CUDA RNG state so a resumed run
        # can reproduce the data order and random ops.
        meta_dict = {"epoch": epoch, "step": step, "cuda_rng_state": paddle.get_cuda_rng_state()}
        paddle.save(meta_dict, os.path.join(save_dir, "meta_state.pdopt"))

        # Export an extra copy in a layout the auto-parallel engine can load
        # for inference.
        save_auto_dir = os.path.join(output_dir, "auto_infer")
        save_for_auto_inference(os.path.join(save_auto_dir, "auto"), model)
    else:
        raise TypeError("`save` requires a valid value of `output_dir`.")
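
# A sketch of the on-disk layout `save` produces, assuming nranks > 1 and,
# for illustration only, epoch=1, step=100, mp_rank=0, sharding_rank=1,
# pp_rank=0:
#
#   <output_dir>/epoch_1_step_100/
#       mp_00_sharding_01_pp_00/
#           model.pdparams      # model state dict for this shard
#           model_state.pdopt   # optimizer state dict for this shard
#           meta_state.pdopt    # epoch/step counters and CUDA RNG state
#       auto_infer/             # extra copy from save_for_auto_inference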


def load(ckpt_dir, model, optimizer=None, mode="train", load_recovery=None):
    """
    Load a checkpoint produced by `save` and restore the state dicts of the
    model and the optimizer.
    """
    nranks = dist.get_world_size()
    if nranks > 1:
        hcg = env.get_hcg()
        mp_rank = hcg.get_model_parallel_rank()
        pp_rank = hcg.get_stage_id()
        sharding_rank = hcg.get_sharding_parallel_rank()

    load_recovery = {} if load_recovery is None else load_recovery

    if ckpt_dir and isinstance(ckpt_dir, str):
        logger.info("Try to load checkpoint from %s" % ckpt_dir)

        # Mirror the per-rank directory layout used by `save`.
        load_dir = (
            "{}/mp_{:0>2d}_sharding_{:0>2d}_pp_{:0>2d}".format(ckpt_dir, mp_rank, sharding_rank, pp_rank)
            if nranks > 1
            else ckpt_dir
        )
        model_path = os.path.join(load_dir, "model.pdparams")
        opt_path = os.path.join(load_dir, "model_state.pdopt")
        meta_path = os.path.join(load_dir, "meta_state.pdopt")

        if os.path.exists(model_path):
            model_dict = paddle.load(model_path)
            for name, param in model.state_dict().items():
                assert name in model_dict.keys(), "No param named `{}` was found in checkpoint file.".format(name)

                # Cast checkpoint tensors saved in a different precision to
                # the dtype the current model expects.
                if param.dtype != model_dict[name].dtype:
                    model_dict[name] = model_dict[name].cast(param.dtype)

            model.set_state_dict(model_dict)
        else:
            raise ValueError("No model checkpoint file found in %s." % model_path)

        if mode == "train":
            # Optimizer and meta states are only needed when resuming
            # training, not for evaluation or export.
            if os.path.exists(opt_path):
                opt_dict = paddle.load(opt_path)
                optimizer.set_state_dict(opt_dict)
            else:
                raise ValueError("No optimizer checkpoint file found in %s." % opt_path)

            if os.path.exists(meta_path):
                meta_dict = paddle.load(meta_path)
                load_recovery.update(
                    {"step": meta_dict["step"], "epoch": meta_dict["epoch"], "rng_state": meta_dict["cuda_rng_state"]}
                )
            else:
                raise ValueError("No meta checkpoint file found in %s." % meta_path)

        logger.info("successfully loaded checkpoints")
    elif ckpt_dir is None:
        logger.warning("`load` requires a valid value of `ckpt_dir`.")
    else:
        raise TypeError("`load` requires a valid value of `ckpt_dir`.")
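

# A minimal usage sketch, assuming a running distributed job; `build_model`
# and `build_optimizer` are hypothetical stand-ins, not part of this module:
#
#   model = build_model(...)
#   optimizer = build_optimizer(model, ...)
#   save("./checkpoints", model, optimizer, step=100, epoch=1)
#
#   recovery = {}
#   load("./checkpoints/epoch_1_step_100", model, optimizer,
#        mode="train", load_recovery=recovery)
#   # `recovery` now holds {"step": ..., "epoch": ..., "rng_state": ...},
#   # which the training loop can use to resume from the saved position.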