scorer.py
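"""Score candidate one-shot prompts with a frozen causal language model.

For each (prompt, test example) pair yielded by the dataset reader, the Scorer
measures how well the prompt helps the model solve the test example:
classification tasks are scored by per-option language-model loss, generation
tasks by greedy decoding followed by the task metric. Per-example results are
ranked by loss and written out as 90%/10% train/validation JSON files.
"""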
import torch
import tqdm
from torch.utils.data import DataLoader
from src.data.collators import DataCollatorWithPaddingAndCuda
import hydra.utils as hu
import hydra
import json
import os
from omegaconf import OmegaConf
import random
from src.utils.cache_util import BufferedJsonWriter, BufferedJsonReader
from src.utils.metric import metric_dict
from accelerate import Accelerator
import glob
import logging
from transformers import AutoModelForCausalLM

logger = logging.getLogger(__name__)


class Scorer:
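    """Score how much each candidate prompt helps a causal LM on its test example.

    The constructor instantiates the dataset reader from the Hydra config, shards it
    across Accelerate processes, and loads the model in fp16 on the local device;
    forward() computes per-prompt scores, and write_results() merges, ranks, and
    saves them.
    """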
    def __init__(self,cfg, accelerator) -> None:
        self.dataset_reader = hu.instantiate(cfg.dataset_reader)
        self.dataset_reader.shard(accelerator)
        co = DataCollatorWithPaddingAndCuda(tokenizer=self.dataset_reader.tokenizer,device=accelerator.device)
        self.dataloader = DataLoader(self.dataset_reader,batch_size=cfg.batch_size,collate_fn=co)
        self.dataset_reader.tokenizer.pad_token_id = self.dataset_reader.tokenizer.eos_token_id

        self.model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path=cfg.model_name, cache_dir=cfg.cache_dir)
        self.output_train_file = cfg.output_train_file
        self.output_valid_file = cfg.output_valid_file
        self.accelerator = accelerator

        self.model = self.model.half().to(self.accelerator.device)
        self.model = self.model.eval()
        self.cfg = cfg
        self.tokenizer=self.dataset_reader.tokenizer
        self.option_num=self.dataset_reader.task.class_num

        os.makedirs(os.path.dirname(cfg.output_train_file), exist_ok=True)

        self.max_length=cfg.max_length # maximum total sequence length, used for the text completion task
        self.generate_max_len=cfg.generate_max_len # maximum number of tokens to generate

    def choice_losses(self,input_ids,input_atten_mask,loss_mask,labels):
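        """Score a batch of multiple-choice examples.

        input_ids, input_atten_mask and loss_mask have shape (bsz, option_num, seq_len);
        labels has shape (bsz, 1). Returns, per example, the L1-normalized loss of the
        gold option, whether the lowest-loss option matches the gold label, and the
        predicted option index.
        """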
        bsz, option_num, seq_len = input_ids.shape
        if self.option_num is not None: assert option_num == self.option_num
        with torch.no_grad():
            output=self.model(input_ids=input_ids.reshape(bsz*option_num, seq_len),
                              attention_mask=input_atten_mask.reshape(bsz*option_num, seq_len))

        logits=output.logits.reshape(bsz, option_num, seq_len, -1)
        logits=logits[:,:, :-1, :] # (bsz, option_num, seq_len-1, vocab_size)
        targets=input_ids[:,:,1:].unsqueeze(-1) # (bsz, option_num, seq_len-1, 1)
        logit_probs= torch.nn.functional.log_softmax(logits.float(), dim=-1) # (bsz, option_num, seq_len-1, vocab_size)
        loss_mask=loss_mask[:,:,1:] # (bsz, option_num, seq_len-1)
        loss= -torch.gather(logit_probs, -1, targets).squeeze(-1) * loss_mask # (bsz, option_num, seq_len-1)
        loss = loss.sum(-1) / loss_mask.sum(-1) # mean token loss per option: (bsz, option_num)
        preds= torch.argmin(loss,dim=-1)
        normed_loss = torch.nn.functional.normalize(loss, p=1,dim=-1)
        labels_losses = torch.gather(normed_loss, -1, labels).squeeze(-1).tolist()
        accurate_list=(preds==labels.squeeze(-1)).int().tolist()
        return  {
                "labels_losses": labels_losses,
                "accurate_list": accurate_list,
                "preds": preds.tolist()
                }

    def completion_losses(self,input_ids,input_atten_mask,labels):
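        """Score a batch of text-completion examples.

        Greedily decodes a continuation for each prompt (stopping at a newline or the
        length budget), then scores the decoded strings against the gold labels with
        the task metric. The per-example "loss" is 1 - metric score.
        """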
        with torch.no_grad():
            answer_start = int(input_atten_mask.shape[-1]) # generated tokens start right after the padded input
            res = self.model.generate(input_ids=input_ids.squeeze(1), # remove the option_num dim
                                        attention_mask=input_atten_mask.squeeze(1),
                                        eos_token_id=self.dataset_reader.tokenizer.encode("\n")[0],
                                        pad_token_id=self.dataset_reader.tokenizer.pad_token_id,
                                        max_length=min(self.max_length,answer_start+self.generate_max_len),
                                        do_sample=False)

        pred_ids=res[:,answer_start:]
        preds=[]
        for i in range(len(pred_ids)):
            pred=self.dataset_reader.tokenizer.decode(pred_ids[i],skip_special_tokens=True)
            # guarantee a non-empty prediction so that Rouge metric computation does not fail
            if '\n' not in pred: pred+='\n'
            preds.append(pred)
        compute_metric=metric_dict[self.dataset_reader.task.metric]
        scores=compute_metric(preds=preds, labels=labels, return_list=True)
        return  {
                "labels_losses": [1-score for score in scores],
                "accurate_list": scores,
                "preds": preds
                }

    def forward(self):
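        """Score every (prompt, test example) pair in this process's shard.

        Each batch is routed to completion_losses() (class_num == 1) or choice_losses(),
        and the per-pair prediction, loss and one-shot accuracy are appended to the
        pair's metadata, which is buffered to a per-device temporary .bin file.
        """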
        if self.accelerator.is_main_process:
            dataloader = tqdm.tqdm(self.dataloader)
        else:
            dataloader = self.dataloader

        with BufferedJsonWriter(f"{self.output_train_file}tmp_{self.accelerator.device}.bin") as buffer:
            for i,entry in enumerate(dataloader):
                if "stop" in self.cfg and self.cfg.stop==i: # optional cfg.stop lets a debug run exit early
                    break
                metadata = entry.pop("metadata")
                if self.dataset_reader.task.class_num==1:
                    one_shot_res=self.completion_losses(
                                                    input_ids=entry.input_ids,
                                                    input_atten_mask=entry.input_atten_mask,
                                                    labels=[x.pop('temp_label') for x in metadata],
                                                    )
                else:
                    one_shot_res=self.choice_losses(
                                                    input_ids=entry.input_ids,
                                                    input_atten_mask=entry.input_atten_mask,
                                                    loss_mask=entry.input_loss_mask,
                                                    labels=entry.labels,
                                                    )
                one_shot_losses=one_shot_res["labels_losses"]
                for j in range(len(metadata)): # j indexes within the batch, avoiding shadowing the batch counter i
                    metadata[j]['pred']=one_shot_res["preds"][j]
                    metadata[j]['loss']=one_shot_losses[j]
                    metadata[j]['one_shot_acc']=one_shot_res["accurate_list"][j]
                buffer.write(metadata)

    def write_results(self):
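        """Merge the per-device score files and write the final train/valid JSON.

        Scored prompts are grouped by test_id into a 'ctxs' list, sorted by loss
        (lowest first), and overall one-shot / first-rank accuracies are logged.
        The examples are then shuffled with a fixed seed and split 90%/10% into
        the train and validation output files; temporary .bin files are deleted.
        """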
        def split_example(entry):
            test_example = {}
            prompt_example = {}
            for key,val in entry.items():
                if key.startswith("test_"):
                    test_example[key[len("test_"):]] = val
                else:
                    prompt_example[key] = val
            return test_example,prompt_example

        data = []
        for path in glob.glob(f"{self.output_train_file}tmp_*.bin"):
            with BufferedJsonReader(path) as f:
                for x in f.read():
                    data.extend(x)

        example_dict = {}
        one_shot_true=0
        for entry in data:
            if entry['test_id'] not in example_dict:
                test_example,prompt_example = split_example(entry)
                test_example['ctxs'] = [prompt_example]
                example_dict[entry['test_id']] = test_example
            else:
                _,prompt_example = split_example(entry)
                example_dict[entry['test_id']]['ctxs'].append(prompt_example)
            one_shot_true+=prompt_example["one_shot_acc"]
        overall_one_shot_acc=one_shot_true/len(data)
        logger.info('task name: %s', self.cfg.task_name)
        logger.info('one_shot_acc: %f', overall_one_shot_acc)
        first_rank_true=0
        example_list = list(example_dict.values())
        for entry in example_list:
            entry['task_name']=self.cfg.task_name

            # sort prompts by loss in ascending order: the lower the loss, the more helpful the prompt
            entry['ctxs'] = sorted(entry['ctxs'],key = lambda x: x['loss'])

            # check whether the first-ranked prompt leads to the gold prediction
            first_rank_true+=entry['ctxs'][0]["one_shot_acc"]

        logger.info('len(example_list): %d',len(example_list))
        overall_first_rank_acc=first_rank_true/len(example_list)
        logger.info('first_rank_acc: %f', overall_first_rank_acc)

        # split the scored data 90% / 10% into training and validation sets
        random.Random(42).shuffle(example_list)
        split_ratio=0.9
        n_train=int(len(example_list)*split_ratio)
        with open(self.output_train_file,"w") as writer:
            writer.write(json.dumps(example_list[:n_train], indent=4) + "\n")
        with open(self.output_valid_file,"w") as writer:
            writer.write(json.dumps(example_list[n_train:], indent=4) + "\n")
        for path in glob.glob(f"{self.output_train_file}tmp_*.bin"):
            os.remove(path)


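# Hydra entry point: cfg is composed from the "scorer" config under configs/ and any field
# can be overridden on the command line. A hypothetical multi-process launch (the real
# config keys live in that config file, which is not shown here; the overrides below only
# name fields this script actually reads, with made-up values):
#
#   accelerate launch scorer.py model_name=gpt2-large batch_size=8 \
#       output_train_file=output/scored_train.json output_valid_file=output/scored_valid.json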
@hydra.main(config_path="configs",config_name="scorer")
def main(cfg):
    logger.info(cfg)
    accelerator = Accelerator()
    scorer = Scorer(cfg, accelerator)
    scorer.forward()
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        scorer.write_results()


if __name__ == "__main__":
    main()