import argparse
import copy
import os

import paddle

from paddlenlp.peft import LoRAConfig, LoRAModel

try:
    from paddle.nn.quant import weight_dequantize, weight_quantize
except ImportError:
    # Older Paddle builds do not ship the weight-only quantization kernels.
    weight_dequantize = None
    weight_quantize = None

try:
    from paddlenlp.quantization.qlora import qlora_weight_quantize_dequantize
except ImportError:
    # The QLoRA helpers are optional; they are only needed for nf4/fp4 checkpoints.
    qlora_weight_quantize_dequantize = None

from paddlenlp.quantization.quantization_config import QuantizationConfig
from paddlenlp.transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from paddlenlp.transformers.utils import device_guard
from paddlenlp.utils.env import CONFIG_NAME
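
# Example invocation (script name and paths are illustrative):
#   python merge_lora_params.py \
#       --lora_path ./checkpoints/lora \
#       --merge_lora_model_path ./checkpoints/merged \
#       --device gpu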


def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name_or_path", default=None, help="The directory of the pretrained model.")
    parser.add_argument(
        "--lora_path", default=None, required=True, help="The directory of LoRA parameters. Defaults to None."
    )
    parser.add_argument(
        "--merge_lora_model_path",
        default=None,
        required=True,
        help="The directory of merged parameters. Defaults to None.",
    )
    parser.add_argument("--device", type=str, default="gpu", help="Device to run on: gpu or cpu.")
    parser.add_argument(
        "--low_gpu_mem", type=bool, default=False, help="Whether to use low GPU memory. Defaults to False."
    )
    return parser.parse_args()
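

# Dequantize the base weight if needed, then fold one LoRA pair into it:
#   W' = W + lora_A @ lora_B * (lora_alpha / r)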
def weight_process(name, quant_config, lora_config, state_dict):
    weight = state_dict.pop(name + ".weight").cuda()
    if quant_config.weight_quantize_algo is None:
        # Weight is already floating point; merge it as-is.
        pass
    elif quant_config.weight_quantize_algo in ["nf4", "fp4"]:
        # Round-trip the weight through nf4/fp4 so the merged tensor matches
        # what the quantized model actually computes with.
        weight = qlora_weight_quantize_dequantize(
            weight,
            quant_algo=quant_config.weight_quantize_algo,
            double_quant=quant_config.weight_double_quant,
            block_size=quant_config.weight_blocksize,
            double_quant_block_size=quant_config.weight_double_quant_block_size,
        )
    elif quant_config.weight_quantize_algo in ["weight_only_int8"]:
        out, scale = weight_quantize(weight, algo=quant_config.weight_quantize_algo)
        weight = weight_dequantize(out, scale)
    else:
        raise ValueError(f"quant_config.weight_quantize_algo {quant_config.weight_quantize_algo} is not supported.")
    lora_A = state_dict.pop(name + ".lora_A").cuda()
    lora_B = state_dict.pop(name + ".lora_B").cuda()
    scaling = lora_config.lora_alpha / lora_config.r
    state_dict[name + ".weight"] = (weight + lora_A @ lora_B * scaling).cpu()
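

# merge(): load the base model and its LoRA adapter, fold the adapters into
# the base weights, and save a standalone merged checkpoint plus tokenizer.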
def merge():
    args = parse_arguments()
    paddle.set_device(args.device)

    lora_config = LoRAConfig.from_pretrained(args.lora_path)
    if lora_config.base_model_name_or_path is None:
        if args.model_name_or_path is not None:
            lora_config.base_model_name_or_path = args.model_name_or_path
        else:
            raise ValueError("Cannot find a valid model_name_or_path.")
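
    # Prefer the config shipped with the LoRA checkpoint; otherwise fall back
    # to the base model's config.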
    if os.path.isfile(os.path.join(args.lora_path, CONFIG_NAME)):
        config = AutoConfig.from_pretrained(args.lora_path)
    elif args.model_name_or_path is not None:
        config = AutoConfig.from_pretrained(args.model_name_or_path)
    else:
        raise ValueError(
            f"Cannot find config.json in lora_path: {args.lora_path} or a valid model_name_or_path."
        )
    config.dtype = lora_config.dtype
    if (
        lora_config.dtype == "bfloat16" or config.quantization_config.weight_quantize_algo in ["nf4", "fp4"]
    ) and args.device == "cpu":
        raise ValueError("bfloat16 and nf4/fp4 LoRA merges are not supported on CPU.")
    if args.low_gpu_mem and args.device == "gpu":
        # Strip the quantization config before loading so the base model comes
        # up unquantized, and keep the adapters unmerged for manual processing.
        quant_config = copy.deepcopy(config.quantization_config)
        config.quantization_config = QuantizationConfig()
        lora_config.merge_weights = False
        with device_guard():
            model = AutoModelForCausalLM.from_pretrained(
                lora_config.base_model_name_or_path,
                config=config,
                low_cpu_mem_usage=True,
            )
            model = LoRAModel.from_pretrained(model=model, lora_path=args.lora_path, lora_config=lora_config)
        model.eval()
        model_state_dict = model.model.state_dict()
        lora_name_list = []
        for key in model_state_dict.keys():
            if "lora_A" in key:
                # Strip the ".lora_A" suffix (7 chars) to recover the layer name.
                lora_name_list.append(key[:-7])
        for name in lora_name_list:
            weight_process(name, quant_config, lora_config, model_state_dict)
    else:
        model = AutoModelForCausalLM.from_pretrained(
            lora_config.base_model_name_or_path,
            config=config,
            low_cpu_mem_usage=True,
        )
        # Let LoRAModel fold the adapters into the base weights on load.
        lora_config.merge_weights = True
        model = LoRAModel.from_pretrained(model=model, lora_path=args.lora_path, lora_config=lora_config)
        model.eval()
        model_state_dict = model.model.state_dict()
        # Drop adapter and quantization tensors; only merged weights remain.
        for key in list(model_state_dict):
            if "lora" in key:
                del model_state_dict[key]
            elif "quant" in key:
                del model_state_dict[key]
        model.model.config.quantization_config = QuantizationConfig()
    model.model.save_pretrained(args.merge_lora_model_path, state_dict=model_state_dict)
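
    # Ship the tokenizer alongside the merged weights so the output directory
    # is self-contained.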
    tokenizer = AutoTokenizer.from_pretrained(lora_config.base_model_name_or_path)
    tokenizer.save_pretrained(args.merge_lora_model_path)
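
    # The merged checkpoint can then be loaded like any regular model, e.g.:
    #   AutoModelForCausalLM.from_pretrained(args.merge_lora_model_path)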


if __name__ == "__main__":
    merge()