# retrieval_benchmarks.py — T2Ranking retrieval benchmark for PaddleNLP feature-extraction models.
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import csv
import math
import time
from collections import defaultdict
from typing import Dict, List, cast

from datasets import load_dataset
from mteb.abstasks import AbsTaskRetrieval

from paddlenlp import Taskflow

25
csv.field_size_limit(500 * 1024 * 1024)
26

27

28
class PaddleModel:
    """Wrap PaddleNLP ``feature_extraction`` Taskflows behind the
    ``encode_queries`` / ``encode_corpus`` interface expected by BEIR's
    dense-retrieval exact search.

    Two independent Taskflow pipelines are kept so that asymmetric
    dual-encoder setups (different query and corpus encoders) can be
    benchmarked as well as symmetric ones.
    """

    def __init__(
        self, query_model, corpus_model, batch_size=16, max_seq_len=512, sep=" ", pooling_mode="mean_tokens", **kwargs
    ):
        """Build the query and corpus embedding pipelines.

        Args:
            query_model: Model name/path used to embed queries.
            corpus_model: Model name/path used to embed corpus passages.
            batch_size: Default batch size for both pipelines.
            max_seq_len: Maximum tokenized sequence length.
            sep: Separator inserted between a document's title and text.
            pooling_mode: Pooling strategy for sentence embeddings.
            **kwargs: Accepted for interface compatibility; currently unused.
        """
        self.query_model = Taskflow(
            "feature_extraction",
            model=query_model,
            pooling_mode=pooling_mode,
            max_seq_len=max_seq_len,
            batch_size=batch_size,
            _static_mode=True,
        )
        self.corpus_model = Taskflow(
            "feature_extraction",
            model=corpus_model,
            pooling_mode=pooling_mode,
            max_seq_len=max_seq_len,
            batch_size=batch_size,
            _static_mode=True,
        )
        self.sep = sep

    def encode_queries(self, queries: List[str], batch_size: int, **kwargs):
        """Embed ``queries`` and return the embeddings as a numpy array."""
        return self.query_model(queries, batch_size=batch_size, **kwargs)["features"].detach().cpu().numpy()

    def encode_corpus(self, corpus: List[Dict[str, str]], batch_size: int, **kwargs):
        """Embed corpus documents and return the embeddings as a numpy array.

        ``corpus`` may be either a columnar dict (``{"title": [...], "text": [...]}``)
        or a list of per-document dicts; when a title is present it is
        prepended to the text with ``self.sep``.
        """
        if isinstance(corpus, dict):  # columnar layout: parallel lists per field
            # The title check is per-corpus here, so hoist it out of the comprehension.
            has_title = "title" in corpus
            sentences = [
                (corpus["title"][i] + self.sep + corpus["text"][i]).strip()
                if has_title
                else corpus["text"][i].strip()
                for i in range(len(corpus["text"]))
            ]
        else:  # record layout: title may be present or absent per document
            sentences = [
                (doc["title"] + self.sep + doc["text"]).strip() if "title" in doc else doc["text"].strip()
                for doc in corpus
            ]
        return self.corpus_model(sentences, batch_size=batch_size, **kwargs)["features"].detach().cpu().numpy()
69
class T2RRetrieval(AbsTaskRetrieval):
    """MTEB-style retrieval task over the THUIR/T2Ranking Chinese dev split."""

    def __init__(self, num_max_passages: "int | None" = None, **kwargs):
        super().__init__(**kwargs)
        # A falsy cap (e.g. None) means "no limit", modelled as +infinity.
        self.num_max_passages = num_max_passages if num_max_passages else math.inf

    @property
    def description(self):
        """Static task metadata consumed by the evaluation harness."""
        return {
            "name": "T2RankingRetrieval",
            "reference": "https://huggingface.co/datasets/THUIR/T2Ranking",
            "type": "Retrieval",
            "category": "s2p",
            "eval_splits": ["dev"],
            "eval_langs": ["zh"],
            "main_score": "ndcg_at_10",
        }

    def evaluate(
        self,
        model_query,
        model_corpus,
        split="test",
        batch_size=32,
        corpus_chunk_size=None,
        target_devices=None,
        score_function="cos_sim",
        **kwargs
    ):
        """Encode the split with Paddle models, run BEIR exact search and
        return a flat dict of ndcg/map/recall/precision/mrr scores."""
        from beir.retrieval.evaluation import EvaluateRetrieval
        from beir.retrieval.search.dense import DenseRetrievalExactSearch as DRES

        if not self.data_loaded:
            self.load_data()

        corpus = self.corpus[split]
        queries = self.queries[split]
        relevant_docs = self.relevant_docs[split]

        chunk_size = 50000 if corpus_chunk_size is None else corpus_chunk_size
        search = DRES(
            PaddleModel(model_query, model_corpus),
            batch_size=batch_size,
            corpus_chunk_size=chunk_size,
            **kwargs,
        )
        retriever = EvaluateRetrieval(search, score_function=score_function)  # or "cos_sim" or "dot"

        start_time = time.time()
        results = retriever.retrieve(corpus, queries)
        print("Time taken to retrieve: {:.2f} seconds".format(time.time() - start_time))

        ndcg, _map, recall, precision = retriever.evaluate(relevant_docs, results, retriever.k_values)
        mrr = retriever.evaluate_custom(relevant_docs, results, retriever.k_values, "mrr")

        # Flatten "metric@k" keyed dicts into "<metric>_at_<k>" entries.
        scores = {}
        metric_groups = (("ndcg", ndcg), ("map", _map), ("recall", recall), ("precision", precision), ("mrr", mrr))
        for prefix, metric in metric_groups:
            for key, value in metric.items():
                scores[f"{prefix}_at_{key.split('@')[1]}"] = value
        print(scores)
        return scores

    def load_data(self, **kwargs):
        """Populate corpus/queries/qrels for the single "dev" split."""
        corpus, queries, qrels = load_t2ranking_for_retraviel(self.num_max_passages)
        self.corpus = {"dev": corpus}
        self.queries = {"dev": queries}
        self.relevant_docs = {"dev": qrels}
        self.data_loaded = True
141
def load_t2ranking_for_retraviel(num_max_passages: float):
    """Load the T2Ranking collection, dev queries and dev qrels from the Hub.

    Args:
        num_max_passages: Cap on the number of passages taken from the
            collection; may be ``math.inf`` for "no limit".

    Returns:
        Tuple ``(corpus, valid_queries, valid_qrels)`` where ``corpus`` maps
        passage id -> ``{"text": ...}``, ``valid_queries`` maps query id ->
        query text, and ``valid_qrels`` maps query id -> {passage id: rel}.
        Only queries with at least one non-zero relevance judgement are kept.
    """
    collection_dataset = load_dataset("THUIR/T2Ranking", "collection")["train"]  # type: ignore
    dev_queries_dataset = load_dataset("THUIR/T2Ranking", "queries.dev")["train"]  # type: ignore
    dev_rels_dataset = load_dataset("THUIR/T2Ranking", "qrels.dev")["train"]  # type: ignore

    corpus = {}
    for index in range(min(len(collection_dataset), num_max_passages)):
        record = cast(dict, collection_dataset[index])
        corpus[str(record["pid"])] = {"text": record["text"]}

    queries = {}
    for record in dev_queries_dataset:
        record = cast(dict, record)
        queries[str(record["qid"])] = record["text"]

    all_qrels = defaultdict(dict)
    for record in dev_rels_dataset:
        record = cast(dict, record)
        # NOTE(review): this filters qrels by pid *value* while the corpus was
        # truncated by *position*; the two only line up if pids follow the
        # collection order — confirm against the dataset before relying on it.
        if record["pid"] > num_max_passages:
            continue
        all_qrels[str(record["qid"])][str(record["pid"])] = record["rel"]

    # Drop queries whose judged passages are all non-relevant (rel == 0).
    valid_qrels = {qid: qrels for qid, qrels in all_qrels.items() if any(rel != 0 for rel in qrels.values())}
    valid_queries = {qid: query for qid, query in queries.items() if qid in valid_qrels}
    print(f"valid qrels: {len(valid_qrels)}")
    return corpus, valid_queries, valid_qrels
175
if __name__ == "__main__":
    # Run the T2Ranking dev-split benchmark with the m3e-base encoder on both sides.
    task = T2RRetrieval(num_max_passages=10000)
    task.evaluate(model_query="moka-ai/m3e-base", model_corpus="moka-ai/m3e-base", split="dev")