from typing import Any, Dict

import json
import re

import pandas as pd
import torch
from datasets import Dataset
from transformers import AutoTokenizer

from src.utils.dataset_utils import pad2sameLen
from DPR.dpr.utils.tasks import task_map


def remove_double_space(string):
    # Collapse runs of two or more spaces into a single space.
    return re.sub("[ ]{2,}", " ", string)


class ScorerDatasetReader(torch.utils.data.Dataset):
    def __init__(
        self,
        example_file,
        model_name,
        task_name,
        prompt_pool_path=None,
        cache_dir=None,
        max_length=2048,
    ) -> None:
        self.task = task_map.cls_dic[task_name]()
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name, cache_dir=cache_dir, model_max_length=max_length
        )
        if self.task.class_num == 1:  # text completion question
            self.tokenizer.padding_side = "left"

        # Candidate prompt pool: enumerate so each prompt keeps its index.
        with open(prompt_pool_path, "r", encoding="utf-8") as f:
            prompt_pool = json.load(f)
        self.prompt_pool = list(enumerate(prompt_pool))

        # Task data: one entry per test example, with retrieved candidate
        # prompt ids stored under "ctxs".
        with open(example_file, "r", encoding="utf-8") as f:
            self.task_data = json.load(f)

        def get_instance(entry):
            # Flatten one test entry into one row per retrieved candidate
            # prompt: each row carries the prompt's fields plus the test
            # example's fields under "test_*" keys.
            examples = entry.pop("ctxs")
            for exp in examples:
                exp.update(self.prompt_pool[exp["id"]][1])
                for key, val in entry.items():
                    exp[f"test_{key}"] = val
            yield from examples

        def get_dataset(data):
            for entry in data:
                yield from get_instance(entry)

        df = pd.DataFrame(list(get_dataset(self.task_data)))
        self.dataset = Dataset.from_pandas(df)

    def shard(self, accelerator):
        # Keep only this process's shard when scoring with multiple processes.
        self.dataset = self.dataset.shard(
            num_shards=accelerator.num_processes, index=accelerator.process_index
        )

    def __getitem__(self, index):
        if self.task.class_num == 1:  # text completion question
            return self.text_to_instance_completion(self.dataset[index])
        else:
            return self.text_to_instance_choice(self.dataset[index])

    def __len__(self):
        return len(self.dataset)

    def get_fields(self, entry):
        # Recover the test example's own fields (stored under "test_*" keys).
        example = {}
        for key, val in entry.items():
            if key.startswith("test_"):
                example[key[len("test_"):]] = val

        test_input_strs = self.task.get_input_strs(example)
        question = self.task.get_question(entry)
        answer = self.task.get_answer(entry)
        demonstration = f"{question}{answer}"
        test_questions = [demonstration + " \n " + input_str for input_str in test_input_strs]
        test_answer_strs = self.task.get_answers(example)
        test_label = self.task.get_label(example)
        return test_questions, test_answer_strs, test_label
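
    # Layout of each scoring input built above (sketch):
    #   "<candidate prompt question><candidate prompt answer> \n <test input>"
    # i.e. the candidate prompt serves as a one-shot demonstration for the
    # test input it is being scored against.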

    def text_to_instance_choice(self, entry):
        """
        Multiple-choice question: build one (question + answer) sequence per
        choice, with a loss mask covering only the answer tokens.
        """
        test_questions, test_answers, test_label = self.get_fields(entry)

        input_ids_list = []
        input_atten_mask_list = []
        input_loss_mask_list = []
        for i in range(len(test_questions)):
            enc_text = remove_double_space(test_questions[i] + test_answers[i])
            enc_answer = remove_double_space(test_answers[i])
            tokenized_example = self.tokenizer.encode_plus(
                enc_text,
                truncation=False,
                add_special_tokens=False,
                return_tensors="pt",
            )
            tokenized_answer = self.tokenizer.encode_plus(
                enc_answer,
                truncation=False,
                add_special_tokens=False,
                return_tensors="pt",
            )

            answer_mask = tokenized_answer.attention_mask.squeeze()
            if len(answer_mask.shape) == 0:  # single-token answer
                answer_mask = torch.tensor([1]).to(answer_mask)

            input_ids = tokenized_example.input_ids.squeeze()
            input_atten_mask = tokenized_example.attention_mask.squeeze()
            # Left-pad the answer mask with zeros so it aligns with the full
            # sequence: only the trailing answer tokens count toward the loss.
            input_loss_mask = torch.nn.functional.pad(
                answer_mask, (input_ids.shape[-1] - answer_mask.shape[-1], 0)
            )

            input_ids_list.append(input_ids)
            input_atten_mask_list.append(input_atten_mask)
            input_loss_mask_list.append(input_loss_mask)

        return {
            "input_ids": pad2sameLen(
                input_ids_list, pad_idx=self.tokenizer.pad_token_id
            ),
            "input_atten_mask": pad2sameLen(input_atten_mask_list, pad_idx=0),
            "input_loss_mask": pad2sameLen(input_loss_mask_list, pad_idx=0),
            "labels": torch.tensor([test_label]),
            "metadata": entry,
        }
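
    # Note (added comment, not in the original): a downstream scorer would
    # typically feed "input_ids"/"input_atten_mask" to a causal LM and use
    # "input_loss_mask" to average the token-level loss over the answer span
    # only, so each choice gets a comparable per-answer score.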

    def text_to_instance_completion(self, entry: Dict[str, Any]):
        """
        Text completion question: a single left-padded input sequence; the
        gold label is carried through to the next step as "temp_label".
        """
        test_questions, _, test_label = self.get_fields(entry)

        input_ids_list = []
        input_atten_mask_list = []
        for i in range(len(test_questions)):  # a single question for completion tasks
            enc_text = remove_double_space(test_questions[i]).strip()
            tokenized_example = self.tokenizer.encode_plus(
                enc_text,
                truncation=False,
                return_tensors="pt",
                add_special_tokens=False,
            )

            input_ids = tokenized_example.input_ids.squeeze()
            input_atten_mask = tokenized_example.attention_mask.squeeze()

            input_ids_list.append(input_ids)
            input_atten_mask_list.append(input_atten_mask)

        entry["temp_label"] = test_label  # pass the label on to the next step
        return {
            "input_ids": pad2sameLen(
                input_ids_list, pad_idx=self.tokenizer.pad_token_id, left_pad=True
            ),
            "input_atten_mask": pad2sameLen(
                input_atten_mask_list, pad_idx=0, left_pad=True
            ),
            "metadata": entry,
        }
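

if __name__ == "__main__":
    # Minimal usage sketch (added for illustration). The file paths, model
    # name, and task name below are placeholders, not part of the original
    # module; the task name must be a key registered in task_map.cls_dic.
    from accelerate import Accelerator
    from torch.utils.data import DataLoader

    accelerator = Accelerator()
    reader = ScorerDatasetReader(
        example_file="data/retrieved_examples.json",  # hypothetical path
        model_name="gpt2",
        task_name="sst2",                             # hypothetical task name
        prompt_pool_path="data/prompt_pool.json",     # hypothetical path
    )
    # GPT-2 style tokenizers define no pad token; reuse EOS so pad2sameLen
    # receives a valid pad_idx.
    if reader.tokenizer.pad_token is None:
        reader.tokenizer.pad_token = reader.tokenizer.eos_token
    reader.shard(accelerator)  # keep only this process's shard

    # __getitem__ already returns padded tensors for one test instance, so
    # batch_size=1 with the default collate keeps shapes simple.
    loader = DataLoader(reader, batch_size=1)
    batch = next(iter(loader))
    print(batch["input_ids"].shape)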