stanford_alpaca/train.py

# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import logging
from dataclasses import dataclass, field
from typing import Dict, Optional, Sequence

import torch
import transformers
import utils
from torch.utils.data import Dataset
from transformers import Trainer

IGNORE_INDEX = -100
DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "<s>"
DEFAULT_UNK_TOKEN = "<unk>"
PROMPT_DICT = {
    "prompt_input": (
        "Below is an instruction that describes a task, paired with an input that provides further context. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
    ),
    "prompt_no_input": (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Response:"
    ),
}
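
# For illustration, with a hypothetical record such as
#   {"instruction": "Name three primary colors.", "input": "", "output": "Red, yellow, and blue."}
# the empty "input" field selects "prompt_no_input", so the source string rendered for the model is:
#
#   Below is an instruction that describes a task. Write a response that appropriately completes the request.
#
#   ### Instruction:
#   Name three primary colors.
#
#   ### Response:
#
# The target appended to it during preprocessing is the "output" string followed by the EOS token.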


@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="facebook/opt-125m")


@dataclass
class DataArguments:
    data_path: str = field(default=None, metadata={"help": "Path to the training data."})


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(
        default=512,
        metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
    )

def smart_tokenizer_and_embedding_resize(
    special_tokens_dict: Dict,
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
):
    """Resize tokenizer and embedding.

    Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
    """
    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
    model.resize_token_embeddings(len(tokenizer))

    if num_new_tokens > 0:
        input_embeddings = model.get_input_embeddings().weight.data
        output_embeddings = model.get_output_embeddings().weight.data

        input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
        output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)

        input_embeddings[-num_new_tokens:] = input_embeddings_avg
        output_embeddings[-num_new_tokens:] = output_embeddings_avg
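
# Note: the embedding rows added for new special tokens (e.g. a freshly added [PAD] token) are
# initialized to the mean of the pre-existing input/output embeddings, rather than left at the
# random values that model.resize_token_embeddings() would otherwise give them.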


def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
    """Tokenize a list of strings."""
    tokenized_list = [
        tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        )
        for text in strings
    ]
    input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
    input_ids_lens = labels_lens = [
        tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
    ]
    return dict(
        input_ids=input_ids,
        labels=labels,
        input_ids_lens=input_ids_lens,
        labels_lens=labels_lens,
    )
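
# Note: each string is tokenized on its own, so padding="longest" is effectively a no-op here;
# input_ids_lens counts the non-pad tokens per string, and preprocess() below uses it to locate
# the boundary between the prompt and the response.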


def preprocess(
    sources: Sequence[str],
    targets: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
    """Preprocess the data by tokenizing."""
    examples = [s + t for s, t in zip(sources, targets)]
    examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer) for strings in (examples, sources)]
    input_ids = examples_tokenized["input_ids"]
    labels = copy.deepcopy(input_ids)
    for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
        label[:source_len] = IGNORE_INDEX
    return dict(input_ids=input_ids, labels=labels)
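
# Label-masking illustration (token ids are hypothetical): for a tokenized example
#   input_ids = [bos, p1, p2, p3, r1, r2, eos]            # p* = prompt tokens, r* = response tokens
# the corresponding labels become
#   labels    = [-100, -100, -100, -100, r1, r2, eos]
# so the loss is computed only on the response and the EOS token; -100 is the default
# ignore_index of torch.nn.CrossEntropyLoss, so masked positions are skipped entirely.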


class SupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning."""

    def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer):
        super(SupervisedDataset, self).__init__()
        logging.warning("Loading data...")
        list_data_dict = utils.jload(data_path)

        logging.warning("Formatting inputs...")
        prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
        sources = [
            prompt_input.format_map(example) if example.get("input", "") != "" else prompt_no_input.format_map(example)
            for example in list_data_dict
        ]
        targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict]

        logging.warning("Tokenizing inputs... This may take some time...")
        data_dict = preprocess(sources, targets, tokenizer)

        self.input_ids = data_dict["input_ids"]
        self.labels = data_dict["labels"]

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        return dict(input_ids=self.input_ids[i], labels=self.labels[i])
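
# Note: the whole dataset is formatted and tokenized eagerly in __init__, so __getitem__ only
# returns the cached tensors for a single example.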


@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning."""

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )
        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
        return dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
        )
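
# Note: padding happens per batch at collation time (dynamic padding to the longest sequence in the
# batch); input_ids are padded with the pad token, labels with IGNORE_INDEX so padded positions never
# contribute to the loss, and the attention mask simply marks the non-pad positions.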


def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict:
    """Make dataset and collator for supervised fine-tuning."""
    train_dataset = SupervisedDataset(tokenizer=tokenizer, data_path=data_args.data_path)
    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)


def train():
    parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
    )

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    special_tokens_dict = dict()
    if tokenizer.pad_token is None:
        special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN
    if tokenizer.eos_token is None:
        special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN
    if tokenizer.bos_token is None:
        special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN
    if tokenizer.unk_token is None:
        special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN

    smart_tokenizer_and_embedding_resize(
        special_tokens_dict=special_tokens_dict,
        tokenizer=tokenizer,
        model=model,
    )

    data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
    trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module)
    trainer.train()
    trainer.save_state()
    trainer.save_model(output_dir=training_args.output_dir)


if __name__ == "__main__":
    train()
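
# Example launch command (illustrative only; the base model path, data file, process count, and
# hyperparameters below are placeholders, not the authors' recommended settings):
#
#   torchrun --nproc_per_node=4 train.py \
#       --model_name_or_path <base_model_path_or_hub_id> \
#       --data_path ./alpaca_data.json \
#       --output_dir ./alpaca_sft_output \
#       --num_train_epochs 3 \
#       --per_device_train_batch_size 4 \
#       --learning_rate 2e-5 \
#       --model_max_length 512
#
# Any other transformers.TrainingArguments option (e.g. --bf16, --gradient_accumulation_steps) can be
# passed the same way, since the arguments are parsed with transformers.HfArgumentParser.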