# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os

import paddle
import paddle.nn as nn
from paddle.distributed.fleet.meta_parallel import (
    ColumnParallelLinear,
    RowParallelLinear,
)
from paddle.quantization import PTQ, QAT, QuantConfig
from paddleslim.quant.advanced import (
    GPTQ,
    AutoClip,
    AWQSearch,
    EMASampler,
    MultiStepSampler,
    PieceWiseSearch,
    Shift,
    Smooth,
)
from paddleslim.quant.advanced.utils import find_parent_layer_and_sub_name
from paddleslim.quant.layers import (
    QuantizedColumnParallelLinear,
    QuantizedRowParallelLinear,
)
from paddleslim.quant.observers import (
    AbsMaxChannelWiseWeightObserver,
    AVGObserver,
    GroupWiseWeightObserver,
)
from paddleslim.quant.observers.abs_max_weight import (
    AbsMaxChannelWiseWeightObserverLayer,
)
from paddleslim.quant.observers.avg import AVGObserverLayer
from paddleslim.quant.observers.groupwise import GroupWiseWeightObserverLayer

from paddlenlp.peft import PrefixModelForCausalLM
from paddlenlp.peft.lora import (
    ColumnParallelLoRALinear,
    LoRALinear,
    RowParallelLoRALinear,
)
from paddlenlp.peft.lora.lora_quant_layers import (
    ColumnParallelQuantedLoRALinear,
    QuantedLoRALinear,
    RowParallelQuantedLoRALinear,
)
from paddlenlp.utils.log import logger
def create_qat_model(quant_args, model, dtype):
    """Wrap ``model`` with fake-quant layers for quantization-aware training.

    Args:
        quant_args: argument namespace; ``quant_args.quant_type`` selects the
            scheme ("a8w8", "weight_only_int4" or "weight_only_int8").
        model: the Paddle model (possibly containing LoRA layers) to quantize.
        dtype: compute dtype forwarded to the PACT activation quanter.

    Returns:
        The model quantized in place by ``QAT.quantize``.

    Raises:
        ValueError: if ``quant_args.quant_type`` is not one of the supported schemes.
    """
    # Imported lazily so that merely importing this module does not pull in
    # the QAT-only quanter implementations.
    from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver
    from paddleslim.quant.quanters import (
        FakeQuanterChannelWiseAbsMaxObserver,
        PACTQuanter,
    )

    q_config = QuantConfig(activation=None, weight=None)
    # Map LoRA layers (including the tensor-parallel variants) to their
    # quantized counterparts so QAT can swap them in.
    q_config.add_qat_layer_mapping(LoRALinear, QuantedLoRALinear)
    q_config.add_qat_layer_mapping(RowParallelLoRALinear, RowParallelQuantedLoRALinear)
    q_config.add_qat_layer_mapping(ColumnParallelLoRALinear, ColumnParallelQuantedLoRALinear)

    if quant_args.quant_type == "a8w8":
        # 8-bit activations (PACT with a learnable clip) + 8-bit channel-wise weights.
        activation = PACTQuanter(quanter=FakeQuanterWithAbsMaxObserver(), init_value=20.0, dtype=dtype)
        weight = FakeQuanterChannelWiseAbsMaxObserver(bit_length=8, dtype="float32")
    elif quant_args.quant_type == "weight_only_int4":
        # Weight-only quantization: activations remain in floating point.
        activation = None
        weight = FakeQuanterChannelWiseAbsMaxObserver(bit_length=4, dtype="float32")
    elif quant_args.quant_type == "weight_only_int8":
        activation = None
        weight = FakeQuanterChannelWiseAbsMaxObserver(bit_length=8, dtype="float32")
    else:
        raise ValueError("quant_type should be one of ['a8w8', 'weight_only_int4', 'weight_only_int8']")

    q_config.add_type_config(RowParallelLoRALinear, weight=weight, activation=activation)
    q_config.add_type_config(ColumnParallelLoRALinear, weight=weight, activation=activation)
    q_config.add_type_config(LoRALinear, weight=weight, activation=activation)
    q_config.add_type_config(nn.Linear, weight=weight, activation=activation)

    qat = QAT(q_config)
    model = qat.quantize(model, inplace=True)
    return model
def apply_shift(quant_args, trainer, ptq_dataloader, ptq_model_config):
    """Run the Shift calibration pass on ``trainer.model`` before PTQ.

    Args:
        quant_args: argument namespace providing ``shift_sampler``
            ("ema" enables EMA sampling), ``shift_all_linears`` and ``shift_step``.
        trainer: trainer that owns ``trainer.model`` and a ``ptq_loop`` method.
        ptq_dataloader: calibration dataloader fed to ``ptq_loop``.
        ptq_model_config: model-structure hints (e.g. fused_qkv / parallel_ffn)
            produced by ``get_ptq_model_config``.
    """
    logger.info("***** Running Shift *****")
    shift_sampler = EMASampler() if quant_args.shift_sampler == "ema" else None
    shift = Shift(
        model=trainer.model,
        model_config=ptq_model_config,
        sample_function=shift_sampler,
        shift_all_linears=quant_args.shift_all_linears,
    )
    # Calibration only: no gradients are needed while collecting statistics.
    with paddle.no_grad():
        trainer.ptq_loop(
            ptq_dataloader,
            description="Shift",
            max_eval_iters=quant_args.shift_step,
        )
        shift.update_weight()
    # Release calibration state eagerly; it holds references into the model.
    del shift, shift_sampler
    logger.info("***** Shift done *****")
def apply_smooth(quant_args, trainer, ptq_dataloader, ptq_model_config):
    """Run Smooth/AWQ outlier-smoothing calibration on ``trainer.model``.

    Chooses the search strategy from ``quant_args``: piecewise alpha/scale
    search, AWQ grid search, or no search at all, then calibrates with
    ``trainer.ptq_loop`` and folds the resulting scales into the weights.

    Args:
        quant_args: argument namespace with ``do_awq``, ``smooth_sampler``,
            ``smooth_piecewise_search``, ``smooth_k_piece``,
            ``smooth_search_piece``, ``smooth_all_linears``, ``smooth_step``
            and ``weight_quant_method``.
        trainer: trainer that owns ``trainer.model`` and a ``ptq_loop`` method.
        ptq_dataloader: calibration dataloader fed to ``ptq_loop``.
        ptq_model_config: model-structure hints for the Smooth pass.
    """
    if quant_args.do_awq:
        logger.info("***** Running AWQ *****")
    else:
        logger.info("***** Running Smooth *****")
    smooth_sampler = MultiStepSampler() if quant_args.smooth_sampler == "multi_step" else None
    if quant_args.smooth_piecewise_search:
        search_func = PieceWiseSearch(
            k_piece=quant_args.smooth_k_piece,
            bits_length=8,  # NOTE(review): reconstructed — confirm against paddleslim PieceWiseSearch defaults
            search_piece=quant_args.smooth_search_piece,
            search_alpha_min=0.2,
            search_alpha_max=0.8,
            search_scale_min=1.0,
            search_scale_max=5.0,
            weight_quant_method="abs_max_channel_wise",
            act_quant_method="avg",
        )
    elif quant_args.do_awq:
        search_func = AWQSearch(
            n_grid=20,  # NOTE(review): reconstructed — confirm against paddleslim AWQSearch signature
            bits_length=4,
            weight_quant_method=quant_args.weight_quant_method,
        )
    else:
        search_func = None
    smooth = Smooth(
        model=trainer.model,
        model_config=ptq_model_config,
        smooth_all_linears=quant_args.smooth_all_linears,
        sample_function=smooth_sampler,
        search_function=search_func,
        smooth_method="awq" if quant_args.do_awq else "smoothquant",
    )
    # Calibration only: no gradients are needed while collecting statistics.
    with paddle.no_grad():
        trainer.ptq_loop(
            ptq_dataloader,
            description="Smooth",
            max_eval_iters=quant_args.smooth_step,
        )
        # Fold the searched smoothing scales back into the linear weights.
        smooth.update_weight()
    del smooth, smooth_sampler, search_func
    logger.info("***** Smooth done *****")
def apply_autoclip(quant_args, trainer, ptq_dataloader):
    """Run the AutoClip pass: search per-layer weight clipping ranges.

    Args:
        quant_args: argument namespace with ``weight_quant_method`` and
            ``autoclip_step``.
        trainer: trainer that owns ``trainer.model`` and a ``ptq_loop`` method.
        ptq_dataloader: calibration dataloader fed to ``ptq_loop``.
    """
    # Use the shared logger instead of a bare print, for consistency with the
    # other calibration passes in this module.
    logger.info("***** Running AutoClip *****")
    sampler = MultiStepSampler()
    auto_clip = AutoClip(
        trainer.model,
        weight_bits=4,  # NOTE(review): reconstructed — confirm against paddleslim AutoClip signature
        weight_quant_method=quant_args.weight_quant_method,
        sample_function=sampler,
    )
    # Calibration only: no gradients are needed while collecting statistics.
    with paddle.no_grad():
        trainer.ptq_loop(
            ptq_dataloader,
            description="AutoClip",
            max_eval_iters=quant_args.autoclip_step,
        )
        auto_clip.auto_clip()
    del sampler, auto_clip
    logger.info("***** AutoClip done *****")
def apply_ptq(quant_args, trainer, ptq_dataloader):
    """Run post-training quantization on ``trainer.model`` and export scales.

    Inserts observers per ``quant_args``, calibrates via ``trainer.ptq_loop``,
    dumps the collected weight/activation scales as JSON under
    ``trainer.args.output_dir``, then converts the model to its quantized form.

    Args:
        quant_args: argument namespace with ``weight_quant_method``
            ("abs_max_channel_wise" or "groupwise"), ``quant_type``
            ("a8w8", "weight_only_int4" or "weight_only_int8") and ``ptq_step``.
        trainer: trainer that owns ``trainer.model``, ``trainer.args.output_dir``
            and a ``ptq_loop`` method.
        ptq_dataloader: calibration dataloader fed to ``ptq_loop``.

    Raises:
        ValueError: on an unsupported ``weight_quant_method`` or ``quant_type``.
    """
    logger.info("***** Running PTQ *****")
    q_config = QuantConfig(activation=None, weight=None)
    if quant_args.weight_quant_method == "abs_max_channel_wise":
        weight_observer = AbsMaxChannelWiseWeightObserver
    elif quant_args.weight_quant_method == "groupwise":
        weight_observer = GroupWiseWeightObserver
    else:
        raise ValueError("weight_quant_method should be one of ['abs_max_channel_wise', 'groupwise']")

    if quant_args.quant_type == "a8w8":
        activation = AVGObserver(quant_bits=8)
        weight = weight_observer(quant_bits=8)
    elif quant_args.quant_type == "weight_only_int4":
        # Weight-only quantization: no activation observer is installed.
        activation = None
        weight = weight_observer(quant_bits=4)
    elif quant_args.quant_type == "weight_only_int8":
        activation = None
        weight = weight_observer(quant_bits=8)
    else:
        raise ValueError("quant_type should be one of ['a8w8', 'weight_only_int4', 'weight_only_int8']")

    # Tensor-parallel linears need explicit mappings to their quantized forms.
    q_config.add_qat_layer_mapping(ColumnParallelLinear, QuantizedColumnParallelLinear)
    q_config.add_qat_layer_mapping(RowParallelLinear, QuantizedRowParallelLinear)
    q_config.add_type_config(
        [paddle.nn.Linear, ColumnParallelLinear, RowParallelLinear, QuantedLoRALinear],
        activation=activation,
        weight=weight,
    )

    ptq = PTQ(q_config)
    trainer.model = ptq.quantize(trainer.model, inplace=True)
    # Calibration pass: lets the observers accumulate statistics.
    trainer.ptq_loop(
        ptq_dataloader,
        description="PTQ",
        max_eval_iters=quant_args.ptq_step,
    )

    # Collect the calibrated scales; skip wrapper layers whose names contain
    # "_observer" so only the layer-level entries are exported.
    weight_scales = {}
    act_scales = {}
    for cur_name, cur_layer in trainer.model.named_sublayers():
        if isinstance(cur_layer, AbsMaxChannelWiseWeightObserverLayer):
            if "_observer" not in cur_name:
                weight_scales[cur_name] = cur_layer.scales().numpy().tolist()
        if isinstance(cur_layer, GroupWiseWeightObserverLayer):
            if "_observer" not in cur_name:
                weight_scales[cur_name] = cur_layer.scales().numpy().tolist()
        if isinstance(cur_layer, AVGObserverLayer):
            if "_observer" not in cur_name:
                act_scales[cur_name] = cur_layer.scales().numpy().tolist()

    weight_scales_path = os.path.join(trainer.args.output_dir, "weight_scales.json")
    with open(weight_scales_path, "w") as f:
        json.dump(weight_scales, f)
    logger.info(f"Weight scales saved in {weight_scales_path}.")

    act_scales_path = os.path.join(trainer.args.output_dir, "act_scales.json")
    with open(act_scales_path, "w") as f:
        json.dump(act_scales, f)
    logger.info(f"Activation scales saved in {act_scales_path}.")

    trainer.model = ptq.convert(trainer.model, inplace=True)
    logger.info("***** PTQ done *****")
def apply_gptq(quant_args, trainer, ptq_dataloader):
    """Run GPTQ layer-by-layer weight quantization on ``trainer.model``.

    Each eligible linear layer is temporarily wrapped in a ``GPTQ`` layer,
    calibrated with ``trainer.ptq_loop``, quantized via ``fasterquant``, and
    then the original (now-updated) layer is restored in the parent module.

    Args:
        quant_args: argument namespace with ``gptq_step``.
        trainer: trainer that owns ``trainer.model`` and a ``ptq_loop`` method.
        ptq_dataloader: calibration dataloader fed to ``ptq_loop``.
    """
    logger.info("***** Running GPTQ *****")
    num_layer = 0
    model = trainer.model
    for cur_name, cur_layer in model.named_sublayers():
        # Exact type check (not isinstance): subclasses such as LoRA layers
        # must not be re-quantized here.
        if type(cur_layer) in [paddle.nn.Linear, ColumnParallelLinear, RowParallelLinear]:
            num_layer += 1
            logger.info(f"GPTQ layer: {num_layer}, {cur_name}")
            parent_layer, sub_name = find_parent_layer_and_sub_name(model, cur_name)
            # Swap the GPTQ wrapper in so it can observe activations.
            cur_quant_layer = GPTQ(cur_layer)
            setattr(parent_layer, sub_name, cur_quant_layer)
            with paddle.no_grad():
                trainer.ptq_loop(
                    ptq_dataloader,
                    description="GPTQ",
                    max_eval_iters=quant_args.gptq_step,
                )
                cur_quant_layer.fasterquant(percdamp=0.1, groupsize=-1, actorder=True)
            # Restore the original layer object (its weights were updated in place).
            setattr(parent_layer, sub_name, cur_layer)
    logger.info("***** GPTQ done *****")
def get_ptq_model_config(model):
277
if isinstance(model, PrefixModelForCausalLM):
278
base_model_prefix = model.model.base_model_prefix
280
base_model_prefix = model.base_model_prefix
282
if base_model_prefix in ["chatglm"]:
283
raise NotImplementedError(f"{model} does not support Shift or Smooth.")
284
elif base_model_prefix == "chatglm_v2":
285
model_config = {"fused_qkv": False, "parallel_ffn": False, "skip_norm_list": ["rms_norm_56"]}
286
elif base_model_prefix == "bloom":
287
model_config = {"fused_qkv": True, "parallel_ffn": False}
288
elif base_model_prefix == "llama":
289
model_config = {"fused_qkv": False, "parallel_ffn": True}
292
f"Unknown base_model_prefix: {model.base_model_prefix}. Supported base_model_prefix list: chatglm_V2, bloom, llama."