stanford_alpaca

train.py 
222 lines · 8.1 KB
#    Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.

import copy
import logging
from dataclasses import dataclass, field
from typing import Dict, Optional, Sequence

import torch
import transformers
import utils
from torch.utils.data import Dataset
from transformers import Trainer

IGNORE_INDEX = -100
DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "<s>"
DEFAULT_UNK_TOKEN = "<unk>"
PROMPT_DICT = {
    "prompt_input": (
        "Below is an instruction that describes a task, paired with an input that provides further context. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
    ),
    "prompt_no_input": (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Response:"
    ),
}
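
# Illustrative example: formatting a hypothetical record with the
# "prompt_no_input" template,
#
#   PROMPT_DICT["prompt_no_input"].format_map({"instruction": "Name three primary colors."})
#
# yields the training prompt
#
#   "Below is an instruction that describes a task. Write a response that
#    appropriately completes the request.\n\n### Instruction:\nName three
#    primary colors.\n\n### Response:"
#
# The record's "output" field plus the EOS token is appended to this string
# later, in SupervisedDataset.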


@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="facebook/opt-125m")


@dataclass
class DataArguments:
    data_path: str = field(default=None, metadata={"help": "Path to the training data."})


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(
        default=512,
        metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
    )


def smart_tokenizer_and_embedding_resize(
    special_tokens_dict: Dict,
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
):
    """Resize tokenizer and embedding.

    Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
    """
    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
    model.resize_token_embeddings(len(tokenizer))

    if num_new_tokens > 0:
        input_embeddings = model.get_input_embeddings().weight.data
        output_embeddings = model.get_output_embeddings().weight.data

        input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
        output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)

        input_embeddings[-num_new_tokens:] = input_embeddings_avg
        output_embeddings[-num_new_tokens:] = output_embeddings_avg
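
# Descriptive note: newly added special tokens (e.g. the [PAD] token) are
# initialized to the mean of the pre-existing input/output embeddings rather
# than to random vectors, a common heuristic so the new rows start from a
# representative point in embedding space.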


def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
    """Tokenize a list of strings."""
    tokenized_list = [
        tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        )
        for text in strings
    ]
    input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
    input_ids_lens = labels_lens = [
        tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
    ]
    return dict(
        input_ids=input_ids,
        labels=labels,
        input_ids_lens=input_ids_lens,
        labels_lens=labels_lens,
    )
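
# Descriptive note: each string is tokenized individually, so padding="longest"
# never actually adds padding here; input_ids_lens therefore equals the full
# token count of each sequence (capped at model_max_length) and is reused by
# `preprocess` as the prompt length when masking labels.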


def preprocess(
    sources: Sequence[str],
    targets: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
    """Preprocess the data by tokenizing."""
    examples = [s + t for s, t in zip(sources, targets)]
    examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer) for strings in (examples, sources)]
    input_ids = examples_tokenized["input_ids"]
    labels = copy.deepcopy(input_ids)
    for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
        label[:source_len] = IGNORE_INDEX
    return dict(input_ids=input_ids, labels=labels)
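
# Illustrative example: if a tokenized prompt ("source") is 30 tokens long and
# the prompt+response ("example") is 45 tokens, the first 30 label positions
# are set to IGNORE_INDEX (-100); Hugging Face models skip -100 positions in
# the cross-entropy loss, so only the response tokens are supervised.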


class SupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning."""

    def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer):
        super(SupervisedDataset, self).__init__()
        logging.warning("Loading data...")
        list_data_dict = utils.jload(data_path)

        logging.warning("Formatting inputs...")
        prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
        sources = [
            prompt_input.format_map(example) if example.get("input", "") != "" else prompt_no_input.format_map(example)
            for example in list_data_dict
        ]
        targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict]

        logging.warning("Tokenizing inputs... This may take some time...")
        data_dict = preprocess(sources, targets, tokenizer)

        self.input_ids = data_dict["input_ids"]
        self.labels = data_dict["labels"]

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        return dict(input_ids=self.input_ids[i], labels=self.labels[i])


@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning."""

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )
        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
        return dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
        )
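
# Descriptive note: for a batch with sequences of, say, 45 and 30 tokens, both
# are right-padded to 45; input_ids is padded with pad_token_id, labels with
# IGNORE_INDEX, and attention_mask is False on padded positions, so padding is
# neither attended to nor counted in the loss.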


def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict:
    """Make dataset and collator for supervised fine-tuning."""
    train_dataset = SupervisedDataset(tokenizer=tokenizer, data_path=data_args.data_path)
    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)


def train():
    parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
    )

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    special_tokens_dict = dict()
    if tokenizer.pad_token is None:
        special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN
    if tokenizer.eos_token is None:
        special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN
    if tokenizer.bos_token is None:
        special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN
    if tokenizer.unk_token is None:
        special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN

    smart_tokenizer_and_embedding_resize(
        special_tokens_dict=special_tokens_dict,
        tokenizer=tokenizer,
        model=model,
    )

    data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
    trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module)
    trainer.train()
    trainer.save_state()
    trainer.save_model(output_dir=training_args.output_dir)


if __name__ == "__main__":
    train()
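
# Illustrative launch command (paths, model id and hyperparameters are
# placeholders, not prescribed values):
#
#   torchrun --nproc_per_node=4 train.py \
#       --model_name_or_path <base_model_or_hub_id> \
#       --data_path ./alpaca_data.json \
#       --output_dir ./output \
#       --num_train_epochs 3 \
#       --per_device_train_batch_size 4 \
#       --model_max_length 512 \
#       --bf16 True
#
# Every flag maps onto a field of ModelArguments, DataArguments or
# TrainingArguments parsed by HfArgumentParser in train().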