""" Main training script """

import argparse
import glob
import os
import random
import sys
import time

import numpy as np
import torch
import wandb
from tqdm import tqdm

from accelerate import Accelerator
from transformers import (
    CLIPImageProcessor,
    get_constant_schedule_with_warmup,
    get_cosine_schedule_with_warmup,
    get_linear_schedule_with_warmup,
)

from otter_ai import FlamingoForConditionalGeneration, OtterForConditionalGeneration

sys.path.append("../..")
from pipeline.mimicit_utils.data import get_data
from pipeline.train.distributed import world_info_from_env
from pipeline.train.train_utils import AverageMeter, get_checkpoint

os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Allow TF32 on Ampere and newer GPUs: faster matmuls and convolutions at
# slightly reduced precision, which is acceptable for training.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--external_save_dir",
        type=str,
        default=None,
        help="set to save model to external path",
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        default=False,
        action="store_true",
        help="whether to resume from checkpoint; if set, models are loaded from --external_save_dir",
    )
    parser.add_argument(
        "--delete_previous_checkpoint",
        action="store_true",
        help="delete the previous checkpoint when saving a new one",
    )
    parser.add_argument("--run_name", type=str, default="otter_9b", help="used to name saving directory and wandb run")
    parser.add_argument(
        "--mmc4_shards",
        type=str,
        help="path to mmc4 shards, this should be a glob pattern such as /path/to/shards/shard-{0000..0999}.tar",
    )
    parser.add_argument(
        "--laion_shards",
        type=str,
        help="path to laion shards, this should be a glob pattern such as /path/to/shards/shard-{0000..0999}.tar",
    )
    parser.add_argument("--train_num_samples_mmc4", type=int, default=100)
    parser.add_argument("--train_num_samples_laion", type=int, default=100)
    parser.add_argument("--batch_size_mmc4", type=int, default=8)
    parser.add_argument("--batch_size_laion", type=int, default=8)
    parser.add_argument("--workers", type=int, default=8)
    parser.add_argument("--dataset_resampled", action="store_true")
    parser.add_argument(
        "--mmc4_textsim_threshold",
        type=float,
        default=30,
        help="threshold for filtering images in mmc4 based on image-text similarity",
    )
    parser.add_argument("--offline", action="store_true")
    parser.add_argument("--num_epochs", type=int, default=1)
    parser.add_argument("--logging_steps", type=int, default=100, help="log loss every n steps")
    parser.add_argument("--checkpointing_steps", type=int, default=5000, help="checkpoint every n steps")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        help="local path to a huggingface model, or a model identifier from huggingface.co",
    )
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--learning_rate", default=1e-4, type=float)
    parser.add_argument("--lr_scheduler", default="constant", type=str, help="constant, linear, or cosine")
    parser.add_argument("--loss_multiplier_mmc4", type=float, default=1.0)
    parser.add_argument("--loss_multiplier_laion", type=float, default=0.2)
    parser.add_argument("--warmup_steps", default=1000, type=int)
    parser.add_argument("--warmup_steps_ratio", default=None, type=float)
    parser.add_argument("--weight_decay", default=0.1, type=float)
    parser.add_argument(
        "--precision",
        choices=["amp_bf16", "amp_bfloat16", "bf16", "amp", "fp16", "fp32"],
        default="fp32",
        help="Floating point precision.",
    )
    parser.add_argument(
        "--dist-url",
        default="env://",
        type=str,
        help="url used to set up distributed training",
    )
    parser.add_argument("--dist-backend", default="nccl", type=str, help="distributed backend")
    parser.add_argument(
        "--no-set-device-rank",
        default=False,
        action="store_true",
        help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc).",
    )
    parser.add_argument("--mask_lm_head", action="store_true")
    parser.add_argument("--max-src-length", type=int, default=1024, help="the maximum src sequence length")
    parser.add_argument("--max-tgt-length", type=int, default=1024, help="the maximum target sequence length")
    parser.add_argument("--patch-image-size", type=int, default=224)
    parser.add_argument("--save_hf_model", default=False, action="store_true")
    parser.add_argument("--report_to_wandb", default=False, action="store_true")
    parser.add_argument("--wandb_project", type=str)
    parser.add_argument("--wandb_entity", type=str)
    parser.add_argument(
        "--save_checkpoints_to_wandb",
        default=False,
        action="store_true",
        help="save checkpoints to wandb",
    )
    return parser
def random_seed(seed=42, rank=0):
    torch.manual_seed(seed + rank)
    np.random.seed(seed + rank)
    random.seed(seed + rank)
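

# One training step consumes one LAION batch and one MMC4 batch: each is run
# through the model separately, its loss scaled by the corresponding
# --loss_multiplier_* flag and backpropagated, so gradients from the two
# passes accumulate before the optimizer step.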
def train_one_epoch(args, model, epoch, laion_loader, mmc4_loader, tokenizer, optimizer, lr_scheduler, device_id, accelerator, wandb):
    num_batches_per_epoch_laion = laion_loader.num_batches
    num_batches_per_epoch_mmc4 = mmc4_loader.num_batches

    assert num_batches_per_epoch_laion == num_batches_per_epoch_mmc4, "Number of batches in laion and mmc4 datasets must be the same"
    num_batches_per_epoch = num_batches_per_epoch_mmc4
    total_training_steps = num_batches_per_epoch * args.num_epochs

    # look up the ids of the special tokens
    media_token_id = tokenizer("<image>", add_special_tokens=False)["input_ids"][-1]
    endofchunk_token_id = tokenizer("<|endofchunk|>", add_special_tokens=False)["input_ids"][-1]
    answer_token_id = tokenizer("<answer>", add_special_tokens=False)["input_ids"][-1]

    model.train()

    # set up logging
    step_time_m = AverageMeter()  # time per optimizer step
    data_time_m = AverageMeter()  # time per batch load
    end = time.time()
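
    # iterate the two dataloaders in lockstep, drawing one batch from each per
    # step; the tqdm progress bar is shown only on rank 0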
    for num_steps, (batch_laion, batch_mmc4) in tqdm(
        enumerate(zip(laion_loader, mmc4_loader)),
        disable=args.rank != 0,
        total=total_training_steps,
        initial=(epoch * num_batches_per_epoch),
    ):
        data_time_m.update(time.time() - end)

        global_step = num_steps + epoch * num_batches_per_epoch
        total_losses = []

        #### LAION FORWARD PASS ####
        images = batch_laion[0].to(device_id, non_blocking=True).unsqueeze(1).unsqueeze(1)
        input_ids = batch_laion[1][0].to(device_id, non_blocking=True)
        attention_mask = batch_laion[1][1].to(device_id, non_blocking=True)

        # mask out padding and media tokens so they contribute no LM loss
        labels = input_ids.clone()
        labels[labels == tokenizer.pad_token_id] = -100
        labels[labels == media_token_id] = -100

        with accelerator.autocast():
            loss_laion = model(
                vision_x=images,
                lang_x=input_ids,
                attention_mask=attention_mask,
                labels=labels,
            )[0]
        accelerator.backward(args.loss_multiplier_laion * loss_laion)
        total_losses.append(args.loss_multiplier_laion * loss_laion)
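
        # MMC4 label masking (below): train only on tokens that follow an
        # <image> token, up to the matching <|endofchunk|>; everything else,
        # including padding and the media tokens themselves, is set to -100.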
        #### MMC4 FORWARD PASS ####
        images = batch_mmc4[0].to(device_id, non_blocking=True).unsqueeze(2)
        input_ids = torch.stack([x[0] for x in batch_mmc4[1]]).squeeze(1)
        attention_mask = torch.stack([x[1] for x in batch_mmc4[1]]).squeeze(1)

        labels = input_ids.clone()
        labels[labels == tokenizer.pad_token_id] = -100

        for i in range(labels.shape[0]):
            # remove loss for any token before the first <image> token
            label_idx = 0
            while label_idx < labels.shape[1] and labels[i][label_idx] != media_token_id:
                labels[i][label_idx] = -100
                label_idx += 1

            # remove loss for any token between <|endofchunk|> and the next <image>
            endofchunk_idxs = torch.where(labels[i] == endofchunk_token_id)[0]
            for endofchunk_idx in endofchunk_idxs:
                token_idx = endofchunk_idx + 1
                while token_idx < labels.shape[1] and labels[i][token_idx] != media_token_id:
                    labels[i][token_idx] = -100
                    token_idx += 1

        labels[labels == media_token_id] = -100
        labels = labels.to(device_id)

        with accelerator.autocast():
            loss_mmc4 = model(
                vision_x=images,
                lang_x=input_ids.to(device_id),
                attention_mask=attention_mask.to(device_id),
                labels=labels,
            )[0]
        accelerator.backward(args.loss_multiplier_mmc4 * loss_mmc4)
        total_losses.append(args.loss_multiplier_mmc4 * loss_mmc4)

        total_loss_sum = sum(total_losses)
        mean_loss = total_loss_sum / len(total_losses)
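
        # zero the gradient of every embedding row except those of the added
        # special tokens, so only the new token embeddings are updated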
        def mask_embedding(m):
            if m.weight.requires_grad:
                zero_mask = torch.zeros_like(m.weight.grad)
                zero_mask[media_token_id] = torch.ones_like(zero_mask[media_token_id])
                zero_mask[endofchunk_token_id] = torch.ones_like(zero_mask[endofchunk_token_id])
                m.weight.grad = m.weight.grad * zero_mask

        if args.mask_lm_head:
            unwrapped_model = accelerator.unwrap_model(model)
            if unwrapped_model.lang_encoder.__class__.__name__ == "MPTForCausalLM":
                unwrapped_model.lang_encoder.transformer.wte.apply(mask_embedding)
            elif unwrapped_model.lang_encoder.__class__.__name__ == "LlamaForCausalLM":
                unwrapped_model.lang_encoder.model.embed_tokens.apply(mask_embedding)
                unwrapped_model.lang_encoder.lm_head.apply(mask_embedding)
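
        # clip only on iterations where accelerate is actually synchronizing
        # and applying the accumulated gradients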
        if accelerator.sync_gradients:
            accelerator.clip_grad_norm_(model.parameters(), 1.0)

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()

        step_time_m.update(time.time() - end)
        end = time.time()

        if accelerator.sync_gradients:
            if args.rank == 0 and args.report_to_wandb:
                mmc4_samples_per_second = args.gradient_accumulation_steps * args.batch_size_mmc4 * args.world_size / step_time_m.val
                mmc4_samples_per_second_per_gpu = args.gradient_accumulation_steps * args.batch_size_mmc4 / step_time_m.val
                laion_samples_per_second = args.gradient_accumulation_steps * args.batch_size_laion * args.world_size / step_time_m.val
                laion_samples_per_second_per_gpu = args.gradient_accumulation_steps * args.batch_size_laion / step_time_m.val
                wandb.log(
                    {
                        "data_time": data_time_m.avg,
                        "step_time": step_time_m.avg,
                        "mmc4_samples_per_second": mmc4_samples_per_second,
                        "mmc4_samples_per_second_per_gpu": mmc4_samples_per_second_per_gpu,
                        "laion_samples_per_second": laion_samples_per_second,
                        "laion_samples_per_second_per_gpu": laion_samples_per_second_per_gpu,
                        "lr": optimizer.param_groups[0]["lr"],
                    },
                    commit=False,
                )
                step_time_m.reset()
                data_time_m.reset()

                wandb.log(
                    {
                        "mmc4_loss": loss_mmc4.item(),
                        "laion_loss": loss_laion.item(),
                        "mean_loss": mean_loss.item(),
                        "global_step": global_step // args.gradient_accumulation_steps,
                    },
                    commit=True,
                )

        if ((num_steps + 1) % args.logging_steps == 0) and args.rank == 0:
            print(f"Step {num_steps+1}/{num_batches_per_epoch} of epoch {epoch+1}/{args.num_epochs} complete. Mean Loss: {mean_loss.item():.3f}")

        if ((num_steps + 1) % args.checkpointing_steps == 0) and args.rank == 0:
            if not os.path.exists(args.external_save_dir):
                os.makedirs(args.external_save_dir)

            unwrapped_model = accelerator.unwrap_model(model)
            checkpoint_dict = {
                "model_state_dict": get_checkpoint(unwrapped_model),
                "optimizer_state_dict": optimizer.state_dict(),
                "lr_scheduler_state_dict": lr_scheduler.state_dict(),
            }
            print(f"Saving checkpoint to {args.external_save_dir}/checkpoint_steps{num_steps + 1}.pt")
            accelerator.save(
                checkpoint_dict,
                f"{args.external_save_dir}/checkpoint_steps{num_steps + 1}.pt",
            )
            print(f"Saving config to {args.external_save_dir}/config.json")
            unwrapped_model.config.save_pretrained(args.external_save_dir)
            if args.delete_previous_checkpoint:
                if (num_steps + 1) // args.checkpointing_steps >= 2:
                    previous_checkpoint_path = f"{args.external_save_dir}/checkpoint_steps{num_steps + 1 - args.checkpointing_steps}.pt"
                    if os.path.exists(previous_checkpoint_path):
                        os.remove(previous_checkpoint_path)
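

# Entry point: parse arguments, build the model, tokenizer, and dataloaders,
# then run the epoch loop and save final weights.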
def main():
    parser = parse_args()
    args = parser.parse_args()

    if args.save_checkpoints_to_wandb and not args.report_to_wandb:
        raise ValueError("save_checkpoints_to_wandb requires report_to_wandb")

    if args.offline:
        os.environ["WANDB_MODE"] = "offline"
        os.environ["TRANSFORMERS_OFFLINE"] = "1"

    args.local_rank, args.rank, args.world_size = world_info_from_env()

    accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps)

    device_id = accelerator.device
    random_seed(args.seed)

    if args.pretrained_model_name_or_path is not None:
        accelerator.print(f"Loading pretrained model from {args.pretrained_model_name_or_path}")
        if "otter" in args.run_name.lower():
            model = OtterForConditionalGeneration.from_pretrained(
                args.pretrained_model_name_or_path,
                device_map={"": device_id},
                local_files_only=args.offline,
            )
        elif "flamingo" in args.run_name.lower():
            if accelerator.num_processes > 1:
                model = FlamingoForConditionalGeneration.from_pretrained(
                    args.pretrained_model_name_or_path,
                    device_map={"": device_id},
                    local_files_only=args.offline,
                )
            else:
                model = FlamingoForConditionalGeneration.from_pretrained(
                    args.pretrained_model_name_or_path,
                    device_map="auto",
                    local_files_only=args.offline,
                )
            # make sure the special tokens used for loss masking exist in the tokenizer
            model.text_tokenizer.add_special_tokens({"additional_special_tokens": ["<|endofchunk|>", "<image>", "<answer>"]})
    accelerator.wait_for_everyone()

    # grow the embedding matrix to cover the added special tokens
    if model.lang_encoder.__class__.__name__ != "MPTForCausalLM":
        model.lang_encoder.resize_token_embeddings(len(model.text_tokenizer))

    args.tokenizer = model.text_tokenizer
    tokenizer = model.text_tokenizer
    random_seed(args.seed, args.rank)

    image_processor = CLIPImageProcessor()

    mmc4_dataset = get_data(args, image_processor, tokenizer, "mmc4")
    laion_dataset = get_data(args, image_processor, tokenizer, "laion")
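
    # Weight decay is applied only to the gated cross-attention weights; the
    # gate scalars, norms, biases, and all other parameters are decayed at 0.0.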
    def get_grouped_params(model):
        params_with_wd, params_without_wd = [], []

        def apply_decay(x):
            return "gated_cross_attn_layer" in x and "ff_gate" not in x and "attn_gate" not in x and "norm" not in x and "bias" not in x

        for n, p in model.named_parameters():
            if apply_decay(n):
                params_with_wd.append(p)
            else:
                params_without_wd.append(p)

        return [
            {"params": params_with_wd, "weight_decay": args.weight_decay},
            {"params": params_without_wd, "weight_decay": 0.0},
        ]

    total_training_steps = mmc4_dataset.dataloader.num_batches * args.num_epochs
    resume_from_epoch = 0
    checkpoint = None

    args.external_save_dir = os.path.join(args.external_save_dir, args.run_name) if args.external_save_dir else args.run_name

    # check if a checkpoint exists for this run
    if os.path.exists(f"{args.external_save_dir}") and args.resume_from_checkpoint is True:
        checkpoint_list = glob.glob(f"{args.external_save_dir}/checkpoint_*.pt")
        if len(checkpoint_list) == 0:
            print(f"Found no checkpoints for run {args.external_save_dir}.")
        else:
            # keep only the digits of the filename suffix (e.g. "epoch3" or
            # "steps5000") when sorting, so int() does not fail on it
            resume_from_checkpoint_path = sorted(checkpoint_list, key=lambda x: int("".join(filter(str.isdigit, x.split("_")[-1]))))[-1]
            print(f"Found checkpoint {resume_from_checkpoint_path} for run {args.external_save_dir}.")

            print(f"Loading checkpoint from {resume_from_checkpoint_path}")
            checkpoint = torch.load(resume_from_checkpoint_path, map_location="cpu")
            model.load_state_dict(checkpoint["model_state_dict"], False)
            resume_from_epoch = checkpoint["epoch"] + 1

    optimizer = torch.optim.AdamW(get_grouped_params(model), lr=args.learning_rate)

    if args.rank == 0:
        print(f"Total training steps: {total_training_steps}")

    args.warmup_steps = total_training_steps * args.warmup_steps_ratio if args.warmup_steps_ratio is not None else args.warmup_steps

    if args.lr_scheduler == "linear":
        lr_scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=args.warmup_steps // args.gradient_accumulation_steps,
            num_training_steps=total_training_steps // args.gradient_accumulation_steps,
        )
    elif args.lr_scheduler == "cosine":
        lr_scheduler = get_cosine_schedule_with_warmup(
            optimizer,
            num_warmup_steps=args.warmup_steps // args.gradient_accumulation_steps,
            num_training_steps=total_training_steps // args.gradient_accumulation_steps,
        )
    else:
        lr_scheduler = get_constant_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps)

    # restore optimizer / scheduler state only after both have been created
    if checkpoint is not None:
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        lr_scheduler.load_state_dict(checkpoint["lr_scheduler_state_dict"])
    if args.rank == 0 and args.report_to_wandb:
        wandb.init(
            project=args.wandb_project,
            entity=args.wandb_entity,
            name=args.run_name,
            config=vars(args),
        )

    model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler)
    model.train()

    if accelerator.num_processes > 1:
        lr_scheduler.split_batches = True
    for epoch in range(resume_from_epoch, args.num_epochs):
        laion_dataset.set_epoch(epoch)
        laion_loader = laion_dataset.dataloader
        mmc4_dataset.set_epoch(epoch)
        mmc4_loader = mmc4_dataset.dataloader

        train_one_epoch(
            args=args,
            model=model,
            epoch=epoch,
            tokenizer=tokenizer,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            mmc4_loader=mmc4_loader,
            laion_loader=laion_loader,
            accelerator=accelerator,
            device_id=device_id,
            wandb=wandb,
        )
        if args.rank == 0:
            if not os.path.exists(args.external_save_dir):
                os.makedirs(args.external_save_dir)

            unwrapped_model = accelerator.unwrap_model(model)
            checkpoint_dict = {
                "epoch": epoch,
                "model_state_dict": get_checkpoint(unwrapped_model),
                "optimizer_state_dict": optimizer.state_dict(),
                "lr_scheduler_state_dict": lr_scheduler.state_dict(),
            }
            print(f"Saving checkpoint to {args.external_save_dir}/checkpoint_epoch{epoch}.pt")
            accelerator.save(checkpoint_dict, f"{args.external_save_dir}/checkpoint_epoch{epoch}.pt")
            unwrapped_model.config.save_pretrained(args.external_save_dir)
            if args.delete_previous_checkpoint:
                if epoch > 0:
                    os.remove(f"{args.external_save_dir}/checkpoint_epoch{epoch-1}.pt")

        accelerator.wait_for_everyone()
    accelerator.wait_for_everyone()

    if args.rank == 0:
        if not os.path.exists(args.external_save_dir):
            os.makedirs(args.external_save_dir)

        unwrapped_model = accelerator.unwrap_model(model)
        accelerator.save(
            get_checkpoint(model=unwrapped_model),
            f"{args.external_save_dir}/final_weights.pt",
        )
        unwrapped_model.config.save_pretrained(args.external_save_dir)

        if args.report_to_wandb and args.save_checkpoints_to_wandb:
            wandb.save(f"{args.external_save_dir}/final_weights.pt")
        if args.save_hf_model:
            unwrapped_model.save_pretrained(f"{args.external_save_dir}")
if __name__ == "__main__":
    main()