intel-extension-for-pytorch

Форк
0
/
test_gptq_nightly.py 
130 строк · 4.8 Кб
1
import tempfile
2
import torch
3

4
import copy
5
import os
6
import unittest
7
import transformers
8
from transformers import AutoConfig
9

10
from common_utils import TestCase
11

12
import intel_extension_for_pytorch as ipex
13

14
from torch.testing._internal.common_utils import run_tests
15
from intel_extension_for_pytorch.quantization import (
16
    WoqWeightDtype,
17
    WoqLowpMode,
18
)
19

20

21
class GPTQLLMTester(TestCase):
    """End-to-end nightly test for IPEX GPTQ weight-only quantization on GPT-J.

    The test builds a randomly-initialized GPT-J from a local HF config,
    quantizes it with ``ipex.quantization.gptq`` (with and without
    activation reordering), checks numerical closeness against the FP32
    model, then reloads the saved low-precision checkpoint through
    ``ipex.llm.optimize`` and verifies the expected IPEX CPU modules were
    swapped in.
    """

    def test_gptq_quantize(self):
        """Run GPTQ quantization + woq reload for act_order in {False, True}."""

        class GPTQLLMDataLoader:
            """Minimal calibration loader: ten identical [1, 512] all-ones batches."""

            def __init__(self):
                self.batch_size = 1

            def __iter__(self):
                # Fixed-content batches are sufficient for calibration here;
                # the model weights are random anyway.
                for _ in range(10):
                    yield torch.ones([1, 512], dtype=torch.long)

        def _get_gptj_example_inputs():
            """Build a (input_ids, attention_mask, past_key_values, position_ids)
            tuple shaped for the IPEX-optimized GPT-J forward signature.

            The past_key_values entry uses the IPEX indirect-access KV-cache
            layout (a 4-tuple per layer) — TODO confirm against the ipex.llm
            runtime if that layout changes.
            """
            input_ids = torch.ones(8).to(torch.long)
            attention_mask = torch.ones(len(input_ids))
            position_ids = torch.arange(len(input_ids))
            past_key_values = tuple(
                [
                    (
                        torch.zeros(1, 1, 0, 1, dtype=torch.long).contiguous(),
                        torch.zeros([1, 1, 1, 1]).contiguous(),
                        torch.zeros([1, 1, 1, 1]).contiguous(),
                        torch.zeros(1, 4, dtype=torch.long),
                    )
                    for _ in range(1)
                ]
            )
            return (
                input_ids.unsqueeze(0),
                attention_mask.unsqueeze(0),
                past_key_values,
                position_ids.unsqueeze(0),
            )

        dataloader = GPTQLLMDataLoader()
        curpath = os.path.abspath(os.path.dirname(__file__))
        # Local config only — no hub download; weights are random-initialized.
        config = AutoConfig.from_pretrained(
            f"{curpath}/hf_configs/gptj", return_dict=False
        )
        gptj = transformers.models.gptj.modeling_gptj.GPTJForCausalLM(config).eval()
        with tempfile.TemporaryDirectory() as work_dir:
            for act_order in [False, True]:
                # Fresh copy per iteration: gptq() may mutate the model it is given.
                model = copy.deepcopy(gptj)
                model.eval()
                compressed_model = ipex.quantization.gptq(
                    model,
                    dataloader=dataloader,
                    wbits=4,
                    group_size=128,
                    act_order=act_order,
                    use_max_length=True,
                    pad_max_length=512,
                    scale_dtype=torch.float16,
                    save_dir=work_dir,
                )
                self.assertTrue(isinstance(compressed_model, torch.nn.Module))
                # Quantized output should stay numerically close to FP32.
                # NOTE: renamed from `input` to avoid shadowing the builtin.
                sample_input = torch.ones([1, 512], dtype=torch.long)
                out0 = model(sample_input)
                out1 = compressed_model(sample_input)
                self.assertTrue(torch.allclose(out0[0], out1[0], atol=1e-05))

                # Reload the checkpoint gptq() wrote to work_dir and feed it
                # into the weight-only-quant optimization path.
                low_precision_checkpoint = torch.load(
                    work_dir + "/gptq_checkpoint_g128.pt"
                )
                qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping(
                    weight_dtype=WoqWeightDtype.INT4,
                    lowp_mode=WoqLowpMode.INT8,
                )
                model = copy.deepcopy(gptj)
                model.eval()
                model = ipex.llm.optimize(
                    model,
                    dtype=torch.float,
                    quantization_config=qconfig,
                    inplace=True,
                    low_precision_checkpoint=low_precision_checkpoint,
                    deployment_mode=False,
                )
                # Verify the attention / decoder layers were swapped for the
                # IPEX CPU implementations and linears became WOQ linears.
                _IPEXAttentionCPU = (
                    ipex.transformers.models.cpu.modules.attentions._IPEXAttentionCPU
                )
                _IPEXDecoderLayerCPU = (
                    ipex.transformers.models.cpu.modules.decoder._IPEXDecoderLayerCPU
                )
                WeightOnlyQuantizedLinear = ipex.nn.modules.WeightOnlyQuantizedLinear
                assert model.transformer.h[0].attn.__class__ is _IPEXAttentionCPU
                assert model.transformer.h[0].__class__ is _IPEXDecoderLayerCPU
                layers_to_check = [
                    model.transformer.h[0].attn.out_proj,
                    model.transformer.h[0].linear_add_add.linear,
                    model.transformer.h[0].linear_gelu.linear,
                ]
                # concat linear is unsupported with act_order
                if not act_order:
                    layers_to_check.append(
                        model.transformer.h[0].attn.concat_qkv.concat_linear
                    )
                assert all(
                    mod.__class__ is WeightOnlyQuantizedLinear
                    for mod in layers_to_check
                )

                # Ensure model can run without errors
                with torch.no_grad():
                    example_inputs = _get_gptj_example_inputs()
                    # the optimized model is ipex_m.trace_graph
                    model(*example_inputs)
126

127

128
if __name__ == "__main__":
    # run_tests() (torch.testing._internal.common_utils) drives unittest
    # discovery itself. The previous version called unittest.main() first,
    # which raises SystemExit after running, so run_tests() was dead code.
    run_tests()
131

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.