llama-index

Форк
0
1
import os
2
from shutil import rmtree
3
from typing import Callable, Dict, List, Optional
4

5
import tqdm
6

7
from llama_index.legacy.core.base_retriever import BaseRetriever
8
from llama_index.legacy.postprocessor.types import BaseNodePostprocessor
9
from llama_index.legacy.schema import Document, QueryBundle
10
from llama_index.legacy.utils import get_cache_dir
11

12

13
class BeirEvaluator:
14
    """
15
    Refer to: https://github.com/beir-cellar/beir for a full list of supported datasets
16
    and a full description of BEIR.
17
    """
18

19
    def __init__(self) -> None:
20
        try:
21
            pass
22
        except ImportError:
23
            raise ImportError(
24
                "Please install beir to use this feature: " "`pip install beir`",
25
            )
26

27
    def _download_datasets(self, datasets: List[str] = ["nfcorpus"]) -> Dict[str, str]:
28
        from beir import util
29

30
        cache_dir = get_cache_dir()
31

32
        dataset_paths = {}
33
        for dataset in datasets:
34
            dataset_full_path = os.path.join(cache_dir, "datasets", "BeIR__" + dataset)
35
            if not os.path.exists(dataset_full_path):
36
                url = f"""https://public.ukp.informatik.tu-darmstadt.de/thakur\
37
/BEIR/datasets/{dataset}.zip"""
38
                try:
39
                    util.download_and_unzip(url, dataset_full_path)
40
                except Exception as e:
41
                    print(
42
                        "Dataset:", dataset, "not found at:", url, "Removing cached dir"
43
                    )
44
                    rmtree(dataset_full_path)
45
                    raise ValueError(f"invalid BEIR dataset: {dataset}") from e
46

47
            print("Dataset:", dataset, "downloaded at:", dataset_full_path)
48
            dataset_paths[dataset] = os.path.join(dataset_full_path, dataset)
49
        return dataset_paths
50

51
    def run(
52
        self,
53
        create_retriever: Callable[[List[Document]], BaseRetriever],
54
        datasets: List[str] = ["nfcorpus"],
55
        metrics_k_values: List[int] = [3, 10],
56
        node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
57
    ) -> None:
58
        from beir.datasets.data_loader import GenericDataLoader
59
        from beir.retrieval.evaluation import EvaluateRetrieval
60

61
        dataset_paths = self._download_datasets(datasets)
62
        for dataset in datasets:
63
            dataset_path = dataset_paths[dataset]
64
            print("Evaluating on dataset:", dataset)
65
            print("-------------------------------------")
66

67
            corpus, queries, qrels = GenericDataLoader(data_folder=dataset_path).load(
68
                split="test"
69
            )
70

71
            documents = []
72
            for id, val in corpus.items():
73
                doc = Document(
74
                    text=val["text"], metadata={"title": val["title"], "doc_id": id}
75
                )
76
                documents.append(doc)
77

78
            retriever = create_retriever(documents)
79

80
            print("Retriever created for: ", dataset)
81

82
            print("Evaluating retriever on questions against qrels")
83

84
            results = {}
85
            for key, query in tqdm.tqdm(queries.items()):
86
                nodes_with_score = retriever.retrieve(query)
87
                node_postprocessors = node_postprocessors or []
88
                for node_postprocessor in node_postprocessors:
89
                    nodes_with_score = node_postprocessor.postprocess_nodes(
90
                        nodes_with_score, query_bundle=QueryBundle(query_str=query)
91
                    )
92
                results[key] = {
93
                    node.node.metadata["doc_id"]: node.score
94
                    for node in nodes_with_score
95
                }
96

97
            ndcg, map_, recall, precision = EvaluateRetrieval.evaluate(
98
                qrels, results, metrics_k_values
99
            )
100
            print("Results for:", dataset)
101
            for k in metrics_k_values:
102
                print(
103
                    {
104
                        f"NDCG@{k}": ndcg[f"NDCG@{k}"],
105
                        f"MAP@{k}": map_[f"MAP@{k}"],
106
                        f"Recall@{k}": recall[f"Recall@{k}"],
107
                        f"precision@{k}": precision[f"P@{k}"],
108
                    }
109
                )
110
            print("-------------------------------------")
111

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.