# LLM-FineTuning-Large-Language-Models — GPTQ quantization script (41 lines, 2.4 KB)
import argparse

import auto_gptq  # imported for its side effects (GPTQ backend registration)
import torch
from transformers import GPTQConfig, AutoModelForCausalLM, AutoTokenizer

5# Function to get the device based on user input or system capability
6def get_device(device_map):
7if device_map == "auto":
8return "cuda" if torch.cuda.is_available() else "cpu"
9return device_map
10
11# Function to configure and return the quantized model
12# https://huggingface.co/docs/transformers/main_classes/quantization#transformers.GPTQConfig
13def configure_model(model_id, bits, dataset, tokenizer, group_size, device):
14gptq_config = GPTQConfig(bits=bits, dataset=dataset, tokenizer=tokenizer, group_size=group_size, desc_act=True)
15model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=gptq_config)
16model.to(device)
17return model
18
19# Main function to execute the script logic
20def main(model_id, bits, dataset, group_size, device_map):
21device = get_device(device_map)
22tokenizer = AutoTokenizer.from_pretrained(model_id)
23model = configure_model(model_id, bits, dataset, tokenizer, group_size, device)
24
25model_dir = f"{model_id}_quantized_{bits}bit"
26model.save_pretrained(model_dir)
27tokenizer.save_pretrained(model_dir)
28
29if __name__ == "__main__":
30parser = argparse.ArgumentParser(description="Quantize a GPT model.")
31parser.add_argument("--model_id", default="mistralai/Mistral-7B-v0.1", type=str, help="The pretrained model ID.")
32parser.add_argument("--bits", default=4, type=int, help="Number of bits for quantization.")
33parser.add_argument("--dataset", default="wikitext2", type=str, help="The dataset to use.")
34parser.add_argument("--group_size", default=128, type=int, help="Group size for quantization.")
35parser.add_argument("--device_map", default="auto", type=str, help="Device map for loading the model.")
36parser.add_argument("--use_exllama", default="True", type=bool, help="Whether to use exllama backend. Defaults to True if unset. Only works with bits = 4.")
37parser.add_argument("--desc_act", default="False", type=bool, help=" Whether to quantize columns in order of decreasing activation size. Setting it to False can significantly speed up inference but the perplexity may become slightly worse. Also known as act-order.")
38
39args = parser.parse_args()
40
41main(model_id=args.model_id, bits=args.bits, dataset=args.dataset, group_size=args.group_size, device_map=args.device_map)
42