dream

Форк
0
147 строк · 5.8 Кб
1
# Copyright 2017 Neural Networks and Deep Learning lab, MIPT
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

15
import re
16
import time
17
from logging import getLogger
18
from typing import List, Any, Tuple
19

20
import numpy as np
21
import pymorphy2
22

23
from deeppavlov.core.common.registry import register
24
from deeppavlov.core.models.estimator import Component
25
from deeppavlov.models.vectorizers.hashing_tfidf_vectorizer import HashingTfIdfVectorizer
26

27
logger = getLogger(__name__)
28

29

30
@register("tfidf_ranker")
class TfidfRanker(Component):
    """Rank documents according to input strings.

    Args:
        vectorizer: a vectorizer class
        top_n: a number of doc ids to keep while ranking
        out_top_n: a number of doc ids to return from :meth:`__call__`
        active: whether to rank a number specified by :attr:`top_n` (``True``) or all ids
         (``False``)
        filter_flag: whether to pre-filter candidate docs by the entity substrings
         found in the question

    Attributes:
        top_n: a number of doc ids to keep while ranking
        out_top_n: a number of doc ids to return
        vectorizer: an instance of vectorizer class
        active: whether to rank a number specified by :attr:`top_n` or all ids
        filter_flag: whether entity-based pre-filtering is enabled
        re_tokenizer: compiled regex used to split a question into tokens
        lemmatizer: pymorphy2 morphological analyzer used for the noun fallback
    """

    def __init__(
        self,
        vectorizer: HashingTfIdfVectorizer,
        top_n: int = 5,
        out_top_n: int = 5,
        active: bool = True,
        filter_flag: bool = False,
        **kwargs,
    ):
        self.top_n = top_n
        self.out_top_n = out_top_n
        self.vectorizer = vectorizer
        self.active = active
        # words-or-single-punctuation tokenizer; compiled once and reused per call
        self.re_tokenizer = re.compile(r"[\w']+|[^\w ]")
        self.lemmatizer = pymorphy2.MorphAnalyzer()
        self.filter_flag = filter_flag
        self.numbers = 0

    def _select_search_substrings(
        self, question: str, entity_substr_list: List[str], tags_list: List[str]
    ) -> List[str]:
        """Pick entity substrings to build the filtering query for ``question``.

        Preference order: named entities with "strong" tags, then location/org
        tags, then nouns extracted from the question itself (excluding a few
        overly generic Russian nouns).
        """
        if entity_substr_list and not tags_list:
            # No tags supplied for the substrings: treat them all as nouns.
            tags_list = ["NOUN" for _ in entity_substr_list]
        substrings = [
            substr
            for substr, tag in zip(entity_substr_list, tags_list)
            if tag in {"PER", "PERSON", "PRODUCT", "WORK_OF_ART", "COUNTRY", "ORGANIZATION", "NOUN"}
        ]
        if not substrings:
            substrings = [
                substr
                for substr, tag in zip(entity_substr_list, tags_list)
                if tag in {"LOCATION", "LOC", "ORG"}
            ]
        if not substrings:
            # Fall back to nouns found in the question text. Parse each token
            # once (the original parsed twice per token) and skip generic nouns.
            for token in re.findall(self.re_tokenizer, question):
                parsed = self.lemmatizer.parse(token)[0]
                if parsed.tag.POS == "NOUN" and parsed.normal_form not in {"мир", "земля", "планета", "человек"}:
                    substrings.append(token)
        return substrings

    def __call__(
        self, questions: List[str], entity_substr_batch: List[List[str]] = None, tags_batch: List[List[str]] = None
    ) -> Tuple[List[Any], List[float]]:
        """Rank documents and return top n document titles with scores.

        Args:
            questions: list of queries used in ranking
            entity_substr_batch: per-question lists of entity substrings used
                for pre-filtering when :attr:`filter_flag` is set
            tags_batch: per-question lists of tags aligned with
                ``entity_substr_batch``

        Returns:
            a tuple of selected doc ids and their scores
        """
        tm_st = time.time()
        batch_doc_ids, batch_docs_scores = [], []

        q_tfidfs = self.vectorizer(questions)
        if entity_substr_batch is None:
            entity_substr_batch = [[] for _ in questions]
        if tags_batch is None:
            # Original only defaulted tags_batch alongside entity_substr_batch;
            # defaulting it independently avoids a TypeError in zip() when
            # entity substrings are passed without tags.
            tags_batch = [[] for _ in questions]

        for question, q_tfidf, entity_substr_list, tags_list in zip(
            questions, q_tfidfs, entity_substr_batch, tags_batch
        ):
            # Doc row indices whose entity-based tf-idf score is nonzero;
            # stays empty (no filtering) when filter_flag is off.
            nonzero_scores = set()
            if self.filter_flag:
                entity_substr_for_search = self._select_search_substrings(question, entity_substr_list, tags_list)
                if entity_substr_for_search:
                    ent_tfidf = self.vectorizer([", ".join(entity_substr_for_search)])[0]
                    ent_scores = ent_tfidf * self.vectorizer.tfidf_matrix
                    ent_scores = np.squeeze(ent_scores.toarray())
                    nonzero_scores = set(np.nonzero(ent_scores)[0])

            scores = q_tfidf * self.vectorizer.tfidf_matrix
            scores = np.squeeze(scores.toarray() + 0.0001)  # add a small value to eliminate zero scores

            thresh = self.top_n if self.active else len(self.vectorizer.doc_index)

            # argpartition needs kth < len(scores); clamp when thresh covers everything.
            if thresh >= len(scores):
                o = np.argpartition(-scores, len(scores) - 1)[0:thresh]
            else:
                o = np.argpartition(-scores, thresh)[0:thresh]
            o_sort = o[np.argsort(-scores[o])]

            if nonzero_scores:
                # Keep only candidates that also match the entity query; if the
                # intersection is empty, fall back to the unfiltered ranking.
                filtered = [elem for elem in o_sort if elem in nonzero_scores]
                if filtered:
                    o_sort = np.array(filtered)

            doc_scores = scores[o_sort].tolist()
            doc_ids = [self.vectorizer.index2doc.get(i, "") for i in o_sort]

            batch_doc_ids.append(doc_ids[: self.out_top_n])
            batch_docs_scores.append(doc_scores[: self.out_top_n])
        tm_end = time.time()
        # Guard against an empty batch (original raised IndexError on
        # batch_doc_ids[0]); lazy %-args avoid formatting when INFO is off.
        logger.info(
            "tfidf ranking time: %s num doc_ids %s",
            tm_end - tm_st,
            len(batch_doc_ids[0]) if batch_doc_ids else 0,
        )

        return batch_doc_ids, batch_docs_scores
148

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.