import json
import os
import re
import string
from collections import Counter
from shutil import rmtree
from typing import Any, Dict, List, Optional, Tuple

import requests
import tqdm

from llama_index.legacy.core.base_query_engine import BaseQueryEngine
from llama_index.legacy.core.base_retriever import BaseRetriever
from llama_index.legacy.query_engine.retriever_query_engine import RetrieverQueryEngine
from llama_index.legacy.schema import NodeWithScore, QueryBundle, TextNode
from llama_index.legacy.utils import get_cache_dir

DEV_DISTRACTOR_URL = """http://curtis.ml.cmu.edu/datasets/\
hotpot/hotpot_dev_distractor_v1.json"""


class HotpotQAEvaluator:
    """
    Benchmarks a query engine on the HotpotQA dev set (distractor setting).
    Refer to https://hotpotqa.github.io/ for more details on the dataset.
    """

    def _download_datasets(self) -> Dict[str, str]:
        cache_dir = get_cache_dir()

        dataset_paths = {}
        dataset = "hotpot_dev_distractor"
        dataset_full_path = os.path.join(cache_dir, "datasets", "HotpotQA")
        if not os.path.exists(dataset_full_path):
            url = DEV_DISTRACTOR_URL
            try:
                os.makedirs(dataset_full_path, exist_ok=True)
                response = requests.get(url, stream=True)

                # Stream the download to disk in fixed-size chunks
                chunk_size = 1024
                with open(
                    os.path.join(dataset_full_path, "dev_distractor.json"), "wb"
                ) as save_file:
                    for chunk in tqdm.tqdm(
                        response.iter_content(chunk_size=chunk_size)
                    ):
                        if chunk:
                            save_file.write(chunk)
            except Exception as e:
                if os.path.exists(dataset_full_path):
                    print(
                        "Dataset:", dataset, "not found at:", url, "Removing cached dir"
                    )
                    rmtree(dataset_full_path)
                raise ValueError(f"could not download {dataset} dataset") from e
        dataset_paths[dataset] = os.path.join(dataset_full_path, "dev_distractor.json")
        print("Dataset:", dataset, "downloaded at:", dataset_full_path)
        return dataset_paths

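    # `queries_fraction`, when provided, takes precedence over the absolute
    # `queries` count; otherwise the first `queries` examples of the dev
    # distractor split are evaluated.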
    def run(
        self,
        query_engine: BaseQueryEngine,
        queries: int = 10,
        queries_fraction: Optional[float] = None,
        show_result: bool = False,
    ) -> None:
        dataset_paths = self._download_datasets()
        dataset = "hotpot_dev_distractor"
        dataset_path = dataset_paths[dataset]
        print("Evaluating on dataset:", dataset)
        print("-------------------------------------")

        with open(dataset_path) as f:
            query_objects = json.load(f)
        if queries_fraction:
            queries_to_load = int(len(query_objects) * queries_fraction)
        else:
            queries_to_load = queries
            queries_fraction = round(queries / len(query_objects), 5)

        print(
            f"Loading {queries_to_load} queries out of \
{len(query_objects)} (fraction: {queries_fraction})"
        )
        query_objects = query_objects[:queries_to_load]

        assert isinstance(
            query_engine, RetrieverQueryEngine
        ), "query_engine must be a RetrieverQueryEngine for this evaluation"
        retriever = HotpotQARetriever(query_objects)
        # Swap in the mocked retriever so the engine answers over the
        # dataset's provided contexts instead of a real index
        query_engine = query_engine.with_retriever(retriever=retriever)

        scores = {"exact_match": 0.0, "f1": 0.0}

        for query in query_objects:
            query_bundle = QueryBundle(
                query_str=query["question"]
                + " Give a short factoid answer (as few words as possible).",
                custom_embedding_strs=[query["question"]],
            )
            response = query_engine.query(query_bundle)
            em = int(
                exact_match_score(
                    prediction=str(response), ground_truth=query["answer"]
                )
            )
            f1, _, _ = f1_score(prediction=str(response), ground_truth=query["answer"])
            scores["exact_match"] += em
            scores["f1"] += f1
            if show_result:
                print("Question: ", query["question"])
                print("Response:", response)
                print("Correct answer: ", query["answer"])
                print("EM:", em, "F1:", f1)
                print("-------------------------------------")

        for score in scores:
            scores[score] /= len(query_objects)

        print("Scores: ", scores)


class HotpotQARetriever(BaseRetriever):
    """
    This is a mocked retriever for the HotpotQA dataset. It is only meant to be
    used with the HotpotQA dev set in the distractor setting, which does not
    require retrieval but does require identifying the supporting facts among
    a list of 10 provided sources.
    """

    def __init__(self, query_objects: Any) -> None:
        assert isinstance(
            query_objects,
            list,
        ), f"query_objects must be a list, got: {type(query_objects)}"
        self._queries = {}
        for query_object in query_objects:
            self._queries[query_object["question"]] = query_object

    def _retrieve(self, query: QueryBundle) -> List[NodeWithScore]:
        if query.custom_embedding_strs:
            query_str = query.custom_embedding_strs[0]
        else:
            query_str = query.query_str
        contexts = self._queries[query_str]["context"]
        node_with_scores = []
        for ctx in contexts:
            text_list = ctx[1]
            text = "\n".join(text_list)
            node = TextNode(text=text, metadata={"title": ctx[0]})
            node_with_scores.append(NodeWithScore(node=node, score=1.0))

        return node_with_scores

    def __str__(self) -> str:
        return "HotpotQARetriever"


"""
Utils from https://github.com/hotpotqa/hotpot/blob/master/hotpot_evaluate_v1.py
"""


def normalize_answer(s: str) -> str:
    def remove_articles(text: str) -> str:
        return re.sub(r"\b(a|an|the)\b", " ", text)

    def white_space_fix(text: str) -> str:
        return " ".join(text.split())

    def remove_punc(text: str) -> str:
        exclude = set(string.punctuation)
        return "".join(ch for ch in text if ch not in exclude)

    def lower(text: str) -> str:
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))
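# For example, normalize_answer("The Eiffel Tower!") lowercases the text,
# strips punctuation and articles, and collapses whitespace, yielding
# "eiffel tower".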


def f1_score(prediction: str, ground_truth: str) -> Tuple[float, float, float]:
    normalized_prediction = normalize_answer(prediction)
    normalized_ground_truth = normalize_answer(ground_truth)

    ZERO_METRIC = (0, 0, 0)

    if (
        normalized_prediction in ["yes", "no", "noanswer"]
        and normalized_prediction != normalized_ground_truth
    ):
        return ZERO_METRIC
    if (
        normalized_ground_truth in ["yes", "no", "noanswer"]
        and normalized_prediction != normalized_ground_truth
    ):
        return ZERO_METRIC

    prediction_tokens = normalized_prediction.split()
    ground_truth_tokens = normalized_ground_truth.split()
    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return ZERO_METRIC
    precision = 1.0 * num_same / len(prediction_tokens)
    recall = 1.0 * num_same / len(ground_truth_tokens)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1, precision, recall
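# Worked example: prediction "The Eiffel Tower in Paris" vs. ground truth
# "Eiffel Tower" normalizes to ["eiffel", "tower", "in", "paris"] and
# ["eiffel", "tower"]; two tokens overlap, so precision = 2/4, recall = 2/2,
# and f1 = 2 * 0.5 * 1.0 / (0.5 + 1.0) = 2/3.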


def exact_match_score(prediction: str, ground_truth: str) -> bool:
    return normalize_answer(prediction) == normalize_answer(ground_truth)
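

if __name__ == "__main__":
    # Minimal usage sketch, not part of the benchmark itself. It assumes the
    # legacy package re-exports VectorStoreIndex and Document (as the pre-0.10
    # top-level `llama_index` did) and that credentials for the default LLM and
    # embedding model are configured. The index contents do not matter: run()
    # swaps in HotpotQARetriever, so a placeholder document is enough to build
    # a RetrieverQueryEngine.
    from llama_index.legacy import Document, VectorStoreIndex

    index = VectorStoreIndex.from_documents([Document.example()])
    engine = index.as_query_engine(similarity_top_k=10)
    HotpotQAEvaluator().run(engine, queries=5, show_result=True)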