import os
import time
import random
import numpy as np
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.distributed import get_rank
import deepspeed
from datetime import timedelta

from accelerate import load_checkpoint_and_dispatch, init_empty_weights
from peft import get_peft_model, LoraConfig, TaskType, PeftModel

# NOTE: mpu and the Parallel* classes are provided by this repo's patched
# transformers, not by the stock library.
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    mpu,
    ParallelOPTForCausalLM,
    ParallelLlamaForCausalLM,
    ParallelGPTJForCausalLM,
    ParallelGPT2LMHeadModel,
    ParallelMistralForCausalLM,
    ParallelQWenLMHeadModel,
)

# Map model_type strings to their model-parallel implementations.
parallel_model_map = {
    "opt": ParallelOPTForCausalLM,
    "gptj": ParallelGPTJForCausalLM,
    "gpt2": ParallelGPT2LMHeadModel,
    "llama": ParallelLlamaForCausalLM,
    "llama2": ParallelLlamaForCausalLM,
    "mistral": ParallelMistralForCausalLM,
    "qwen": ParallelQWenLMHeadModel,
}


def print_args(args):
    """Print arguments."""
    print('arguments:', flush=True)
    for arg in vars(args):
        dots = '.' * (29 - len(arg))
        print(' {} {} {}'.format(arg, dots, getattr(args, arg)), flush=True)


def save_rank(log_str, save_path, rank=0):
    # Append a log line to `save_path`, only on the given rank.
    if not dist.is_initialized() or dist.get_rank() == rank:
        with open(save_path, "a") as f:
            f.write(log_str + "\n")


def print_rank(*args, rank=0, **kwargs):
    # Print only on the given rank (or always, when not distributed).
    if not dist.is_initialized() or dist.get_rank() == rank:
        print(*args, **kwargs)


def all_gather(t, dim=0, world_size=None, group=None, op="cat"):
    # Gather `t` from all ranks and combine the pieces along `dim`.
    if world_size is None:
        world_size = dist.get_world_size()
    all_t = [torch.zeros_like(t) for _ in range(world_size)]
    dist.all_gather(all_t, t, group=group)
    if op == "cat":
        all_t = torch.cat(all_t, dim=dim)
    elif op == "stack":
        all_t = torch.stack(all_t, dim=dim)
    return all_t
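
# Example (hypothetical usage): gather each rank's scalar loss for logging.
#   loss = torch.tensor([local_loss], device="cuda")
#   losses = all_gather(loss, dim=0, op="cat")  # -> tensor of shape [world_size]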


def set_random_seed(seed, mp=False):
    """Set random seed for reproducibility."""
    seed = dist.get_rank() + seed
    if seed is not None and seed > 0:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if mp:
            # Also seed the model-parallel CUDA RNG tracker.
            mpu.model_parallel_cuda_manual_seed(seed)


def init_distributed(args):
    args.rank = int(os.getenv("RANK", "0"))
    args.world_size = int(os.getenv("WORLD_SIZE", "1"))
    args.local_rank = int(os.getenv("LOCAL_RANK", "0"))

    print(f"using world size: {args.world_size}")

    # Manually set the device id for this process.
    device = args.rank % torch.cuda.device_count()
    if args.local_rank is not None:
        device = args.local_rank
    torch.cuda.set_device(device)

    dist.init_process_group(backend="nccl", timeout=timedelta(minutes=300))
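
# The RANK / WORLD_SIZE / LOCAL_RANK environment variables read above are set
# by the launcher, e.g. (hypothetical command line):
#   torchrun --nproc_per_node=8 train.py ...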


def init_distributed_ds(args):
    # Same as init_distributed, but lets DeepSpeed own the process group.
    args.rank = int(os.getenv("RANK", "0"))
    args.world_size = int(os.getenv("WORLD_SIZE", "1"))
    args.local_rank = int(os.getenv("LOCAL_RANK", "0"))

    print(f"using world size: {args.world_size}")

    device = args.rank % torch.cuda.device_count()
    if args.local_rank is not None:
        device = args.local_rank
    torch.cuda.set_device(device)

    deepspeed.init_distributed(timeout=timedelta(minutes=300))


def initialize(args):
    # Pick the distributed backend; `args.deepspeed` is assumed to be the flag
    # selecting DeepSpeed-managed initialization.
    if args.deepspeed:
        init_distributed_ds(args)
    else:
        init_distributed(args)

    if args.model_parallel:
        assert dist.get_world_size() % args.model_parallel_size == 0
        mpu.initialize_model_parallel(args.model_parallel_size)

    set_random_seed(args.seed, args.model_parallel)

    # Create the save directory if one is given.
    if args.save is not None:
        os.makedirs(args.save, exist_ok=True)


def get_model(args, device):
    # Build the model, optionally with model parallelism and LoRA.
    config = AutoConfig.from_pretrained(args.model_path)

    st_time = time.time()
    if args.model_parallel:
        config.is_model_parallel = True
        with init_empty_weights():
            if args.model_type == "qwen":
                model = parallel_model_map[args.model_type](config).to(torch.bfloat16)
            else:
                model = parallel_model_map[args.model_type](config).half()
        load_parallel(model, args.model_path)

        if mpu.get_data_parallel_rank() == 0:
            print(' > number of parameters on model parallel rank {}: {}'.format(
                mpu.get_model_parallel_rank(),
                sum([p.nelement() for p in model.parameters()])), flush=True)
    else:
        config.is_model_parallel = False
        if args.model_type == "qwen":
            # Qwen checkpoints are trained in bfloat16; fp16 is prone to overflow.
            dtype = torch.float32 if args.fp32 else torch.bfloat16
        else:
            dtype = torch.float32 if args.fp32 else torch.float16
        model = AutoModelForCausalLM.from_pretrained(args.model_path, config=config, device_map={"": device}, torch_dtype=dtype)

        if args.peft is not None:
            if args.peft == "lora":
                model.enable_input_require_grads()
                if args.peft_path is not None:
                    # Resume from an existing PEFT adapter.
                    model = PeftModel.from_pretrained(model, args.peft_path)
                else:
                    peft_config = LoraConfig(
                        task_type=TaskType.CAUSAL_LM, inference_mode=(not args.do_train), r=args.peft_lora_r, lora_alpha=args.peft_lora_alpha, lora_dropout=args.peft_lora_dropout
                    )
                    model = get_peft_model(model, peft_config)
                    model.print_trainable_parameters()
            else:
                raise NotImplementedError
        else:
            if dist.get_rank() == 0:
                print(' > number of parameters: {}'.format(
                    sum([p.nelement() for p in model.parameters()])), flush=True)

    if args.gradient_checkpointing:
        model.gradient_checkpointing_enable()

    ed_time = time.time()

    print_rank(f"Model load time: {ed_time - st_time}s")

    return model


def get_optimizer_params(args, model: nn.Module):
    # Split weights into two groups: one with weight decay and one without.
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'ln_f.weight', 'ln_1.weight', 'ln_2.weight', 'ln_cross_attn']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer
                    if not any(nd in n for nd in no_decay)]},
        {'params': [p for n, p in param_optimizer
                    if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]

    return optimizer_grouped_parameters
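
# Example (hypothetical usage): the groups plug directly into a torch optimizer,
# so biases and LayerNorm weights are exempt from weight decay:
#   optimizer = torch.optim.AdamW(get_optimizer_params(args, model), lr=args.lr)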


def get_optimizer_params_peft(args, model: nn.Module):
    # Only optimize trainable (adapter) parameters.
    param_optimizer = list(model.named_parameters())
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if p.requires_grad]},
    ]

    return optimizer_grouped_parameters


def get_tokenizer(args):
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    if args.model_type in ["gpt2", "opt", "llama", "gptj", "llama2", "mistral"]:
        # These tokenizers ship without a pad token; reuse EOS for padding.
        tokenizer.pad_token_id = tokenizer.eos_token_id
    elif args.model_type == "qwen":
        tokenizer.pad_token_id = 151646  # <|extra_0|>
        tokenizer.eos_token_id = 151643  # <|endoftext|>

    return tokenizer


def load_parallel(model, load_dir):
    # Load this rank's model-parallel shard from <load_dir>/mp<MP>/pytorch_model_<rank>.bin.
    mp_rank = mpu.get_model_parallel_rank()
    assert mpu.get_model_parallel_world_size() != 1
    checkpoint_name = os.path.join(load_dir, f"mp{mpu.get_model_parallel_world_size()}", f"pytorch_model_{mp_rank}.bin")
    assert os.path.exists(checkpoint_name), f"{checkpoint_name} does not exist."
    model = load_checkpoint_and_dispatch(model=model, checkpoint=checkpoint_name, device_map={"": torch.cuda.current_device()}, dtype=torch.float16)
    dist.barrier()
    print(f"Rank {get_rank()}: {checkpoint_name} loaded.")


def save_parallel(model, save_dir):
    # Save this rank's model-parallel shard under <save_dir>/mp<MP>/.
    mp_rank = mpu.get_model_parallel_rank()
    os.makedirs(os.path.join(save_dir, f"mp{mpu.get_model_parallel_world_size()}"), exist_ok=True)
    checkpoint_name = os.path.join(save_dir, f"mp{mpu.get_model_parallel_world_size()}", f"pytorch_model_{mp_rank}.bin")
    torch.save(model.state_dict(), checkpoint_name)
    print(f"Rank {get_rank()}: {checkpoint_name} saved.")