lmops
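
"""
ScorerDatasetReader: pairs each test example with its retrieved candidate prompts so a
frozen language model can score every (candidate prompt, test example) combination.
Multiple-choice tasks get one tokenized sequence per answer choice with an answer-only
loss mask; text completion tasks get a single left-padded sequence per example.
"""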

from typing import Any, Dict
from transformers import AutoTokenizer
import torch
import pandas as pd
import json
from datasets import Dataset
import re
from src.utils.dataset_utils import pad2sameLen
from DPR.dpr.utils.tasks import task_map


def remove_double_space(string):
    # Collapse runs of two or more spaces into a single space.
    return re.sub("[ ]{2,}", " ", string)


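# Input files (as consumed by __init__ below):
#   * prompt_pool_path: JSON list of prompt entries; a prompt's id is its position in this list.
#   * example_file: JSON list of test entries, each carrying a "ctxs" list of retrieved
#     candidates that reference prompts by "id".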
class ScorerDatasetReader(torch.utils.data.Dataset):
    """Flattens (candidate prompt, test example) pairs into one dataset row each,
    so every row can be scored independently by the language model."""

    def __init__(
        self,
        example_file,
        model_name,
        task_name,
        prompt_pool_path=None,
        cache_dir=None,
        max_length=2048,
    ) -> None:
        self.task = task_map.cls_dic[task_name]()
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name, cache_dir=cache_dir, model_max_length=max_length
        )
        if self.task.class_num == 1:  # text completion question
            self.tokenizer.padding_side = "left"

        # prompt_pool: candidate demonstrations, indexed by their position in the file
        with open(prompt_pool_path, "r", encoding="utf-8") as f:
            prompt_pool = json.load(f)
        self.prompt_pool = list(enumerate(prompt_pool))

        # task_data: test examples, each carrying retrieved candidates under "ctxs"
        with open(example_file) as f1:
            self.task_data = json.load(f1)

        def get_instance(entry):
            # Expand one test entry into one record per retrieved candidate prompt:
            # the candidate's fields plus the test example's fields prefixed with "test_".
            examples = entry.pop("ctxs")
            for exp in examples:
                exp.update(self.prompt_pool[exp["id"]][1])
                for key, val in entry.items():
                    exp[f"test_{key}"] = val
            yield from examples

        def get_dataset(data):
            for entry in data:
                yield from get_instance(entry)

        df = pd.DataFrame(list(get_dataset(self.task_data)))
        self.dataset = Dataset.from_pandas(df)

    def shard(self, accelerator):
        # Split the dataset across processes so each worker scores only its own shard.
        self.dataset = self.dataset.shard(
            num_shards=accelerator.num_processes, index=accelerator.process_index
        )

    def __getitem__(self, index):
        if self.task.class_num == 1:  # text completion question
            return self.text_to_instance_completion(self.dataset[index])
        else:
            return self.text_to_instance_choice(self.dataset[index])

    def __len__(self):
        return len(self.dataset)

    def get_fields(self, entry):
        """Build the scoring inputs for one row: prompt demonstration + test input(s),
        together with the test example's answer strings and gold label."""
        # Recover the test example's own fields (stored with a "test_" prefix).
        example = {}
        for key, val in entry.items():
            if key.startswith("test_"):
                example[key[len("test_") :]] = val

        test_input_strs = self.task.get_input_strs(example)
        question = self.task.get_question(entry)
        answer = self.task.get_answer(entry)
        demonstration = f"{question}{answer}"
        test_questions = [demonstration + " \n " + input_str for input_str in test_input_strs]
        test_answer_strs = self.task.get_answers(example)
        test_label = self.task.get_label(example)
        return test_questions, test_answer_strs, test_label

    def text_to_instance_choice(self, entry):
        """
        Multiple-choice question: one tokenized sequence per answer choice, with a
        loss mask covering only the answer tokens.
        """
        test_questions, test_answers, test_label = self.get_fields(entry)

        input_ids_list = []
        input_atten_mask_list = []
        input_loss_mask_list = []
        for i in range(len(test_questions)):
            enc_text = remove_double_space(test_questions[i] + test_answers[i])
            enc_answer = remove_double_space(test_answers[i])
            tokenized_example = self.tokenizer.encode_plus(
                enc_text,
                truncation=False,
                add_special_tokens=False,
                return_tensors="pt",
            )
            tokenized_answer = self.tokenizer.encode_plus(
                enc_answer,
                truncation=False,
                add_special_tokens=False,
                return_tensors="pt",
            )

            answer_mask = tokenized_answer.attention_mask.squeeze()
            if len(answer_mask.shape) == 0:  # single-token answer squeezed to a scalar
                answer_mask = torch.tensor([1]).to(answer_mask)

            input_ids = tokenized_example.input_ids.squeeze()
            input_atten_mask = tokenized_example.attention_mask.squeeze()
            # Left-pad the answer mask so it aligns with the answer tokens at the end of the input.
            input_loss_mask = torch.nn.functional.pad(
                answer_mask, (input_ids.shape[-1] - answer_mask.shape[-1], 0)
            )

            input_ids_list.append(input_ids)
            input_atten_mask_list.append(input_atten_mask)
            input_loss_mask_list.append(input_loss_mask)

        return {
            "input_ids": pad2sameLen(
                input_ids_list, pad_idx=self.tokenizer.pad_token_id
            ),
            "input_atten_mask": pad2sameLen(input_atten_mask_list, pad_idx=0),
            "input_loss_mask": pad2sameLen(input_loss_mask_list, pad_idx=0),
            "labels": torch.tensor([test_label]),
            "metadata": entry,
        }

    def text_to_instance_completion(self, entry: Dict[str, Any]):
        """
        Text completion question: a single left-padded sequence; the gold label is
        carried through the metadata for the next step.
        """
        test_questions, _, test_label = self.get_fields(entry)

        input_ids_list = []
        input_atten_mask_list = []
        for i in range(len(test_questions)):  # len(test_questions) = 1 for completion question
            enc_text = remove_double_space(test_questions[i]).strip()
            tokenized_example = self.tokenizer.encode_plus(
                enc_text,
                truncation=False,
                return_tensors="pt",
                add_special_tokens=False,
            )

            input_ids = tokenized_example.input_ids.squeeze()
            input_atten_mask = tokenized_example.attention_mask.squeeze()

            input_ids_list.append(input_ids)
            input_atten_mask_list.append(input_atten_mask)

        entry["temp_label"] = test_label  # pass label for the next step
        return {
            "input_ids": pad2sameLen(
                input_ids_list, pad_idx=self.tokenizer.pad_token_id, left_pad=True
            ),
            "input_atten_mask": pad2sameLen(
                input_atten_mask_list, pad_idx=0, left_pad=True
            ),
            "metadata": entry,
        }
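
# Minimal usage sketch (illustrative only): the file paths, task name, and model id below are
# placeholders, and the real scoring pipeline shards this reader across processes and batches
# it with its own collator.
if __name__ == "__main__":
    reader = ScorerDatasetReader(
        example_file="retrieved_examples.json",  # hypothetical path to retriever output
        model_name="gpt2",                       # any HF causal-LM tokenizer id
        task_name="sst2",                        # assumed to be registered in task_map.cls_dic
        prompt_pool_path="prompt_pool.json",     # hypothetical path to the prompt pool
    )
    if reader.tokenizer.pad_token is None:
        reader.tokenizer.pad_token = reader.tokenizer.eos_token  # GPT-style tokenizers lack a pad token
    print(len(reader))       # number of (candidate prompt, test example) pairs
    print(reader[0].keys())  # tensors ready for scoring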
