5
from contextlib import suppress
9
from torch.utils.data.distributed import DistributedSampler
10
import torch.distributed as dist
13
from transformers.models.idefics.processing_idefics import image_attention_mask_for_packed_input_ids, incremental_to_binary_attention_mask
15
print("Failed to import Idefics processing module.")
18
def truncate_text(path, keep_start=10, keep_end=10, truncate_to="..."):
    """Shorten ``path`` by replacing its middle with ``truncate_to``.

    Args:
        path: The string to shorten.
        keep_start: Number of leading characters to keep.
        keep_end: Number of trailing characters to keep.
        truncate_to: Marker inserted in place of the removed middle.

    Returns:
        ``path`` unchanged when it is no longer than
        ``keep_start + keep_end + len(truncate_to)``; otherwise the first
        ``keep_start`` characters, the marker, and the last ``keep_end``
        characters.
    """
    if len(path) <= (keep_start + keep_end + len(truncate_to)):
        return path
    # Slice the tail from an explicit start index: ``path[-0:]`` would be the
    # whole string, so ``keep_end=0`` must not go through a negative index.
    tail = path[len(path) - keep_end:] if keep_end > 0 else ""
    return path[:keep_start] + truncate_to + tail
24
def master_print(*args, **kwargs):
    """Print once per job: on rank 0 when ``torch.distributed`` is running.

    When the distributed package is unavailable or no process group has been
    initialised, this behaves exactly like :func:`print`. All positional and
    keyword arguments are forwarded to :func:`print` unchanged.
    """
    if dist.is_available() and dist.is_initialized():
        # Inside a process group only rank 0 prints; other ranks stay silent.
        if dist.get_rank() == 0:
            print(*args, **kwargs)
    else:
        print(*args, **kwargs)
33
def random_seed(seed=42, rank=0):
    """Seed torch, NumPy and the stdlib ``random`` module with ``seed + rank``.

    Args:
        seed: Base seed.
        rank: Offset added to ``seed`` before seeding, so that different
            ``rank`` values give different random streams.
    """
    effective_seed = seed + rank
    torch.manual_seed(effective_seed)
    np.random.seed(effective_seed)
    random.seed(effective_seed)
39
def get_cast_dtype(precision: str):
    """Translate a precision flag into the torch dtype to cast to.

    Args:
        precision: Precision flag; ``"bf16"`` and ``"fp16"`` are recognised.

    Returns:
        ``torch.bfloat16`` for ``"bf16"``, ``torch.float16`` for ``"fp16"``,
        and ``None`` for any other value.
    """
    cast_dtype = None
    if precision == "bf16":
        cast_dtype = torch.bfloat16
    elif precision == "fp16":
        cast_dtype = torch.float16
    return cast_dtype
48
def get_autocast(precision):
    """Return a zero-argument factory for the autocast context to train under.

    Args:
        precision: Precision flag.

    Returns:
        A callable that produces a context manager when called with no
        arguments:

        * ``"amp"``: ``torch.cuda.amp.autocast`` itself.
        * ``"amp_bfloat16"`` / ``"amp_bf16"``: autocast with ``torch.bfloat16``.
        * ``"fp16"``: autocast with ``torch.float16``.
        * anything else: ``contextlib.suppress``, which called with no
          arguments is a do-nothing context manager.
    """
    if precision == "amp":
        return torch.cuda.amp.autocast
    if precision in ("amp_bfloat16", "amp_bf16"):
        return lambda: torch.cuda.amp.autocast(dtype=torch.bfloat16)
    if precision == "fp16":
        return lambda: torch.cuda.amp.autocast(dtype=torch.float16)
    return suppress
60
def get_checkpoint(model):
    """Return ``model.state_dict()`` without the frozen parameters.

    Every parameter reported by ``model.named_parameters()`` whose
    ``requires_grad`` is false is removed from the state dict. Entries that
    are not parameters (e.g. buffers) are left in place.

    Args:
        model: Object exposing ``state_dict()`` and ``named_parameters()``.

    Returns:
        The filtered state dict.
    """
    state_dict = model.state_dict()

    for name, param in model.named_parameters():
        if not param.requires_grad:
            # pop() rather than del: a name missing from the state dict is
            # simply skipped instead of aborting the checkpoint.
            state_dict.pop(name, None)

    return state_dict
70
def get_checkpoint_deepspeed_zero3(args, model):
    """Collect the ``.data`` tensor of every trainable parameter of ``model``.

    Args:
        args: Unused; kept so existing call sites keep working.
        model: Object exposing ``named_parameters()``.

    Returns:
        A dict mapping parameter name to ``param.data`` for each parameter
        whose ``requires_grad`` is true.
    """
    return {name: param.data for name, param in model.named_parameters() if param.requires_grad}
83
class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, running sum, sample count and average."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, counted ``n`` times.

        Args:
            val: The new observation.
            n: Weight of the observation (how many samples it stands for).
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # An update with n == 0 on a fresh meter leaves count at 0; report an
        # average of 0 instead of dividing by zero.
        self.avg = self.sum / self.count if self.count else 0
102
class DistributedProxySampler(DistributedSampler):
    """Sampler that restricts data loading to a subset of input sampler indices.

    It is especially useful in conjunction with
    :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
    process can pass a DistributedSampler instance as a DataLoader sampler,
    and load a subset of the original dataset that is exclusive to it.

    .. note::
        Input sampler is assumed to be of constant size.

    Args:
        sampler: Input data sampler.
        num_replicas (optional): Number of processes participating in
            distributed training.
        rank (optional): Rank of the current process within num_replicas.
    """

    def __init__(self, sampler, num_replicas=None, rank=None):
        # The wrapped sampler is handed to the parent in the dataset slot.
        super(DistributedProxySampler, self).__init__(sampler, num_replicas=num_replicas, rank=rank, shuffle=False)
        self.sampler = sampler

    def __iter__(self):
        """Yield this rank's share of the indices produced by the wrapped sampler.

        Raises:
            RuntimeError: If the padded index list does not have
                ``total_size`` entries, or this rank's slice does not have
                ``num_samples`` entries.
        """
        # Seed the global torch RNG from the epoch before drawing from the
        # wrapped sampler.
        torch.manual_seed(self.epoch)
        indices = list(self.sampler)

        # Pad up to total_size by cycling through the drawn indices, so a
        # sampler shorter than the required padding is still handled.
        missing = self.total_size - len(indices)
        if missing > 0 and indices:
            repeats = -(-missing // len(indices))  # ceiling division
            indices += (indices * repeats)[:missing]
        if len(indices) != self.total_size:
            raise RuntimeError("{} vs {}".format(len(indices), self.total_size))

        # Each rank takes every num_replicas-th index, starting at its rank.
        indices = indices[self.rank : self.total_size : self.num_replicas]
        if len(indices) != self.num_samples:
            raise RuntimeError("{} vs {}".format(len(indices), self.num_samples))

        return iter(indices)
143
def get_image_attention_mask(output_input_ids, max_num_images, tokenizer, include_image=True):
    """Build the image attention mask for a batch of token ids.

    Args:
        output_input_ids: Token-id tensor; its first two dimensions size the
            text-only mask.
        max_num_images: Passed as ``num_classes`` when converting the
            incremental mask to a binary one.
        tokenizer: Forwarded to ``image_attention_mask_for_packed_input_ids``.
        include_image: When false, no image masking is computed.

    Returns:
        With ``include_image``: the result of
        ``incremental_to_binary_attention_mask`` applied to the mask from
        ``image_attention_mask_for_packed_input_ids``. Otherwise an all-False
        ``torch.bool`` tensor of shape
        ``(output_input_ids.shape[0], output_input_ids.shape[1], 1)``.
    """
    if include_image:
        image_attention_mask, _ = image_attention_mask_for_packed_input_ids(output_input_ids, tokenizer)
        image_attention_mask = incremental_to_binary_attention_mask(image_attention_mask, num_classes=max_num_images)
    else:
        # Text-only batch: the image mask is all zeros.
        image_attention_mask = torch.zeros(output_input_ids.shape[0], output_input_ids.shape[1], 1, dtype=torch.bool)
    return image_attention_mask
155
def verify_yaml(args):
    """Run the ``prerun`` pytest suite against the training-data YAML.

    Only rank 0 performs the check; every other rank returns immediately.

    Args:
        args: Namespace providing ``rank`` and ``training_data_yaml``.

    Raises:
        SystemExit: With exit status 1 when pytest returns a non-zero code.
    """
    if args.rank != 0:
        return

    # Argument list (no shell), so the YAML path is passed through verbatim.
    result = subprocess.run(["pytest", "-m", "prerun", f"--yaml-path={args.training_data_yaml}"])

    if result.returncode != 0:
        print("YAML verification failed!")
        raise SystemExit(1)
167
def get_grouped_params(model, wd):
    """Split the model's parameters into weight-decay and no-decay groups.

    A parameter gets weight decay when its name contains
    ``"gated_cross_attn_layer"`` and none of ``"ff_gate"``, ``"attn_gate"``,
    ``"norm"`` or ``"bias"``. Every other parameter goes in the no-decay group.

    Args:
        model: Object exposing ``named_parameters()``.
        wd: Weight-decay value for the decayed group.

    Returns:
        A two-element list of dicts with ``"params"`` and ``"weight_decay"``
        keys: the decayed group first, then the group with decay ``0.0``.
    """
    params_with_wd, params_without_wd = [], []

    def apply_decay(x):
        return "gated_cross_attn_layer" in x and "ff_gate" not in x and "attn_gate" not in x and "norm" not in x and "bias" not in x

    for name, param in model.named_parameters():
        if apply_decay(name):
            params_with_wd.append(param)
        else:
            params_without_wd.append(param)

    return [
        {"params": params_with_wd, "weight_decay": wd},
        {"params": params_without_wd, "weight_decay": 0.0},
    ]
186
def save_checkpoint(epoch, model, args, accelerator, unwrapped_model=None, global_step=None):
    """Save a checkpoint for the model.

    The checkpoint holds the trainable weights returned by ``get_checkpoint``
    (plus ``"steps"`` for step checkpoints). It is written to
    ``checkpoint_steps_{global_step}.pt`` when ``global_step`` is truthy and to
    ``checkpoint_{epoch}.pt`` otherwise. Only rank 0 writes files.

    NOTE(review): a second function named ``save_checkpoint`` is defined later
    in this module and rebinds the name, so this definition is not reachable
    through ``save_checkpoint`` after import. One of the two should be renamed.

    Args:
        epoch: Epoch index used for epoch checkpoints.
        model: The (possibly wrapped) model.
        args: Namespace providing ``external_save_dir``, ``rank``,
            ``delete_previous_checkpoint`` and, for step checkpoints,
            ``save_steps_interval``.
        accelerator: Object providing ``unwrap_model`` and ``save``.
        unwrapped_model: Already-unwrapped model; derived from ``model`` via
            ``accelerator.unwrap_model`` when ``None``.
        global_step: Step counter; a truthy value selects a step checkpoint.
    """
    save_dir = args.external_save_dir
    # exist_ok avoids a failure when several ranks create the directory at once.
    os.makedirs(save_dir, exist_ok=True)

    if unwrapped_model is None:
        unwrapped_model = accelerator.unwrap_model(model)

    # Name the checkpoint after the step when one is given, else after the epoch.
    if global_step:
        checkpoint_path = f"{save_dir}/checkpoint_steps_{global_step}.pt"
        checkpoint_dict = {
            "steps": global_step,
            "model_state_dict": get_checkpoint(unwrapped_model),
        }
    else:
        checkpoint_path = f"{save_dir}/checkpoint_{epoch}.pt"
        checkpoint_dict = {"model_state_dict": get_checkpoint(unwrapped_model)}

    if args.rank != 0:
        return

    print(f"Saving checkpoint to {checkpoint_path}")
    accelerator.save(checkpoint_dict, checkpoint_path)

    # Keep the model configuration next to the weights.
    unwrapped_model.config.save_pretrained(save_dir)

    if not args.delete_previous_checkpoint:
        return

    if global_step:
        # Must use the same "checkpoint_steps_" prefix the file was saved
        # under; the former "checkpoint_step_" prefix never matched a file.
        prev_checkpoint_path = f"{save_dir}/checkpoint_steps_{global_step - args.save_steps_interval}.pt"
    elif epoch > 0:
        prev_checkpoint_path = f"{save_dir}/checkpoint_{epoch - 1}.pt"
    else:
        return

    if os.path.exists(prev_checkpoint_path):
        os.remove(prev_checkpoint_path)
224
def save_checkpoint(checkpoint_dict, save_path, is_main_process, save_function):
    """Helper function to save the checkpoint.

    Writes ``checkpoint_dict`` to ``{save_path}/final_weights.pt`` by calling
    ``save_function(checkpoint_dict, path, is_main_process=is_main_process)``.

    NOTE(review): this reuses the name of the epoch/step ``save_checkpoint``
    defined earlier in this module and therefore replaces it at import time.
    The name is kept here so existing callers of this helper keep working.
    NOTE(review): ``save_function`` must accept an ``is_main_process`` keyword;
    confirm that holds for every callable passed in.

    Args:
        checkpoint_dict: Object to serialise.
        save_path: Directory that receives ``final_weights.pt``.
        is_main_process: Forwarded to ``save_function`` as a keyword.
        save_function: Callable that performs the write.
    """
    final_weights_path = f"{save_path}/final_weights.pt"
    save_function(checkpoint_dict, final_weights_path, is_main_process=is_main_process)
229
def save_pretrained(component, save_path, is_main_process, save_function):
    """Helper function to save pretrained components.

    Calls ``component.save_pretrained`` with ``safe_serialization=False``,
    forwarding ``is_main_process`` and ``save_function`` as keywords.

    Args:
        component: Object exposing ``save_pretrained``.
        save_path: Destination passed as the first positional argument.
        is_main_process: Forwarded as a keyword.
        save_function: Forwarded as a keyword.
    """
    component.save_pretrained(
        save_path,
        is_main_process=is_main_process,
        save_function=save_function,
        safe_serialization=False,
    )
234
def save_final_weights(model, args, accelerator, processor=None, tokenizer=None):
    """Save final weights of the model.

    The model configuration is always written to ``args.external_save_dir``.
    With ``args.save_hf_model`` the model is saved through ``save_pretrained``,
    followed by the processor for model names containing ``"idefics"`` or
    ``"fuyu"`` and the tokenizer for names containing ``"llama2"``. Otherwise
    only the trainable weights are written, via the module-level
    ``save_checkpoint`` helper.

    Args:
        model: The (possibly wrapped) model.
        args: Namespace providing ``external_save_dir``, ``model_name`` and
            ``save_hf_model``.
        accelerator: Object providing ``unwrap_model``, ``is_main_process``,
            ``save``, ``distributed_type``, ``state`` and ``get_state_dict``.
        processor: Saved for idefics/fuyu models in the ``save_hf_model`` path.
        tokenizer: Saved for llama2 models in the ``save_hf_model`` path.
    """
    unwrapped_model = accelerator.unwrap_model(model)
    is_main_process = accelerator.is_main_process
    save_path = args.external_save_dir
    model_name = args.model_name.lower()

    unwrapped_model.config.save_pretrained(save_path)

    if args.save_hf_model:
        save_pretrained(unwrapped_model, save_path, is_main_process, accelerator.save)

        if "idefics" in model_name or "fuyu" in model_name:
            save_pretrained(processor, save_path, is_main_process, accelerator.save)

        if "llama2" in model_name:
            save_pretrained(tokenizer, save_path, is_main_process, accelerator.save)
        return

    # NOTE(review): the raw-weights path below runs only when save_hf_model is
    # false; confirm this matches the intended branching.
    uses_zero3 = accelerator.distributed_type == "DEEPSPEED" and accelerator.state.deepspeed_plugin.zero_stage == 3
    if uses_zero3:
        # Fetch the state dict through the accelerator, then keep only the
        # trainable entries.
        full_state_dict = accelerator.get_state_dict(model)
        trainable_names = {name for name, p in unwrapped_model.named_parameters() if p.requires_grad}
        checkpoint_dict = {k: v for k, v in full_state_dict.items() if k in trainable_names}
    else:
        checkpoint_dict = get_checkpoint(model=unwrapped_model)

    save_checkpoint(checkpoint_dict, save_path, is_main_process, accelerator.save)
265
def get_weights_for_dataloaders(dataloaders):
    """Return each dataloader's share of the total number of samples.

    Args:
        dataloaders: Sequence of objects whose ``dataset`` supports ``len``.

    Returns:
        A list with one weight per dataloader,
        ``len(dataloader.dataset) / total_samples``, in input order. An empty
        input yields an empty list.

    Raises:
        ZeroDivisionError: If dataloaders are given but all datasets are empty.
    """
    dataset_sizes = [len(dataloader.dataset) for dataloader in dataloaders]
    total_samples = sum(dataset_sizes)
    return [size / total_samples for size in dataset_sizes]
271
def get_next_dataloader(dataloader_iterators, weights):
    """Pick one of ``dataloader_iterators`` at random according to ``weights``.

    The index is drawn with ``np.random.choice``, so the result depends on
    NumPy's global random state.

    Args:
        dataloader_iterators: Sequence of candidates to choose from.
        weights: Selection probabilities, one per candidate, passed as ``p``.

    Returns:
        The chosen element of ``dataloader_iterators``.
    """
    chosen_dataloader_index = np.random.choice(len(dataloader_iterators), p=weights)
    return dataloader_iterators[chosen_dataloader_index]
276
def find_and_remove_tokens(input_tensor, labels_tensor, attention_mask_tensor, token_id, tokenizer):
    """Drop every position holding ``token_id`` and re-pad the batch.

    Positions are chosen from ``input_tensor`` only; the same positions are
    removed from the labels and the attention mask. Rows are then padded to
    the longest remaining row.

    Args:
        input_tensor: 2-D tensor of token ids (unpacked as batch, sequence).
        labels_tensor: Tensor indexed the same way as ``input_tensor``.
        attention_mask_tensor: Tensor indexed the same way as ``input_tensor``.
        token_id: The token id to remove.
        tokenizer: Provides ``pad_token_id``, used to pad the inputs.

    Returns:
        ``(new_input, new_labels, new_attention_mask)``, padded with
        ``tokenizer.pad_token_id``, ``-100`` and ``0`` respectively.
    """
    batch_size, _ = input_tensor.size()

    new_input_list = []
    new_labels_list = []
    new_attention_mask_list = []

    for i in range(batch_size):
        single_input = input_tensor[i, :]
        # One keep-mask per row, shared by all three tensors.
        keep = single_input != token_id

        new_input_list.append(torch.masked_select(single_input, keep))
        new_labels_list.append(torch.masked_select(labels_tensor[i, :], keep))
        new_attention_mask_list.append(torch.masked_select(attention_mask_tensor[i, :], keep))

    # Rows can now differ in length, so pad them back to a rectangle.
    new_input = torch.nn.utils.rnn.pad_sequence(new_input_list, batch_first=True, padding_value=tokenizer.pad_token_id)
    new_labels = torch.nn.utils.rnn.pad_sequence(new_labels_list, batch_first=True, padding_value=-100)
    new_attention_mask = torch.nn.utils.rnn.pad_sequence(new_attention_mask_list, batch_first=True, padding_value=0)

    return new_input, new_labels, new_attention_mask
308
def delete_tensors_from_dict(d):
309
"""Recursively delete tensors from a nested dictionary."""
311
for k, v in d.items():
312
if isinstance(v, torch.Tensor):
313
keys_to_delete.append(k)
314
elif isinstance(v, list):
315
new_list = [item for item in v if not isinstance(item, torch.Tensor)]
317
elif isinstance(v, dict):
318
delete_tensors_from_dict(v)
319
for key in keys_to_delete: