# aurora | Fork 0 | /preprocess.py
# 275 lines · 13.0 KB
1
import os
2
import tiktoken
3
from itertools import chain
4
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal, Tuple, Union
5

6
from datasets import load_from_disk
7

8
from llmtuner.data.template import get_template_and_fix_tokenizer
9
from llmtuner.extras.constants import IGNORE_INDEX
10
from llmtuner.extras.logging import get_logger
11

12
if TYPE_CHECKING:
    # These names are only needed for the string type annotations below;
    # `TYPE_CHECKING` is False at runtime, so they are never imported then.
    from datasets import Dataset, IterableDataset
    from transformers import Seq2SeqTrainingArguments
    from transformers.tokenization_utils import PreTrainedTokenizer
    from llmtuner.hparams import DataArguments


# Module-level logger obtained from llmtuner's `get_logger` helper.
logger = get_logger(__name__)
20

21

22
def construct_example(examples: Dict[str, List[Any]]) -> Generator[Any, None, None]:
    """Yield one ``(query, response, history, system)`` tuple per row of a batched example dict.

    The number of rows is taken from ``examples["prompt"]``.

    Args:
        examples: column-oriented batch. ``"prompt"`` and ``"response"`` are required.
            ``"query"``, ``"history"`` and ``"system"`` are optional columns.

    Yields:
        ``(query, response, history, system)`` where ``query`` is the prompt, followed by
        a newline and the row's ``"query"`` value when that value is truthy.
        ``history`` / ``system`` are ``None`` when their column is absent.
    """
    for idx, prompt in enumerate(examples["prompt"]):
        response = examples["response"][idx]
        extra_query = examples["query"][idx] if "query" in examples else None
        # Only append the extra query when it is non-empty, so no stray "\n" is added.
        query = prompt + "\n" + extra_query if extra_query else prompt
        history = examples["history"][idx] if "history" in examples else None
        system = examples["system"][idx] if "system" in examples else None
        yield query, response, history, system
29

30

31
def infer_max_len(source_len: int, target_len: int, data_args: "DataArguments") -> Tuple[int, int]:
    """Split ``data_args.cutoff_len`` between source and target proportionally to their lengths.

    The target budget is the target's proportional share of ``cutoff_len``, raised to at
    least ``data_args.reserved_label_len`` and capped at ``cutoff_len``. The source gets
    whatever remains.

    Args:
        source_len: number of source (prompt) tokens.
        target_len: number of target (response) tokens.
        data_args: provides ``cutoff_len`` and ``reserved_label_len``.

    Returns:
        ``(max_source_len, max_target_len)``, which sum to ``cutoff_len``. Both are
        non-negative when ``cutoff_len`` is non-negative.
    """
    cutoff_len = data_args.cutoff_len
    total_len = source_len + target_len
    # Guard against ZeroDivisionError when both sides are empty.
    target_ratio = target_len / total_len if total_len > 0 else 0.0
    max_target_len = max(int(cutoff_len * target_ratio), data_args.reserved_label_len)
    # A reserved label budget larger than the cutoff would make the source budget
    # negative, and `source_ids[:negative]` drops tokens from the END instead of
    # truncating to zero, so clamp it.
    max_target_len = min(max_target_len, cutoff_len)
    max_source_len = cutoff_len - max_target_len
    return max_source_len, max_target_len
36

37

38
def preprocess_dataset(
    dataset: Union["Dataset", "IterableDataset"],
    tokenizer: "PreTrainedTokenizer",
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
    stage: Literal["pt", "sft", "rm", "ppo"]
) -> Union["Dataset", "IterableDataset"]:
    """Tokenize ``dataset`` for the given training ``stage`` and return the mapped dataset.

    The per-batch preprocessing function is selected by ``stage``:

    * ``"pt"``: concatenate tokenized prompts and cut them into ``cutoff_len`` blocks.
    * ``"sft"`` without ``predict_with_generate``: supervised prompt/response pairs,
      packed into fixed ``cutoff_len`` blocks when ``data_args.sft_packing`` is set.
    * ``"rm"``: pairwise ``prompt_ids`` / ``chosen_ids`` / ``rejected_ids``.
    * anything else (``"ppo"``, or ``"sft"`` with ``predict_with_generate``): prompt ids
      as inputs and response ids as labels.

    If ``data_args.cache_path`` exists, the dataset is loaded from disk and all other
    data arguments are ignored. If ``cache_path`` is set but does not exist yet, the
    processed dataset is saved there (when ``training_args.should_save``). The process
    then stops with ``SystemExit`` so the script can be rerun.

    Raises:
        ValueError: if ``train_on_prompt`` is requested with an efficient-eos template.
        RuntimeError: if the processed dataset is empty (only checked when
            ``training_args.should_log``).
    """
    template = get_template_and_fix_tokenizer(data_args.template, tokenizer)

    if data_args.train_on_prompt and template.efficient_eos:
        raise ValueError("Current template does not support `train_on_prompt`.")

    def build_source_mask(turn_idx: int, source_ids: List[int]) -> List[int]:
        # Labels for the prompt (source) part of one prompt-response pair.
        if data_args.train_on_prompt:
            return source_ids
        if turn_idx != 0 and template.efficient_eos:
            # The first prompt token of every later turn is labelled with eos.
            # Presumably it stands in for the eos that efficient-eos templates omit
            # after the previous response -- verify against template.encode_multiturn.
            return [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
        return [IGNORE_INDEX] * len(source_ids)

    def preprocess_pretrain_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
        # build grouped texts with format `X1 X2 X3 ...`, then split them into `cutoff_len` blocks
        if isinstance(getattr(tokenizer, "tokenizer", None), tiktoken.Encoding):  # for tiktoken tokenizer (Qwen)
            tokenize_kwargs = dict(allowed_special="all")
        else:
            tokenize_kwargs = dict(add_special_tokens=True)

        # For LLaMA-style tokenizers, force `add_eos_token` on while tokenizing (presumably so
        # each text ends with eos). Restore it in `finally` so the saved tokenizer is the same
        # as the original one, even if tokenization raises.
        patch_eos = hasattr(tokenizer, "add_eos_token")
        if patch_eos:
            add_eos_token_flag = getattr(tokenizer, "add_eos_token")
            setattr(tokenizer, "add_eos_token", True)
        try:
            tokenized_examples = tokenizer(examples["prompt"], **tokenize_kwargs)
        finally:
            if patch_eos:
                setattr(tokenizer, "add_eos_token", add_eos_token_flag)

        concatenated_examples = {
            k: list(chain.from_iterable(tokenized_examples[k])) for k in tokenized_examples.keys()
        }
        total_length = len(next(iter(concatenated_examples.values())))
        block_size = data_args.cutoff_len
        # we drop the small remainder, and if the total_length < block_size, we exclude this batch
        total_length = (total_length // block_size) * block_size
        # split by chunks of cutoff_len
        return {
            k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
            for k, t in concatenated_examples.items()
        }

    def preprocess_supervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
        # build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
        # for multiturn examples, we only mask the prompt part in each prompt-response pair.
        model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}

        for query, response, history, system in construct_example(examples):
            # skip rows whose prompt or response is not a non-empty string
            if not (isinstance(query, str) and isinstance(response, str) and query != "" and response != ""):
                continue

            input_ids, labels = [], []
            for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
                tokenizer, query, response, history, system
            )):
                # truncate each turn's source/target to their proportional share of cutoff_len
                source_len, target_len = len(source_ids), len(target_ids)
                max_source_len, max_target_len = infer_max_len(source_len, target_len, data_args)
                if source_len > max_source_len:
                    source_ids = source_ids[:max_source_len]
                if target_len > max_target_len:
                    target_ids = target_ids[:max_target_len]

                input_ids += source_ids + target_ids
                labels += build_source_mask(turn_idx, source_ids) + target_ids

            if template.efficient_eos:
                input_ids += [tokenizer.eos_token_id]
                labels += [tokenizer.eos_token_id]

            if len(input_ids) > data_args.cutoff_len:
                input_ids = input_ids[:data_args.cutoff_len]
                labels = labels[:data_args.cutoff_len]

            model_inputs["input_ids"].append(input_ids)
            model_inputs["attention_mask"].append([1] * len(input_ids))
            model_inputs["labels"].append(labels)

        return model_inputs

    def preprocess_packed_supervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
        # build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
        # and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
        # (all examples of the batch are concatenated, then cut into `cutoff_len` blocks)
        model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
        input_ids, labels = [], []
        for query, response, history, system in construct_example(examples):
            # skip rows whose prompt or response is not a non-empty string
            if not (isinstance(query, str) and isinstance(response, str) and query != "" and response != ""):
                continue

            for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
                tokenizer, query, response, history, system
            )):
                input_ids += source_ids + target_ids
                labels += build_source_mask(turn_idx, source_ids) + target_ids

            # Terminate EVERY packed example, mirroring the non-packed path above. Appending
            # the eos once after the loop would leave all but the last example of the batch
            # without an (efficient-eos) terminator.
            if template.efficient_eos:
                input_ids += [tokenizer.eos_token_id]
                labels += [tokenizer.eos_token_id]

        total_length = len(input_ids)
        block_size = data_args.cutoff_len
        # we drop the small remainder, and if the total_length < block_size, we exclude this batch
        total_length = (total_length // block_size) * block_size
        # split by chunks of cutoff_len
        for i in range(0, total_length, block_size):
            model_inputs["input_ids"].append(input_ids[i: i + block_size])
            model_inputs["attention_mask"].append([1] * block_size)
            model_inputs["labels"].append(labels[i: i + block_size])

        return model_inputs

    def preprocess_unsupervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
        # build inputs with format `<bos> X` and labels with format `Y <eos>`
        model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}

        for query, response, history, system in construct_example(examples):
            # only the prompt is required here; the response may be missing
            if not (isinstance(query, str) and query != ""):
                continue

            input_ids, labels = template.encode_oneturn(tokenizer, query, response, history, system)

            if template.efficient_eos:
                labels += [tokenizer.eos_token_id]

            # inputs and labels are separate sequences here, so truncate them independently
            if len(input_ids) > data_args.cutoff_len:
                input_ids = input_ids[:data_args.cutoff_len]
            if len(labels) > data_args.cutoff_len:
                labels = labels[:data_args.cutoff_len]

            model_inputs["input_ids"].append(input_ids)
            model_inputs["attention_mask"].append([1] * len(input_ids))
            model_inputs["labels"].append(labels)

        return model_inputs

    def preprocess_pairwise_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
        # build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
        # (response[0] is used as the chosen answer, response[1] as the rejected one)
        model_inputs = {"prompt_ids": [], "chosen_ids": [], "rejected_ids": []}
        for query, response, history, system in construct_example(examples):
            if not (isinstance(query, str) and isinstance(response, list) and query != "" and len(response) > 1):
                continue

            prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, query, response[0], history, system)
            _, rejected_ids = template.encode_oneturn(tokenizer, query, response[1], history, system)

            if template.efficient_eos:
                chosen_ids += [tokenizer.eos_token_id]
                rejected_ids += [tokenizer.eos_token_id]

            # the longer of the two answers decides the shared target budget
            source_len, target_len = len(prompt_ids), max(len(chosen_ids), len(rejected_ids))
            max_source_len, max_target_len = infer_max_len(source_len, target_len, data_args)
            if source_len > max_source_len:
                prompt_ids = prompt_ids[:max_source_len]
            if target_len > max_target_len:
                chosen_ids = chosen_ids[:max_target_len]
                rejected_ids = rejected_ids[:max_target_len]

            model_inputs["prompt_ids"].append(prompt_ids)
            model_inputs["chosen_ids"].append(chosen_ids)
            model_inputs["rejected_ids"].append(rejected_ids)

        return model_inputs

    def print_supervised_dataset_example(example: Dict[str, List[int]]) -> None:
        print("input_ids:\n{}".format(example["input_ids"]))
        print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
        print("label_ids:\n{}".format(example["labels"]))
        print("labels:\n{}".format(
            tokenizer.decode(list(filter(lambda x: x != IGNORE_INDEX, example["labels"])), skip_special_tokens=False)
        ))

    def print_pairwise_dataset_example(example: Dict[str, List[int]]) -> None:
        print("prompt_ids:\n{}".format(example["prompt_ids"]))
        print("prompt:\n{}".format(tokenizer.decode(example["prompt_ids"], skip_special_tokens=False)))
        print("chosen_ids:\n{}".format(example["chosen_ids"]))
        print("chosen:\n{}".format(tokenizer.decode(example["chosen_ids"], skip_special_tokens=False)))
        print("rejected_ids:\n{}".format(example["rejected_ids"]))
        print("rejected:\n{}".format(tokenizer.decode(example["rejected_ids"], skip_special_tokens=False)))

    def print_unsupervised_dataset_example(example: Dict[str, List[int]]) -> None:
        print("input_ids:\n{}".format(example["input_ids"]))
        print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))

    if stage == "pt":
        preprocess_func = preprocess_pretrain_dataset
        print_function = print_unsupervised_dataset_example
    elif stage == "sft" and not training_args.predict_with_generate:
        preprocess_func = preprocess_packed_supervised_dataset if data_args.sft_packing else preprocess_supervised_dataset
        print_function = print_supervised_dataset_example
    elif stage == "rm":
        preprocess_func = preprocess_pairwise_dataset
        print_function = print_pairwise_dataset_example
    else:
        preprocess_func = preprocess_unsupervised_dataset
        print_function = print_unsupervised_dataset_example

    if data_args.cache_path is not None and os.path.exists(data_args.cache_path):
        logger.warning("Loading dataset from disk will ignore other data arguments.")
        return load_from_disk(data_args.cache_path)

    with training_args.main_process_first(desc="dataset map pre-processing"):
        # all raw columns (read from the first row) are replaced by the tokenized ones
        column_names = list(next(iter(dataset)).keys())
        map_kwargs = {}
        if not data_args.streaming:
            map_kwargs = dict(
                num_proc=data_args.preprocessing_num_workers,
                load_from_cache_file=(not data_args.overwrite_cache),
                desc="Running tokenizer on dataset"
            )

        dataset = dataset.map(
            preprocess_func,
            batched=True,
            remove_columns=column_names,
            **map_kwargs
        )

        if data_args.cache_path is not None and not os.path.exists(data_args.cache_path):
            if training_args.should_save:
                dataset.save_to_disk(data_args.cache_path)
            raise SystemExit("Dataset saved, rerun this script with the same `--cache_path`.")

        if training_args.should_log:
            try:
                print_function(next(iter(dataset)))
            except StopIteration:
                raise RuntimeError("Empty dataset!")

        return dataset
276

# Use of cookies
#
# We use cookies in accordance with the Privacy Policy and the Cookie Policy.
#
# By clicking "Accept", you consent to JSC "SberTech" processing your personal data in order
# to improve our website and the GitVerse service and to make them more convenient to use.
#
# You can disable cookies yourself in your browser settings.