import argparse
import os

import torch
from transformers import AutoModelForCausalLM
# NOTE: these tensor-parallel split/merge helpers are provided by the modified
# `transformers` used by this project; they do not exist in stock Hugging Face
# Transformers.
from transformers import (
    decrease_mp_opt, increase_mp_opt,
    decrease_mp_gptj, increase_mp_gptj,
    decrease_mp_llama, increase_mp_llama,
    decrease_mp_mistral, increase_mp_mistral,
    decrease_mp_qwen, increase_mp_qwen,
)

# model_type -> (merge shards down to mp=1, split an mp=1 model into shards)
func_map = {
    "opt": (decrease_mp_opt, increase_mp_opt),
    "gptj": (decrease_mp_gptj, increase_mp_gptj),
    "llama": (decrease_mp_llama, increase_mp_llama),
    "llama2": (decrease_mp_llama, increase_mp_llama),
    "mistral": (decrease_mp_mistral, increase_mp_mistral),
    "qwen": (decrease_mp_qwen, increase_mp_qwen),
}
def main():
    parser = argparse.ArgumentParser(
        description="Change the tensor-parallel size of a model."
    )

    parser.add_argument("--input_path", type=str)
    parser.add_argument("--model_type", type=str, default="opt")
    parser.add_argument("--source_mp_size", type=int, default=1)
    parser.add_argument("--target_mp_size", type=int, default=2)
    # parser.add_argument("--save_path", type=str)
    parser.add_argument("--half", action="store_true")  # convert weights to fp16
    parser.add_argument("--exist_ok", action="store_true")  # allow overwriting existing outputs

    args = parser.parse_args()

    decrease_mp, increase_mp = func_map[args.model_type]
    if args.source_mp_size == 1:
        # Split a single (mp=1) checkpoint into `target_mp_size` shards.
        assert args.target_mp_size > args.source_mp_size
        args.save_path = os.path.join(args.input_path, f"mp{args.target_mp_size}")
        assert args.exist_ok or not any(
            os.path.exists(os.path.join(args.save_path, f"pytorch_model_{i}.bin"))
            for i in range(args.target_mp_size)
        )
        os.makedirs(args.save_path, exist_ok=True)
        if args.model_type == 'qwen':
            # Qwen's custom modeling code selects the dtype via fp16/fp32
            # kwargs; passing the input path and trust_remote_code here is an
            # assumption for the lines elided from the original.
            model_hf = AutoModelForCausalLM.from_pretrained(
                args.input_path,
                trust_remote_code=True,
                fp16=args.half,
                fp32=not args.half,
            ).state_dict()
        else:
            # Load fp32 unless --half is requested, mirroring the qwen branch.
            model_hf = AutoModelForCausalLM.from_pretrained(
                args.input_path,
                torch_dtype=torch.float16 if args.half else torch.float32,
            ).state_dict()
        d_list = increase_mp(model_hf, args.target_mp_size, half=args.half)
        for i, d in enumerate(d_list):
            torch.save(d, os.path.join(args.save_path, f"pytorch_model_{i}.bin"))
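        # Resulting layout (e.g. with --target_mp_size 2):
        #   <input_path>/mp2/pytorch_model_0.bin
        #   <input_path>/mp2/pytorch_model_1.bin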
    elif args.target_mp_size == 1:
        # Merge `source_mp_size` shards back into a single (mp=1) checkpoint.
        assert args.source_mp_size > args.target_mp_size
        args.save_path = args.input_path
        assert args.exist_ok or not os.path.exists(
            os.path.join(args.save_path, "pytorch_model.bin")
        )
        ckpt_path = os.path.join(args.input_path, f"mp{args.source_mp_size}")
        d_list = [
            torch.load(os.path.join(ckpt_path, f"pytorch_model_{i}.bin"), map_location="cpu")
            for i in range(args.source_mp_size)
        ]
        d = decrease_mp(d_list, half=args.half)
        torch.save(d, os.path.join(args.save_path, "pytorch_model.bin"))
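        # Resulting layout: <input_path>/pytorch_model.bin (merged, mp=1).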
    else:
        # Re-shard: merge the source shards to mp=1 first, then split again
        # into `target_mp_size` shards.
        args.save_path = os.path.join(args.input_path, f"mp{args.target_mp_size}")
        assert args.exist_ok or not any(
            os.path.exists(os.path.join(args.save_path, f"pytorch_model_{i}.bin"))
            for i in range(args.target_mp_size)
        )

        ckpt_path = os.path.join(args.input_path, f"mp{args.source_mp_size}")
        d_list = [
            torch.load(os.path.join(ckpt_path, f"pytorch_model_{i}.bin"), map_location="cpu")
            for i in range(args.source_mp_size)
        ]
        d = decrease_mp(d_list, half=args.half)
        # Write the merged mp=1 checkpoint into input_path so it can be
        # re-loaded through from_pretrained below.
        torch.save(d, os.path.join(args.input_path, "pytorch_model.bin"))

        os.makedirs(args.save_path, exist_ok=True)
        if args.model_type == 'qwen':
            # As above: Qwen takes fp16/fp32 kwargs; the input path and
            # trust_remote_code are assumed for the elided lines.
            model_hf = AutoModelForCausalLM.from_pretrained(
                args.input_path,
                trust_remote_code=True,
                fp16=args.half,
                fp32=not args.half,
            ).state_dict()
        else:
            model_hf = AutoModelForCausalLM.from_pretrained(
                args.input_path,
                torch_dtype=torch.float16 if args.half else torch.float32,
            ).state_dict()
        d_list = increase_mp(model_hf, args.target_mp_size, half=args.half)
        for i, d in enumerate(d_list):
            torch.save(d, os.path.join(args.save_path, f"pytorch_model_{i}.bin"))

if __name__ == '__main__':
    main()
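# Example invocations (script name and paths are placeholders):
#   python change_mp.py --input_path checkpoints/llama-7b --model_type llama \
#       --source_mp_size 1 --target_mp_size 4 --half
#   python change_mp.py --input_path checkpoints/llama-7b --model_type llama \
#       --source_mp_size 4 --target_mp_size 1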