import glob
import json
import logging
import os
import random

import hydra
import hydra.utils as hu
import torch
import tqdm
from accelerate import Accelerator
from omegaconf import OmegaConf
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM

from src.data.collators import DataCollatorWithPaddingAndCuda
from src.utils.cache_util import BufferedJsonWriter, BufferedJsonReader
from src.utils.metric import metric_dict

logger = logging.getLogger(__name__)
def __init__(self,cfg, accelerator) -> None:
22
self.dataset_reader = hu.instantiate(cfg.dataset_reader)
23
self.dataset_reader.shard(accelerator)
24
co = DataCollatorWithPaddingAndCuda(tokenizer=self.dataset_reader.tokenizer,device=accelerator.device)
25
self.dataloader = DataLoader(self.dataset_reader,batch_size=cfg.batch_size,collate_fn=co)
26
self.dataset_reader.tokenizer.pad_token_id = self.dataset_reader.tokenizer.eos_token_id
28
self.model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path=cfg.model_name, cache_dir=cfg.cache_dir)
29
self.output_train_file = cfg.output_train_file
30
self.output_valid_file = cfg.output_valid_file
31
self.accelerator = accelerator
33
self.model = self.model.half().to(self.accelerator.device)
34
self.model = self.model.eval()
36
self.tokenizer=self.dataset_reader.tokenizer
37
self.option_num=self.dataset_reader.task.class_num
39
os.makedirs(os.path.dirname(cfg.output_train_file), exist_ok=True)
41
self.max_length=cfg.max_length #used for text completion task,
42
self.generate_max_len=cfg.generate_max_len # max seq len to be generated
44
def choice_losses(self,input_ids,input_atten_mask,loss_mask,labels):
45
bsz, option_num, seq_len = input_ids.shape
46
if self.option_num is not None: assert option_num == self.option_num
48
output=self.model(input_ids=input_ids.reshape(bsz*option_num, seq_len),
49
attention_mask=input_atten_mask.reshape(bsz*option_num, seq_len))
51
logits=output.logits.reshape(bsz, option_num, seq_len, -1)
52
logits=logits[:,:, :-1, :] # (bsz, option_num, seq_len-1, hidden_dim)
53
targets=input_ids[:,:,1:].unsqueeze(-1) # (bsz,option_num, seq_len-1, 1)
54
logit_probs= torch.nn.functional.log_softmax(logits.float(), dim=-1) # (bsz, option_num, seq_len-1,hidden_dim)
55
loss_mask=loss_mask[:,:,1:] # (bsz, option_num, seq_len-1)
56
loss= -torch.gather(logit_probs, -1, targets).squeeze(-1) * loss_mask # (bsz, option_num, seq_len-1)
57
loss = loss.sum(-1) / loss_mask.sum(-1) # (bsz, option_num)
58
preds= torch.argmin(loss,dim=-1)
59
normed_loss = torch.nn.functional.normalize(loss, p=1,dim=-1)
60
labels_losses = torch.gather(normed_loss, -1, labels).squeeze(-1).tolist()
61
accurate_list=(preds==labels.squeeze(-1)).int().tolist()
63
"labels_losses": labels_losses,
64
"accurate_list": accurate_list,
65
"preds": preds.tolist()
68
def completion_losses(self,input_ids,input_atten_mask,labels):
70
answer_start = int(input_atten_mask.shape[-1])
71
res = self.model.generate(input_ids=input_ids.squeeze(1), #remove the dim for option_num
72
attention_mask=input_atten_mask.squeeze(1),
73
eos_token_id=self.dataset_reader.tokenizer.encode("\n")[0],
74
pad_token_id=self.dataset_reader.tokenizer.pad_token_id,
75
max_length=min(self.max_length,answer_start+self.generate_max_len),
78
pred_ids=res[:,answer_start:]
80
for i in range(len(pred_ids)):
81
pred=self.dataset_reader.tokenizer.decode(pred_ids[i],skip_special_tokens=True)
82
# avoid empty prediction to avoid errors when calculating Rouge metric scores
83
if '\n' not in pred: pred+='\n'
85
compute_metric=metric_dict[self.dataset_reader.task.metric]
86
scores=compute_metric(preds=preds, labels=labels, return_list=True)
88
"labels_losses": [1-score for score in scores],
89
"accurate_list": scores,
95
if self.accelerator.is_main_process:
96
dataloader = tqdm.tqdm(self.dataloader)
98
dataloader = self.dataloader
100
with BufferedJsonWriter(f"{self.output_train_file}tmp_{self.accelerator.device}.bin") as buffer:
101
for i,entry in enumerate(dataloader):
102
if "stop" in self.cfg and self.cfg.stop==i: # pass stop for debug
104
metadata = entry.pop("metadata")
105
if self.dataset_reader.task.class_num==1:
106
one_shot_res=self.completion_losses(
107
input_ids=entry.input_ids,
108
input_atten_mask=entry.input_atten_mask,
109
labels=[x.pop('temp_label') for x in metadata],
112
one_shot_res=self.choice_losses(
113
input_ids=entry.input_ids,
114
input_atten_mask=entry.input_atten_mask,
115
loss_mask=entry.input_loss_mask,
118
one_shot_losses=one_shot_res["labels_losses"]
119
for i in range(len(metadata)):
120
metadata[i]['pred']=one_shot_res["preds"][i]
121
metadata[i]['loss']=one_shot_losses[i]
122
metadata[i]['one_shot_acc']=one_shot_res["accurate_list"][i]
123
buffer.write(metadata)
125
def write_results(self):
126
def split_example(entry):
129
for key,val in entry.items():
130
if key.startswith("test_"):
131
test_example[key[len("test_"):]] = val
133
prompt_example[key] = val
134
return test_example,prompt_example
137
for path in glob.glob(f"{self.output_train_file}tmp_*.bin"):
138
with BufferedJsonReader(path) as f:
145
if entry['test_id'] not in example_dict:
146
test_example,prompt_example = split_example(entry)
147
test_example['ctxs'] = [prompt_example]
148
example_dict[entry['test_id']] = test_example
150
_,prompt_example = split_example(entry)
151
example_dict[entry['test_id']]['ctxs'].append(prompt_example)
152
one_shot_true+=prompt_example["one_shot_acc"]
153
overall_one_shot_acc=one_shot_true/len(data)
154
logger.info('task name: %s', self.cfg.task_name)
155
logger.info('one_shot_acc: %f', overall_one_shot_acc)
157
example_list = list(example_dict.values())
158
for entry in example_list:
159
entry['task_name']=self.cfg.task_name
161
# rank loss from low to high, the lower the loss, the higher the efficiency of prompt
162
entry['ctxs'] = sorted(entry['ctxs'],key = lambda x: x['loss'])
164
# check whether the first-ranked prompt can lead to the gold prediction
165
first_rank_true+=entry['ctxs'][0]["one_shot_acc"]
167
logger.info('len(example_list): %d',len(example_list))
168
overall_first_rank_acc=first_rank_true/len(example_list)
169
logger.info('first_rank_acc: %f', overall_first_rank_acc)
171
# split the scored data to 90% : 10% for training and validation respectively
172
random.Random(42).shuffle(example_list)
174
n_train=int(len(example_list)*split_ratio)
175
with open(self.output_train_file,"w") as writer:
176
writer.write(json.dumps(example_list[:n_train], indent=4) + "\n")
177
with open(self.output_valid_file,"w") as writer:
178
writer.write(json.dumps(example_list[n_train:], indent=4) + "\n")
179
for path in glob.glob(f"{self.output_train_file}tmp_*.bin"):
183
@hydra.main(config_path="configs",config_name="scorer")
186
accelerator = Accelerator()
187
scorer = Scorer(cfg, accelerator)
189
accelerator.wait_for_everyone()
190
if accelerator.is_main_process:
191
scorer.write_results()
193
if __name__ == "__main__":