MetaGPT

import asyncio
from typing import List, Tuple, Union

import evaluate
import jieba
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.evaluation import SemanticSimilarityEvaluator
from llama_index.core.schema import NodeWithScore
from pydantic import BaseModel

from metagpt.const import EXAMPLE_BENCHMARK_PATH
from metagpt.logs import logger
from metagpt.rag.factories import get_rag_embedding
from metagpt.utils.common import read_json_file


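# Dataset descriptors: each entry names a benchmark dataset, its document files, and its ground-truth records.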
class DatasetInfo(BaseModel):
    name: str
    document_files: List[str]
    gt_info: List[dict]


class DatasetConfig(BaseModel):
    datasets: List[DatasetInfo]


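# Evaluates RAG outputs with generation metrics (BLEU, ROUGE-L, semantic similarity)
# and retrieval metrics (recall, hit rate, MRR).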
class RAGBenchmark:
    def __init__(
        self,
        embed_model: BaseEmbedding = None,
    ):
        self.evaluator = SemanticSimilarityEvaluator(
            embed_model=embed_model or get_rag_embedding(),
        )

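    # Bundle the per-sample metric values and the raw texts into a single {"metrics": ..., "log": ...} dict.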
    def set_metrics(
        self,
        bleu_avg: float = 0.0,
        bleu_1: float = 0.0,
        bleu_2: float = 0.0,
        bleu_3: float = 0.0,
        bleu_4: float = 0.0,
        rouge_l: float = 0.0,
        semantic_similarity: float = 0.0,
        recall: float = 0.0,
        hit_rate: float = 0.0,
        mrr: float = 0.0,
        length: float = 0.0,
        generated_text: str = None,
        ground_truth_text: str = None,
        question: str = None,
    ):
        metrics = {
            "bleu-avg": bleu_avg,
            "bleu-1": bleu_1,
            "bleu-2": bleu_2,
            "bleu-3": bleu_3,
            "bleu-4": bleu_4,
            "rouge-L": rouge_l,
            "semantic similarity": semantic_similarity,
            "recall": recall,
            "hit_rate": hit_rate,
            "mrr": mrr,
            "length": length,
        }

        log = {
            "generated_text": generated_text,
            "ground_truth_text": ground_truth_text,
            "question": question,
        }

        return {"metrics": metrics, "log": log}

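    # BLEU via the HuggingFace `evaluate` package, tokenized with jieba so Chinese text is scored on words.
    # Returns (bleu_avg, bleu-1..4); by default the brevity penalty is divided back out of the average.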
    def bleu_score(self, response: str, reference: str, with_penalty=False) -> Union[float, Tuple[float]]:
        f = lambda text: list(jieba.cut(text))
        bleu = evaluate.load(path="bleu")
        results = bleu.compute(predictions=[response], references=[[reference]], tokenizer=f)

        bleu_avg = results["bleu"]
        bleu1 = results["precisions"][0]
        bleu2 = results["precisions"][1]
        bleu3 = results["precisions"][2]
        bleu4 = results["precisions"][3]
        brevity_penalty = results["brevity_penalty"]

        if with_penalty:
            return bleu_avg, bleu1, bleu2, bleu3, bleu4
        else:
            return 0.0 if brevity_penalty == 0 else bleu_avg / brevity_penalty, bleu1, bleu2, bleu3, bleu4

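    # ROUGE-L (longest common subsequence) score with the same jieba tokenizer.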
    def rougel_score(self, response: str, reference: str) -> float:
        # pip install rouge_score
        f = lambda text: list(jieba.cut(text))
        rouge = evaluate.load(path="rouge")

        results = rouge.compute(predictions=[response], references=[[reference]], tokenizer=f, rouge_types=["rougeL"])
        score = results["rougeL"]
        return score

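    # Fraction of ground-truth docs that are covered by at least one retrieved node.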
    def recall(self, nodes: list[NodeWithScore], reference_docs: list[str]) -> float:
        if nodes:
            total_recall = sum(any(node.text in doc for node in nodes) for doc in reference_docs)
            return total_recall / len(reference_docs)
        else:
            return 0.0

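    # 1.0 if any retrieved node's text appears in any ground-truth doc, else 0.0.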
    def hit_rate(self, nodes: list[NodeWithScore], reference_docs: list[str]) -> float:
        if nodes:
            return 1.0 if any(node.text in doc for doc in reference_docs for node in nodes) else 0.0
        else:
            return 0.0

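    # Reciprocal rank of the first retrieved node whose text appears in a ground-truth doc (0.0 if none match).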
    def mean_reciprocal_rank(self, nodes: list[NodeWithScore], reference_docs: list[str]) -> float:
        mrr_sum = 0.0

        for i, node in enumerate(nodes, start=1):
            for doc in reference_docs:
                if node.text in doc:
                    mrr_sum += 1.0 / i
                    return mrr_sum

        return mrr_sum

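    # Embedding-based similarity between response and reference via llama_index's SemanticSimilarityEvaluator.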
    async def semantic_similarity(self, response: str, reference: str) -> float:
        result = await self.evaluator.aevaluate(
            response=response,
            reference=reference,
        )

        return result.score

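    # Compute every metric for a single (question, response, reference, retrieved nodes) sample.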
    async def compute_metric(
        self,
        response: str = None,
        reference: str = None,
        nodes: list[NodeWithScore] = None,
        reference_doc: list[str] = None,
        question: str = None,
    ):
        recall = self.recall(nodes, reference_doc)
        bleu_avg, bleu1, bleu2, bleu3, bleu4 = self.bleu_score(response, reference)
        rouge_l = self.rougel_score(response, reference)
        hit_rate = self.hit_rate(nodes, reference_doc)
        mrr = self.mean_reciprocal_rank(nodes, reference_doc)

        similarity = await self.semantic_similarity(response, reference)

        result = self.set_metrics(
            bleu_avg,
            bleu1,
            bleu2,
            bleu3,
            bleu4,
            rouge_l,
            similarity,
            recall,
            hit_rate,
            mrr,
            len(response),
            response,
            reference,
            question,
        )

        return result

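    # Read dataset_info.json under EXAMPLE_BENCHMARK_PATH and build a DatasetConfig; "all" selects every dataset.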
    @staticmethod
    def load_dataset(ds_names: list[str] = ["all"]):
        infos = read_json_file((EXAMPLE_BENCHMARK_PATH / "dataset_info.json").as_posix())
        dataset_config = DatasetConfig(
            datasets=[
                DatasetInfo(
                    name=name,
                    document_files=[
                        (EXAMPLE_BENCHMARK_PATH / name / file).as_posix() for file in info["document_file"]
                    ],
                    gt_info=read_json_file((EXAMPLE_BENCHMARK_PATH / name / info["gt_file"]).as_posix()),
                )
                for dataset_info in infos
                for name, info in dataset_info.items()
                if name in ds_names or "all" in ds_names
            ]
        )

        return dataset_config


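# Quick manual check of the generation metrics on a sample Chinese answer/ground-truth pair.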
if __name__ == "__main__":
    benchmark = RAGBenchmark()
    answer = "是的,根据提供的信息,2023年7月20日,应急管理部和财政部确实联合发布了《因灾倒塌、损坏住房恢复重建救助工作规范》的通知。这份《规范》旨在进一步规范因灾倒塌、损坏住房的恢复重建救助相关工作。它明确了地方各级政府负责实施救助工作,应急管理部和财政部则负责统筹指导。地方财政应安排足够的资金,中央财政也会提供适当的补助。救助资金将通过专账管理,并采取特定的管理方式。救助对象是那些因自然灾害导致住房倒塌或损坏,并向政府提出申请且符合条件的受灾家庭。相关部门将组织调查统计救助对象信息,并建立档案。此外,《规范》还强调了资金发放的具体方式和公开透明的要求。"
    ground_truth = "“启明行动”是为了防控儿童青少年的近视问题,并发布了《防控儿童青少年近视核心知识十条》。"
    bleu_avg, bleu1, bleu2, bleu3, bleu4 = benchmark.bleu_score(answer, ground_truth)
    rougeL_score = benchmark.rougel_score(answer, ground_truth)
    similarity = asyncio.run(benchmark.semantic_similarity(answer, ground_truth))

    logger.info(
        f"BLEU Scores: bleu_avg = {bleu_avg}, bleu1 = {bleu1}, bleu2 = {bleu2}, bleu3 = {bleu3}, bleu4 = {bleu4}, "
        f"RougeL Score: {rougeL_score}, "
        f"Semantic Similarity: {similarity}"
    )
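    # A minimal sketch (not part of the original script) of how the full pipeline could be exercised,
    # assuming a retriever is available; `engine`, `reference_docs`, and `question` are hypothetical names:
    #
    #   config = RAGBenchmark.load_dataset(["all"])
    #   nodes = engine.retrieve(question)  # list[NodeWithScore] from your RAG engine
    #   result = asyncio.run(
    #       benchmark.compute_metric(answer, ground_truth, nodes, reference_docs, question)
    #   )
    #   logger.info(result["metrics"])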
