# lmops / utils.py

from typing import Dict
import numpy as np
import os
import time
import torch.distributed as dist
from torch.distributed import get_rank
import random
import torch
import torch.nn as nn
from datetime import timedelta
import deepspeed
from accelerate import load_checkpoint_and_dispatch, init_empty_weights
from peft import get_peft_model, LoraConfig, TaskType, PeftModel


# NOTE: the Parallel* classes and `mpu` are not part of stock Hugging Face
# transformers; they come from the customized transformers build used by this
# project.
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    AutoConfig,
    ParallelOPTForCausalLM,
    ParallelLlamaForCausalLM,
    ParallelGPTJForCausalLM,
    ParallelGPT2LMHeadModel,
    ParallelMistralForCausalLM,
    ParallelQWenLMHeadModel,
    mpu,
)


parallel_model_map = {
    "opt": ParallelOPTForCausalLM,
    "gptj": ParallelGPTJForCausalLM,
    "gpt2": ParallelGPT2LMHeadModel,
    "llama": ParallelLlamaForCausalLM,
    "llama2": ParallelLlamaForCausalLM,
    "mistral": ParallelMistralForCausalLM,
    "qwen": ParallelQWenLMHeadModel,
}


# Logging
def print_args(args):
    """Print arguments."""

    print('arguments:', flush=True)
    for arg in vars(args):
        dots = '.' * (29 - len(arg))
        print('  {} {} {}'.format(arg, dots, getattr(args, arg)), flush=True)


def save_rank(log_str, save_path, rank=0):
    if not dist.is_initialized() or dist.get_rank() == rank:
        with open(save_path, "a") as f:
            f.write(log_str + "\n")


def print_rank(*args, rank=0, **kwargs):
    if not dist.is_initialized() or dist.get_rank() == rank:
        print(*args, **kwargs)


# Distributed
def all_gather(t, dim=0, world_size=None, group=None, op="cat"):
    if world_size is None:
        world_size = dist.get_world_size()
    all_t = [torch.zeros_like(t) for _ in range(world_size)]
    dist.all_gather(all_t, t, group=group)
    if op == "cat":
        all_t = torch.cat(all_t, dim=dim)
    elif op == "stack":
        all_t = torch.stack(all_t, dim=dim)
    return all_t
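

# Illustrative sketch (not part of the original file): one way `all_gather`
# might be used, e.g. collecting a per-rank scalar loss onto every rank and
# averaging it. The name `_example_gather_mean_loss` and the `loss` argument
# are hypothetical; torch.distributed must already be initialized and `loss`
# is assumed to live on the current CUDA device.
def _example_gather_mean_loss(loss: torch.Tensor) -> torch.Tensor:
    # Each rank contributes a 1-element tensor; op="stack" yields a tensor of
    # shape (world_size, 1), whose mean is the global average loss.
    gathered = all_gather(loss.view(1), dim=0, op="stack")
    return gathered.float().mean()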


# Initialize
def set_random_seed(seed, mp=False):
    """Set a per-rank random seed for reproducibility."""
    if seed is not None and seed > 0:
        seed = dist.get_rank() + seed
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if mp:
            mpu.model_parallel_cuda_manual_seed(seed)


def init_distributed(args):
    args.rank = int(os.getenv("RANK", "0"))
    args.world_size = int(os.getenv("WORLD_SIZE", "1"))
    args.local_rank = int(os.getenv("LOCAL_RANK", "0"))

    if args.rank == 0:
        print(f"using world size: {args.world_size}")

    # Manually set the device ids.
    device = args.rank % torch.cuda.device_count()
    if args.local_rank is not None:
        device = args.local_rank
    torch.cuda.set_device(device)

    dist.init_process_group(backend="nccl", timeout=timedelta(minutes=300))


def init_distributed_ds(args):
    args.rank = int(os.getenv("RANK", "0"))
    args.world_size = int(os.getenv("WORLD_SIZE", "1"))
    args.local_rank = int(os.getenv("LOCAL_RANK", "0"))

    if args.rank == 0:
        print(f"using world size: {args.world_size}")

    # Manually set the device ids.
    device = args.rank % torch.cuda.device_count()
    if args.local_rank is not None:
        device = args.local_rank
    torch.cuda.set_device(device)

    deepspeed.init_distributed(timeout=timedelta(minutes=300))


def initialize(args):
    # init distributed backend
    if args.deepspeed:
        init_distributed_ds(args)
    else:
        init_distributed(args)

    if args.model_parallel:
        assert dist.get_world_size() % args.model_parallel_size == 0
        mpu.initialize_model_parallel(args.model_parallel_size)

    set_random_seed(args.seed, args.model_parallel)
    # init save folder
    if args.save is not None:
        os.makedirs(args.save, exist_ok=True)
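

# Illustrative sketch (not part of the original file): the attributes that
# `initialize`, `get_tokenizer`, and `get_model` read from `args`. All values
# below are hypothetical placeholders; in the actual training scripts they come
# from argparse, and the process is launched with torchrun / deepspeed.
def _example_setup():
    from argparse import Namespace
    args = Namespace(
        deepspeed=True, model_parallel=False, model_parallel_size=1,
        seed=42, save="results/example",
        model_path="path/to/model", model_type="llama",
        fp32=False, peft=None, peft_path=None,
        gradient_checkpointing=False, do_train=True,
    )
    initialize(args)                       # distributed / DeepSpeed init + seeding
    tokenizer = get_tokenizer(args)
    model = get_model(args, torch.cuda.current_device())
    return args, tokenizer, model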


# Load and save model
def get_model(args, device):
    config = AutoConfig.from_pretrained(args.model_path)

    st_time = time.time()
    if args.model_parallel:
        config.is_model_parallel = True
        with init_empty_weights():
            if args.model_type == "qwen":
                model = parallel_model_map[args.model_type](config).to(torch.bfloat16)
            else:
                model = parallel_model_map[args.model_type](config).half()
        load_parallel(model, args.model_path)

        if mpu.get_data_parallel_rank() == 0:
            print(' > number of parameters on model parallel rank {}: {}'.format(
                mpu.get_model_parallel_rank(),
                sum([p.nelement() for p in model.parameters()])), flush=True)
    else:
        config.is_model_parallel = False
        if args.model_type == "qwen":
            dtype = torch.float32 if args.fp32 else torch.float16
        else:
            dtype = torch.float32 if args.fp32 else torch.bfloat16
        model = AutoModelForCausalLM.from_pretrained(args.model_path, config=config, device_map={"": device}, torch_dtype=dtype)

        if args.peft is not None:
            if args.peft == "lora":
                model.enable_input_require_grads()
                if args.peft_path is not None:
                    model = PeftModel.from_pretrained(model, args.peft_path)
                else:
                    peft_config = LoraConfig(
                        task_type=TaskType.CAUSAL_LM, inference_mode=(not args.do_train), r=args.peft_lora_r, lora_alpha=args.peft_lora_alpha, lora_dropout=args.peft_lora_dropout
                    )
                    model = get_peft_model(model, peft_config)
                model.print_trainable_parameters()
            else:
                raise NotImplementedError
        else:
            if dist.get_rank() == 0:
                print(' > number of parameters: {}'.format(
                    sum([p.nelement() for p in model.parameters()])), flush=True)
        # model = DDP(model)
        # NOTE: no need to wrap in DDP; DeepSpeed already handles data parallelism
    if args.gradient_checkpointing:
        model.gradient_checkpointing_enable()

    ed_time = time.time()

    print_rank(f"Model load time: {ed_time - st_time}s")

    return model


def get_optimizer_params(args, model: nn.Module):
    # taken from https://github.com/facebookresearch/SpanBERT/blob/0670d8b6a38f6714b85ea7a033f16bd8cc162676/code/run_tacred.py
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'ln_f.weight', 'ln_1.weight', 'ln_2.weight', 'ln_cross_attn']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer
                    if not any(nd in n for nd in no_decay)]},
        {'params': [p for n, p in param_optimizer
                    if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]

    return optimizer_grouped_parameters
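

# Illustrative sketch (not part of the original file): the grouped parameters
# returned above are what one would typically hand to an optimizer, so the
# no-decay group keeps weight_decay=0.0 while the other group uses the
# optimizer default. The `lr` and `weight_decay` values here are hypothetical.
def _example_build_optimizer(args, model: nn.Module):
    grouped = get_optimizer_params(args, model)
    return torch.optim.AdamW(grouped, lr=1e-5, weight_decay=0.01)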


def get_optimizer_params_peft(args, model: nn.Module):
    # taken from https://github.com/facebookresearch/SpanBERT/blob/0670d8b6a38f6714b85ea7a033f16bd8cc162676/code/run_tacred.py
    param_optimizer = list(model.named_parameters())
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if p.requires_grad]},
    ]

    return optimizer_grouped_parameters


def get_tokenizer(args):
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    if args.model_type in ["gpt2", "opt", "llama", "gptj", "llama2", "mistral"]:
        tokenizer.pad_token_id = tokenizer.eos_token_id
    elif args.model_type == "qwen":
        tokenizer.pad_token_id = 151646
        tokenizer.eos_token_id = 151643
        tokenizer.pad_token_id = tokenizer.eos_token_id

    return tokenizer


def load_parallel(model, load_dir):
    mp_rank = mpu.get_model_parallel_rank()
    assert mpu.get_model_parallel_world_size() != 1
    checkpoint_name = os.path.join(load_dir, f"mp{mpu.get_model_parallel_world_size()}", f"pytorch_model_{mp_rank}.bin")
    assert os.path.exists(checkpoint_name), f"{checkpoint_name} does not exist."
    model = load_checkpoint_and_dispatch(model=model, checkpoint=checkpoint_name, device_map={"": torch.cuda.current_device()}, dtype=torch.float16)
    dist.barrier()
    print(f"Rank {get_rank()}: {checkpoint_name} loaded.")


def save_parallel(model, save_dir):
    mp_rank = mpu.get_model_parallel_rank()
    os.makedirs(os.path.join(save_dir, f"mp{mpu.get_model_parallel_world_size()}"), exist_ok=True)
    checkpoint_name = os.path.join(save_dir, f"mp{mpu.get_model_parallel_world_size()}", f"pytorch_model_{mp_rank}.bin")
    torch.save(model.state_dict(), checkpoint_name)
    print(f"Rank {get_rank()}: {checkpoint_name} saved.")