lmops

Форк
0
/
few_shot_dsr.py 
251 строка · 8.6 Кб
1
"""
2
for inference
3
"""
4
from typing import Any, Dict
5
from transformers import AutoTokenizer
6
import torch
7
import numpy as np
8
import json
9
import random
10
import re
11

12
import more_itertools
13
from src.utils.dataset_utils import pad2sameLen
14
from DPR.dpr.utils.tasks import task_map, get_prompt_files
15
from DPR.dpr.utils.data_utils import read_data_from_json_files
16

17

18
def remove_double_space(string):
19
    return re.sub("[ ]{2,}", " ", string)
20

21

22
class FewShotDatasetReader(torch.utils.data.Dataset):
23
    def __init__(
24
        self,
25
        model_name,
26
        task_name,
27
        prompt_file,
28
        prompt_pool_path,
29
        num_prompts=-1,
30
        n_tokens=1600,
31
        random_sample=False,
32
        random_seed=0,
33
        cache_dir=None,
34
        max_length=2048,
35
        train_clusters=None,
36
    ) -> None:
37
        self.task = task_map.cls_dic[task_name]()
38
        self.tokenizer = AutoTokenizer.from_pretrained(
39
            model_name, cache_dir=cache_dir, model_max_length=max_length
40
        )
41
        if self.task.class_num == 1:
42
            self.tokenizer.padding_side = "left"
43

44
        # retreived prompt_file,
45
        with open(prompt_file) as f:
46
            self.prompts = json.load(f)
47

48
        # random sample from prompt pool
49
        self.random_sample = random_sample
50
        if random_sample:
51
            prompt_pool_path = get_prompt_files(prompt_pool_path, train_clusters)
52
            print("prompt files: ", prompt_pool_path)
53
            self.prompt_pool = read_data_from_json_files(prompt_pool_path)
54
            print("prompt passages num : ", len(self.prompt_pool))
55
            self.random_seed = random_seed
56

57
        self.num_prompts = num_prompts
58
        self.n_tokens_in_prompt = n_tokens
59
        self.num_processes = 1
60
        self.process_index = 0
61

62
    def __getitem__(self, index):
63
        if self.task.class_num == 1:
64
            return self.text_to_instance_completion(self.prompts[index])
65
        else:
66
            return self.text_to_instance_choice(self.prompts[index])
67

68
    def __len__(self):
69
        return len(self.prompts)
70

71
    def shard(self, accelerator):
72
        self.num_processes = accelerator.num_processes
73
        self.process_index = accelerator.process_index
74
        self.prompts = list(
75
            more_itertools.distribute(accelerator.num_processes, self.prompts)[
76
                accelerator.process_index
77
            ]
78
        )
79

80
    def get_length(self, text):
81
        tokenized_example = self.tokenizer.encode_plus(
82
            text, truncation=False, return_tensors="pt"
83
        )
84
        shape = tokenized_example.input_ids.squeeze().shape
85
        if len(shape) == 0:
86
            return 1
87
        else:
88
            return int(shape[0])
89

90
    def format_prompt(self, entry):
91
        prompt_task = task_map.cls_dic[entry["task_name"]]()
92
        prompt_question = prompt_task.get_question(entry)
93
        prompt_answer = prompt_task.get_answer(entry)
94
        qa=f'{prompt_question}{prompt_answer}'
95
        return remove_double_space(qa)
96

97
    def get_fields(self, entry):
98
        example = entry["meta_data"]
99
        questions = self.task.get_input_strs(example)
100
        answers = self.task.get_answers(example)
101
        label = self.task.get_label(example)
102
        if self.num_prompts == 0:
103
            prompts_list = []
104
        elif self.random_sample:
105
            random.seed(self.random_seed)
106
            prompts_list = [
107
                p for p in random.choices(self.prompt_pool, k=len(entry["ctxs"]))
108
            ]
109
        else:
110
            prompts_list = [p["meta_data"] for p in entry["ctxs"]]
111
        lengths_list = [self.get_length(self.format_prompt(prompt)) for prompt in prompts_list]
112

113
        max_q_length = 0
114
        max_a_length = 0
115
        for i in range(len(questions)):
116
            max_q_length = max(
117
                max_q_length, self.get_length(remove_double_space(questions[i]))
118
            )
119
            max_a_length = max(
120
                max_a_length, self.get_length(remove_double_space(answers[i]))
121
            )
122

123
        max_prompts = np.searchsorted(
124
            np.cumsum(lengths_list),
125
            self.n_tokens_in_prompt - (max_q_length + max_a_length),
126
        )
127
        if self.num_prompts > -1:
128
            max_prompts = min(
129
                self.num_prompts, max_prompts
130
            )
131
        trunc_prompts_list = prompts_list[:max_prompts][::-1]
132

133
        
134
        prompt_enc_text = " \n ".join(
135
            [self.format_prompt(prompt) for prompt in trunc_prompts_list]
136
        )
137
        if max_prompts == 0:
138
            questions = [remove_double_space(question) for question in questions]
139
        else:
140
            questions = [
141
                remove_double_space(prompt_enc_text + " \n " + question)
142
                for question in questions
143
            ]
144
        return questions, answers, label, max_prompts
145

146
    def text_to_instance_choice(self, entry: Dict[str, Any]):
147
        """
148
        multiple-choice question
149
        """
150
        questions, answers, label, max_prompts = self.get_fields(entry)
151
        input_ids_list = []
152
        input_atten_mask_list = []
153
        input_loss_mask_list = []
154

155
        example = {}
156
        example["enc_text"] = []
157
        example["enc_answer"] = []
158
        for i in range(len(questions)):
159
            enc_text = remove_double_space(questions[i] + answers[i])
160
            example["enc_text"].append(
161
                remove_double_space(questions[i]).strip()
162
            )  # remove trailing space after question
163
            tokenized_example = self.tokenizer.encode_plus(
164
                enc_text,
165
                truncation=False,
166
                return_tensors="pt",
167
                add_special_tokens=False,
168
            )
169
            enc_answer = remove_double_space(remove_double_space(answers[i]))
170
            tokenized_answer = self.tokenizer.encode_plus(
171
                enc_answer,
172
                truncation=False,
173
                add_special_tokens=False,
174
                return_tensors="pt",
175
            )
176

177
            answer_mask = tokenized_answer.attention_mask.squeeze()
178
            if len(answer_mask.shape) == 0:
179
                answer_mask = torch.tensor([1]).to(answer_mask)
180

181
            input_ids = tokenized_example.input_ids.squeeze()
182
            input_atten_mask = tokenized_example.attention_mask.squeeze()
183
            input_loss_mask = torch.nn.functional.pad(
184
                answer_mask, (input_ids.shape[-1] - answer_mask.shape[-1], 0)
185
            )
186

187
            input_ids_list.append(input_ids)
188
            input_atten_mask_list.append(input_atten_mask)
189
            input_loss_mask_list.append(input_loss_mask)
190

191
            example["enc_answer"].append(enc_answer)
192
        example["n_prompts"] = str(max_prompts)
193
        example["label"] = label
194
        return {
195
            "input_ids": pad2sameLen(
196
                input_ids_list, pad_idx=self.tokenizer.pad_token_id
197
            ),
198
            "input_atten_mask": pad2sameLen(input_atten_mask_list, pad_idx=0),
199
            "input_loss_mask": pad2sameLen(input_loss_mask_list, pad_idx=0),
200
            "labels": torch.tensor([label]),
201
            "metadata": example,
202
        }
203

204
    def text_to_instance_completion(self, entry: Dict[str, Any]):
205
        """
206
        text completion question
207
        """
208
        questions, answers, label, max_prompts = self.get_fields(entry)
209

210
        input_ids_list = []
211
        input_atten_mask_list = []
212
        example = {}
213
        example["enc_text"] = []
214
        example["enc_answer"] = []
215
        for i in range(len(questions)):
216
            enc_text = remove_double_space(questions[i])
217
            example["enc_text"].append(
218
                remove_double_space(questions[i]).strip()
219
            )  # remove trailing space after question
220

221
            tokenized_example = self.tokenizer.encode_plus(
222
                enc_text,
223
                truncation=False,
224
                return_tensors="pt",
225
                add_special_tokens=False,
226
            )
227
            enc_answer = remove_double_space(remove_double_space(answers[i]))
228

229
            input_ids = tokenized_example.input_ids.squeeze()
230
            input_atten_mask = tokenized_example.attention_mask.squeeze()
231

232
            if len(input_ids.shape) == 0:
233
                input_ids = input_ids.unsqueeze(0)
234
                input_atten_mask = input_atten_mask.unsqueeze(0)
235

236
            input_ids_list.append(input_ids)
237
            input_atten_mask_list.append(input_atten_mask)
238

239
            example["enc_answer"].append(enc_answer)
240
        example["temp_label"] = label
241
        example["n_prompts"] = str(max_prompts)
242
        example["label"] = label
243
        return {
244
            "input_ids": pad2sameLen(
245
                input_ids_list, pad_idx=self.tokenizer.pad_token_id, left_pad=True
246
            ),
247
            "input_atten_mask": pad2sameLen(
248
                input_atten_mask_list, pad_idx=0, left_pad=True
249
            ),
250
            "metadata": example,
251
        }
252

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.