dream

Форк
0
/
pyserini_ranker.py 
91 строка · 3.2 Кб
1
# Copyright 2017 Neural Networks and Deep Learning lab, MIPT
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

15
import json
16
from logging import getLogger
17
from typing import List, Any, Tuple
18

19
from pyserini.search import SimpleSearcher
20

21
from deeppavlov.core.commands.utils import expand_path
22
from deeppavlov.core.common.registry import register
23
from deeppavlov.core.models.estimator import Component
24

25
logger = getLogger(__name__)
26

27

28
@register("pyserini_ranker")
29
class PyseriniRanker(Component):
30
    def __init__(
31
        self,
32
        index_folder: str,
33
        n_threads: int = 1,
34
        top_n: int = 5,
35
        text_column_name: str = "contents",
36
        return_scores: bool = False,
37
        *args,
38
        **kwargs,
39
    ):
40
        self.searcher = SimpleSearcher(str(expand_path(index_folder)))
41
        self.n_threads = n_threads
42
        self.top_n = top_n
43
        self.text_column_name = text_column_name
44
        self.return_scores = return_scores
45

46
    def __call__(self, questions: List[str]) -> Tuple[List[Any], List[float]]:
47
        docs_batch = []
48
        scores_batch = []
49
        _doc_ids_batch = []
50

51
        if len(questions) == 1:
52
            for question in questions:
53
                res = self.searcher.search(question, self.top_n)
54
                docs, doc_ids, scores = self._processing_search_result(res)
55
                docs_batch.append(docs)
56
                scores_batch.append(scores)
57
                _doc_ids_batch.append(doc_ids)
58
        else:
59
            n_batches = len(questions) // self.n_threads + int(len(questions) % self.n_threads > 0)
60
            for i in range(n_batches):
61
                questions_cur = questions[i * self.n_threads : (i + 1) * self.n_threads]
62
                qids_cur = list(range(len(questions_cur)))
63
                res_batch = self.searcher.batch_search(questions_cur, qids_cur, self.top_n, self.n_threads)
64
                for qid in qids_cur:
65
                    res = res_batch.get(qid)
66
                    docs, doc_ids, scores = self._processing_search_result(res)
67
                    docs_batch.append(docs)
68
                    scores_batch.append(scores)
69
                    _doc_ids_batch.append(doc_ids)
70

71
        logger.debug(f"found docs {_doc_ids_batch}")
72

73
        if self.return_scores:
74
            return docs_batch, scores_batch
75
        else:
76
            return docs_batch
77

78
    @staticmethod
79
    def _processing_search_result(res):
80
        docs = []
81
        doc_ids = []
82
        scores = []
83
        for elem in res:
84
            doc = json.loads(elem.raw)
85
            score = elem.score
86
            if doc and isinstance(doc, dict):
87
                docs.append(doc.get("contents", ""))
88
                doc_ids.append(doc.get("id", ""))
89
                scores.append(score)
90

91
        return docs, doc_ids, scores
92

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.