import os
from shutil import rmtree
from typing import Callable, Dict, List, Optional

import tqdm

from llama_index.legacy.core.base_retriever import BaseRetriever
from llama_index.legacy.postprocessor.types import BaseNodePostprocessor
from llama_index.legacy.schema import Document, QueryBundle
from llama_index.legacy.utils import get_cache_dir


class BeirEvaluator:
    """
    Refer to: https://github.com/beir-cellar/beir for a full list of supported datasets
    and a full description of BEIR.
    """

    def __init__(self) -> None:
        # BEIR is an optional dependency; fail early with a clear message if it
        # is not installed.
        try:
            import beir  # noqa: F401
        except ImportError:
            raise ImportError(
                "Please install beir to use this feature: `pip install beir`",
            )

    def _download_datasets(self, datasets: List[str] = ["nfcorpus"]) -> Dict[str, str]:
        from beir import util

        cache_dir = get_cache_dir()

        dataset_paths = {}
        for dataset in datasets:
            dataset_full_path = os.path.join(cache_dir, "datasets", "BeIR__" + dataset)
            if not os.path.exists(dataset_full_path):
                url = f"""https://public.ukp.informatik.tu-darmstadt.de/thakur\
/BEIR/datasets/{dataset}.zip"""
                try:
                    util.download_and_unzip(url, dataset_full_path)
                except Exception as e:
                    print(
                        "Dataset:", dataset, "not found at:", url, "Removing cached dir"
                    )
                    rmtree(dataset_full_path)
                    raise ValueError(f"invalid BEIR dataset: {dataset}") from e

            print("Dataset:", dataset, "downloaded at:", dataset_full_path)
            dataset_paths[dataset] = os.path.join(dataset_full_path, dataset)
        return dataset_paths

    def run(
        self,
        create_retriever: Callable[[List[Document]], BaseRetriever],
        datasets: List[str] = ["nfcorpus"],
        metrics_k_values: List[int] = [3, 10],
        node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
    ) -> None:
        from beir.datasets.data_loader import GenericDataLoader
        from beir.retrieval.evaluation import EvaluateRetrieval

        dataset_paths = self._download_datasets(datasets)
        for dataset in datasets:
            dataset_path = dataset_paths[dataset]
            print("Evaluating on dataset:", dataset)
            print("-------------------------------------")

            corpus, queries, qrels = GenericDataLoader(data_folder=dataset_path).load(
                split="test"
            )

            documents = []
            for id, val in corpus.items():
                doc = Document(
                    text=val["text"], metadata={"title": val["title"], "doc_id": id}
                )
                documents.append(doc)

            retriever = create_retriever(documents)

            print("Retriever created for: ", dataset)

            print("Evaluating retriever on questions against qrels")

            results = {}
            for key, query in tqdm.tqdm(queries.items()):
                nodes_with_score = retriever.retrieve(query)
                node_postprocessors = node_postprocessors or []
                for node_postprocessor in node_postprocessors:
                    nodes_with_score = node_postprocessor.postprocess_nodes(
                        nodes_with_score, query_bundle=QueryBundle(query_str=query)
                    )
                results[key] = {
                    node.node.metadata["doc_id"]: node.score
                    for node in nodes_with_score
                }

            ndcg, map_, recall, precision = EvaluateRetrieval.evaluate(
                qrels, results, metrics_k_values
            )
            print("Results for:", dataset)
            for k in metrics_k_values:
                print(
                    {
                        f"NDCG@{k}": ndcg[f"NDCG@{k}"],
                        f"MAP@{k}": map_[f"MAP@{k}"],
                        f"Recall@{k}": recall[f"Recall@{k}"],
                        f"precision@{k}": precision[f"P@{k}"],
                    }
                )
            print("-------------------------------------")
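

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal way to drive the evaluator, assuming a default llama-index
# embedding model is configured: build an in-memory VectorStoreIndex over the
# BEIR corpus and expose it as a retriever. The `similarity_top_k` value is an
# arbitrary choice for this sketch.
if __name__ == "__main__":
    from llama_index.legacy import VectorStoreIndex

    def create_vector_retriever(documents: List[Document]) -> BaseRetriever:
        # Embed the corpus into a vector index and return it as a retriever.
        index = VectorStoreIndex.from_documents(documents)
        return index.as_retriever(similarity_top_k=10)

    BeirEvaluator().run(
        create_vector_retriever,
        datasets=["nfcorpus"],
        metrics_k_values=[3, 10],
    )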