# merge_lora_params.py
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
15
import copy
16
import os
17

18
import paddle
19

20
from paddlenlp.peft import LoRAConfig, LoRAModel
21

22
try:
23
    from paddle.nn.quant import weight_dequantize, weight_quantize
24
except:
25
    weight_dequantize = None
26
    weight_quantize = None
27
try:
28
    from paddlenlp.quantization.qlora import qlora_weight_quantize_dequantize
29
except:
30
    qlora_weight_quantize_dequantize = None
31

32
from paddlenlp.quantization.quantization_config import QuantizationConfig
33
from paddlenlp.transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
34
from paddlenlp.transformers.utils import device_guard
35
from paddlenlp.utils.env import CONFIG_NAME
36

37

38
def parse_arguments():
39
    parser = argparse.ArgumentParser()
40
    parser.add_argument("--model_name_or_path", default=None, help="The directory of pretrained model.")
41
    parser.add_argument(
42
        "--lora_path", default=None, required=True, help="The directory of LoRA parameters. Default to None"
43
    )
44
    parser.add_argument(
45
        "--merge_lora_model_path",
46
        default=None,
47
        required=True,
48
        help="The directory of merged parameters. Default to None",
49
    )
50
    parser.add_argument("--device", type=str, default="gpu", help="Device")
51
    parser.add_argument(
52
        "--low_gpu_mem", type=bool, default=False, help="Whether to use low gpu memory. Default to False"
53
    )
54
    return parser.parse_args()
55

56

57
def weight_process(name, quant_config, lora_config, state_dict):
58
    weight = state_dict.pop(name + ".weight").cuda()
59
    if quant_config.weight_quantize_algo is None:
60
        pass
61
    elif quant_config.weight_quantize_algo in ["nf4", "fp4"]:
62
        weight = qlora_weight_quantize_dequantize(
63
            weight,
64
            quant_algo=quant_config.weight_quantize_algo,
65
            double_quant=quant_config.weight_double_quant,
66
            block_size=quant_config.weight_blocksize,
67
            double_quant_block_size=quant_config.weight_double_quant_block_size,
68
        )
69
    elif quant_config.weight_quantize_algo in ["weight_only_int8"]:
70
        out, scale = weight_quantize(weight, algo=quant_config.weight_quantize_algo)
71
        weight = weight_dequantize(out, scale)
72
    else:
73
        raise ValueError(f"quant_config.weight_quantize_algo {quant_config.weight_quantize_algo} is not supported.")
74
    lora_A = state_dict.pop(name + ".lora_A").cuda()
75
    lora_B = state_dict.pop(name + ".lora_B").cuda()
76
    scaling = lora_config.lora_alpha / lora_config.r
77
    state_dict[name + ".weight"] = (weight + lora_A @ lora_B * scaling).cpu()
78

79

80
def merge():
81
    args = parse_arguments()
82
    paddle.set_device(args.device)
83

84
    lora_config = LoRAConfig.from_pretrained(args.lora_path)
85
    if lora_config.base_model_name_or_path is None:
86
        if args.model_name_or_path is not None:
87
            raise ValueError("We can not find a valid model_name_or_path.")
88
        else:
89
            lora_config.base_model_name_or_path = args.model_name_or_path
90

91
    if os.path.isfile(os.path.join(args.lora_path, CONFIG_NAME)):
92
        config = AutoConfig.from_pretrained(args.lora_path)
93
    elif args.model_name_or_path is not None:
94
        config = AutoConfig.from_pretrained(args.model_name_or_path)
95
    else:
96
        raise ValueError(
97
            f"We can not find config.json in lora_path: {args.lora_path} or find a valid model_name_or_path."
98
        )
99
    config.dtype = lora_config.dtype
100
    if (
101
        lora_config.dtype == "bfloat16" or config.quantization_config.weight_quantize_algo in ["nf4", "fp4"]
102
    ) and args.device == "cpu":
103
        raise ValueError("We can not apply bfloat16 or nf4/fp4 lora merge on cpu.")
104

105
    if args.low_gpu_mem and args.device == "gpu":
106
        quant_config = copy.deepcopy(config.quantization_config)
107
        config.quantization_config = QuantizationConfig()
108
        lora_config.merge_weights = False
109
        with device_guard():
110
            model = AutoModelForCausalLM.from_pretrained(
111
                lora_config.base_model_name_or_path,
112
                config=config,
113
                low_cpu_mem_usage=True,
114
            )
115
            model = LoRAModel.from_pretrained(model=model, lora_path=args.lora_path, lora_config=lora_config)
116
        model.eval()
117
        model_state_dict = model.model.state_dict()
118
        lora_name_list = []
119
        for key in model_state_dict.keys():
120
            if "lora_A" in key:
121
                lora_name_list.append(key[:-7])
122
        for name in lora_name_list:
123
            weight_process(name, quant_config, lora_config, model_state_dict)
124
    else:
125
        model = AutoModelForCausalLM.from_pretrained(
126
            lora_config.base_model_name_or_path,
127
            config=config,
128
            low_cpu_mem_usage=True,
129
        )
130
        lora_config.merge_weights = True
131
        model = LoRAModel.from_pretrained(model=model, lora_path=args.lora_path, lora_config=lora_config)
132
        model.eval()
133
        model_state_dict = model.model.state_dict()
134
        for key in list(model_state_dict):
135
            if "lora" in key:
136
                del model_state_dict[key]
137
            if "quant" in key:
138
                del model_state_dict[key]
139
        model.model.config.quantization_config = QuantizationConfig()
140
    model.model.save_pretrained(args.merge_lora_model_path, state_dict=model_state_dict)
141

142
    tokenizer = AutoTokenizer.from_pretrained(lora_config.base_model_name_or_path)
143
    tokenizer.save_pretrained(args.merge_lora_model_path)
144

145

146
if __name__ == "__main__":
147
    merge()