UltraChat

train_bm.py 
240 lines · 8.6 KB
import argparse
import torch
from transformers import LlamaForCausalLM, LlamaTokenizer
from ultrachat_dataset import load_raw_data, PromptIterableDataset, collator
from transformers.optimization import get_linear_schedule_with_warmup
from tqdm import tqdm
from torch.utils.data import DataLoader
import bmtrain as bmt
from functools import partial
import time
import os
import wandb
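
# The helpers imported from ultrachat_dataset are not shown in this file; based
# purely on how they are used below, they are assumed to behave roughly as follows:
#   load_raw_data(path)        -> list of raw multi-turn dialogue records
#   PromptIterableDataset(...) -> tokenizes each record up to max_seq_length
#   collator(tokenizer, batch) -> dict with "input_ids", "attention_mask" and
#                                 "labels", so that model(**inputs).loss is defined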

def get_model_tokenizer(args):
    model = LlamaForCausalLM.from_pretrained(args.model_name_or_path)
    tokenizer = LlamaTokenizer.from_pretrained(args.model_name_or_path)
    # LLaMA ships without a pad token; add one and resize the embeddings to match
    tokenizer.add_special_tokens({'pad_token': "<pad>"})
    model.resize_token_embeddings(len(tokenizer))
    # wrap the HuggingFace model so BMTrain can shard and train it
    model = bmt.BMTrainModelWrapper(model)
    return model, tokenizer

def get_optimizer(args, model):
    # Adam with optimizer states offloaded to CPU to reduce GPU memory pressure;
    # the learning rate itself is driven by the BMTrain scheduler below
    optimizer = bmt.optim.AdamOffloadOptimizer(
        model.parameters(), weight_decay=args.weight_decay
    )
    return optimizer


def get_learning_rate_scheduler(args, optimizer):
    # warm up for warmup_ratio * train_iters steps, then decay until lr_decay_iters;
    # num_iter=start_step lets a resumed run continue from the right point
    if args.lr_decay_iters is None:
        args.lr_decay_iters = args.train_iters
    if args.lr_decay_style == "linear":
        lr_scheduler = bmt.lr_scheduler.Linear(
            optimizer,
            start_lr=args.lr,
            warmup_iter=int(args.warmup_ratio * args.train_iters),
            end_iter=args.lr_decay_iters,
            num_iter=args.start_step,
        )
    elif args.lr_decay_style == "cosine":
        print("use cosine")
        lr_scheduler = bmt.lr_scheduler.Cosine(
            optimizer,
            start_lr=args.lr,
            warmup_iter=int(args.warmup_ratio * args.train_iters),
            end_iter=args.lr_decay_iters,
            num_iter=args.start_step,
        )
    elif args.lr_decay_style == "noam":
        print("use noam")
        lr_scheduler = bmt.lr_scheduler.Noam(
            optimizer,
            start_lr=args.lr,
            warmup_iter=int(args.warmup_ratio * args.train_iters),
            end_iter=args.lr_decay_iters,
            num_iter=args.start_step,
        )
    else:
        raise NotImplementedError
    return lr_scheduler
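
# Illustrative numbers using the argparse defaults below (not the repo's actual
# training recipe): with train_iters=2000000 and warmup_ratio=0.03, the schedulers
# get warmup_iter = int(0.03 * 2000000) = 60000 warmup steps before decay begins.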


def setup_model_and_optimizer(args):
    model, tokenizer = get_model_tokenizer(args)
    bmt.synchronize()
    optimizer = get_optimizer(args, model)
    lr_scheduler = get_learning_rate_scheduler(args, optimizer)
    bmt.synchronize()
    return tokenizer, model, optimizer, lr_scheduler

def train(args):
    bmt.init_distributed(
        seed=args.seed,
        zero_level=3,
    )

    if args.wandb and bmt.rank() == 0:
        wandb.init()

    if args.tensorboard is not None and bmt.rank() == 0:
        from torch.utils.tensorboard import SummaryWriter
        import distutils.version  # noqa: F401

        if not os.path.exists(args.tensorboard):
            os.makedirs(args.tensorboard)
        writer = SummaryWriter(log_dir=args.tensorboard)

    tokenizer, model, optimizer, lr_scheduler = setup_model_and_optimizer(args)
    # note: the --loss-scale CLI flag defined below is not used here; the initial
    # loss scale passed to OptimManager is hard-coded to 2**10
    optim_manager = bmt.optim.OptimManager(loss_scale=2**10)
    optim_manager.add_optimizer(optimizer, lr_scheduler)

    bmt.synchronize()

    original_dataset = load_raw_data(args.data_file)
    print("total training instance number:", len(original_dataset))

    bmt.print_rank("Model memory")
    bmt.print_rank(torch.cuda.memory_summary())

    avg_time_recorder = bmt.utils.AverageRecorder()
    avg_loss_recorder = bmt.utils.AverageRecorder()

    global_step = 0
    for epoch in range(args.epochs):
        # shuffle once per epoch, then give each rank a contiguous, equal-sized shard
        indices = torch.randperm(len(original_dataset))
        dataset = [original_dataset[i] for i in indices]

        data_per_gpu = len(dataset) // bmt.world_size()
        dataset = dataset[bmt.rank() * data_per_gpu : (bmt.rank() + 1) * data_per_gpu]

        dataset = PromptIterableDataset(dataset, tokenizer=tokenizer, max_seq_length=args.max_seq_length, teacher_forcing=True, truncate_method="tail")
        dataloader = DataLoader(dataset, batch_size=args.batch_size_per_device, collate_fn=partial(collator, tokenizer))

        if global_step >= args.train_iters:
            break
        progress_bar = tqdm(range(len(dataloader)), disable=bmt.rank() != 0, desc=f"epoch {epoch}")
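
        # effective batch per optimizer step (derived from the flags, not stated in
        # the code): batch_size_per_device * world_size() * gradient_accumulation_steps
        # sequences, each up to max_seq_length tokens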

        for step, inputs in enumerate(dataloader):
            st = time.time()

            with bmt.inspect.inspect_tensor() as inspector:
                for k in inputs:
                    inputs[k] = inputs[k].cuda()
                output = model(**inputs)
                loss = output.loss

                global_loss = bmt.sum_loss(loss).item()

                optim_manager.backward(loss)

                # step the optimizer only every gradient_accumulation_steps micro-batches
                # (and on the last batch of the epoch); the loss is not divided by the
                # accumulation count, so gradients are summed across micro-batches
                if (step + 1) % args.gradient_accumulation_steps == 0 or step == len(dataloader) - 1:
                    optim_manager.clip_grad_norm(optimizer.param_groups, max_norm=args.clip_grad)

                    optim_manager.step()
                    optim_manager.zero_grad()

            global_step += 1
            progress_bar.update(1)

            # record time and loss
            iteration_time = time.time() - st

            avg_time_recorder.record(iteration_time)
            avg_loss_recorder.record(global_loss)

            # print time and loss
            if global_step % args.logging_step == 0:
                bmt.print_rank(
                    "| Iter: {:6d} | loss: {:.4f} average_loss: {:.4f} | lr: {:.4e} | time: {:.4f}".format(
                        global_step,
                        global_loss,
                        avg_loss_recorder.value,
                        lr_scheduler.current_lr,
                        avg_time_recorder.value
                    )
                )
                if args.wandb and bmt.rank() == 0:
                    wandb.log({
                        "loss": global_loss,
                        "average_loss": avg_loss_recorder.value,
                        "lr": lr_scheduler.current_lr,
                    }, step=global_step)
                if args.tensorboard and bmt.rank() == 0:
                    writer.add_scalar("Loss/train", global_loss, global_step)
                    writer.add_scalar("average_Loss/train", avg_loss_recorder.value, global_step)
                    writer.add_scalar("lr/train", lr_scheduler.current_lr, global_step)

            # save model
            if global_step % args.save_step == 0:
                os.makedirs(f"ultrachat_{args.model}/step_{global_step}", exist_ok=True)

                bmt.save(model, f"ultrachat_{args.model}/step_{global_step}/checkpoint.pt")

                if bmt.rank() == 0:
                    torch.save(optimizer.state_dict(), f"ultrachat_{args.model}/step_{global_step}/optimizer.pt")
                    torch.save(lr_scheduler.state_dict(), f"ultrachat_{args.model}/step_{global_step}/scheduler.pt")

            if global_step == args.train_iters:
                break

    bmt.save(model, f"ultrachat_{args.model}/final.pt")
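
# Resuming appears only partially wired up: --start-step advances the scheduler, but
# reloading the saved optimizer.pt / scheduler.pt state is left to the caller
# (an observation from this script, not documented behaviour).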

if __name__ == "__main__":
    parser = argparse.ArgumentParser("")
    parser.add_argument("--lr", type=float, default=1e-5)
    parser.add_argument("--model", type=str, default='llama')
    parser.add_argument("--model_name_or_path", default='/path/to/huggingface/llama')
    parser.add_argument("--epochs", default=3, type=int)
    parser.add_argument("--seed", default=0, type=int)

    parser.add_argument("--max_seq_length", default=2048, type=int)
    parser.add_argument("--batch_size_per_device", default=2, type=int)
    parser.add_argument("--logging_step", default=100, type=int)
    parser.add_argument("--save_step", default=50000, type=int)
    parser.add_argument("--data_file", default="../data/processed/data.json", type=str)
    parser.add_argument("--gradient_accumulation_steps", default=1, type=int)
    parser.add_argument("--wandb", action="store_true")
    parser.add_argument("--with_eval", action="store_true")

    parser.add_argument("--clip-grad", type=float, default=1.0, help="gradient clipping")
    parser.add_argument("--weight-decay", type=float, default=0.0, help="weight decay rate")
    # note: --loss-scale is accepted but not used by train(); see the OptimManager above
    parser.add_argument("--loss-scale", type=float, default=65536, help="loss scale")

    parser.add_argument("--train-iters", type=int, default=2000000)

    parser.add_argument(
        "--warmup-ratio",
        type=float,
        default=0.03,
    )
    # only linear, cosine and noam are implemented in get_learning_rate_scheduler;
    # the remaining choices fall through to NotImplementedError
    parser.add_argument(
        "--lr-decay-style",
        type=str,
        default="cosine",
        choices=["constant", "linear", "cosine", "exponential", "noam"],
        help="learning rate decay function",
    )
    parser.add_argument("--lr-decay-iters", type=int, default=None, help="lr decay steps")
    parser.add_argument(
        "--start-step", type=int, default=0, help="step to start or continue training"
    )
    parser.add_argument("--tensorboard", type=str, default=None, help="tensorboard log directory")

    args = parser.parse_args()

    train(args)
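
# Example launch (illustrative only; the 8-GPU setup and flag values are assumptions,
# and BMTrain is expected to read the distributed environment set up by torchrun):
#   torchrun --nproc_per_node=8 train_bm.py \
#       --model_name_or_path /path/to/huggingface/llama \
#       --data_file ../data/processed/data.json \
#       --batch_size_per_device 2 --gradient_accumulation_steps 8 --wandb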