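# Nightly tests for ipex.llm.optimize: for each supported HuggingFace
# architecture, verify that the attention/decoder modules are replaced by
# their IPEX CPU counterparts and that the optimized model's output matches
# the reference model within tolerance.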
import unittest
import torch
import intel_extension_for_pytorch as ipex
import sys
import subprocess
import os
import copy
import re
from collections import namedtuple
import itertools

from hf_configs.baichuan.modeling_baichuan import BaichuanForCausalLM
from hf_configs.chatglm.modeling_chatglm import ChatGLMForConditionalGeneration
from hf_configs.qwen.modeling_qwen import QWenLMHeadModel
from intel_extension_for_pytorch.cpu._auto_kernel_selection import _disable_tpp

try:
    import transformers
    from transformers import AutoConfig
except ImportError:
    subprocess.check_call(
        [sys.executable, "-m", "pip", "install", "transformers==4.37.0"]
    )
    import transformers
    from transformers import AutoConfig

from common_utils import TestCase

torch.manual_seed(128)

curpath = os.path.abspath(os.path.dirname(__file__))

model_info = namedtuple(
    "model_info",
    "name, model_class, has_position_ids, attention_class, decoder_class",
)
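# One entry per covered architecture: the hf_configs directory name, the model
# class to instantiate, whether forward() takes position_ids, and lambdas that
# look up the attention (and, when applicable, decoder-layer) class on a model
# instance.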
supported_models = [
    model_info(
        "gptneox",
        transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXForCausalLM,
        True,
        lambda m: m.gpt_neox.layers[0].attention.__class__,
        None,
    ),
    model_info(
        "opt",
        transformers.models.opt.modeling_opt.OPTForCausalLM,
        False,
        lambda m: m.model.decoder.layers[0].self_attn.__class__,
        lambda m: m.model.decoder.layers[0].__class__,
    ),
    model_info(
        "falcon",
        transformers.models.falcon.modeling_falcon.FalconForCausalLM,
        False,
        lambda m: m.transformer.h[0].self_attention.__class__,
        lambda m: m.transformer.h[0].__class__,
    ),
    model_info(
        "bloom",
        transformers.models.bloom.modeling_bloom.BloomForCausalLM,
        False,
        lambda m: m.transformer.h[0].self_attention.__class__,
        lambda m: m.transformer.h[0].__class__,
    ),
    model_info(
        "codegen",
        transformers.models.codegen.modeling_codegen.CodeGenForCausalLM,
        True,
        lambda m: m.transformer.h[0].attn.__class__,
        lambda m: m.transformer.h[0].__class__,
    ),
    model_info(
        "baichuan",
        BaichuanForCausalLM,
        False,
        lambda m: m.model.layers[0].self_attn.__class__,
        lambda m: m.model.layers[0].__class__,
    ),
    model_info(
        "chatglm",
        ChatGLMForConditionalGeneration,
        False,
        lambda m: m.transformer.encoder.layers[0].self_attention.__class__,
        lambda m: m.transformer.encoder.layers[0].__class__,
    ),
    model_info(
        "gptbigcode",
        transformers.models.gpt_bigcode.modeling_gpt_bigcode.GPTBigCodeForCausalLM,
        True,
        lambda m: m.transformer.h[0].attn.__class__,
        lambda m: m.transformer.h[0].__class__,
    ),
    model_info(
        "t5",
        transformers.models.t5.modeling_t5.T5ForConditionalGeneration,
        False,
        lambda m: m.decoder.block[0].layer[0].SelfAttention.__class__,
        lambda m: m.decoder.block[0].__class__,
    ),
    model_info(
        "mistral",
        transformers.models.mistral.modeling_mistral.MistralForCausalLM,
        True,
        lambda m: m.model.layers[0].self_attn.__class__,
        lambda m: m.model.layers[0].__class__,
    ),
    model_info(
        "mpt",
        transformers.models.mpt.modeling_mpt.MptForCausalLM,
        False,
        lambda m: m.transformer.blocks[0].attn.__class__,
        lambda m: m.transformer.blocks[0].__class__,
    ),
    model_info(
        "mixtral",
        transformers.models.mixtral.modeling_mixtral.MixtralForCausalLM,
        True,
        lambda m: m.model.layers[0].self_attn.__class__,
        lambda m: m.model.layers[0].__class__,
    ),
    model_info(
        "stablelm",
        transformers.models.stablelm.modeling_stablelm.StableLmForCausalLM,
        True,
        lambda m: m.model.layers[0].self_attn.__class__,
        lambda m: m.model.layers[0].__class__,
    ),
    model_info(
        "qwen",
        QWenLMHeadModel,
        False,
        lambda m: m.transformer.h[0].attn.__class__,
        lambda m: m.transformer.h[0].__class__,
    ),
    model_info(
        "git",
        transformers.models.git.modeling_git.GitForCausalLM,
        False,
        lambda m: m.git.encoder.layer[0].attention.self.__class__,
        lambda m: m.git.encoder.layer[0].__class__,
    ),
]

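# For reference, a minimal sketch (not part of the test) of how
# ipex.llm.optimize is typically applied to a real checkpoint; the checkpoint
# name and generate() arguments here are illustrative assumptions:
#
#   model = transformers.AutoModelForCausalLM.from_pretrained("<checkpoint>")
#   model = ipex.llm.optimize(model.eval(), dtype=torch.bfloat16)
#   with torch.no_grad(), torch.cpu.amp.autocast():
#       output = model.generate(input_ids, max_new_tokens=32)
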

class OptimizeTransformersNightlyTester(TestCase):
    def model_replacement_check(
        self, m, dtype, deployment_mode, torchcompile=False, return_dict=False
    ):
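        # Build reference and IPEX-optimized copies of one model from its
        # local HF config, then check module replacement and output parity.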
        config = AutoConfig.from_pretrained(
            f"{curpath}/hf_configs/{m.name}",
            return_dict=return_dict,
            trust_remote_code=True,
        )
        model = m.model_class(config).eval()
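        # Per-model fixups before optimization: falcon's customized linear is
        # replaced with nn.Linear; chatglm/baichuan layernorm weights are
        # randomized (presumably so the numerical comparison below exercises
        # non-default weights).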
        if m.name == "falcon":
            with torch.no_grad():
                ipex.nn.utils._model_convert.replace_customized_linear_with_linear(
                    model.eval()
                )
        elif m.name == "chatglm":
            state_dict = model.state_dict()
            for weight in [
                "transformer.encoder.layers.0.input_layernorm.weight",
                "transformer.encoder.layers.0.post_attention_layernorm.weight",
                "transformer.encoder.final_layernorm.weight",
            ]:
                state_dict[weight] = torch.rand(state_dict[weight].shape)
            model.load_state_dict(state_dict)
        elif m.name == "baichuan":
            state_dict = model.state_dict()
            for weight in [
                "model.layers.0.input_layernorm.weight",
                "model.layers.0.post_attention_layernorm.weight",
                "model.norm.weight",
            ]:
                state_dict[weight] = torch.rand(state_dict[weight].shape)
            model.load_state_dict(state_dict)
        model.eval()
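        # Keep an untouched reference copy; the other copy is optimized in
        # place and optionally wrapped with torch.compile on the ipex backend.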
        ref_m = copy.deepcopy(model)
        ipex_m = copy.deepcopy(model)
        ipex_m = ipex.llm.optimize(
            ipex_m, dtype=dtype, deployment_mode=deployment_mode, inplace=True
        )
        if torchcompile:
            torch._dynamo.reset()
            ipex._set_compiler_backend("inductor")
            ipex_m = torch.compile(ipex_m, backend="ipex")

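        # The optimization pass should have swapped in the IPEX CPU modules.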
        assert (
            m.attention_class(ipex_m)
            is ipex.transformers.models.cpu.modules.attentions._IPEXAttentionCPU
        )
        if m.decoder_class is not None:
            assert (
                m.decoder_class(ipex_m)
                is ipex.transformers.models.cpu.modules.decoder._IPEXDecoderLayerCPU
            )

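        # Build a minimal dummy batch; t5 additionally needs decoder_input_ids
        # and git needs pixel_values.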
        input_ids = torch.ones(10).to(torch.long)
        attention_mask = torch.ones(len(input_ids))
        position_ids = torch.arange(len(input_ids))
        decoder_input_ids = torch.ones(1).to(torch.long)
        input_dict = {
            "input_ids": input_ids.unsqueeze(0),
            "attention_mask": attention_mask.unsqueeze(0),
            "use_cache": True,
        }
        if m.has_position_ids:
            input_dict["position_ids"] = position_ids.unsqueeze(0)
        if re.search("t5", model.config.architectures[0], re.IGNORECASE):
            input_dict["decoder_input_ids"] = decoder_input_ids.unsqueeze(0)
        if m.name == "git":
            input_dict["pixel_values"] = torch.zeros(1, 3, 224, 224)

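        # Run the reference model in fp32 and the optimized model under
        # autocast (enabled only for bf16), then compare logits.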
        with torch.no_grad():
            key_hf = ref_m(**input_dict)
        with torch.no_grad(), torch.cpu.amp.autocast(
            enabled=(dtype is torch.bfloat16)
        ):
            key_ipex = ipex_m(**input_dict)
        error_message = (
            f"model={m.name}, deployment_mode={deployment_mode}, "
            f"torchcompile={torchcompile}, return_dict={return_dict}"
        )
        if return_dict:
            assert isinstance(key_ipex, dict)
            self.assertEqual(
                key_hf["logits"], key_ipex["logits"], prec=0.1, message=error_message
            )
        else:
            assert isinstance(key_ipex, tuple)
            self.assertEqual(key_hf[0], key_ipex[0], prec=0.1, message=error_message)

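    # Sweeps every (model, torch.compile, dtype, deployment_mode, return_dict)
    # combination; the torch.compile + deployment_mode pairing is skipped, as
    # the two execution paths are exercised separately.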
    def test_model_replacement(self):
        dtypes = [torch.bfloat16]
        enable_torchcompile = [False, True]
        deployment_mode = [True, False]
        return_dict = [False, True]
        for m, torchcompile, dtype, jit, return_dict in itertools.product(
            supported_models, enable_torchcompile, dtypes, deployment_mode, return_dict
        ):
            # Check the loop variable `jit` (this iteration's deployment_mode
            # value) rather than the always-truthy `deployment_mode` list,
            # which would skip every torch.compile run.
            if torchcompile and jit:
                continue
            self.model_replacement_check(m, dtype, jit, torchcompile, return_dict)
        _disable_tpp()


if __name__ == "__main__":
    test = unittest.main()