import json
from typing import Dict, List, Optional, Sequence, Union

import torch
from torch.utils.data import IterableDataset, Dataset

from transformers.tokenization_utils import PreTrainedTokenizer
def load_single_file(data_file):
    """Load one JSON-lines file (one JSON document per line).

    Args:
        data_file: path to a file whose every non-empty line is a JSON object.

    Returns:
        List of parsed records, in file order.
    """
    # Explicit utf-8: chat training data is frequently multilingual.
    with open(data_file, encoding="utf-8") as f:
        return [json.loads(line) for line in f]
19
def load_raw_data(data_file):
    """Load one or several JSON-lines files into a single flat list.

    Args:
        data_file: a single path (str) or a list of paths.

    Returns:
        Concatenated list of parsed records from all files.

    Raises:
        TypeError: if data_file is neither a str nor a list.
    """
    raw_dataset = []
    if isinstance(data_file, str):
        raw_dataset += load_single_file(data_file)
    elif isinstance(data_file, list):
        for f_ in data_file:
            raw_dataset += load_single_file(f_)
    else:
        # Fail loudly instead of silently returning an empty dataset.
        raise TypeError(f"data_file must be str or list, got {type(data_file)}")
    return raw_dataset
31
def collator(tokenizer, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
    """Pad a batch of tokenized instances into rectangular tensors.

    Args:
        tokenizer: tokenizer providing ``pad_token_id`` for input padding.
        instances: dicts each holding 1-D LongTensors "input_ids" and "labels".

    Returns:
        Dict with batched "input_ids", "labels", and a boolean
        "attention_mask" (True on real tokens, False on padding).
    """
    input_ids, labels = tuple(
        [instance[key] for instance in instances] for key in ("input_ids", "labels")
    )
    input_ids = torch.nn.utils.rnn.pad_sequence(
        input_ids, batch_first=True, padding_value=tokenizer.pad_token_id
    )
    # Labels are padded with IGNORE_INDEX so the loss skips padding positions.
    labels = torch.nn.utils.rnn.pad_sequence(
        labels, batch_first=True, padding_value=IGNORE_INDEX
    )
    return dict(
        input_ids=input_ids,
        labels=labels,
        attention_mask=input_ids.ne(tokenizer.pad_token_id),
    )
44
class PromptIterableDataset(IterableDataset):
    """Iterable dataset rendering multi-turn dialogues into LM training pairs.

    Each raw example is expected to be a dict with key "data" holding an
    alternating list of turns [user, assistant, user, assistant, ...].
    Turns are rendered as "User: ..." / "Assistant: ..."; only the assistant
    reply tokens are supervised, everything else is labelled IGNORE_INDEX.
    """

    def __init__(
        self,
        raw_dataset: Union[Dataset, List],
        sep: Optional[List] = None,
        tokenizer: "PreTrainedTokenizer" = None,  # quoted forward ref: importable without transformers
        max_seq_length: Optional[int] = 512,
        teacher_forcing: Optional[bool] = True,
        truncate_method: Optional[str] = "tail",
    ):
        assert hasattr(raw_dataset, "__iter__"), f"The dataset must have __iter__ method. dataset is {raw_dataset}"
        assert hasattr(raw_dataset, "__len__"), f"The dataset must have __len__ method. dataset is {raw_dataset}"
        self.raw_dataset = raw_dataset
        # None sentinel instead of a mutable default argument; historical default is ["EOS", "\n"].
        self.sep = ["EOS", "\n"] if sep is None else sep
        self._end_token = None  # lazily resolved by the end_token property
        self.start_token = self.sep[-1]
        self.teacher_forcing = teacher_forcing
        # plain-string messages: `assert x, print(...)` shows "None" as the message
        assert self.teacher_forcing, "must use teacher forcing"
        self.tokenizer = tokenizer
        self.truncate_method = truncate_method
        self.max_seq_length = max_seq_length
        assert self.truncate_method == "tail", "only tail truncate support"

    @property
    def end_token(self):
        """Separator appended after every turn; "EOS" maps to tokenizer.eos_token."""
        if self._end_token is not None:
            return self._end_token
        end_token = self.sep[0]
        if end_token == "EOS":
            self._end_token = self.tokenizer.eos_token
        else:
            self._end_token = end_token
        return self._end_token

    def tokenize_example(self, example):
        """Tokenize one dialogue into supervised input_ids / labels.

        Args:
            example: dict with "data": alternating [user, assistant, ...] turns
                (an even number of turns is assumed — TODO confirm upstream).

        Returns:
            Dict of 1-D LongTensors "input_ids" and "labels"; labels are
            IGNORE_INDEX everywhere except assistant reply tokens.
        """
        end_token = self.end_token
        # Alternating role tags, one per turn.
        tags = [tag for _ in range(len(example["data"]) // 2) for tag in ["User", "Assistant"]]
        tokenized_ids = []
        labels = []
        for i, c in enumerate(example["data"]):
            if i % 2 == 1:
                # Assistant turn: the "Assistant: " prompt is masked, the reply is supervised.
                c_input = self.start_token + tags[i] + ": "
                tokenized = self.tokenizer(c_input, add_special_tokens=False)
                tokenized_ids += tokenized["input_ids"]
                labels += [IGNORE_INDEX] * len(tokenized["input_ids"])

                c_generate = c + end_token
                tokenized = self.tokenizer(c_generate, add_special_tokens=False)
                tokenized_ids += tokenized["input_ids"]
                labels += tokenized["input_ids"]
            else:
                # User turn: fully masked. First turn opens with BOS instead of start_token.
                if i == 0:
                    c_new = self.tokenizer.bos_token + tags[i] + ": " + c + end_token
                else:
                    c_new = self.start_token + tags[i] + ": " + c + end_token
                tokenized = self.tokenizer(c_new, add_special_tokens=False)
                tokenized_ids += tokenized["input_ids"]
                labels += [IGNORE_INDEX] * len(tokenized["input_ids"])

        assert len(tokenized_ids) == len(labels)
        return {"input_ids": torch.LongTensor(tokenized_ids), "labels": torch.LongTensor(labels)}

    def truncate(self, tokenized_example):
        """Tail-truncate every field to max_seq_length (keeps the sequence head)."""
        old_len = len(tokenized_example["input_ids"])
        if old_len > self.max_seq_length:
            for k in tokenized_example:
                # equivalent to the negative-slice form [:-(old_len - max)], but clearer
                tokenized_example[k] = tokenized_example[k][: self.max_seq_length]
        return tokenized_example

    def __iter__(self):
        """Yield tokenized, truncated examples in raw-dataset order."""
        for example in self.raw_dataset:
            yield self.truncate(self.tokenize_example(example))

    def __len__(self):
        """Number of raw examples."""
        return len(self.raw_dataset)
133
if __name__ == "__main__":
    # Smoke test: tokenize the dataset and eyeball the rendered prompt vs.
    # the supervised (assistant) portion for the first example.
    from transformers import LlamaTokenizer

    tokenizer = LlamaTokenizer.from_pretrained("../../llama-7B-HF")
    raw_dataset = load_raw_data("../data/processed/part2_1.json")
    dataset = PromptIterableDataset(
        raw_dataset, tokenizer=tokenizer, max_seq_length=2048, teacher_forcing=True
    )
    for data in dataset:
        print("##### full input")
        print(tokenizer.decode(data["input_ids"][:1000]))
        # Keep only positions whose label is supervised (not IGNORE_INDEX).
        model_output = data["input_ids"][:1000][data["labels"][:1000] != -100]
        print("##### model output")
        print(tokenizer.decode(model_output))
        break  # one example is enough for a visual check