llama-index
import json
import os
import re
import string
from collections import Counter
from shutil import rmtree
from typing import Any, Dict, List, Optional, Tuple

import requests
import tqdm

from llama_index.legacy.core.base_query_engine import BaseQueryEngine
from llama_index.legacy.core.base_retriever import BaseRetriever
from llama_index.legacy.query_engine.retriever_query_engine import RetrieverQueryEngine
from llama_index.legacy.schema import NodeWithScore, QueryBundle, TextNode
from llama_index.legacy.utils import get_cache_dir

DEV_DISTRACTOR_URL = """http://curtis.ml.cmu.edu/datasets/\
hotpot/hotpot_dev_distractor_v1.json"""


class HotpotQAEvaluator:
    """
    Evaluates a query engine on the HotpotQA dev set in the distractor setting.

    Refer to https://hotpotqa.github.io/ for more details on the dataset.
    """

    def _download_datasets(self) -> Dict[str, str]:
        cache_dir = get_cache_dir()

        dataset_paths = {}
        dataset = "hotpot_dev_distractor"
        dataset_full_path = os.path.join(cache_dir, "datasets", "HotpotQA")
        if not os.path.exists(dataset_full_path):
            url = DEV_DISTRACTOR_URL
            try:
                os.makedirs(dataset_full_path, exist_ok=True)
                save_file = open(
                    os.path.join(dataset_full_path, "dev_distractor.json"), "wb"
                )
                response = requests.get(url, stream=True)

                # Stream the download in 1 KiB chunks and write them to disk
                chunk_size = 1024
                for chunk in tqdm.tqdm(response.iter_content(chunk_size=chunk_size)):
                    if chunk:
                        save_file.write(chunk)
                save_file.close()
            except Exception as e:
                if os.path.exists(dataset_full_path):
                    print(
                        "Dataset:", dataset, "not found at:", url, "Removing cached dir"
                    )
                    rmtree(dataset_full_path)
                raise ValueError(f"could not download {dataset} dataset") from e
        dataset_paths[dataset] = os.path.join(dataset_full_path, "dev_distractor.json")
        print("Dataset:", dataset, "downloaded at:", dataset_full_path)
        return dataset_paths

    def run(
        self,
        query_engine: BaseQueryEngine,
        queries: int = 10,
        queries_fraction: Optional[float] = None,
        show_result: bool = False,
    ) -> None:
        dataset_paths = self._download_datasets()
        dataset = "hotpot_dev_distractor"
        dataset_path = dataset_paths[dataset]
        print("Evaluating on dataset:", dataset)
        print("-------------------------------------")

        with open(dataset_path) as f:
            query_objects = json.loads(f.read())
        if queries_fraction:
            queries_to_load = int(len(query_objects) * queries_fraction)
        else:
            queries_to_load = queries
            queries_fraction = round(queries / len(query_objects), 5)

        print(
            f"Loading {queries_to_load} queries out of "
            f"{len(query_objects)} (fraction: {queries_fraction})"
        )
        query_objects = query_objects[:queries_to_load]

        assert isinstance(
            query_engine, RetrieverQueryEngine
        ), "query_engine must be a RetrieverQueryEngine for this evaluation"
        # Swap in the mocked retriever so each query is answered over the
        # dataset's own (gold + distractor) context paragraphs
        retriever = HotpotQARetriever(query_objects)
        query_engine = query_engine.with_retriever(retriever=retriever)

        scores = {"exact_match": 0.0, "f1": 0.0}

        for query in query_objects:
            query_bundle = QueryBundle(
                query_str=query["question"]
                + " Give a short factoid answer (as few words as possible).",
                custom_embedding_strs=[query["question"]],
            )
            response = query_engine.query(query_bundle)
            em = int(
                exact_match_score(
                    prediction=str(response), ground_truth=query["answer"]
                )
            )
            f1, _, _ = f1_score(prediction=str(response), ground_truth=query["answer"])
            scores["exact_match"] += em
            scores["f1"] += f1
            if show_result:
                print("Question: ", query["question"])
                print("Response:", response)
                print("Correct answer: ", query["answer"])
                print("EM:", em, "F1:", f1)
                print("-------------------------------------")

        # Average the per-query scores
        for score in scores:
            scores[score] /= len(query_objects)

        print("Scores: ", scores)


class HotpotQARetriever(BaseRetriever):
    """
    This is a mocked retriever for the HotpotQA dataset. It is only meant to be
    used with the HotpotQA dev dataset in the distractor setting. This is the
    setting that does not require retrieval but requires identifying the
    supporting facts from a list of 10 sources.
    """

    def __init__(self, query_objects: Any) -> None:
        assert isinstance(
            query_objects,
            list,
        ), f"query_objects must be a list, got: {type(query_objects)}"
        # Index the dataset records by question so they can be looked up at query time
        self._queries = {}
        for obj in query_objects:
            self._queries[obj["question"]] = obj

    def _retrieve(self, query: QueryBundle) -> List[NodeWithScore]:
        # The raw question is passed via custom_embedding_strs, since query_str
        # has the factoid-answer instruction appended to it by the evaluator.
        if query.custom_embedding_strs:
            query_str = query.custom_embedding_strs[0]
        else:
            query_str = query.query_str
        contexts = self._queries[query_str]["context"]
        node_with_scores = []
        for ctx in contexts:
            # Each context is a [title, list-of-sentences] pair
            text_list = ctx[1]
            text = "\n".join(text_list)
            node = TextNode(text=text, metadata={"title": ctx[0]})
            node_with_scores.append(NodeWithScore(node=node, score=1.0))

        return node_with_scores

    def __str__(self) -> str:
        return "HotpotQARetriever"
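
# Illustrative sketch (an assumption about the record shape, inferred from how
# _retrieve reads it; the values below are placeholders, not real dataset rows):
#
# query_object = {
#     "question": "...",
#     "answer": "...",
#     "context": [
#         ["Paragraph title", ["Sentence 1.", "Sentence 2."]],
#         # ... 10 [title, [sentence, ...]] pairs in the distractor setting
#     ],
# }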

"""
Utils from https://github.com/hotpotqa/hotpot/blob/master/hotpot_evaluate_v1.py
"""


def normalize_answer(s: str) -> str:
    def remove_articles(text: str) -> str:
        return re.sub(r"\b(a|an|the)\b", " ", text)

    def white_space_fix(text: str) -> str:
        return " ".join(text.split())

    def remove_punc(text: str) -> str:
        exclude = set(string.punctuation)
        return "".join(ch for ch in text if ch not in exclude)

    def lower(text: str) -> str:
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))
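
# For example (an illustration added here, not in the original source):
#   normalize_answer("The Eiffel Tower!") -> "eiffel tower"
# i.e. lowercased, punctuation stripped, articles removed, whitespace collapsed.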


def f1_score(prediction: str, ground_truth: str) -> Tuple[float, float, float]:
    normalized_prediction = normalize_answer(prediction)
    normalized_ground_truth = normalize_answer(ground_truth)

    ZERO_METRIC = (0, 0, 0)

    if (
        normalized_prediction in ["yes", "no", "noanswer"]
        and normalized_prediction != normalized_ground_truth
    ):
        return ZERO_METRIC
    if (
        normalized_ground_truth in ["yes", "no", "noanswer"]
        and normalized_prediction != normalized_ground_truth
    ):
        return ZERO_METRIC

    prediction_tokens = normalized_prediction.split()
    ground_truth_tokens = normalized_ground_truth.split()
    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return ZERO_METRIC
    precision = 1.0 * num_same / len(prediction_tokens)
    recall = 1.0 * num_same / len(ground_truth_tokens)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1, precision, recall
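
# Worked example (an illustration added here, not in the original source):
#   f1_score("Barack Obama", "Obama") -> (~0.667, 0.5, 1.0)
# One overlapping token out of two predicted tokens (precision 0.5) and one
# ground-truth token (recall 1.0), so F1 = 2 * 0.5 * 1.0 / 1.5 ≈ 0.667.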


def exact_match_score(prediction: str, ground_truth: str) -> bool:
    return normalize_answer(prediction) == normalize_answer(ground_truth)
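
# Minimal usage sketch (an assumption, not part of the original module): any
# RetrieverQueryEngine works here, since HotpotQAEvaluator.run swaps in the
# mocked HotpotQARetriever. The VectorStoreIndex below exists only to build
# such an engine; a configured LLM and embedding model (e.g. an OpenAI key in
# the environment) is assumed.
if __name__ == "__main__":
    from llama_index.legacy import VectorStoreIndex
    from llama_index.legacy.schema import Document

    index = VectorStoreIndex.from_documents([Document(text="placeholder")])
    engine = index.as_query_engine()  # a RetrieverQueryEngine by default
    HotpotQAEvaluator().run(engine, queries=5, show_result=True)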