paddlenlp

Форк
0
/
compression_helper.py 
70 строк · 3.1 Кб
1
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

15
import paddle
16
import paddleslim
17

18

19
def get_pruned_params(model):
20
    params = []
21
    for sublayer in model.sublayers():
22
        for param in sublayer.parameters(include_sublayers=False):
23
            if (
24
                isinstance(sublayer, paddle.nn.layer.common.Linear)
25
                or isinstance(sublayer, paddle.distributed.fleet.layers.mpu.mp_layers.ColumnParallelLinear)
26
                or isinstance(sublayer, paddle.distributed.fleet.layers.mpu.mp_layers.RowParallelLinear)
27
            ):
28
                if len(param.shape) != 2:
29
                    continue
30

31
                # NOTE(minghaoBD):
32
                # 1. param.shape[1] == 3 * param.shape[0]: prune fused-qkv's weight and its next weight: out-linear's weight
33
                # 2. param.shape[1] == 4 * param.shape[0]: prune ffn1's weight and its next weight: ffn2's weight
34
                # If your model has a different architecture, like your qkv's weights are not fused or ffn1_weight.shape[1] != 4*ffn1_weight.shape[0], you may need to customize this function to suit your model.
35
                if param.shape[1] == 3 * param.shape[0] or param.shape[1] == 4 * param.shape[0]:
36
                    params.append(param.name)
37

38
    return params
39

40

41
def prune_model(model, configs, inputs_desc=[]):
42
    prune_criterion = configs.criterion
43
    ratio = configs.ratio
44
    shapes, dtypes = [], []
45
    for input_desc in inputs_desc:
46
        dtypes.append(input_desc.dtype)
47
        new_shape = [10 if item == -1 else item for item in input_desc.shape]
48
        shapes.append(new_shape)
49
    # TODO(minghaoBD): support ViT and other model architectures in the future
50
    num_attention_heads = model.gpt.decoder.layers[0].self_attn.num_heads
51

52
    if prune_criterion == "l1_norm":
53
        pruner = paddleslim.L1NormFilterPruner(
54
            model, shapes, skip_leaves=False, prune_type="fc", input_dtype=dtypes[0], num_head=num_attention_heads
55
        )
56
    elif prune_criterion == "l2_norm":
57
        pruner = paddleslim.L2NormFilterPruner(
58
            model, shapes, skip_leaves=False, prune_type="fc", input_dtype=dtypes[0], num_head=num_attention_heads
59
        )
60
    params = get_pruned_params(model)
61
    ratios = {}
62
    for param in params:
63
        ratios[param] = ratio
64
    # NOTE(minghaoBD): hidden size in Layernorm must be 768/1024/2048/4096 for best inference performace, and when axis=0, the hidden size in layernorm will be changed accordingly. So axis=1 is required.
65
    pruner.prune_vars(ratios, [1])
66

67

68
def quant_model(model, configs):
69
    quanter = paddleslim.dygraph.quant.QAT(configs)
70
    return quanter.quantize(model), quanter
71

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.