import os
import time
import argparse
from functools import partial

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

import bmtrain as bmt
import wandb
from transformers import LlamaForCausalLM, LlamaTokenizer
from transformers.optimization import get_linear_schedule_with_warmup  # noqa: F401

from ultrachat_dataset import load_raw_data, PromptIterableDataset, collator
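

# LLaMA checkpoints ship without a pad token, so get_model_tokenizer adds one,
# resizes the embedding matrix to match, and wraps the model for BMTrain's
# ZeRO-style data parallelism.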
def get_model_tokenizer(args):
    model = LlamaForCausalLM.from_pretrained(args.model_name_or_path)
    tokenizer = LlamaTokenizer.from_pretrained(args.model_name_or_path)
    tokenizer.add_special_tokens({"pad_token": "<pad>"})
    model.resize_token_embeddings(len(tokenizer))
    model = bmt.BMTrainModelWrapper(model)
    return model, tokenizer
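

# AdamOffloadOptimizer keeps the Adam state in host memory and runs the
# parameter update on CPU, cutting GPU memory at some cost in step time.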
def get_optimizer(args, model):
    optimizer = bmt.optim.AdamOffloadOptimizer(
        model.parameters(), weight_decay=args.weight_decay
    )
    return optimizer
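

# All three options below follow BMTrain's WarmupLRScheduler interface:
# linear warmup for warmup_iter steps, then the named decay until end_iter;
# num_iter (--start-step) lets a resumed run continue mid-schedule.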
def get_learning_rate_scheduler(args, optimizer):
    if args.lr_decay_iters is None:
        args.lr_decay_iters = args.train_iters
    if args.lr_decay_style == "linear":
        lr_scheduler = bmt.lr_scheduler.Linear(
            optimizer,
            start_lr=args.lr,
            warmup_iter=int(args.warmup_ratio * args.train_iters),
            end_iter=args.lr_decay_iters,
            num_iter=args.start_step,
        )
    elif args.lr_decay_style == "cosine":
        lr_scheduler = bmt.lr_scheduler.Cosine(
            optimizer,
            start_lr=args.lr,
            warmup_iter=int(args.warmup_ratio * args.train_iters),
            end_iter=args.lr_decay_iters,
            num_iter=args.start_step,
        )
    elif args.lr_decay_style == "noam":
        lr_scheduler = bmt.lr_scheduler.Noam(
            optimizer,
            start_lr=args.lr,
            warmup_iter=int(args.warmup_ratio * args.train_iters),
            end_iter=args.lr_decay_iters,
            num_iter=args.start_step,
        )
    else:
        raise NotImplementedError(f"unsupported lr decay style: {args.lr_decay_style}")
    return lr_scheduler


def setup_model_and_optimizer(args):
    model, tokenizer = get_model_tokenizer(args)
    bmt.synchronize()  # barrier after the model load (gap in the excerpt; the usual BMTrain pattern)
    optimizer = get_optimizer(args, model)
    lr_scheduler = get_learning_rate_scheduler(args, optimizer)
    bmt.synchronize()
    return tokenizer, model, optimizer, lr_scheduler


def finetune(args):
    # NOTE: the original function header was elided in this excerpt; a
    # `finetune(args)` entry point is assumed for the training code below.
    if args.wandb and bmt.rank() == 0:
        wandb.init()  # run configuration elided in the original; defaults assumed
    if args.tensorboard is not None and bmt.rank() == 0:
        from torch.utils.tensorboard import SummaryWriter
        import distutils.version  # noqa: F401

        if not os.path.exists(args.tensorboard):
            os.makedirs(args.tensorboard)
        writer = SummaryWriter(log_dir=args.tensorboard)

    tokenizer, model, optimizer, lr_scheduler = setup_model_and_optimizer(args)
    optim_manager = bmt.optim.OptimManager(loss_scale=2**10)
    optim_manager.add_optimizer(optimizer, lr_scheduler)
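    # OptimManager applies dynamic fp16 loss scaling starting from 2**10.
    # Note the --loss-scale CLI flag defined below is parsed but not used in
    # this excerpt.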

    original_dataset = load_raw_data(args.data_file)
    print("total training instance number:", len(original_dataset))

    bmt.print_rank("Model memory")
    bmt.print_rank(torch.cuda.memory_summary())

    avg_time_recorder = bmt.utils.AverageRecorder()
    avg_loss_recorder = bmt.utils.AverageRecorder()
    global_step = 0  # initialization fell in an elided gap; required by the loop below

    for epoch in range(args.epochs):
        indices = torch.randperm(len(original_dataset))
        dataset = [original_dataset[i] for i in indices]

        data_per_gpu = len(dataset) // bmt.world_size()
        dataset = dataset[bmt.rank() * data_per_gpu : (bmt.rank() + 1) * data_per_gpu]

        dataset = PromptIterableDataset(
            dataset,
            tokenizer=tokenizer,
            max_seq_length=args.max_seq_length,
            teacher_forcing=True,
            truncate_method="tail",
        )
        dataloader = DataLoader(
            dataset,
            batch_size=args.batch_size_per_device,
            collate_fn=partial(collator, tokenizer),
        )
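        # Every rank draws the same permutation here, assuming the seed is
        # shared via bmt.init_distributed, so the per-rank slices partition the
        # shuffled data disjointly; the len(dataset) % world_size() tail is dropped.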

        if global_step >= args.train_iters:
            break

        progress_bar = tqdm(range(len(dataloader)), disable=bmt.rank() != 0, desc=f"epoch {epoch}")

        for step, inputs in enumerate(dataloader):
            st = time.time()
            with bmt.inspect.inspect_tensor() as inspector:
                for k in inputs:
                    inputs[k] = inputs[k].cuda()
                output = model(**inputs)
                loss = output.loss  # assumed for the elided line: the HF causal-LM loss
                global_loss = bmt.sum_loss(loss).item()

                optim_manager.backward(loss)

            if (step + 1) % args.gradient_accumulation_steps == 0 or step == len(dataloader) - 1:
                optim_manager.clip_grad_norm(optimizer.param_groups, max_norm=args.clip_grad)
                optim_manager.step()
                optim_manager.zero_grad()

            global_step += 1
            progress_bar.update(1)
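            # Note: gradients from accumulation micro-batches are summed, not
            # averaged; the loss is not rescaled by 1/gradient_accumulation_steps.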

            # record time and loss
            iteration_time = time.time() - st
            avg_time_recorder.record(iteration_time)
            avg_loss_recorder.record(global_loss)

            # print time and loss
            if global_step % args.logging_step == 0:
                bmt.print_rank(
                    "| Iter: {:6d} | loss: {:.4f} average_loss: {:.4f} | lr: {:.4e} | time: {:.4f}".format(
                        global_step,
                        global_loss,
                        avg_loss_recorder.value,
                        lr_scheduler.current_lr,
                        avg_time_recorder.value,
                    )
                )
                if args.wandb and bmt.rank() == 0:
                    wandb.log(
                        {
                            "loss": global_loss,
                            "average_loss": avg_loss_recorder.value,
                            "lr": lr_scheduler.current_lr,
                        },
                        step=global_step,
                    )
                if args.tensorboard and bmt.rank() == 0:
                    writer.add_scalar("Loss/train", global_loss, global_step)
                    writer.add_scalar("average_Loss/train", avg_loss_recorder.value, global_step)
                    writer.add_scalar("lr/train", lr_scheduler.current_lr, global_step)
            if global_step % args.save_step == 0:
                os.makedirs(f"ultrachat_{args.model}/step_{global_step}", exist_ok=True)
                bmt.save(model, f"ultrachat_{args.model}/step_{global_step}/checkpoint.pt")
                torch.save(optimizer.state_dict(), f"ultrachat_{args.model}/step_{global_step}/optimizer.pt")
                torch.save(lr_scheduler.state_dict(), f"ultrachat_{args.model}/step_{global_step}/scheduler.pt")

            if global_step == args.train_iters:
                break

    bmt.save(model, f"ultrachat_{args.model}/final.pt")
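

# Typical launch is one process per GPU; the command below is illustrative,
# not taken from the original repo:
#   torchrun --nproc_per_node=8 train.py \
#       --model_name_or_path /path/to/huggingface/llama \
#       --data_file ../data/processed/data.json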
if __name__ == "__main__":
    parser = argparse.ArgumentParser("")
    parser.add_argument("--lr", type=float, default=1e-5)
    parser.add_argument("--model", type=str, default="llama")
    parser.add_argument("--model_name_or_path", default="/path/to/huggingface/llama")
    parser.add_argument("--epochs", default=3, type=int)
    parser.add_argument("--seed", default=0, type=int)
    parser.add_argument("--max_seq_length", default=2048, type=int)
    parser.add_argument("--batch_size_per_device", default=2, type=int)
    parser.add_argument("--logging_step", default=100, type=int)
    parser.add_argument("--save_step", default=50000, type=int)
    parser.add_argument("--data_file", default="../data/processed/data.json", type=str)
    parser.add_argument("--gradient_accumulation_steps", default=1, type=int)
    parser.add_argument("--wandb", action="store_true")
    parser.add_argument("--with_eval", action="store_true")
    parser.add_argument("--clip-grad", type=float, default=1.0, help="gradient clipping")
    parser.add_argument("--weight-decay", type=float, default=0.0, help="weight decay rate")
    parser.add_argument("--loss-scale", type=float, default=65536, help="loss scale")
    parser.add_argument("--train-iters", type=int, default=2000000)
    # --warmup-ratio is referenced by the schedulers above but sat in an elided
    # stretch of the original; the default here is an assumption.
    parser.add_argument("--warmup-ratio", type=float, default=0.03)
    parser.add_argument(
        "--lr-decay-style",
        type=str,
        default="cosine",  # assumed default; only the choices/help lines survive in the excerpt
        choices=["constant", "linear", "cosine", "exponential", "noam"],
        help="learning rate decay function",
    )
    parser.add_argument("--lr-decay-iters", type=int, default=None, help="lr decay steps")
    parser.add_argument(
        "--start-step", type=int, default=0, help="step to start or continue training"
    )
    parser.add_argument("--tensorboard", type=str, default=None, help="tensorboard log directory")

    args = parser.parse_args()
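
    # Assumed continuation (the excerpt stops at parse_args): initialize the
    # BMTrain backend, then launch training. bmt.init_distributed() has to run
    # before any model is wrapped, i.e. before setup_model_and_optimizer.
    bmt.init_distributed(seed=args.seed)
    finetune(args)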