# This code is based on tatsu-lab/stanford_alpaca. Below is the original copyright:
#
#    Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.
#
# ==============================================================================
#
# The code was modified by the lmsys-org/FastChat authors, and the following is the license:
#    Copyright 2023 FastChat authors
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.

from dataclasses import dataclass
from dataclasses import field
import json
import os
import pathlib
import shutil
import subprocess
from typing import Dict, Optional

from fastchat.conversation import SeparatorStyle
from fastchat.model.model_adapter import get_conversation_template
import torch
from torch.utils.data import Dataset
import transformers
from transformers import Trainer
from transformers.trainer_pt_utils import LabelSmoother

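# LabelSmoother.ignore_index is -100, which matches the default ignore_index
# of torch.nn.CrossEntropyLoss: positions labeled with this value contribute
# nothing to the loss.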
IGNORE_TOKEN_ID = LabelSmoother.ignore_index


@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="facebook/opt-125m")


@dataclass
class DataArguments:
    data_path: Optional[str] = field(
        default=None, metadata={"help": "Path to the training data."})
    eval_data_path: Optional[str] = field(
        default=None, metadata={"help": "Path to the evaluation data."})
    lazy_preprocess: bool = False


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(
        default=512,
        metadata={
            "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
        },
    )
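
# HfArgumentParser exposes each dataclass field above as a command-line flag
# (e.g. --model_name_or_path, --data_path, --lazy_preprocess), on top of the
# standard transformers.TrainingArguments options.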


local_rank = None


def rank0_print(*args):
    if local_rank == 0:
        print(*args)


def safe_save_model_for_hf_trainer(trainer: transformers.Trainer,
                                   output_dir: str):
    """Collects the state dict and dumps it to disk."""
    state_dict = trainer.model.state_dict()
    if trainer.args.should_save:
        cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
        del state_dict
        trainer._save(output_dir, state_dict=cpu_state_dict)  # noqa

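# `preprocess` converts ShareGPT-style conversations into fixed-length training
# tensors, padding/truncating to tokenizer.model_max_length. One element of
# `sources` looks like:
#   [{"from": "human", "value": "Hi!"}, {"from": "gpt", "value": "Hello!"}]
# The returned input_ids/labels have shape (batch, model_max_length), and
# attention_mask is a same-shape bool tensor that is True on non-pad tokens.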

def preprocess(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
    conv = get_conversation_template("vicuna")
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

    # Apply prompt templates
    conversations = []
    for source in sources:
        if not source or source[0]["from"] not in roles:
            continue
        if roles[source[0]["from"]] != conv.roles[0]:
            # Skip the first message if it is not from the human
            source = source[1:]

        conv.messages = []
        role_id = 0
        for sentence in source:
            if sentence["from"] not in roles:
                print(f"Skip unknown role {sentence['from']!r}")
                continue
            role = roles[sentence["from"]]
            if role != conv.roles[role_id % 2]:
                print(f"Skip out-of-order role {role!r}")
                continue
            role_id += 1
            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt())
    if not conversations:
        conv.append_message(conv.roles[0], '')
        conv.append_message(conv.roles[1], '')
        conversations.append(conv.get_prompt())

    # Tokenize conversations
    input_ids = tokenizer(
        conversations,
        return_tensors="pt",
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
    ).input_ids
    targets = input_ids.clone()

    assert conv.sep_style == SeparatorStyle.ADD_COLON_TWO

    # Mask targets. Only compute loss on the assistant outputs.
    sep = conv.sep + conv.roles[1] + ": "
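    # For the Vicuna template, conv.sep is " " and conv.sep2 is "</s>", so a
    # prompt looks like:
    #   "<system> USER: <q1> ASSISTANT: <a1></s>USER: <q2> ASSISTANT: <a2></s>"
    # Splitting on sep2 yields one chunk per assistant reply; within each
    # chunk, everything up to and including " ASSISTANT: " is masked out.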
    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        turns = conversation.split(conv.sep2)
        cur_len = 1
        target[:cur_len] = IGNORE_TOKEN_ID
        for i, turn in enumerate(turns):
            if turn == "":
                break
            turn_len = len(tokenizer(turn).input_ids)

            parts = turn.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep
            # "-2" is hardcoded for the LLaMA tokenizer to make the offset correct.
            instruction_len = len(tokenizer(parts[0]).input_ids) - 2

            # Ignore the user instructions
            target[cur_len:cur_len + instruction_len] = IGNORE_TOKEN_ID
            cur_len += turn_len

        target[cur_len:] = IGNORE_TOKEN_ID

        if False:  # Inspect and check the correctness of masking
            z = target.clone()
            z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z)
            rank0_print(tokenizer.decode(z))

        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                target[:] = IGNORE_TOKEN_ID
                rank0_print(
                    f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                    f" (ignored)")

    return dict(
        input_ids=input_ids,
        labels=targets,
        attention_mask=input_ids.ne(tokenizer.pad_token_id),
    )


class SupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning."""

    def __init__(self, raw_data, tokenizer: transformers.PreTrainedTokenizer):
        super(SupervisedDataset, self).__init__()

        rank0_print("Formatting inputs...")
        sources = [example["conversations"] for example in raw_data]
        data_dict = preprocess(sources, tokenizer)

        self.input_ids = data_dict["input_ids"]
        self.labels = data_dict["labels"]
        self.attention_mask = data_dict["attention_mask"]

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        return dict(
            input_ids=self.input_ids[i],
            labels=self.labels[i],
            attention_mask=self.attention_mask[i],
        )


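# LazySupervisedDataset defers tokenization to __getitem__ and caches each
# processed example, trading first-epoch latency for a much faster startup on
# large datasets.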
class LazySupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning."""

    def __init__(self, raw_data, tokenizer: transformers.PreTrainedTokenizer):
        super(LazySupervisedDataset, self).__init__()

        rank0_print("Formatting inputs...Skip in lazy mode")
        self.tokenizer = tokenizer
        self.raw_data = raw_data
        self.cached_data_dict = {}

    def __len__(self):
        return len(self.raw_data)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        if i in self.cached_data_dict:
            return self.cached_data_dict[i]

        ret = preprocess([self.raw_data[i]["conversations"]], self.tokenizer)
        ret = dict(
            input_ids=ret["input_ids"][0],
            labels=ret["labels"][0],
            attention_mask=ret["attention_mask"][0],
        )
        self.cached_data_dict[i] = ret

        return ret


def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer,
                                data_args) -> Dict:
    """Make dataset and collator for supervised fine-tuning."""
    dataset_cls = (LazySupervisedDataset
                   if data_args.lazy_preprocess else SupervisedDataset)
    rank0_print("Loading data...")

    with open(data_args.data_path, "r") as f:
        train_json = json.load(f)
    train_dataset = dataset_cls(train_json, tokenizer=tokenizer)

    if data_args.eval_data_path:
        with open(data_args.eval_data_path, "r") as f:
            eval_json = json.load(f)
        eval_dataset = dataset_cls(eval_json, tokenizer=tokenizer)
    else:
        eval_dataset = None

    return dict(train_dataset=train_dataset, eval_dataset=eval_dataset)


class CheckpointCallback(transformers.TrainerCallback):

    def on_save(self, args, state, control, **kwargs):
        """Write a 'complete' marker file so interrupted checkpoints can be detected."""
        if state.is_world_process_zero:
            ckpt_path = os.path.join(args.output_dir,
                                     f'checkpoint-{state.global_step}')
            with open(os.path.join(ckpt_path, 'complete'), 'w') as f:
                f.write('')
            print(f'Checkpoint {state.global_step} saved.')
        torch.distributed.barrier()


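# cleanup_incomplete_checkpoints is the reader side of the 'complete' marker
# written by CheckpointCallback.on_save: any checkpoint-* directory that lacks
# the marker was interrupted mid-save and is removed before resuming.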
def cleanup_incomplete_checkpoints(output_dir):
    """Remove incomplete checkpoints."""
    checkpoints = list(pathlib.Path(output_dir).glob('checkpoint-*'))
    checkpoints = [c for c in checkpoints if c.name.split('-')[-1].isdigit()]
    checkpoints = sorted(checkpoints,
                         key=lambda x: int(x.name.split('-')[-1]),
                         reverse=True)
    for checkpoint in checkpoints:
        if not (checkpoint / 'complete').exists():
            print(f'Removing incomplete checkpoint {checkpoint}')
            shutil.rmtree(checkpoint)
        else:
            print(f'Using checkpoint {checkpoint}, copying to ~/tmp/ to '
                  'speed up loading.')
            tmp_dir = os.path.expanduser('~/tmp')
            os.makedirs(tmp_dir, exist_ok=True)
            try:
                # Force the mounting tool to download the checkpoint files in
                # parallel first; this significantly speeds up checkpoint
                # loading.
                subprocess.run(
                    ['gsutil', '-m', 'rsync', '-r', checkpoint, tmp_dir],
                    check=True)
            except Exception:
                print('Failed to optimize checkpoint loading. Skipping.')
            break


def train():
    global local_rank

    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    local_rank = training_args.local_rank
    if local_rank == 0:
        cleanup_incomplete_checkpoints(training_args.output_dir)
    torch.distributed.barrier()

    # Check for existing checkpoints in all processes.
    # All ranks must resume from a checkpoint simultaneously if one exists;
    # otherwise, upon recovery the model weights may not reload correctly,
    # causing loss spikes.
    resume_from_checkpoint = False
    checkpoints = list(
        pathlib.Path(training_args.output_dir).glob('checkpoint-*'))
    checkpoints = [c for c in checkpoints if c.name.split('-')[-1].isdigit()]
    if checkpoints:
        resume_from_checkpoint = True
    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
    )
    model.config.use_cache = False
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
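    # LLaMA-style tokenizers ship without a pad token; reusing the existing
    # unk token avoids adding a new special token and resizing the embeddings.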
    tokenizer.pad_token = tokenizer.unk_token

    data_module = make_supervised_data_module(tokenizer=tokenizer,
                                              data_args=data_args)
    trainer = Trainer(model=model,
                      tokenizer=tokenizer,
                      args=training_args,
                      **data_module)
    trainer.add_callback(CheckpointCallback)
    trainer.train(resume_from_checkpoint=resume_from_checkpoint)
    trainer.save_state()
    safe_save_model_for_hf_trainer(trainer=trainer,
                                   output_dir=training_args.output_dir)


if __name__ == "__main__":
    train()
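
# A typical multi-GPU launch sketch (assuming this file is saved as train.py;
# the model id, data path, and flag values below are illustrative):
#   torchrun --nproc_per_node=4 train.py \
#     --model_name_or_path meta-llama/Llama-2-7b-hf \
#     --data_path data.json --output_dir ./output \
#     --bf16 True --per_device_train_batch_size 2 --lazy_preprocess True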
