paddlenlp

Форк
0
133 строки · 4.8 Кб
1
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

15
import os
16

17
import paddle
18
import paddle.distributed as dist
19
from paddle.incubate.distributed.utils.io import save_for_auto_inference
20
from ppfleetx.distributed.apis import env
21
from ppfleetx.utils.log import logger
22

23

24
def save(output_dir, model, optimizer=None, step=0, epoch=0, sharding_stage=2):
    """
    Save the state dicts of the model and optimizer into a checkpoint.

    Checkpoints are written under ``output_dir/epoch_<epoch>_step_<step>``;
    in distributed runs each rank writes into a per-rank subdirectory named
    ``mp_XX_sharding_XX_pp_XX``.

    Args:
        output_dir (str): Root directory for checkpoints. Must be a
            non-empty string.
        model: Model whose ``state_dict()`` is saved.
        optimizer: Optional optimizer; its state is saved when provided.
        step (int): Current training step, recorded in the meta file.
        epoch (int): Current epoch, recorded in the meta file.
        sharding_stage (int): Sharding stage in use; stage 3 gathers the
            scattered parameters before saving.

    Raises:
        TypeError: If ``output_dir`` is not a non-empty string.
    """
    nranks = dist.get_world_size()
    if nranks > 1:
        hcg = env.get_hcg()

        dp_rank = hcg.get_data_parallel_rank()
        mp_rank = hcg.get_model_parallel_rank()
        pp_rank = hcg.get_stage_id()
        sharding_rank = hcg.get_sharding_parallel_rank()
    else:
        dp_rank = 0

    # Only data-parallel rank 0 persists the checkpoint; DP replicas hold
    # identical weights, so the others simply return.
    if dp_rank != 0:
        logger.info("DP_Rank %d doesn't save model" % dp_rank)
        return

    if output_dir and isinstance(output_dir, str):
        output_dir = os.path.join(output_dir, "epoch_%d_step_%d" % (epoch, step))

        os.makedirs(output_dir, exist_ok=True)
        logger.info("Save model to %s" % output_dir)

        save_dir = (
            "{}/mp_{:0>2d}_sharding_{:0>2d}_pp_{:0>2d}".format(output_dir, mp_rank, sharding_rank, pp_rank)
            if nranks > 1
            else output_dir
        )
        # Fix: the per-rank subdirectory was never created here; whether
        # paddle.save creates missing parent directories is not guaranteed
        # across versions, so create it explicitly before writing.
        os.makedirs(save_dir, exist_ok=True)

        if sharding_stage == 3:
            # Stage-3 sharding scatters parameters across ranks; gather them
            # so the saved state dict is complete.
            model.get_all_parameters(convert2cpu=False)

        paddle.save(model.state_dict(), os.path.join(save_dir, "model.pdparams"))

        if optimizer is not None:
            paddle.save(optimizer.state_dict(), os.path.join(save_dir, "model_state.pdopt"))

        # Persist training progress plus the CUDA RNG state so a resumed run
        # can continue deterministically.
        meta_dict = {"epoch": epoch, "step": step, "cuda_rng_state": paddle.get_cuda_rng_state()}
        paddle.save(meta_dict, os.path.join(save_dir, "meta_state.pdopt"))

        save_auto_dir = os.path.join(output_dir, "auto_infer")
        save_for_auto_inference(os.path.join(save_auto_dir, "auto"), model)

    else:
        raise TypeError("`save` requires a valid value of `output_dir`.")
73

74

75
def load(ckpt_dir, model, optimizer=None, mode="train", load_recovery=None):
    """
    Load model (and, in train mode, optimizer and meta) state from a checkpoint.

    In distributed runs (except ``mode == "quant"``) the per-rank
    subdirectory ``mp_XX_sharding_XX_pp_XX`` under ``ckpt_dir`` is read.
    Parameters present in the checkpoint with a different dtype than the
    model's are cast to the model's dtype before loading.

    Args:
        ckpt_dir (str): Checkpoint root directory. Must be a non-empty string.
        model: Model to receive the loaded state dict.
        optimizer: Optimizer to restore; required when ``mode == "train"``.
        mode (str): ``"train"`` restores optimizer and meta state as well;
            ``"quant"`` reads directly from ``ckpt_dir`` without the
            per-rank subdirectory.
        load_recovery (dict | None): If provided, updated in place with
            ``step``, ``epoch`` and ``rng_state`` from the meta file.

    Raises:
        TypeError: If ``ckpt_dir`` is not a non-empty string.
        ValueError: If a required checkpoint file is missing, a model
            parameter is absent from the checkpoint, or ``optimizer`` is
            ``None`` in train mode.
    """
    nranks = dist.get_world_size()
    if nranks > 1:
        hcg = env.get_hcg()

        mp_rank = hcg.get_model_parallel_rank()
        pp_rank = hcg.get_stage_id()
        sharding_rank = hcg.get_sharding_parallel_rank()

    load_recovery = {} if load_recovery is None else load_recovery

    if ckpt_dir and isinstance(ckpt_dir, str):
        logger.info("Try to load checkpoint from %s " % ckpt_dir)

        if mode == "quant":
            load_dir = ckpt_dir
        else:
            load_dir = (
                "{}/mp_{:0>2d}_sharding_{:0>2d}_pp_{:0>2d}".format(ckpt_dir, mp_rank, sharding_rank, pp_rank)
                if nranks > 1
                else ckpt_dir
            )
        model_path = os.path.join(load_dir, "model.pdparams")
        opt_path = os.path.join(load_dir, "model_state.pdopt")
        meta_path = os.path.join(load_dir, "meta_state.pdopt")

        if os.path.exists(model_path):
            model_dict = paddle.load(model_path)
            for name, param in model.state_dict().items():
                # Fix: `assert` is stripped under `python -O`; raise an
                # explicit error for a missing parameter instead.
                if name not in model_dict:
                    raise ValueError("No param named `{}` was found in checkpoint file.".format(name))

                # Cast checkpoint tensors to the model's dtype (e.g. when
                # resuming fp16 training from an fp32 checkpoint).
                if param.dtype != model_dict[name].dtype:
                    model_dict[name] = model_dict[name].cast(param.dtype)

            model.set_state_dict(model_dict)
        else:
            raise ValueError("No model checkpoint file found in %s." % model_path)

        if mode == "train":
            # Fix: with the default optimizer=None this previously crashed
            # with an opaque AttributeError on set_state_dict.
            if optimizer is None:
                raise ValueError("`load` requires an optimizer when mode is 'train'.")

            if os.path.exists(opt_path):
                opt_dict = paddle.load(opt_path)
                optimizer.set_state_dict(opt_dict)
            else:
                raise ValueError("No optimizer checkpoint file found in %s." % opt_path)

            if os.path.exists(meta_path):
                meta_dict = paddle.load(meta_path)

                load_recovery.update(
                    {"step": meta_dict["step"], "epoch": meta_dict["epoch"], "rng_state": meta_dict["cuda_rng_state"]}
                )

            else:
                raise ValueError("No meta checkpoint file found in %s." % meta_path)

        logger.info("successfully load checkpoints")
    else:
        logger.warning("`load` requires a valid value of `ckpt_dir`.")
        raise TypeError("`load` requires a valid value of `ckpt_dir`.")
134

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.