# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os

import paddle
from paddle import nn
from paddle.distributed.fleet.meta_parallel import (
    ColumnParallelLinear,
    RowParallelLinear,
)
from paddle.quantization import PTQ, QAT, QuantConfig
from paddleslim.quant.advanced import (
    GPTQ,
    AutoClip,
    AWQSearch,
    EMASampler,
    MultiStepSampler,
    PieceWiseSearch,
    Shift,
    Smooth,
)
from paddleslim.quant.advanced.utils import find_parent_layer_and_sub_name
from paddleslim.quant.layers import (
    QuantizedColumnParallelLinear,
    QuantizedRowParallelLinear,
)
from paddleslim.quant.observers import (
    AbsMaxChannelWiseWeightObserver,
    AVGObserver,
    GroupWiseWeightObserver,
)
from paddleslim.quant.observers.abs_max_weight import (
    AbsMaxChannelWiseWeightObserverLayer,
)
from paddleslim.quant.observers.avg import AVGObserverLayer
from paddleslim.quant.observers.groupwise import GroupWiseWeightObserverLayer

from paddlenlp.peft import PrefixModelForCausalLM
from paddlenlp.peft.lora import (
    ColumnParallelLoRALinear,
    LoRALinear,
    RowParallelLoRALinear,
)
from paddlenlp.peft.lora.lora_quant_layers import (
    ColumnParallelQuantedLoRALinear,
    QuantedLoRALinear,
    RowParallelQuantedLoRALinear,
)
from paddlenlp.utils.log import logger


def create_qat_model(quant_args, model, dtype):
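    """Prepare a model for quantization-aware training (QAT).

    Maps LoRA linear layers (and their tensor-parallel variants) to their
    quantized counterparts, picks fake quanters according to
    quant_args.quant_type ("a8w8", "weight_only_int4" or "weight_only_int8"),
    and returns the model quantized in place by QAT.quantize.
    """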
    from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver
    from paddleslim.quant.quanters import (
        FakeQuanterChannelWiseAbsMaxObserver,
        PACTQuanter,
    )

    q_config = QuantConfig(activation=None, weight=None)
    q_config.add_qat_layer_mapping(LoRALinear, QuantedLoRALinear)
    q_config.add_qat_layer_mapping(RowParallelLoRALinear, RowParallelQuantedLoRALinear)
    q_config.add_qat_layer_mapping(ColumnParallelLoRALinear, ColumnParallelQuantedLoRALinear)
    if quant_args.quant_type == "a8w8":
        activation = PACTQuanter(quanter=FakeQuanterWithAbsMaxObserver(), init_value=20.0, dtype=dtype)
        weight = FakeQuanterChannelWiseAbsMaxObserver(bit_length=8, dtype="float32")
    elif quant_args.quant_type == "weight_only_int4":
        activation = None
        weight = FakeQuanterChannelWiseAbsMaxObserver(bit_length=4, dtype="float32")
    elif quant_args.quant_type == "weight_only_int8":
        activation = None
        weight = FakeQuanterChannelWiseAbsMaxObserver(bit_length=8, dtype="float32")
    else:
        raise ValueError("quant_type should be one of ['a8w8', 'weight_only_int4', 'weight_only_int8']")

    q_config.add_type_config(RowParallelLoRALinear, weight=weight, activation=activation)
    q_config.add_type_config(ColumnParallelLoRALinear, weight=weight, activation=activation)
    q_config.add_type_config(LoRALinear, weight=weight, activation=activation)
    q_config.add_type_config(nn.Linear, weight=weight, activation=activation)

    qat = QAT(q_config)
    model = qat.quantize(model, inplace=True)
    return model


def apply_shift(quant_args, trainer, ptq_dataloader, ptq_model_config):
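    """Apply the Shift transformation to the model before quantization.

    Collects statistics over quant_args.shift_step calibration batches
    (optionally with an EMASampler) and then folds the shifts into the
    weights via shift.update_weight().
    """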
    logger.info("***** Running Shift *****")
    shift_sampler = EMASampler() if quant_args.shift_sampler == "ema" else None
    shift = Shift(
        model=trainer.model,
        model_config=ptq_model_config,
        sample_function=shift_sampler,
        shift_all_linears=quant_args.shift_all_linears,
    )
    with paddle.no_grad():
        trainer.ptq_loop(
            ptq_dataloader,
            description="Shift",
            max_eval_iters=quant_args.shift_step,
        )
        shift.update_weight()
    del shift, shift_sampler
    logger.info("***** Shift done *****")


def apply_smooth(quant_args, trainer, ptq_dataloader, ptq_model_config):
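    """Apply SmoothQuant-style smoothing or AWQ scale search to the model.

    The smoothing scales come from a PieceWiseSearch, an AWQSearch (when
    quant_args.do_awq is set), or plain alpha=0.5 smoothing. Statistics are
    collected over quant_args.smooth_step calibration batches and folded
    into the weights via smooth.update_weight().
    """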

    if quant_args.do_awq:
        logger.info("***** Running AWQ *****")
    else:
        logger.info("***** Running Smooth *****")
    smooth_sampler = MultiStepSampler() if quant_args.smooth_sampler == "multi_step" else None
    if quant_args.smooth_piecewise_search:
        search_func = PieceWiseSearch(
            k_piece=quant_args.smooth_k_piece,
            bits_length=8,
            search_piece=quant_args.smooth_search_piece,
            search_alpha_min=0.2,
            search_alpha_max=0.8,
            search_scale_min=1.0,
            search_scale_max=5.0,
            weight_quant_method="abs_max_channel_wise",
            act_quant_method="avg",
        )
    elif quant_args.do_awq:
        search_func = AWQSearch(
            n_grid=20,
            bits_length=4,
            weight_quant_method=quant_args.weight_quant_method,
        )
    else:
        search_func = None
    smooth = Smooth(
        trainer.model,
        ptq_model_config,
        alpha=0.5,
        smooth_all_linears=quant_args.smooth_all_linears,
        sample_function=smooth_sampler,
        search_function=search_func,
        smooth_method="awq" if quant_args.do_awq else "smoothquant",
    )
    with paddle.no_grad():
        trainer.ptq_loop(
            ptq_dataloader,
            description="Smooth",
            max_eval_iters=quant_args.smooth_step,
        )

        smooth.update_weight()
    del smooth, smooth_sampler, search_func
    logger.info("***** Smooth done *****")


def apply_autoclip(quant_args, trainer, ptq_dataloader):
    """Search per-channel weight clipping ranges (AutoClip) for 4-bit weights.

    Calibration data is gathered with a MultiStepSampler over
    quant_args.autoclip_step batches, after which auto_clip.auto_clip()
    shrinks the weight quantization ranges in place.
    """
    logger.info("***** Running AutoClip *****")
    sampler = MultiStepSampler()
    auto_clip = AutoClip(
        trainer.model,
        weight_bits=4,
        weight_quant_method=quant_args.weight_quant_method,
        sample_function=sampler,
        n_grid=20,
        max_shrink=0.5,
    )
    with paddle.no_grad():
        trainer.ptq_loop(
            ptq_dataloader,
            description="AutoClip",
            max_eval_iters=quant_args.autoclip_step,
        )
        auto_clip.auto_clip()
    del sampler, auto_clip
    logger.info("***** AutoClip done *****")


def apply_ptq(quant_args, trainer, ptq_dataloader):
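    """Run post-training quantization (PTQ) on trainer.model.

    Sets up weight observers (per-channel abs-max or group-wise) and, for
    "a8w8", an AVG activation observer, calibrates the quantized model for
    quant_args.ptq_step batches, saves the collected weight/activation scales
    as JSON under trainer.args.output_dir, and converts the model in place.
    """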
    logger.info("***** Running PTQ *****")
    q_config = QuantConfig(activation=None, weight=None)
    if quant_args.weight_quant_method == "abs_max_channel_wise":
        weight_observer = AbsMaxChannelWiseWeightObserver
    elif quant_args.weight_quant_method == "groupwise":
        weight_observer = GroupWiseWeightObserver
    else:
        raise ValueError("weight_quant_method should be one of ['abs_max_channel_wise', 'groupwise']")

    if quant_args.quant_type == "a8w8":
        activation = AVGObserver(quant_bits=8)
        weight = weight_observer(quant_bits=8)
    elif quant_args.quant_type == "weight_only_int4":
        activation = None
        weight = weight_observer(quant_bits=4)
    elif quant_args.quant_type == "weight_only_int8":
        activation = None
        weight = weight_observer(quant_bits=8)
    else:
        raise ValueError("quant_type should be one of ['a8w8', 'weight_only_int4', 'weight_only_int8']")

    q_config.add_qat_layer_mapping(ColumnParallelLinear, QuantizedColumnParallelLinear)
    q_config.add_qat_layer_mapping(RowParallelLinear, QuantizedRowParallelLinear)
    q_config.add_type_config(
        [paddle.nn.Linear, ColumnParallelLinear, RowParallelLinear, QuantedLoRALinear],
        activation=activation,
        weight=weight,
    )

    ptq = PTQ(q_config)
    trainer.model = ptq.quantize(trainer.model, inplace=True)
    trainer.ptq_loop(
        ptq_dataloader,
        description="PTQ",
        max_eval_iters=quant_args.ptq_step,
    )
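
    # Collect the scales computed by the weight/activation observer layers
    # during calibration; sublayers whose names contain "_observer" are skipped.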
    weight_scales = {}
    act_scales = {}
    for cur_name, cur_layer in trainer.model.named_sublayers():
        if isinstance(cur_layer, AbsMaxChannelWiseWeightObserverLayer):
            if "_observer" not in cur_name:
                weight_scales[cur_name] = cur_layer.scales().numpy().tolist()
        if isinstance(cur_layer, GroupWiseWeightObserverLayer):
            if "_observer" not in cur_name:
                weight_scales[cur_name] = cur_layer.scales().numpy().tolist()
        if isinstance(cur_layer, AVGObserverLayer):
            if "_observer" not in cur_name:
                act_scales[cur_name] = cur_layer.scales().numpy().tolist()
    weight_scales_path = os.path.join(trainer.args.output_dir, "weight_scales.json")
    with open(weight_scales_path, "w") as f:
        json.dump(weight_scales, f)
    logger.info(f"Weight scales saved in {weight_scales_path}.")

    act_scales_path = os.path.join(trainer.args.output_dir, "act_scales.json")
    with open(act_scales_path, "w") as f:
        json.dump(act_scales, f)
    logger.info(f"Activation scales saved in {act_scales_path}.")

    trainer.model = ptq.convert(trainer.model, inplace=True)
    logger.info("***** PTQ done *****")


def apply_gptq(quant_args, trainer, ptq_dataloader):
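    """Quantize eligible linear layers one at a time with GPTQ.

    Each Linear / ColumnParallelLinear / RowParallelLinear sublayer is
    temporarily wrapped in a GPTQ layer, calibrated for quant_args.gptq_step
    batches, quantized via fasterquant, and then the original layer object
    is restored in the parent module.
    """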
    logger.info("***** Running GPTQ *****")
    num_layer = 0
    model = trainer.model
    for cur_name, cur_layer in model.named_sublayers():
        if type(cur_layer) in [paddle.nn.Linear, ColumnParallelLinear, RowParallelLinear]:
            num_layer += 1
            logger.info(f"GPTQ layer: {num_layer}, {cur_name}")
            parent_layer, sub_name = find_parent_layer_and_sub_name(model, cur_name)
            cur_quant_layer = GPTQ(cur_layer)
            setattr(parent_layer, sub_name, cur_quant_layer)
            with paddle.no_grad():
                trainer.ptq_loop(
                    ptq_dataloader,
                    description="GPTQ",
                    max_eval_iters=quant_args.gptq_step,
                )
                cur_quant_layer.fasterquant(percdamp=0.1, groupsize=-1, actorder=True)
            del cur_quant_layer
            setattr(parent_layer, sub_name, cur_layer)
    logger.info("***** GPTQ done *****")


def get_ptq_model_config(model):
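    """Return the Shift/Smooth model_config for the model's architecture.

    The config records whether the QKV projection is fused and whether the
    FFN is parallel; chatglm_v2, bloom and llama are supported, while
    chatglm raises NotImplementedError.
    """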
    if isinstance(model, PrefixModelForCausalLM):
        base_model_prefix = model.model.base_model_prefix
    else:
        base_model_prefix = model.base_model_prefix

    if base_model_prefix in ["chatglm"]:
        raise NotImplementedError(f"{model} does not support Shift or Smooth.")
    elif base_model_prefix == "chatglm_v2":
        model_config = {"fused_qkv": False, "parallel_ffn": False, "skip_norm_list": ["rms_norm_56"]}
    elif base_model_prefix == "bloom":
        model_config = {"fused_qkv": True, "parallel_ffn": False}
    elif base_model_prefix == "llama":
        model_config = {"fused_qkv": False, "parallel_ffn": True}
    else:
        raise ValueError(
            f"Unknown base_model_prefix: {base_model_prefix}. Supported base_model_prefix list: chatglm_v2, bloom, llama."
        )
    return model_config
