# This code is based on tatsu-lab/stanford_alpaca. Below is the original copyright:
#
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
#
# The code was modified by the lmsys-org/FastChat authors, and the following is the license:
# Copyright 2023 FastChat authors
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from dataclasses import field
import json
import os
import pathlib
import shutil
import subprocess
from typing import Dict, Optional

from fastchat.conversation import SeparatorStyle
from fastchat.model.model_adapter import get_conversation_template
import torch
from torch.utils.data import Dataset
import transformers
from transformers import Trainer
from transformers.trainer_pt_utils import LabelSmoother

# -100 by default: the label id that PyTorch's cross-entropy loss (and hence the
# HF Trainer's loss computation) ignores.
IGNORE_TOKEN_ID = LabelSmoother.ignore_index


@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="facebook/opt-125m")


@dataclass
class DataArguments:
    data_path: str = field(default=None,
                           metadata={"help": "Path to the training data."})
    eval_data_path: str = field(
        default=None, metadata={"help": "Path to the evaluation data."})
    lazy_preprocess: bool = False


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(
        default=512,
        metadata={
            "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
        },
    )


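# Note: HfArgumentParser in train() below exposes every dataclass field above as
# a command-line flag (e.g. --model_name_or_path, --data_path, --eval_data_path,
# --lazy_preprocess, --model_max_length), in addition to the standard
# transformers.TrainingArguments flags such as --output_dir and --learning_rate.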
local_rank = None


def rank0_print(*args):
    if local_rank == 0:
        print(*args)


def safe_save_model_for_hf_trainer(trainer: transformers.Trainer,
                                   output_dir: str):
    """Collects the state dict and dumps it to disk."""
    state_dict = trainer.model.state_dict()
    if trainer.args.should_save:
        cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
        del state_dict
        trainer._save(output_dir, state_dict=cpu_state_dict)  # noqa


def preprocess(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
    conv = get_conversation_template("vicuna")
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

    # Apply prompt templates
    conversations = []
    for i, source in enumerate(sources):
        if not source or source[0]["from"] not in roles:
            continue
        if roles[source[0]["from"]] != conv.roles[0]:
            # Skip the first one if it is not from human
            source = source[1:]

        conv.messages = []
        role_id = 0
        for sentence in source:
            if sentence["from"] not in roles:
                print(f"Skip unknown role {sentence['from']!r}")
                continue
            role = roles[sentence["from"]]
            if role != conv.roles[role_id % 2]:
                print(f"Skip duplicated role {role!r}")
                continue
            role_id += 1
            conv.append_message(role, sentence["value"])
        else:
            # for-else: the loop has no break, so the assembled prompt is
            # always appended once the turns of this source are processed.
            conversations.append(conv.get_prompt())
    if not conversations:
        # All sources were skipped; fall back to a single empty conversation so
        # that tokenization below still produces a (padded) batch.
        conv.append_message(conv.roles[0], '')
        conv.append_message(conv.roles[1], '')
        conversations.append(conv.get_prompt())

    # Tokenize conversations
    input_ids = tokenizer(
        conversations,
        return_tensors="pt",
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
    ).input_ids
    targets = input_ids.clone()

    assert conv.sep_style == SeparatorStyle.ADD_COLON_TWO

    # Mask targets. Only compute loss on the assistant outputs.
    sep = conv.sep + conv.roles[1] + ": "
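    # With the default Vicuna conversation template (an ADD_COLON_TWO style),
    # a rendered prompt looks roughly like (layout shown for illustration; the
    # exact system message comes from the template):
    #
    #   {system} USER: {question 1} ASSISTANT: {answer 1}</s>USER: {question 2} ASSISTANT: {answer 2}</s>
    #
    # so `sep` is " ASSISTANT: " and `conv.sep2` is "</s>". Splitting on
    # `conv.sep2` below yields one user/assistant turn per piece, and splitting
    # each turn on `sep` separates the prefix to ignore (system prompt + user
    # message) from the assistant reply whose tokens keep their labels.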
    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        turns = conversation.split(conv.sep2)
        cur_len = 1
        target[:cur_len] = IGNORE_TOKEN_ID
        for i, turn in enumerate(turns):
            if turn == "":
                break
            turn_len = len(tokenizer(turn).input_ids)

            parts = turn.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep
            # "-2" is hardcoded for the LLaMA tokenizer to make the offset correct.
            instruction_len = len(tokenizer(parts[0]).input_ids) - 2

            # Ignore the user instructions
            target[cur_len:cur_len + instruction_len] = IGNORE_TOKEN_ID
            cur_len += turn_len

        target[cur_len:] = IGNORE_TOKEN_ID

        if False:  # Inspect and check the correctness of masking
            z = target.clone()
            z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z)
            rank0_print(tokenizer.decode(z))

        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                target[:] = IGNORE_TOKEN_ID
                rank0_print(
                    f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                    f" (ignored)")

    return dict(
        input_ids=input_ids,
        labels=targets,
        attention_mask=input_ids.ne(tokenizer.pad_token_id),
    )


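# The raw training data is expected to be a JSON list of ShareGPT-style records,
# each with a "conversations" list of turns whose "from" field is "human" or
# "gpt"; this mirrors what `preprocess` and the datasets below read. An
# illustrative record (the "id" field is common in such data but not required
# by this script):
#
#   {
#     "id": "identity_0",
#     "conversations": [
#       {"from": "human", "value": "Who are you?"},
#       {"from": "gpt", "value": "I am Vicuna."}
#     ]
#   }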
class SupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning."""

    def __init__(self, raw_data, tokenizer: transformers.PreTrainedTokenizer):
        super(SupervisedDataset, self).__init__()

        rank0_print("Formatting inputs...")
        sources = [example["conversations"] for example in raw_data]
        data_dict = preprocess(sources, tokenizer)

        self.input_ids = data_dict["input_ids"]
        self.labels = data_dict["labels"]
        self.attention_mask = data_dict["attention_mask"]

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        return dict(
            input_ids=self.input_ids[i],
            labels=self.labels[i],
            attention_mask=self.attention_mask[i],
        )


class LazySupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning (tokenizes lazily, one item at a time)."""

    def __init__(self, raw_data, tokenizer: transformers.PreTrainedTokenizer):
        super(LazySupervisedDataset, self).__init__()
        self.tokenizer = tokenizer

        rank0_print("Formatting inputs...Skip in lazy mode")
        self.raw_data = raw_data
        self.cached_data_dict = {}

    def __len__(self):
        return len(self.raw_data)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        if i in self.cached_data_dict:
            return self.cached_data_dict[i]

        ret = preprocess([self.raw_data[i]["conversations"]], self.tokenizer)
        ret = dict(
            input_ids=ret["input_ids"][0],
            labels=ret["labels"][0],
            attention_mask=ret["attention_mask"][0],
        )
        self.cached_data_dict[i] = ret

        return ret


def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer,
                                data_args) -> Dict:
    """Make dataset and collator for supervised fine-tuning."""
    dataset_cls = (LazySupervisedDataset
                   if data_args.lazy_preprocess else SupervisedDataset)
    rank0_print("Loading data...")

    with open(data_args.data_path, "r") as f:
        train_json = json.load(f)
    train_dataset = dataset_cls(train_json, tokenizer=tokenizer)

    if data_args.eval_data_path:
        with open(data_args.eval_data_path, "r") as f:
            eval_json = json.load(f)
        eval_dataset = dataset_cls(eval_json, tokenizer=tokenizer)
    else:
        eval_dataset = None

    return dict(train_dataset=train_dataset, eval_dataset=eval_dataset)


class CheckpointCallback(transformers.TrainerCallback):

    def on_save(self, args, state, control, **kwargs):
        """Add a completion indicator to avoid using incomplete checkpoints."""
        if state.is_world_process_zero:
            ckpt_path = os.path.join(args.output_dir,
                                     f'checkpoint-{state.global_step}')
            with open(os.path.join(ckpt_path, 'complete'), 'w') as f:
                f.write('')
            print(f'Checkpoint {state.global_step} saved.')
        torch.distributed.barrier()


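# Together, CheckpointCallback and cleanup_incomplete_checkpoints implement a
# simple completeness protocol for interruptible (e.g. spot) training: rank 0
# drops an empty 'complete' marker file into each checkpoint directory only
# after the save finishes, and on (re)start any checkpoint-* directory without
# the marker is treated as a partial write from an interrupted run and deleted
# before the Trainer is allowed to resume from it.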
def cleanup_incomplete_checkpoints(output_dir):
    """Remove incomplete checkpoints."""
    checkpoints = list(pathlib.Path(output_dir).glob('checkpoint-*'))
    checkpoints = [c for c in checkpoints if c.name.split('-')[-1].isdigit()]
    checkpoints = sorted(checkpoints,
                         key=lambda x: int(x.name.split('-')[-1]),
                         reverse=True)
    for checkpoint in checkpoints:
        if not (checkpoint / 'complete').exists():
            print(f'Removing incomplete checkpoint {checkpoint}')
            shutil.rmtree(checkpoint)
        else:
            print(f'Using checkpoint {checkpoint}, copying to ~/tmp/ for '
                  'faster loading.')
            tmp_dir = os.path.expanduser('~/tmp')
            os.makedirs(tmp_dir, exist_ok=True)
            try:
                # Optimization for checkpoint loading: force the mounting tool
                # to download the checkpoint files in parallel first, which
                # significantly improves loading speed.
                subprocess.run(
                    ['gsutil', '-m', 'rsync', '-r', str(checkpoint), tmp_dir],
                    check=True)
            except Exception:
                print('Failed to optimize checkpoint loading. Skipping.')
            break


def train():
    global local_rank

    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    local_rank = training_args.local_rank
    if local_rank == 0:
        cleanup_incomplete_checkpoints(training_args.output_dir)
    torch.distributed.barrier()

    # Check the existence of checkpoints in all processes.
    # All ranks must simultaneously resume from a checkpoint if it exists.
    # Otherwise, upon recovery the model weights may not reload correctly,
    # causing loss spikes.
    resume_from_checkpoint = False
    checkpoints = list(
        pathlib.Path(training_args.output_dir).glob('checkpoint-*'))
    checkpoints = [c for c in checkpoints if c.name.split('-')[-1].isdigit()]
    if checkpoints:
        resume_from_checkpoint = True

    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
    )
    model.config.use_cache = False
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    tokenizer.pad_token = tokenizer.unk_token

    data_module = make_supervised_data_module(tokenizer=tokenizer,
                                              data_args=data_args)
    trainer = Trainer(model=model,
                      tokenizer=tokenizer,
                      args=training_args,
                      **data_module)
    trainer.add_callback(CheckpointCallback)
    trainer.train(resume_from_checkpoint=resume_from_checkpoint)
    trainer.save_state()
    safe_save_model_for_hf_trainer(trainer=trainer,
                                   output_dir=training_args.output_dir)


if __name__ == "__main__":
    train()
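
# Example launch (illustrative only; any equivalent torchrun/deepspeed
# invocation works, and the flag values below are placeholders rather than
# recommendations):
#
#   torchrun --nproc_per_node=4 train.py \
#       --model_name_or_path meta-llama/Llama-2-7b-hf \
#       --data_path data/sharegpt_clean.json \
#       --output_dir ./checkpoints \
#       --bf16 True \
#       --num_train_epochs 3 \
#       --per_device_train_batch_size 2 \
#       --gradient_accumulation_steps 8 \
#       --learning_rate 2e-5 \
#       --model_max_length 2048 \
#       --lazy_preprocess True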