intel-extension-for-pytorch

test_ipex_optimize_transformers_nightly.py
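# Nightly coverage for ipex.llm.optimize: each supported architecture is built from
# a local config under hf_configs/, optimized by IPEX, and checked against the
# unmodified HuggingFace model (module replacement plus output closeness).
# Run directly with `python test_ipex_optimize_transformers_nightly.py`, assuming
# hf_configs/ and common_utils.py sit next to this file as the imports expect.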
import unittest
import torch
import intel_extension_for_pytorch as ipex
import sys
import subprocess
import os
import copy
import re
from collections import namedtuple
import itertools

from hf_configs.baichuan.modeling_baichuan import BaichuanForCausalLM
from hf_configs.chatglm.modeling_chatglm import ChatGLMForConditionalGeneration
from hf_configs.qwen.modeling_qwen import QWenLMHeadModel
from intel_extension_for_pytorch.cpu._auto_kernel_selection import _disable_tpp

try:
    import transformers
    from transformers import AutoConfig
except ImportError:
    # transformers is required; install the pinned version and retry the import.
    subprocess.check_call(
        [sys.executable, "-m", "pip", "install", "transformers==4.37.0"]
    )
    import transformers
    from transformers import AutoConfig

from common_utils import TestCase

torch.manual_seed(128)

curpath = os.path.abspath(os.path.dirname(__file__))

# Each entry: the config directory name under hf_configs/, the model class to
# instantiate, whether the forward pass takes position_ids, and accessors returning
# the class of the first attention module / decoder layer (the decoder accessor is
# None when that check does not apply).
model_info = namedtuple(
    "model_info",
    "name, model_class, has_position_ids, attention_class, decoder_class",
)
supported_models = [
    model_info(
        "gptneox",
        transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXForCausalLM,
        True,
        lambda m: m.gpt_neox.layers[0].attention.__class__,
        None,
    ),
    model_info(
        "opt",
        transformers.models.opt.modeling_opt.OPTForCausalLM,
        False,
        lambda m: m.model.decoder.layers[0].self_attn.__class__,
        lambda m: m.model.decoder.layers[0].__class__,
    ),
    model_info(
        "falcon",
        transformers.models.falcon.modeling_falcon.FalconForCausalLM,
        False,
        lambda m: m.transformer.h[0].self_attention.__class__,
        lambda m: m.transformer.h[0].__class__,
    ),
    model_info(
        "bloom",
        transformers.models.bloom.modeling_bloom.BloomForCausalLM,
        False,
        lambda m: m.transformer.h[0].self_attention.__class__,
        lambda m: m.transformer.h[0].__class__,
    ),
    model_info(
        "codegen",
        transformers.models.codegen.modeling_codegen.CodeGenForCausalLM,
        True,
        lambda m: m.transformer.h[0].attn.__class__,
        lambda m: m.transformer.h[0].__class__,
    ),
    model_info(
        "baichuan",
        BaichuanForCausalLM,
        False,
        lambda m: m.model.layers[0].self_attn.__class__,
        lambda m: m.model.layers[0].__class__,
    ),
    model_info(
        "chatglm",
        ChatGLMForConditionalGeneration,
        False,
        lambda m: m.transformer.encoder.layers[0].self_attention.__class__,
        lambda m: m.transformer.encoder.layers[0].__class__,
    ),
    model_info(
        "gptbigcode",
        transformers.models.gpt_bigcode.modeling_gpt_bigcode.GPTBigCodeForCausalLM,
        True,
        lambda m: m.transformer.h[0].attn.__class__,
        lambda m: m.transformer.h[0].__class__,
    ),
    model_info(
        "t5",
        transformers.models.t5.modeling_t5.T5ForConditionalGeneration,
        False,
        lambda m: m.decoder.block[0].layer[0].SelfAttention.__class__,
        lambda m: m.decoder.block[0].__class__,
    ),
    model_info(
        "mistral",
        transformers.models.mistral.modeling_mistral.MistralForCausalLM,
        True,
        lambda m: m.model.layers[0].self_attn.__class__,
        lambda m: m.model.layers[0].__class__,
    ),
    model_info(
        "mpt",
        transformers.models.mpt.modeling_mpt.MptForCausalLM,
        False,
        lambda m: m.transformer.blocks[0].attn.__class__,
        lambda m: m.transformer.blocks[0].__class__,
    ),
    model_info(
        "mixtral",
        transformers.models.mixtral.modeling_mixtral.MixtralForCausalLM,
        True,
        lambda m: m.model.layers[0].self_attn.__class__,
        lambda m: m.model.layers[0].__class__,
    ),
    model_info(
        "stablelm",
        transformers.models.stablelm.modeling_stablelm.StableLmForCausalLM,
        True,
        lambda m: m.model.layers[0].self_attn.__class__,
        lambda m: m.model.layers[0].__class__,
    ),
    model_info(
        "qwen",
        QWenLMHeadModel,
        False,
        lambda m: m.transformer.h[0].attn.__class__,
        lambda m: m.transformer.h[0].__class__,
    ),
    model_info(
        "git",
        transformers.models.git.modeling_git.GitForCausalLM,
        False,
        lambda m: m.git.encoder.layer[0].attention.self.__class__,
        lambda m: m.git.encoder.layer[0].__class__,
    ),
]
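# baichuan, chatglm and qwen are instantiated from the modeling files vendored under
# hf_configs/ (imported at the top of this file); the remaining entries use the
# implementations shipped with transformers.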


class OptimizeTransformersNightlyTester(TestCase):
    def model_replacement_check(
        self, m, dtype, deployment_mode, torchcompile=False, return_dict=False
    ):
        """Build ``m`` from its local config, optimize a copy with ipex.llm.optimize,
        then verify module replacement and output closeness against the reference."""
        config = AutoConfig.from_pretrained(
            f"{curpath}/hf_configs/{m.name}",
            return_dict=return_dict,
            trust_remote_code=True,
        )
        model = m.model_class(config).eval()
        if m.name == "falcon":
            # falcon ships a customized linear layer; swap it for torch.nn.Linear
            # before optimization.
            with torch.no_grad():
                ipex.nn.utils._model_convert.replace_customized_linear_with_linear(
                    model.eval()
                )
        elif m.name == "chatglm":
            # Replace layernorm weights that the reference implementation initializes
            # to constants with random values, so the numerical check is non-trivial.
            state_dict = model.state_dict()
            for weight in [
                "transformer.encoder.layers.0.input_layernorm.weight",
                "transformer.encoder.layers.0.post_attention_layernorm.weight",
                "transformer.encoder.final_layernorm.weight",
            ]:
                state_dict[weight] = torch.rand(state_dict[weight].shape)
            model.load_state_dict(state_dict)
        elif m.name == "baichuan":
            state_dict = model.state_dict()
            for weight in [
                "model.layers.0.input_layernorm.weight",
                "model.layers.0.post_attention_layernorm.weight",
                "model.norm.weight",
            ]:
                state_dict[weight] = torch.rand(state_dict[weight].shape)
            model.load_state_dict(state_dict)
        model.eval()
        ref_m = copy.deepcopy(model)
        ipex_m = copy.deepcopy(model)
        ipex_m = ipex.llm.optimize(
            ipex_m, dtype=dtype, deployment_mode=deployment_mode, inplace=True
        )
        if torchcompile:
            torch._dynamo.reset()
            ipex._set_compiler_backend("inductor")
            ipex_m = torch.compile(ipex_m, backend="ipex")

        # The attention module (and, where applicable, the whole decoder layer) must
        # have been replaced by the IPEX CPU implementations.
        assert (
            m.attention_class(ipex_m)
            is ipex.transformers.models.cpu.modules.attentions._IPEXAttentionCPU
        )
        assert (
            m.decoder_class(ipex_m)
            is ipex.transformers.models.cpu.modules.decoder._IPEXDecoderLayerCPU
            if m.decoder_class is not None
            else True
        )

        # Minimal batch-of-one inputs; extra tensors are added only for models whose
        # forward pass requires them.
        input_ids = torch.ones(10).to(torch.long)
        attention_mask = torch.ones(len(input_ids))
        position_ids = torch.arange(len(input_ids))
        decoder_input_ids = torch.ones(1).to(torch.long)
        input_dict = {
            "input_ids": input_ids.unsqueeze(0),
            "attention_mask": attention_mask.unsqueeze(0),
            "use_cache": True,
        }
        if m.has_position_ids:
            input_dict["position_ids"] = position_ids.unsqueeze(0)
        if re.search("t5", model.config.architectures[0], re.IGNORECASE):
            input_dict["decoder_input_ids"] = decoder_input_ids.unsqueeze(0)
        if m.name == "git":
            input_dict["pixel_values"] = torch.zeros(1, 3, 224, 224)

        # Reference run in eager mode; optimized run under autocast when bf16 is requested.
        with torch.no_grad():
            key_hf = ref_m(**input_dict)
        with torch.no_grad(), torch.cpu.amp.autocast(
            enabled=dtype is torch.bfloat16
        ):
            key_ipex = ipex_m(**input_dict)
        error_message = f"model={m.name}, deployment_mode={deployment_mode}, torchcompile={torchcompile}, return_dict={return_dict}"
        if return_dict:
            assert isinstance(key_ipex, dict)
            self.assertEqual(
                key_hf["logits"], key_ipex["logits"], prec=0.1, message=error_message
            )
        else:
            assert isinstance(key_ipex, tuple)
            self.assertEqual(key_hf[0], key_ipex[0], prec=0.1, message=error_message)

    def test_model_replacement(self):
        dtypes = [torch.bfloat16]
        enable_torchcompile = [False, True]
        deployment_mode = [True, False]
        return_dict = [False, True]
        for m, torchcompile, dtype, jit, return_dict in itertools.product(
            supported_models, enable_torchcompile, dtypes, deployment_mode, return_dict
        ):
            # torch.compile and the jit deployment path are exercised separately,
            # so skip combinations that enable both.
            if torchcompile and jit:
                continue
            self.model_replacement_check(m, dtype, jit, torchcompile, return_dict)
        _disable_tpp()


if __name__ == "__main__":
    test = unittest.main()
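# For reference, the optimization path exercised above is typically driven outside of
# this test harness roughly as follows (a sketch; "<model-id>" and "inputs" are
# placeholders, not values defined in this file):
#
#   from transformers import AutoModelForCausalLM
#   model = AutoModelForCausalLM.from_pretrained("<model-id>").eval()
#   model = ipex.llm.optimize(model, dtype=torch.bfloat16, inplace=True)
#   with torch.no_grad(), torch.cpu.amp.autocast():
#       outputs = model(**inputs)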
