fastrag

Форк
0
/
nq-plaid.py 
64 строки · 2.0 Кб
1
import logging
2
import os
3
import pathlib
4

5
import tqdm
6
from beir import LoggingHandler, util
7
from beir.datasets.data_loader import GenericDataLoader
8
from beir.retrieval.evaluation import EvaluateRetrieval
9
from haystack import Pipeline
10

11
from fastrag.retrievers.colbert import ColBERTRetriever
12
from fastrag.stores import PLAIDDocumentStore
13

14
logging.getLogger().setLevel(logging.INFO)
15
logging.basicConfig(
16
    format="%(asctime)s - %(message)s",
17
    datefmt="%Y-%m-%d %H:%M:%S",
18
    level=logging.INFO,
19
    handlers=[LoggingHandler()],
20
)
21

22
dataset = "nq"
23
url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
24
out_dir = os.path.join(pathlib.Path(".").absolute(), "benchmarks", "datasets")
25
data_path = util.download_and_unzip(url, out_dir)
26

27
logging.info("Loading dataset...")
28
corpus, queries, qrels = GenericDataLoader(data_path).load(split="test")
29

30
k_values = [10, 20, 50, 100]
31

32
beir_retriever = EvaluateRetrieval(k_values=k_values)
33

34
logging.info("Loading PLAID index...")
35
document_store = PLAIDDocumentStore(
36
    index_path="/path/to/index",
37
    checkpoint_path="/path/to/checkpoint",
38
    collection_path="/path/to/collection.tsv",
39
)
40

41
retriever = ColBERTRetriever(document_store=document_store)
42

43
p = Pipeline()
44
p.add_node(component=retriever, name="Retriever", inputs=["Query"])
45

46

47
def retrieve_queries(queries: dict, k: int):
48
    ans = {}
49
    for q, query in tqdm.tqdm(queries.items(), "queries"):
50
        ans[q] = {
51
            # need to concat "doc" to the PLAID IDs to match the qrels IDs
52
            "doc" + str(doc.id): doc.score
53
            for doc in p.run(query, params={"Retriever": {"top_k": k}})["documents"]
54
        }
55
    return ans
56

57

58
logging.info("Querying documents...")
59
results = retrieve_queries(queries, max(k_values))
60

61

62
#### Evaluate the retrieval using NDCG@k, MAP@K ...
63
logging.info("Retriever evaluation for k in: {}".format(beir_retriever.k_values))
64
ndcg, _map, recall, precision = beir_retriever.evaluate(qrels, results, beir_retriever.k_values)
65

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.