lmops

Форк
0
/
convert_mp.py 
94 строки · 4.2 Кб
1
#coding:utf-8
2
import torch
3
import argparse
4
import os
5
from transformers import AutoModelForCausalLM
6
from transformers import (
7
    decrease_mp_opt, increase_mp_opt,
8
    decrease_mp_gptj, increase_mp_gptj,
9
    decrease_mp_llama, increase_mp_llama,
10
    decrease_mp_mistral, increase_mp_mistral,
11
    decrease_mp_qwen, increase_mp_qwen,
12
)
13

14
# Dispatch table mapping --model_type to its (merge_fn, split_fn) pair:
# the first entry decreases model-parallel size (merge shards), the second
# increases it (split a checkpoint into shards). "llama" and "llama2" share
# the same conversion routines.
_llama_funcs = (decrease_mp_llama, increase_mp_llama)

func_map = dict(
    opt=(decrease_mp_opt, increase_mp_opt),
    gptj=(decrease_mp_gptj, increase_mp_gptj),
    llama=_llama_funcs,
    llama2=_llama_funcs,
    mistral=(decrease_mp_mistral, increase_mp_mistral),
    qwen=(decrease_mp_qwen, increase_mp_qwen),
)
22

23

24
def _shard_paths(save_path, mp_size):
    """Return the per-rank shard file paths for an mp_size-way checkpoint."""
    return [os.path.join(save_path, f"pytorch_model_{i}.bin") for i in range(mp_size)]


def _check_not_exists(paths, exist_ok):
    """Refuse to overwrite existing checkpoint files unless --exist_ok was given.

    Raises:
        FileExistsError: if any path exists and *exist_ok* is False.
    """
    if exist_ok:
        return
    for path in paths:
        if os.path.exists(path):
            raise FileExistsError(f"{path} already exists; pass --exist_ok to overwrite.")


def _load_hf_state_dict(input_path, model_type, half):
    """Load the HuggingFace checkpoint at *input_path* and return its state dict.

    Qwen checkpoints take explicit dtype flags and need flash attention
    disabled; other model types accept a plain ``torch_dtype``. In both
    branches the --half flag selects fp16 vs fp32 loading (previously the
    non-qwen branch always forced fp16, ignoring the flag).
    """
    if model_type == "qwen":
        model = AutoModelForCausalLM.from_pretrained(
            input_path,
            use_flash_attn=False,
            fp16=half,
            fp32=not half,
            bf16=False,
        )
    else:
        dtype = torch.float16 if half else torch.float32
        model = AutoModelForCausalLM.from_pretrained(input_path, torch_dtype=dtype)
    return model.state_dict()


def _split_and_save(input_path, save_path, model_type, target_mp_size, half, increase_mp):
    """Split the unsharded checkpoint at *input_path* into target_mp_size shards."""
    os.makedirs(save_path, exist_ok=True)
    state_dict = _load_hf_state_dict(input_path, model_type, half)
    shards = increase_mp(state_dict, target_mp_size, half=half)
    for path, shard in zip(_shard_paths(save_path, target_mp_size), shards):
        torch.save(shard, path)


def _merge_and_save(input_path, source_mp_size, out_path, half, decrease_mp):
    """Merge the source_mp_size shards under input_path/mp{N} into *out_path*."""
    ckpt_path = os.path.join(input_path, f"mp{source_mp_size}")
    shard_dicts = [
        torch.load(path, map_location="cpu")
        for path in _shard_paths(ckpt_path, source_mp_size)
    ]
    torch.save(decrease_mp(shard_dicts, half=half), out_path)


def main():
    """Convert a checkpoint between tensor (model) parallel sizes.

    Three modes, chosen from --source_mp_size / --target_mp_size:
      * 1 -> N: split an unsharded HF checkpoint into input_path/mp{N}/.
      * N -> 1: merge input_path/mp{N}/ shards into input_path/pytorch_model.bin.
      * M -> N: merge the mp{M} shards into input_path, then re-split into mp{N}.

    Raises:
        ValueError: on an unknown --model_type or an invalid size combination.
        FileExistsError: if output files exist and --exist_ok was not given.
    """
    parser = argparse.ArgumentParser("Change the tensor parallel of a model.")

    parser.add_argument("--input_path", type=str)
    parser.add_argument("--model_type", type=str, default="opt")
    parser.add_argument("--source_mp_size", type=int, default=1)
    parser.add_argument("--target_mp_size", type=int, default=2)
    parser.add_argument("--half", action="store_true")
    parser.add_argument("--exist_ok", action="store_true")

    args = parser.parse_args()

    try:
        decrease_mp, increase_mp = func_map[args.model_type]
    except KeyError:
        raise ValueError(
            f"Unsupported --model_type {args.model_type!r}; "
            f"choose one of {sorted(func_map)}."
        ) from None

    if args.source_mp_size == 1:
        # Split an unsharded checkpoint into target_mp_size shards.
        if args.target_mp_size <= args.source_mp_size:
            raise ValueError("target_mp_size must be > 1 when source_mp_size == 1.")
        save_path = os.path.join(args.input_path, f"mp{args.target_mp_size}")
        _check_not_exists(_shard_paths(save_path, args.target_mp_size), args.exist_ok)
        _split_and_save(
            args.input_path, save_path, args.model_type,
            args.target_mp_size, args.half, increase_mp,
        )
    elif args.target_mp_size == 1:
        # Merge shards back into a single unsharded checkpoint.
        out_path = os.path.join(args.input_path, "pytorch_model.bin")
        _check_not_exists([out_path], args.exist_ok)
        _merge_and_save(
            args.input_path, args.source_mp_size, out_path, args.half, decrease_mp,
        )
    else:
        # Re-shard M -> N: merge into input_path/pytorch_model.bin first,
        # then split that merged checkpoint into the new shard count.
        save_path = os.path.join(args.input_path, f"mp{args.target_mp_size}")
        _check_not_exists(_shard_paths(save_path, args.target_mp_size), args.exist_ok)
        out_path = os.path.join(args.input_path, "pytorch_model.bin")
        _merge_and_save(
            args.input_path, args.source_mp_size, out_path, args.half, decrease_mp,
        )
        _split_and_save(
            args.input_path, save_path, args.model_type,
            args.target_mp_size, args.half, increase_mp,
        )
93
# Script entry point: only run the conversion when executed directly.
if __name__ == "__main__":
    main()
95

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.