intel-extension-for-pytorch
130 lines · 4.8 KB
# Standard library and PyTorch.
import tempfile
import torch

import copy
import os
import unittest
import transformers
from transformers import AutoConfig

# Shared test utilities for this test suite (project-local).
from common_utils import TestCase

import intel_extension_for_pytorch as ipex

# Torch test runner entry point and IPEX weight-only-quantization enums.
from torch.testing._internal.common_utils import run_tests
from intel_extension_for_pytorch.quantization import (
    WoqWeightDtype,
    WoqLowpMode,
)
19
20
class GPTQLLMTester(TestCase):
    """End-to-end test of IPEX GPTQ weight-only quantization on a GPT-J model."""

    def test_gptq_quantize(self):
        class GPTQLLMDataLoader:
            # Minimal calibration dataloader: yields ten identical [1, 512]
            # all-ones token-id tensors (content does not matter for this
            # smoke test, only shape/dtype).
            def __init__(self):
                self.batch_size = 1

            def __iter__(self):
                for i in range(10):
                    yield torch.ones([1, 512], dtype=torch.long)

        def _get_gptj_example_inputs():
            # Build one dummy forward-call input tuple:
            # (input_ids, attention_mask, past_key_values, position_ids),
            # each with a leading batch dimension of 1.
            input_ids = torch.ones(8).to(torch.long)
            attention_mask = torch.ones(len(input_ids))
            position_ids = torch.arange(len(input_ids))
            # Single-layer placeholder KV cache; the zero-sized third dim of
            # the first tensor presumably marks an empty cache — TODO confirm
            # against the IPEX optimized-model input contract.
            past_key_values = tuple(
                [
                    (
                        torch.zeros(1, 1, 0, 1, dtype=torch.long).contiguous(),
                        torch.zeros([1, 1, 1, 1]).contiguous(),
                        torch.zeros([1, 1, 1, 1]).contiguous(),
                        torch.zeros(1, 4, dtype=torch.long),
                    )
                    for i in range(1)
                ]
            )
            return (
                input_ids.unsqueeze(0),
                attention_mask.unsqueeze(0),
                past_key_values,
                position_ids.unsqueeze(0),
            )

        dataloader = GPTQLLMDataLoader()
        # Instantiate GPT-J from a local config (random weights, eval mode).
        curpath = os.path.abspath(os.path.dirname(__file__))
        config = AutoConfig.from_pretrained(
            f"{curpath}/hf_configs/gptj", return_dict=False
        )
        gptj = transformers.models.gptj.modeling_gptj.GPTJForCausalLM(config).eval()
        with tempfile.TemporaryDirectory() as work_dir:
            # Exercise both activation orderings supported by GPTQ.
            for act_order in [False, True]:
                model = copy.deepcopy(gptj)
                model.eval()
                # Run GPTQ calibration/quantization; writes a checkpoint
                # into work_dir (filename encodes the group size, see
                # torch.load below).
                compressed_model = ipex.quantization.gptq(
                    model,
                    dataloader=dataloader,
                    wbits=4,
                    group_size=128,
                    act_order=act_order,
                    use_max_length=True,
                    pad_max_length=512,
                    scale_dtype=torch.float16,
                    save_dir=work_dir,
                )
                self.assertTrue(isinstance(compressed_model, torch.nn.Module))
                # Quantized output should stay numerically close to the
                # original model's output on the same input.
                input = torch.ones([1, 512], dtype=torch.long)
                out0 = model(input)
                out1 = compressed_model(input)
                self.assertTrue(torch.allclose(out0[0], out1[0], atol=1e-05))

                # Reload the GPTQ checkpoint and feed it to ipex.llm.optimize
                # as a low-precision (INT4 weight / INT8 lowp) checkpoint.
                low_precision_checkpoint = torch.load(
                    work_dir + "/gptq_checkpoint_g128.pt"
                )
                qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping(
                    weight_dtype=WoqWeightDtype.INT4,
                    lowp_mode=WoqLowpMode.INT8,
                )
                model = copy.deepcopy(gptj)
                model.eval()
                model = ipex.llm.optimize(
                    model,
                    dtype=torch.float,
                    quantization_config=qconfig,
                    inplace=True,
                    low_precision_checkpoint=low_precision_checkpoint,
                    deployment_mode=False,
                )
                # Verify that IPEX swapped in its optimized CPU module
                # classes and weight-only-quantized linear layers.
                _IPEXAttentionCPU = (
                    ipex.transformers.models.cpu.modules.attentions._IPEXAttentionCPU
                )
                _IPEXDecoderLayerCPU = (
                    ipex.transformers.models.cpu.modules.decoder._IPEXDecoderLayerCPU
                )
                WeightOnlyQuantizedLinear = ipex.nn.modules.WeightOnlyQuantizedLinear
                assert model.transformer.h[0].attn.__class__ is _IPEXAttentionCPU
                assert model.transformer.h[0].__class__ is _IPEXDecoderLayerCPU
                layers_to_check = [
                    model.transformer.h[0].attn.out_proj,
                    model.transformer.h[0].linear_add_add.linear,
                    model.transformer.h[0].linear_gelu.linear,
                ]
                # concat linear is unsupported with act_order
                if not act_order:
                    layers_to_check.append(
                        model.transformer.h[0].attn.concat_qkv.concat_linear
                    )
                assert all(
                    mod.__class__ is WeightOnlyQuantizedLinear
                    for mod in layers_to_check
                )

                # Ensure model can run without errors
                with torch.no_grad():
                    example_inputs = _get_gptj_example_inputs()
                    # the optimized model is ipex_m.trace_graph
                    model(*example_inputs)
127
if __name__ == "__main__":
    # unittest.main() calls sys.exit() when it finishes, so a subsequent
    # run_tests() call would never execute (dead code). Use torch's
    # run_tests() as the single entry point; it drives unittest itself.
    run_tests()