UltraChat

ultrachat_dataset.py 
147 lines · 5.3 KB
import json
from typing import Dict, List, Optional, Sequence, Union

import torch
from torch.utils.data import IterableDataset, Dataset

from transformers.tokenization_utils import PreTrainedTokenizer


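# Data loading helpers. Each data file is expected to be in JSON-lines format,
# one conversation per line, with a "data" field holding the alternating
# User/Assistant turns consumed by PromptIterableDataset below.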
def load_single_file(data_file):
    with open(data_file) as f:
        lines = f.readlines()
    return [json.loads(l) for l in lines]


def load_raw_data(data_file):
    raw_dataset = []
    if isinstance(data_file, str):
        raw_dataset += load_single_file(data_file)
    elif isinstance(data_file, list):
        for f_ in data_file:
            raw_dataset += load_single_file(f_)
    return raw_dataset


IGNORE_INDEX = -100

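# Batch collator: pads variable-length examples to the longest sequence in the
# batch. Input ids are padded with the tokenizer's pad token, labels with
# IGNORE_INDEX so padded positions are ignored by the loss, and the attention
# mask marks the non-padding positions.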
def collator(tokenizer, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
    input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
    input_ids = torch.nn.utils.rnn.pad_sequence(
        input_ids, batch_first=True, padding_value=tokenizer.pad_token_id
    )
    labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
    return dict(
        input_ids=input_ids,
        labels=labels,
        attention_mask=input_ids.ne(tokenizer.pad_token_id),
    )


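# Iterable dataset that turns a raw UltraChat-style conversation (a "data" list
# of alternating User/Assistant utterances) into token ids plus labels, masking
# everything except the assistant replies.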
class PromptIterableDataset(IterableDataset):
    def __init__(self,
                 raw_dataset: Union[Dataset, List],
                 sep: List = ["EOS", "\n"],
                 tokenizer: PreTrainedTokenizer = None,
                 max_seq_length: Optional[int] = 512,
                 teacher_forcing: Optional[bool] = True,
                 truncate_method: Optional[str] = "tail",
                ):
        assert hasattr(raw_dataset, "__iter__"), f"The dataset must have an __iter__ method. dataset is {raw_dataset}"
        assert hasattr(raw_dataset, "__len__"), f"The dataset must have a __len__ method. dataset is {raw_dataset}"
        self.raw_dataset = raw_dataset
        self.sep = sep
        self._end_token = None
        self.start_token = self.sep[-1]
        self.teacher_forcing = teacher_forcing
        assert self.teacher_forcing, "must use teacher forcing"

        self.tokenizer = tokenizer
        self.truncate_method = truncate_method
        self.max_seq_length = max_seq_length
        assert self.truncate_method == "tail", "only tail truncation is supported"

    @property
    def end_token(self):
        # Resolve the configured separator lazily: "EOS" maps to the tokenizer's EOS token.
        if self._end_token is not None:
            return self._end_token
        end_token = self.sep[0]
        if end_token == "EOS":
            self._end_token = self.tokenizer.eos_token
        else:
            self._end_token = end_token
        return self._end_token

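    # Convert one conversation into token ids and labels: user turns and the
    # "Assistant: " prefix are labeled IGNORE_INDEX, so only the assistant
    # replies (plus the end token) contribute to the loss.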
    def tokenize_example(self, example):
        end_token = self.end_token
        tags = [i for _ in range(len(example["data"]) // 2) for i in ["User", "Assistant"]]
        labels = []
        tokenized_ids = []
        for i, c in enumerate(example["data"]):
            if i % 2 == 1:
                # assistant turn: the role prefix is masked, the reply itself is supervised
                c_input = self.start_token + tags[i] + ": "
                tokenized = self.tokenizer(c_input, add_special_tokens=False)
                tokenized_ids += tokenized["input_ids"]
                labels += [IGNORE_INDEX] * len(tokenized["input_ids"])

                c_generate = c + end_token
                tokenized = self.tokenizer(c_generate, add_special_tokens=False)
                tokenized_ids += tokenized["input_ids"]
                labels += tokenized["input_ids"]

            else:
                # user turn: fully masked from the loss
                if i == 0:
                    # first turn: prepend BOS instead of the start token
                    c_new = self.tokenizer.bos_token + tags[i] + ": " + c + end_token
                else:
                    c_new = self.start_token + tags[i] + ": " + c + end_token
                tokenized = self.tokenizer(c_new, add_special_tokens=False)
                tokenized_ids += tokenized["input_ids"]
                labels += [IGNORE_INDEX] * len(tokenized["input_ids"])

        assert len(tokenized_ids) == len(labels)

        return {"input_ids": torch.LongTensor(tokenized_ids), "labels": torch.LongTensor(labels)}

    def truncate(self, tokenized_example):
        # "tail" truncation: drop tokens from the end until max_seq_length is met
        old_len = len(tokenized_example["input_ids"])
        if old_len > self.max_seq_length:
            for k in tokenized_example:
                tokenized_example[k] = tokenized_example[k][:-(old_len - self.max_seq_length)]

        return tokenized_example

    def __iter__(self):
        for example in self.raw_dataset:
            tokenized_example = self.tokenize_example(example)
            tokenized_example = self.truncate(tokenized_example)
            yield tokenized_example

    def __len__(self):
        return len(self.raw_dataset)


if __name__ == "__main__":
    # Quick smoke test: tokenize one conversation and print the supervised part.
    from transformers import LlamaTokenizer

    tokenizer = LlamaTokenizer.from_pretrained("../../llama-7B-HF")
    raw_dataset = load_raw_data("../data/processed/part2_1.json")

    dataset = PromptIterableDataset(raw_dataset, tokenizer=tokenizer, max_seq_length=2048, teacher_forcing=True)
    for data in dataset:
        print(data)
        print(tokenizer.decode(data["input_ids"][:1000]))

        model_output = data["input_ids"][:1000][data["labels"][:1000] != -100]
        print("##### model output")
        print(tokenizer.decode(model_output))
        break
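For training, PromptIterableDataset and collator are meant to be wired into a standard PyTorch DataLoader. Below is a minimal sketch of that wiring; the module name ultrachat_dataset, the tokenizer and data paths, and the batch size are assumptions, and the pad token is set to EOS because LLaMA tokenizers do not define one by default.

from functools import partial

from torch.utils.data import DataLoader
from transformers import LlamaTokenizer

# Assumption: this file is importable as `ultrachat_dataset`.
from ultrachat_dataset import PromptIterableDataset, collator, load_raw_data

tokenizer = LlamaTokenizer.from_pretrained("../../llama-7B-HF")  # placeholder path
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # assumption: reuse EOS for padding

raw_dataset = load_raw_data("../data/processed/part2_1.json")  # placeholder path
dataset = PromptIterableDataset(raw_dataset, tokenizer=tokenizer, max_seq_length=2048)

# collator expects (tokenizer, instances), so bind the tokenizer first.
loader = DataLoader(dataset, batch_size=4, collate_fn=partial(collator, tokenizer))

for batch in loader:
    # Padded tensors of shape (batch, max_len_in_batch).
    print(batch["input_ids"].shape, batch["labels"].shape, batch["attention_mask"].shape)
    break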