train_utils.py

import os
import random
import subprocess
import sys
from contextlib import suppress

import numpy as np
import torch
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler

try:
    from transformers.models.idefics.processing_idefics import image_attention_mask_for_packed_input_ids, incremental_to_binary_attention_mask
except ImportError:
    print("Failed to import Idefics processing module.")


def truncate_text(path, keep_start=10, keep_end=10, truncate_to="..."):
    if len(path) <= (keep_start + keep_end + len(truncate_to)):
        return path
    return path[:keep_start] + truncate_to + path[-keep_end:]


def master_print(*args, **kwargs):
    if dist.is_available() and dist.is_initialized():
        rank = dist.get_rank()
        if rank == 0:
            print(*args, **kwargs)
    else:
        print(*args, **kwargs)


def random_seed(seed=42, rank=0):
    torch.manual_seed(seed + rank)
    np.random.seed(seed + rank)
    random.seed(seed + rank)


def get_cast_dtype(precision: str):
    cast_dtype = None
    if precision == "bf16":
        cast_dtype = torch.bfloat16
    elif precision == "fp16":
        cast_dtype = torch.float16
    return cast_dtype


def get_autocast(precision):
    if precision == "amp":
        return torch.cuda.amp.autocast
    elif precision == "amp_bfloat16" or precision == "amp_bf16":
        # amp_bfloat16 is more stable than amp float16 for clip training
        return lambda: torch.cuda.amp.autocast(dtype=torch.bfloat16)
    elif precision == "fp16":
        return lambda: torch.cuda.amp.autocast(dtype=torch.float16)
    else:
        return suppress
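
# Illustrative usage of get_cast_dtype / get_autocast in a training step (a sketch only;
# `args`, `model`, `batch`, and `device_id` are assumptions, not names defined in this module):
#
#   autocast = get_autocast(args.precision)
#   cast_dtype = get_cast_dtype(args.precision)
#   with autocast():
#       images = batch["images"].to(device_id, dtype=cast_dtype, non_blocking=True)
#       loss = model(images, batch["input_ids"].to(device_id)).loss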


def get_checkpoint(model):
    state_dict = model.state_dict()

    for name, p in model.named_parameters():
        if not p.requires_grad:
            del state_dict[name]

    return state_dict


def get_checkpoint_deepspeed_zero3(args, model):
    state_dict = {}

    for name, p in model.named_parameters():
        if p.requires_grad:
            state_dict[name] = p.data
    return state_dict

    # if torch.distributed.get_rank() == 0:
    #     # has parameters
    #     print(device_id, f"IDEFICS Trainable Params: {(sum(p.numel() for p in model.parameters() if p.requires_grad)) / 1e9:.3f} B")


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
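
# Illustrative usage (a sketch; `data_loader`, `compute_loss`, and `batch_size` are assumptions):
#
#   loss_meter = AverageMeter()
#   for batch in data_loader:
#       loss = compute_loss(batch)
#       loss_meter.update(loss.item(), n=batch_size)
#   master_print(f"average loss: {loss_meter.avg:.4f}")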


class DistributedProxySampler(DistributedSampler):
    """Sampler that restricts data loading to a subset of input sampler indices.

    It is especially useful in conjunction with
    :class:`torch.nn.parallel.DistributedDataParallel`. In such a case, each
    process can pass a DistributedSampler instance as a DataLoader sampler,
    and load a subset of the original dataset that is exclusive to it.

    .. note::
        Input sampler is assumed to be of constant size.

    Arguments:
        sampler: Input data sampler.
        num_replicas (optional): Number of processes participating in
            distributed training.
        rank (optional): Rank of the current process within num_replicas.
    """

    def __init__(self, sampler, num_replicas=None, rank=None):
        super(DistributedProxySampler, self).__init__(sampler, num_replicas=num_replicas, rank=rank, shuffle=False)
        self.sampler = sampler

    def __iter__(self):
        # deterministically shuffle based on epoch
        torch.manual_seed(self.epoch)
        indices = list(self.sampler)

        # add extra samples to make it evenly divisible
        indices += indices[: (self.total_size - len(indices))]
        if len(indices) != self.total_size:
            raise RuntimeError("{} vs {}".format(len(indices), self.total_size))

        # subsample
        indices = indices[self.rank : self.total_size : self.num_replicas]
        if len(indices) != self.num_samples:
            raise RuntimeError("{} vs {}".format(len(indices), self.num_samples))

        return iter(indices)
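
# Illustrative usage (a sketch; the dataset and `sample_weights` are assumptions):
#
#   base_sampler = torch.utils.data.WeightedRandomSampler(sample_weights, num_samples=len(dataset))
#   dist_sampler = DistributedProxySampler(base_sampler)
#   loader = torch.utils.data.DataLoader(dataset, batch_size=8, sampler=dist_sampler)
#   # call dist_sampler.set_epoch(epoch) at the start of every epoch for a different shuffle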


# supporting idefics processing
def get_image_attention_mask(output_input_ids, max_num_images, tokenizer, include_image=True):
    if include_image:
        image_attention_mask, _ = image_attention_mask_for_packed_input_ids(output_input_ids, tokenizer)
        image_attention_mask = incremental_to_binary_attention_mask(image_attention_mask, num_classes=max_num_images)
    else:
        # in full language mode we set the image mask to all-0s
        image_attention_mask = torch.zeros(output_input_ids.shape[0], output_input_ids.shape[1], 1, dtype=torch.bool)
    return image_attention_mask


def verify_yaml(args):
    if args.rank != 0:
        return

    # Run pytest with the necessary arguments.
    result = subprocess.run(["pytest", "-m", "prerun", f"--yaml-path={args.training_data_yaml}"])

    if result.returncode != 0:
        print("YAML verification failed!")
        sys.exit(1)


def get_grouped_params(model, wd):
    params_with_wd, params_without_wd = [], []

    def apply_decay(x):
        return "gated_cross_attn_layer" in x and "ff_gate" not in x and "attn_gate" not in x and "norm" not in x and "bias" not in x

    for n, p in model.named_parameters():
        # if p.requires_grad:
        if apply_decay(n):
            params_with_wd.append(p)
        else:
            params_without_wd.append(p)

    return [
        {"params": params_with_wd, "weight_decay": wd},
        {"params": params_without_wd, "weight_decay": 0.0},
    ]
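
# Illustrative usage with an optimizer (a sketch; `args.weight_decay` and `args.learning_rate` are assumptions):
#
#   optimizer = torch.optim.AdamW(get_grouped_params(model, wd=args.weight_decay), lr=args.learning_rate)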


def save_checkpoint(epoch, model, args, accelerator, unwrapped_model=None, global_step=None):
    """Save a checkpoint for the model."""
    # Ensure the directory exists
    if not os.path.exists(args.external_save_dir):
        os.makedirs(args.external_save_dir)

    if unwrapped_model is None:
        unwrapped_model = accelerator.unwrap_model(model)

    # Formulate the checkpoint filename based on whether it's an epoch or global_step checkpoint
    if global_step:
        checkpoint_path = f"{args.external_save_dir}/checkpoint_steps_{global_step}.pt"
        checkpoint_dict = {
            "steps": global_step,
            "model_state_dict": get_checkpoint(unwrapped_model),
        }
    else:
        checkpoint_path = f"{args.external_save_dir}/checkpoint_{epoch}.pt"
        checkpoint_dict = {"model_state_dict": get_checkpoint(unwrapped_model)}

    # Save the checkpoint if rank is 0
    if args.rank == 0:
        print(f"Saving checkpoint to {checkpoint_path}")
        accelerator.save(checkpoint_dict, checkpoint_path)

        # Save the model's configuration
        unwrapped_model.config.save_pretrained(args.external_save_dir)

        # Remove the previous checkpoint if required
        if args.delete_previous_checkpoint:
            if global_step:
                # The prefix must match the "checkpoint_steps_" name used when saving above
                prev_checkpoint_path = f"{args.external_save_dir}/checkpoint_steps_{global_step - args.save_steps_interval}.pt"
                if os.path.exists(prev_checkpoint_path):
                    os.remove(prev_checkpoint_path)
            elif epoch > 0:
                os.remove(f"{args.external_save_dir}/checkpoint_{epoch - 1}.pt")


def save_final_checkpoint(checkpoint_dict, save_path, is_main_process, save_function):
    """Helper function to save the final weights checkpoint (distinct from the epoch/step save_checkpoint above)."""
    save_function(checkpoint_dict, f"{save_path}/final_weights.pt", is_main_process=is_main_process)


def save_pretrained(component, save_path, is_main_process, save_function):
    """Helper function to save pretrained components."""
    component.save_pretrained(save_path, is_main_process=is_main_process, save_function=save_function, safe_serialization=False)


def save_final_weights(model, args, accelerator, processor=None, tokenizer=None):
    """Save final weights of the model."""
    unwrapped_model = accelerator.unwrap_model(model)
    is_main_process = accelerator.is_main_process
    save_path = args.external_save_dir
    model_name = args.model_name.lower()

    unwrapped_model.config.save_pretrained(save_path)

    if args.save_hf_model:
        save_pretrained(unwrapped_model, save_path, is_main_process, accelerator.save)

        if "idefics" in model_name or "fuyu" in model_name:
            save_pretrained(processor, save_path, is_main_process, accelerator.save)

        if "llama2" in model_name:
            save_pretrained(tokenizer, save_path, is_main_process, accelerator.save)
    else:
        # Save based on the distributed type
        if accelerator.distributed_type == "DEEPSPEED" and accelerator.state.deepspeed_plugin.zero_stage == 3:
            # Under ZeRO-3, gather the full state dict, then keep only the trainable parameters
            checkpoint_dict = accelerator.get_state_dict(model)
            trainable_params_name = [name for name, p in unwrapped_model.named_parameters() if p.requires_grad]
            checkpoint_dict = {k: v for k, v in checkpoint_dict.items() if k in trainable_params_name}
        else:
            checkpoint_dict = get_checkpoint(model=unwrapped_model)

        save_final_checkpoint(checkpoint_dict, save_path, is_main_process, accelerator.save)


def get_weights_for_dataloaders(dataloaders):
    total_samples = sum(len(dataloader.dataset) for dataloader in dataloaders)
    weights = [len(dataloader.dataset) / total_samples for dataloader in dataloaders]
    return weights


def get_next_dataloader(dataloader_iterators, weights):
    chosen_dataloader_index = np.random.choice(len(dataloader_iterators), p=weights)
    return dataloader_iterators[chosen_dataloader_index]
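
# Illustrative weighted-sampling loop over several dataloaders (a sketch; `dataloaders`,
# `total_training_steps`, and `train_one_batch` are assumptions):
#
#   weights = get_weights_for_dataloaders(dataloaders)
#   iterators = [iter(dl) for dl in dataloaders]
#   for _ in range(total_training_steps):
#       dataloader_iterator = get_next_dataloader(iterators, weights)
#       batch = next(dataloader_iterator)
#       train_one_batch(batch)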


def find_and_remove_tokens(input_tensor, labels_tensor, attention_mask_tensor, token_id, tokenizer):
    batch_size, seq_len = input_tensor.size()

    # Create lists to store the new tensors
    new_input_list = []
    new_labels_list = []
    new_attention_mask_list = []

    # Loop over each sequence in the batch
    for i in range(batch_size):
        single_input = input_tensor[i, :]
        single_label = labels_tensor[i, :]
        single_attention_mask = attention_mask_tensor[i, :]

        # Remove the token_id
        new_single_input = torch.masked_select(single_input, single_input != token_id)
        new_single_label = torch.masked_select(single_label, single_input != token_id)
        new_single_attention_mask = torch.masked_select(single_attention_mask, single_input != token_id)

        # Append the new sequence to the list
        new_input_list.append(new_single_input)
        new_labels_list.append(new_single_label)
        new_attention_mask_list.append(new_single_attention_mask)

    # Pad sequences within each batch to match the longest sequence
    new_input = torch.nn.utils.rnn.pad_sequence(new_input_list, batch_first=True, padding_value=tokenizer.pad_token_id)
    new_labels = torch.nn.utils.rnn.pad_sequence(new_labels_list, batch_first=True, padding_value=-100)
    new_attention_mask = torch.nn.utils.rnn.pad_sequence(new_attention_mask_list, batch_first=True, padding_value=0)

    return new_input, new_labels, new_attention_mask
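
# Illustrative usage: dropping a special token id from a padded batch before the forward pass
# (a sketch; `input_ids`, `labels`, `attention_mask`, `special_token_id`, and `tokenizer` are assumptions):
#
#   input_ids, labels, attention_mask = find_and_remove_tokens(
#       input_ids, labels, attention_mask, special_token_id, tokenizer
#   )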


def delete_tensors_from_dict(d):
    """Recursively delete tensors from a nested dictionary."""
    keys_to_delete = []
    for k, v in d.items():
        if isinstance(v, torch.Tensor):
            keys_to_delete.append(k)
        elif isinstance(v, list):
            new_list = [item for item in v if not isinstance(item, torch.Tensor)]
            d[k] = new_list
        elif isinstance(v, dict):
            delete_tensors_from_dict(v)
    for key in keys_to_delete:
        del d[key]
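
# Illustrative usage: stripping tensors out of a nested logging dict before serialization
# (a sketch; the dict contents are an assumption):
#
#   log_dict = {"loss": torch.tensor(0.5), "meta": {"step": 10, "grads": [torch.zeros(3)]}}
#   delete_tensors_from_dict(log_dict)
#   # log_dict is now {"meta": {"step": 10, "grads": []}}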