3
from itertools import chain
4
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal, Tuple, Union
6
from datasets import load_from_disk
8
from llmtuner.data.template import get_template_and_fix_tokenizer
9
from llmtuner.extras.constants import IGNORE_INDEX
10
from llmtuner.extras.logging import get_logger
13
from datasets import Dataset, IterableDataset
14
from transformers import Seq2SeqTrainingArguments
15
from transformers.tokenization_utils import PreTrainedTokenizer
16
from llmtuner.hparams import DataArguments
19
# Module-level logger, created through the project's logging helper with this module's name.
logger = get_logger(__name__)
22
def construct_example(examples: Dict[str, List[Any]]) -> Generator[Any, None, None]:
    """Yield one `(query, response, history, system)` tuple per row of a columnar batch.

    Args:
        examples: Mapping from column name to a list of per-row values. The
            `prompt` and `response` columns are required; `query`, `history`
            and `system` are optional.

    Yields:
        A 4-tuple `(query, response, history, system)`. When the row has a
        truthy `query` value it is appended to the prompt after a newline.
        `history` and `system` are `None` when their column is absent.
    """
    responses = examples["response"]
    # Column presence does not change between rows, so resolve it once up front.
    extra_queries = examples["query"] if "query" in examples else None
    histories = examples["history"] if "history" in examples else None
    systems = examples["system"] if "system" in examples else None

    for i, query in enumerate(examples["prompt"]):
        if extra_queries is not None and extra_queries[i]:
            query = query + "\n" + extra_queries[i]

        history = histories[i] if histories is not None else None
        system = systems[i] if systems is not None else None
        yield query, responses[i], history, system
31
def infer_max_len(source_len: int, target_len: int, data_args: "DataArguments") -> Tuple[int, int]:
    """Split the `cutoff_len` token budget between source and target.

    The target receives a share of `data_args.cutoff_len` proportional to its
    length, but never less than `data_args.reserved_label_len`; the source
    receives whatever remains.

    Args:
        source_len: Number of source (prompt) tokens.
        target_len: Number of target (response) tokens.
        data_args: Object providing `cutoff_len` and `reserved_label_len`.

    Returns:
        A tuple `(max_source_len, max_target_len)`.
    """
    total_len = source_len + target_len
    # An empty pair has no defined ratio; use a zero target share instead of
    # raising ZeroDivisionError, so only the reserved label budget applies.
    target_ratio = target_len / total_len if total_len != 0 else 0.0
    max_target_len = max(int(data_args.cutoff_len * target_ratio), data_args.reserved_label_len)
    max_source_len = data_args.cutoff_len - max_target_len
    return max_source_len, max_target_len
38
def preprocess_dataset(
    dataset: Union["Dataset", "IterableDataset"],
    tokenizer: "PreTrainedTokenizer",
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
    stage: Literal["pt", "sft", "rm", "ppo"]
) -> Union["Dataset", "IterableDataset"]:
    """Tokenize `dataset` with the preprocessing routine that matches `stage`.

    Stage dispatch:
        * `"pt"`: texts are concatenated and cut into `cutoff_len` blocks.
        * `"sft"` without `predict_with_generate`: supervised examples with
          masked prompts (packed into blocks when `data_args.sft_packing`).
        * `"rm"`: `(prompt, chosen, rejected)` triples.
        * anything else: prompt ids as inputs and response ids as labels.

    Args:
        dataset: Dataset whose rows are read through `construct_example`
            (or through the `prompt` column alone for the `"pt"` stage).
        tokenizer: Tokenizer used for encoding and for the example preview.
        data_args: Data options (`template`, `cutoff_len`, `train_on_prompt`,
            `sft_packing`, `cache_path`, `streaming`, ...).
        training_args: Training options (`predict_with_generate`,
            `should_save`, `should_log`, `main_process_first`).
        stage: Training stage selecting the preprocessing routine.

    Returns:
        The mapped dataset, or the dataset loaded from `data_args.cache_path`
        when that path already exists.

    Raises:
        ValueError: If `train_on_prompt` is requested with a template whose
            `efficient_eos` flag is set.
        SystemExit: After processing when `cache_path` is set but did not
            exist yet; the script is meant to be rerun with the same path.
        RuntimeError: If the processed dataset is empty when previewing it.
    """
    template = get_template_and_fix_tokenizer(data_args.template, tokenizer)

    if data_args.train_on_prompt and template.efficient_eos:
        raise ValueError("Current template does not support `train_on_prompt`.")

    def build_source_mask(source_ids: List[int], turn_idx: int) -> List[int]:
        # Labels for the prompt part of one turn; always the same length as `source_ids`.
        if data_args.train_on_prompt:
            return source_ids
        # The `source_ids` check keeps labels aligned with inputs when truncation
        # leaves an empty prompt: without it a one-element mask would be emitted.
        if turn_idx != 0 and template.efficient_eos and source_ids:
            return [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
        return [IGNORE_INDEX] * len(source_ids)

    def preprocess_pretrain_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
        # build grouped texts with format `X1 X2 X3 ...`
        if isinstance(getattr(tokenizer, "tokenizer", None), tiktoken.Encoding):  # for tiktoken tokenizer (Qwen)
            kwargs = dict(allowed_special="all")
        else:
            kwargs = dict(add_special_tokens=True)

        has_eos_switch = hasattr(tokenizer, "add_eos_token")  # for LLaMA tokenizer
        if has_eos_switch:
            add_eos_token_flag = getattr(tokenizer, "add_eos_token")
            setattr(tokenizer, "add_eos_token", True)

        try:
            tokenized_examples = tokenizer(examples["prompt"], **kwargs)
            concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
            total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
            block_size = data_args.cutoff_len
            # we drop the small remainder, and if the total_length < block_size, we exclude this batch
            total_length = (total_length // block_size) * block_size
            # split by chunks of cutoff_len
            result = {
                k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
                for k, t in concatenated_examples.items()
            }
        finally:
            # make sure the saved tokenizer is the same as the original one,
            # even when tokenization fails part-way through
            if has_eos_switch:
                setattr(tokenizer, "add_eos_token", add_eos_token_flag)

        return result

    def preprocess_supervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
        # build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
        # for multiturn examples, we only mask the prompt part in each prompt-response pair.
        model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}

        for query, response, history, system in construct_example(examples):
            # skip rows without a usable, non-empty text prompt and response
            if not (isinstance(query, str) and isinstance(response, str) and query != "" and response != ""):
                continue

            input_ids, labels = [], []
            for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
                tokenizer, query, response, history, system
            )):
                # truncate each turn to its share of the cutoff_len budget
                source_len, target_len = len(source_ids), len(target_ids)
                max_source_len, max_target_len = infer_max_len(source_len, target_len, data_args)
                if source_len > max_source_len:
                    source_ids = source_ids[:max_source_len]
                if target_len > max_target_len:
                    target_ids = target_ids[:max_target_len]

                source_mask = build_source_mask(source_ids, turn_idx)
                input_ids += source_ids + target_ids
                labels += source_mask + target_ids

            if template.efficient_eos:
                input_ids += [tokenizer.eos_token_id]
                labels += [tokenizer.eos_token_id]

            # the per-turn budgets can still add up to more than cutoff_len
            if len(input_ids) > data_args.cutoff_len:
                input_ids = input_ids[:data_args.cutoff_len]
                labels = labels[:data_args.cutoff_len]

            model_inputs["input_ids"].append(input_ids)
            model_inputs["attention_mask"].append([1] * len(input_ids))
            model_inputs["labels"].append(labels)

        return model_inputs

    def preprocess_packed_supervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
        # build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
        # and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
        model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
        input_ids, labels = [], []
        for query, response, history, system in construct_example(examples):
            if not (isinstance(query, str) and isinstance(response, str) and query != "" and response != ""):
                continue

            for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
                tokenizer, query, response, history, system
            )):
                source_mask = build_source_mask(source_ids, turn_idx)
                input_ids += source_ids + target_ids
                labels += source_mask + target_ids

            if template.efficient_eos:
                input_ids += [tokenizer.eos_token_id]
                labels += [tokenizer.eos_token_id]

        total_length = len(input_ids)
        block_size = data_args.cutoff_len
        # we drop the small remainder, and if the total_length < block_size, we exclude this batch
        total_length = (total_length // block_size) * block_size
        # split by chunks of cutoff_len
        for i in range(0, total_length, block_size):
            model_inputs["input_ids"].append(input_ids[i: i + block_size])
            model_inputs["attention_mask"].append([1] * block_size)
            model_inputs["labels"].append(labels[i: i + block_size])

        return model_inputs

    def preprocess_unsupervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
        # build inputs with format `<bos> X` and labels with format `Y <eos>`
        model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}

        for query, response, history, system in construct_example(examples):
            # only the prompt has to be a non-empty string here
            if not (isinstance(query, str) and query != ""):
                continue

            input_ids, labels = template.encode_oneturn(tokenizer, query, response, history, system)

            if template.efficient_eos:
                labels += [tokenizer.eos_token_id]

            # inputs and labels are truncated independently of each other
            if len(input_ids) > data_args.cutoff_len:
                input_ids = input_ids[:data_args.cutoff_len]
            if len(labels) > data_args.cutoff_len:
                labels = labels[:data_args.cutoff_len]

            model_inputs["input_ids"].append(input_ids)
            model_inputs["attention_mask"].append([1] * len(input_ids))
            model_inputs["labels"].append(labels)

        return model_inputs

    def preprocess_pairwise_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
        # build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
        model_inputs = {"prompt_ids": [], "chosen_ids": [], "rejected_ids": []}
        for query, response, history, system in construct_example(examples):
            # `response` must hold at least two candidates: index 0 is encoded as chosen, index 1 as rejected
            if not (isinstance(query, str) and isinstance(response, list) and query != "" and len(response) > 1):
                continue

            prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, query, response[0], history, system)
            _, rejected_ids = template.encode_oneturn(tokenizer, query, response[1], history, system)

            if template.efficient_eos:
                chosen_ids += [tokenizer.eos_token_id]
                rejected_ids += [tokenizer.eos_token_id]

            # budget the prompt against the longer of the two answers
            source_len, target_len = len(prompt_ids), max(len(chosen_ids), len(rejected_ids))
            max_source_len, max_target_len = infer_max_len(source_len, target_len, data_args)
            if source_len > max_source_len:
                prompt_ids = prompt_ids[:max_source_len]
            if target_len > max_target_len:
                chosen_ids = chosen_ids[:max_target_len]
                rejected_ids = rejected_ids[:max_target_len]

            model_inputs["prompt_ids"].append(prompt_ids)
            model_inputs["chosen_ids"].append(chosen_ids)
            model_inputs["rejected_ids"].append(rejected_ids)

        return model_inputs

    def print_supervised_dataset_example(example: Dict[str, List[int]]) -> None:
        # ignored positions are dropped before decoding the labels
        visible_labels = [token_id for token_id in example["labels"] if token_id != IGNORE_INDEX]
        print("input_ids:\n{}".format(example["input_ids"]))
        print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
        print("label_ids:\n{}".format(example["labels"]))
        print("labels:\n{}".format(tokenizer.decode(visible_labels, skip_special_tokens=False)))

    def print_pairwise_dataset_example(example: Dict[str, List[int]]) -> None:
        print("prompt_ids:\n{}".format(example["prompt_ids"]))
        print("prompt:\n{}".format(tokenizer.decode(example["prompt_ids"], skip_special_tokens=False)))
        print("chosen_ids:\n{}".format(example["chosen_ids"]))
        print("chosen:\n{}".format(tokenizer.decode(example["chosen_ids"], skip_special_tokens=False)))
        print("rejected_ids:\n{}".format(example["rejected_ids"]))
        print("rejected:\n{}".format(tokenizer.decode(example["rejected_ids"], skip_special_tokens=False)))

    def print_unsupervised_dataset_example(example: Dict[str, List[int]]) -> None:
        print("input_ids:\n{}".format(example["input_ids"]))
        print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))

    if stage == "pt":
        preprocess_func = preprocess_pretrain_dataset
        print_function = print_unsupervised_dataset_example
    elif stage == "sft" and not training_args.predict_with_generate:
        preprocess_func = preprocess_packed_supervised_dataset if data_args.sft_packing else preprocess_supervised_dataset
        print_function = print_supervised_dataset_example
    elif stage == "rm":
        preprocess_func = preprocess_pairwise_dataset
        print_function = print_pairwise_dataset_example
    else:
        preprocess_func = preprocess_unsupervised_dataset
        print_function = print_unsupervised_dataset_example

    # an existing cache short-circuits all processing below
    if data_args.cache_path is not None and os.path.exists(data_args.cache_path):
        logger.warning("Loading dataset from disk will ignore other data arguments.")
        return load_from_disk(data_args.cache_path)

    with training_args.main_process_first(desc="dataset map pre-processing"):
        column_names = list(next(iter(dataset)).keys())
        kwargs = {}
        # these map options are only passed for non-streaming datasets
        if not data_args.streaming:
            kwargs = dict(
                num_proc=data_args.preprocessing_num_workers,
                load_from_cache_file=(not data_args.overwrite_cache),
                desc="Running tokenizer on dataset"
            )

        dataset = dataset.map(
            preprocess_func,
            batched=True,
            remove_columns=column_names,
            **kwargs
        )

        if data_args.cache_path is not None and not os.path.exists(data_args.cache_path):
            if training_args.should_save:
                dataset.save_to_disk(data_args.cache_path)
            raise SystemExit("Dataset saved, rerun this script with the same `--cache_path`.")

        if training_args.should_log:
            # keep the try body minimal: only fetching the first row can signal an empty dataset
            try:
                first_example = next(iter(dataset))
            except StopIteration:
                raise RuntimeError("Empty dataset!")
            print_function(first_example)

        return dataset