1"""
2for inference
3"""
4from typing import Any, Dict
5from transformers import AutoTokenizer
6import torch
7import numpy as np
8import json
9import random
10import re
11
12import more_itertools
13from src.utils.dataset_utils import pad2sameLen
14from DPR.dpr.utils.tasks import task_map, get_prompt_files
15from DPR.dpr.utils.data_utils import read_data_from_json_files
16
17
18def remove_double_space(string):
19return re.sub("[ ]{2,}", " ", string)


class FewShotDatasetReader(torch.utils.data.Dataset):
    def __init__(
        self,
        model_name,
        task_name,
        prompt_file,
        prompt_pool_path,
        num_prompts=-1,
        n_tokens=1600,
        random_sample=False,
        random_seed=0,
        cache_dir=None,
        max_length=2048,
        train_clusters=None,
    ) -> None:
        self.task = task_map.cls_dic[task_name]()
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name, cache_dir=cache_dir, model_max_length=max_length
        )
        # completion-style tasks (class_num == 1) use left padding for generation
        if self.task.class_num == 1:
            self.tokenizer.padding_side = "left"

        # load the retrieved prompt file
        with open(prompt_file) as f:
            self.prompts = json.load(f)

        # optionally draw demonstrations at random from a prompt pool
        # instead of using the retrieved ones
        self.random_sample = random_sample
        if random_sample:
            prompt_pool_path = get_prompt_files(prompt_pool_path, train_clusters)
            print("prompt files: ", prompt_pool_path)
            self.prompt_pool = read_data_from_json_files(prompt_pool_path)
            print("prompt passages num : ", len(self.prompt_pool))
            self.random_seed = random_seed

        self.num_prompts = num_prompts
        self.n_tokens_in_prompt = n_tokens
        self.num_processes = 1
        self.process_index = 0

    def __getitem__(self, index):
        if self.task.class_num == 1:
            return self.text_to_instance_completion(self.prompts[index])
        else:
            return self.text_to_instance_choice(self.prompts[index])

    def __len__(self):
        return len(self.prompts)

    def shard(self, accelerator):
        self.num_processes = accelerator.num_processes
        self.process_index = accelerator.process_index
        self.prompts = list(
            more_itertools.distribute(accelerator.num_processes, self.prompts)[
                accelerator.process_index
            ]
        )
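
    # Illustrative sketch only (made-up values, not from this file): shard()
    # splits the evaluation set round-robin across accelerator processes.
    # more_itertools.distribute deals items out like cards, so with 5 prompts
    # and 2 processes,
    #   list(more_itertools.distribute(2, range(5))[0]) == [0, 2, 4]
    #   list(more_itertools.distribute(2, range(5))[1]) == [1, 3]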

    def get_length(self, text):
        tokenized_example = self.tokenizer.encode_plus(
            text, truncation=False, return_tensors="pt"
        )
        shape = tokenized_example.input_ids.squeeze().shape
        # a single-token input squeezes to a 0-d tensor, so its length is 1
        if len(shape) == 0:
            return 1
        else:
            return int(shape[0])

    def format_prompt(self, entry):
        prompt_task = task_map.cls_dic[entry["task_name"]]()
        prompt_question = prompt_task.get_question(entry)
        prompt_answer = prompt_task.get_answer(entry)
        qa = f"{prompt_question}{prompt_answer}"
        return remove_double_space(qa)
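
    # Illustrative only; the exact strings come from the task class. For a
    # hypothetical sentiment task, get_question/get_answer might yield
    #   question = "Review: great movie ! Sentiment:", answer = " positive"
    # so format_prompt returns "Review: great movie ! Sentiment: positive".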

    def get_fields(self, entry):
        example = entry["meta_data"]
        questions = self.task.get_input_strs(example)
        answers = self.task.get_answers(example)
        label = self.task.get_label(example)
        if self.num_prompts == 0:
            prompts_list = []
        elif self.random_sample:
            random.seed(self.random_seed)
            prompts_list = random.choices(self.prompt_pool, k=len(entry["ctxs"]))
        else:
            prompts_list = [p["meta_data"] for p in entry["ctxs"]]
        lengths_list = [
            self.get_length(self.format_prompt(prompt)) for prompt in prompts_list
        ]

        # the longest question/answer pair determines how many tokens remain
        # for in-context demonstrations
        max_q_length = 0
        max_a_length = 0
        for i in range(len(questions)):
            max_q_length = max(
                max_q_length, self.get_length(remove_double_space(questions[i]))
            )
            max_a_length = max(
                max_a_length, self.get_length(remove_double_space(answers[i]))
            )

        # keep as many leading prompts as fit within the remaining token budget
        max_prompts = np.searchsorted(
            np.cumsum(lengths_list),
            self.n_tokens_in_prompt - (max_q_length + max_a_length),
        )
        if self.num_prompts > -1:
            max_prompts = min(self.num_prompts, max_prompts)
        trunc_prompts_list = prompts_list[:max_prompts][::-1]

        prompt_enc_text = " \n ".join(
            [self.format_prompt(prompt) for prompt in trunc_prompts_list]
        )
        if max_prompts == 0:
            questions = [remove_double_space(question) for question in questions]
        else:
            questions = [
                remove_double_space(prompt_enc_text + " \n " + question)
                for question in questions
            ]
        return questions, answers, label, max_prompts
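
    # Worked example of the truncation above (hypothetical numbers): with
    # n_tokens_in_prompt = 1600 and max_q_length + max_a_length = 100, the
    # demonstration budget is 1500 tokens. If lengths_list = [600, 600, 600],
    # np.cumsum gives [600, 1200, 1800] and np.searchsorted(..., 1500) == 2,
    # so only the first two retrieved prompts are kept. Assuming ctxs are
    # ranked best-first, the [::-1] reversal places the top-ranked prompt
    # nearest the test question.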

    def text_to_instance_choice(self, entry: Dict[str, Any]):
        """
        multiple-choice question: score one sequence per answer choice
        """
        questions, answers, label, max_prompts = self.get_fields(entry)
        input_ids_list = []
        input_atten_mask_list = []
        input_loss_mask_list = []

        example = {}
        example["enc_text"] = []
        example["enc_answer"] = []
        for i in range(len(questions)):
            enc_text = remove_double_space(questions[i] + answers[i])
            example["enc_text"].append(
                remove_double_space(questions[i]).strip()
            )  # remove trailing space after question
            tokenized_example = self.tokenizer.encode_plus(
                enc_text,
                truncation=False,
                return_tensors="pt",
                add_special_tokens=False,
            )
            enc_answer = remove_double_space(answers[i])
            tokenized_answer = self.tokenizer.encode_plus(
                enc_answer,
                truncation=False,
                add_special_tokens=False,
                return_tensors="pt",
            )

            # a one-token answer squeezes to a 0-d mask; restore a 1-d view
            answer_mask = tokenized_answer.attention_mask.squeeze()
            if len(answer_mask.shape) == 0:
                answer_mask = torch.tensor([1]).to(answer_mask)

            input_ids = tokenized_example.input_ids.squeeze()
            input_atten_mask = tokenized_example.attention_mask.squeeze()
            # score only the answer tokens: left-pad the answer mask out to
            # the full sequence length
            input_loss_mask = torch.nn.functional.pad(
                answer_mask, (input_ids.shape[-1] - answer_mask.shape[-1], 0)
            )

            input_ids_list.append(input_ids)
            input_atten_mask_list.append(input_atten_mask)
            input_loss_mask_list.append(input_loss_mask)

            example["enc_answer"].append(enc_answer)
        example["n_prompts"] = str(max_prompts)
        example["label"] = label
        return {
            "input_ids": pad2sameLen(
                input_ids_list, pad_idx=self.tokenizer.pad_token_id
            ),
            "input_atten_mask": pad2sameLen(input_atten_mask_list, pad_idx=0),
            "input_loss_mask": pad2sameLen(input_loss_mask_list, pad_idx=0),
            "labels": torch.tensor([label]),
            "metadata": example,
        }

    def text_to_instance_completion(self, entry: Dict[str, Any]):
        """
        text completion question: left-padded inputs for generation
        """
        questions, answers, label, max_prompts = self.get_fields(entry)

        input_ids_list = []
        input_atten_mask_list = []
        example = {}
        example["enc_text"] = []
        example["enc_answer"] = []
        for i in range(len(questions)):
            enc_text = remove_double_space(questions[i])
            example["enc_text"].append(
                remove_double_space(questions[i]).strip()
            )  # remove trailing space after question

            tokenized_example = self.tokenizer.encode_plus(
                enc_text,
                truncation=False,
                return_tensors="pt",
                add_special_tokens=False,
            )
            enc_answer = remove_double_space(answers[i])

            input_ids = tokenized_example.input_ids.squeeze()
            input_atten_mask = tokenized_example.attention_mask.squeeze()

            # a single-token input squeezes to a 0-d tensor; restore a 1-d view
            if len(input_ids.shape) == 0:
                input_ids = input_ids.unsqueeze(0)
                input_atten_mask = input_atten_mask.unsqueeze(0)

            input_ids_list.append(input_ids)
            input_atten_mask_list.append(input_atten_mask)

            example["enc_answer"].append(enc_answer)
        example["temp_label"] = label
        example["n_prompts"] = str(max_prompts)
        example["label"] = label
        return {
            "input_ids": pad2sameLen(
                input_ids_list, pad_idx=self.tokenizer.pad_token_id, left_pad=True
            ),
            "input_atten_mask": pad2sameLen(
                input_atten_mask_list, pad_idx=0, left_pad=True
            ),
            "metadata": example,
        }
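

# A minimal usage sketch, not part of the original file. All names and paths
# below are assumptions for illustration: it needs this repo's src/ and DPR/
# modules importable, a task name registered in task_map.cls_dic, and a real
# retrieved-prompt JSON on disk.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    reader = FewShotDatasetReader(
        model_name="gpt2",                     # any HF tokenizer name
        task_name="sst2",                      # hypothetical registered task
        prompt_file="retrieved_prompts.json",  # hypothetical path
        prompt_pool_path=None,                 # unused when random_sample=False
        num_prompts=8,
    )
    # gpt2 ships without a pad token; reuse eos for padding in this sketch
    if reader.tokenizer.pad_token is None:
        reader.tokenizer.pad_token = reader.tokenizer.eos_token

    # each __getitem__ already returns a full tensor dict, so the collate
    # function just unwraps the singleton batch
    loader = DataLoader(reader, batch_size=1, collate_fn=lambda items: items[0])
    batch = next(iter(loader))
    print(batch["metadata"]["n_prompts"], batch["input_ids"].shape)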