# Inspired by: https://github.com/hendrycks/test/blob/master/evaluate_flan.py
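# Example invocation (hypothetical flag values; the argument names below follow llmtuner's
# model/data/evaluation arguments, but verify them against your installed version):
#   python evaluator.py \
#       --model_name_or_path path/to/your/model \
#       --template default \
#       --task mmlu \
#       --split test \
#       --lang en \
#       --n_shot 5 \
#       --batch_size 4 \
#       --save_dir path/to/save_dir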

import os
import json
import torch
import inspect
import tiktoken
import numpy as np
from tqdm import tqdm, trange
from typing import Any, Dict, List, Optional

from datasets import load_dataset
from transformers.utils import cached_file

from llmtuner.data.template import get_template_and_fix_tokenizer
from llmtuner.eval.template import get_eval_template
from llmtuner.extras.constants import CHOICES, SUBJECTS
from llmtuner.model import dispatch_model, get_eval_args, load_model_and_tokenizer


class Evaluator:

    def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
        self.model_args, self.data_args, self.eval_args, finetuning_args = get_eval_args(args)
        self.model, self.tokenizer = load_model_and_tokenizer(self.model_args, finetuning_args)
        self.tokenizer.padding_side = "right"  # avoid overflow issue in batched inference for llama2
        self.model = dispatch_model(self.model)
        self.template = get_template_and_fix_tokenizer(self.data_args.template, self.tokenizer)
        self.eval_template = get_eval_template(self.eval_args.lang)
        self.choice_inputs = self._encode_choices()

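    # Map each candidate answer letter in CHOICES (e.g. "A".."D") to the id of the last token
    # produced when it is encoded after the template prefix, so predictions can be read off
    # the logits of these token ids alone.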
    def _encode_choices(self) -> List[int]:
        if isinstance(getattr(self.tokenizer, "tokenizer", None), tiktoken.Encoding):  # for tiktoken tokenizer (Qwen)
            kwargs = dict(allowed_special="all")
        else:
            kwargs = dict(add_special_tokens=False)

        return [self.tokenizer.encode(self.eval_template.prefix + ch, **kwargs)[-1] for ch in CHOICES]

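    # Score a padded batch in a single forward pass: take the logits at each sequence's last
    # non-padding position, softmax over the candidate answer token ids only, and return the
    # letter with the highest probability for every example.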
    @torch.inference_mode()
    def batch_inference(self, batch_input: Dict[str, torch.Tensor]) -> List[str]:
        logits = self.model(**batch_input).logits
        lengths = torch.sum(batch_input["attention_mask"], dim=-1)
        word_probs = torch.stack([logits[i, lengths[i] - 1] for i in range(len(lengths))], dim=0)
        choice_probs = torch.nn.functional.softmax(word_probs[:, self.choice_inputs], dim=-1).detach()
        return [chr(ord("A") + offset.item()) for offset in torch.argmax(choice_probs, dim=-1)]

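    # Full evaluation loop: for every subject of the task, build few-shot prompts, run batched
    # inference, accumulate per-category correctness, and finally save the results.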
    def eval(self) -> None:
        if "token" in inspect.signature(cached_file).parameters:
            kwargs = {"token": self.model_args.hf_hub_token}
        elif "use_auth_token" in inspect.signature(cached_file).parameters:  # for transformers==4.31.0
            kwargs = {"use_auth_token": self.model_args.hf_hub_token}
        else:
            kwargs = {}  # neither parameter is accepted; pass no auth token

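        # mapping.json maps each subject of the task to its human-readable name and its category
        # label; it is resolved from the local task_dir or the Hugging Face Hub cache.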
        mapping = cached_file(
            path_or_repo_id=os.path.join(self.eval_args.task_dir, self.eval_args.task),
            filename="mapping.json",
            cache_dir=self.model_args.cache_dir,
            **kwargs
        )

        with open(mapping, "r", encoding="utf-8") as f:
            categories: Dict[str, Dict[str, str]] = json.load(f)

        category_corrects = {subj: np.array([], dtype="bool") for subj in SUBJECTS}
        pbar = tqdm(categories.keys(), desc="Processing subjects", position=0)
        results = {}
        for subject in pbar:
            dataset = load_dataset(
                path=os.path.join(self.eval_args.task_dir, self.eval_args.task),
                name=subject,
                cache_dir=self.model_args.cache_dir,
                download_mode=self.eval_args.download_mode,
                token=self.model_args.hf_hub_token
            )
            pbar.set_postfix_str(categories[subject]["name"])
            inputs, outputs, labels = [], [], []
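            # Build one few-shot prompt per test example: sample up to n_shot training examples as
            # demonstrations and encode the whole prompt with the chat template.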
            for i in trange(len(dataset[self.data_args.split]), desc="Formatting batches", position=1, leave=False):
                support_set = dataset["train"].shuffle().select(range(min(self.eval_args.n_shot, len(dataset["train"]))))
                query, resp, history = self.eval_template.format_example(
                    target_data=dataset[self.data_args.split][i],
                    support_set=support_set,
                    subject_name=categories[subject]["name"],
                    use_history=self.template.use_history
                )
                input_ids, _ = self.template.encode_oneturn(
                    tokenizer=self.tokenizer, query=query, resp=resp, history=history
                )
                inputs.append({"input_ids": input_ids, "attention_mask": [1] * len(input_ids)})
                labels.append(resp)

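            # Run the formatted prompts through the model in padded batches and collect the
            # predicted answer letters.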
            for i in trange(0, len(inputs), self.eval_args.batch_size, desc="Predicting batches", position=1, leave=False):
                batch_input = self.tokenizer.pad(
                    inputs[i : i + self.eval_args.batch_size], return_attention_mask=True, return_tensors="pt"
                ).to(self.model.device)
                preds = self.batch_inference(batch_input)
                outputs += preds

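            # Compare predictions with the gold answers and append the per-example correctness to
            # both the subject's category bucket and the overall average.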
            corrects = (np.array(outputs) == np.array(labels))
            category_name = categories[subject]["category"]
            category_corrects[category_name] = np.concatenate([category_corrects[category_name], corrects], axis=0)
            category_corrects["Average"] = np.concatenate([category_corrects["Average"], corrects], axis=0)
            results[subject] = {str(i): outputs[i] for i in range(len(outputs))}

        pbar.close()
        self._save_results(category_corrects, results)

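    # Print per-category accuracy (in percent) and, if save_dir is set, write the raw predictions
    # to results.json and the score summary to results.log.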
    def _save_results(self, category_corrects: Dict[str, np.ndarray], results: Dict[str, Dict[str, str]]) -> None:
        score_info = "\n".join([
            "{:>15}: {:.2f}".format(category_name, 100 * np.mean(category_correct))
            for category_name, category_correct in category_corrects.items() if len(category_correct)
        ])
        print(score_info)
        if self.eval_args.save_dir is not None:
            os.makedirs(self.eval_args.save_dir, exist_ok=False)  # fail if the directory exists to avoid overwriting results
            with open(os.path.join(self.eval_args.save_dir, "results.json"), "w", encoding="utf-8", newline="\n") as f:
                json.dump(results, f, indent=2)

            with open(os.path.join(self.eval_args.save_dir, "results.log"), "w", encoding="utf-8", newline="\n") as f:
                f.write(score_info)


if __name__ == "__main__":
    evaluator = Evaluator()
    evaluator.eval()
