LLM-FineTuning-Large-Language-Models

Fork: 0 · 41 lines · 2.4 KB
1
import argparse

import auto_gptq
import torch
from transformers import GPTQConfig, AutoModelForCausalLM, AutoTokenizer

# Function to get the device based on user input or system capability
def get_device(device_map):
    """Resolve the device string to run the model on.

    Args:
        device_map: Either "auto" (pick CUDA when available, otherwise CPU)
            or an explicit device string such as "cpu" or "cuda:0".

    Returns:
        The resolved device string.
    """
    if device_map == "auto":
        # BUG FIX: `torch` was referenced here but never imported anywhere
        # in the file, so "auto" raised NameError. Fixed by adding
        # `import torch` to the top-of-file import block.
        return "cuda" if torch.cuda.is_available() else "cpu"
    return device_map

# Function to configure and return the quantized model
# https://huggingface.co/docs/transformers/main_classes/quantization#transformers.GPTQConfig
def configure_model(model_id, bits, dataset, tokenizer, group_size, device, desc_act=True):
    """Quantize ``model_id`` with GPTQ and move the result to ``device``.

    Args:
        model_id: Hugging Face model identifier to load.
        bits: Quantization bit width (e.g. 4 or 8).
        dataset: Calibration dataset name passed to GPTQConfig.
        tokenizer: Tokenizer used for calibration.
        group_size: GPTQ quantization group size.
        device: Target device string (e.g. "cpu" or "cuda").
        desc_act: Quantize columns in order of decreasing activation size
            (act-order). Previously hard-coded to True; exposed as a
            keyword parameter (default True preserves prior behavior) so
            the CLI's --desc_act flag can be wired through.

    Returns:
        The quantized model, placed on ``device``.
    """
    gptq_config = GPTQConfig(bits=bits, dataset=dataset, tokenizer=tokenizer, group_size=group_size, desc_act=desc_act)
    model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=gptq_config)
    model.to(device)
    return model

# Main function to execute the script logic
def main(model_id, bits, dataset, group_size, device_map):
    """Quantize the given pretrained model and save model + tokenizer.

    The output directory is derived from the model ID and bit width,
    e.g. ``<model_id>_quantized_4bit``.
    """
    target_device = get_device(device_map)
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    quantized = configure_model(
        model_id, bits, dataset, tokenizer, group_size, target_device
    )

    output_dir = f"{model_id}_quantized_{bits}bit"
    for artifact in (quantized, tokenizer):
        artifact.save_pretrained(output_dir)

if __name__ == "__main__":
30
    parser = argparse.ArgumentParser(description="Quantize a GPT model.")
31
    parser.add_argument("--model_id", default="mistralai/Mistral-7B-v0.1", type=str, help="The pretrained model ID.")
32
    parser.add_argument("--bits", default=4, type=int, help="Number of bits for quantization.")
33
    parser.add_argument("--dataset", default="wikitext2", type=str, help="The dataset to use.")
34
    parser.add_argument("--group_size", default=128, type=int, help="Group size for quantization.")
35
    parser.add_argument("--device_map", default="auto", type=str, help="Device map for loading the model.")
36
    parser.add_argument("--use_exllama", default="True", type=bool, help="Whether to use exllama backend. Defaults to True if unset. Only works with bits = 4.")
37
    parser.add_argument("--desc_act", default="False", type=bool, help=" Whether to quantize columns in order of decreasing activation size. Setting it to False can significantly speed up inference but the perplexity may become slightly worse. Also known as act-order.")
38

39
    args = parser.parse_args()
40

41
    main(model_id=args.model_id, bits=args.bits, dataset=args.dataset, group_size=args.group_size, device_map=args.device_map)
42

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.