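"""Fact retrieval service.

A Flask wrapper around DeepPavlov models: for each user utterance it retrieves free-text
facts (fact retrieval model) and entity-centric facts extracted from Wikipedia and wikiHow
pages (page extractor models).
"""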
import json
import logging
import nltk
import os
import pickle
import re
import time

import numpy as np
import sentry_sdk
from flask import Flask, request, jsonify
from sentry_sdk.integrations.flask import FlaskIntegration
from deeppavlov import build_model

from common.fact_retrieval import topic_titles, find_topic_titles
from common.wiki_skill import find_all_titles, find_paragraph, delete_hyperlinks, WIKI_BADLIST

logging.basicConfig(format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO)
logger = logging.getLogger(__name__)
sentry_sdk.init(dsn=os.getenv("SENTRY_DSN"), integrations=[FlaskIntegration()])

FILTER_FREQ = False

# DeepPavlov config paths are passed via environment variables
CONFIG = os.getenv("CONFIG")
CONFIG_PAGE_EXTRACTOR = os.getenv("CONFIG_WIKI")
CONFIG_WOW_PAGE_EXTRACTOR = os.getenv("CONFIG_WHOW")
N_FACTS = int(os.getenv("N_FACTS", 3))

DATA_GOOGLE_10K_ENG_NO_SWEARS = "common/google-10000-english-no-swears.txt"
DATA_SENTENCES = "data/sentences.pickle"

# splits text into word tokens (keeping apostrophes) and punctuation marks
re_tokenizer = re.compile(r"[\w']+|[^\w ]")

# the 800 most frequent English words, used to filter out trivial utterances and noun phrases
with open(DATA_GOOGLE_10K_ENG_NO_SWEARS, "r") as fl:
    lines = fl.readlines()
    freq_words = [line.strip() for line in lines]
    freq_words = set(freq_words[:800])

with open(DATA_SENTENCES, "rb") as fl:
    test_sentences = pickle.load(fl)

try:
    fact_retrieval = build_model(CONFIG, download=True)

    # entity type name (e.g. "food", "fruit") -> set of entity ids belonging to that type
    with open("/root/.deeppavlov/downloads/wikidata/entity_types_sets.pickle", "rb") as fl:
        entity_types_sets = pickle.load(fl)

    page_extractor = build_model(CONFIG_PAGE_EXTRACTOR, download=True)
    logger.info("model loaded, test query processed")

    whow_page_extractor = build_model(CONFIG_WOW_PAGE_EXTRACTOR, download=True)

    with open("/root/.deeppavlov/downloads/wikihow/wikihow_topics.json", "r") as fl:
        wikihow_topics = json.load(fl)
except Exception as e:
    sentry_sdk.capture_exception(e)
    logger.exception(e)
    raise e

app = Flask(__name__)


def get_page_content(page_title):
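    """Return structured content of a Wikipedia page extracted by the page extractor model.

    Returns an empty dict if the title is empty or extraction fails.
    """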
    page_content = {}
    try:
        if page_title:
            page_content_batch, main_pages_batch = page_extractor([[page_title]])
            if page_content_batch and page_content_batch[0]:
                page_content = page_content_batch[0][0]
    except Exception as e:
        sentry_sdk.capture_exception(e)
        logger.exception(e)

    return page_content


def get_wikihow_content(page_title):
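    """Return structured content of a wikiHow article extracted by the wikiHow page extractor model.

    Returns an empty dict if the title is empty or extraction fails.
    """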
    page_content = {}
    try:
        if page_title:
            page_content_batch = whow_page_extractor([[page_title]])
            if page_content_batch and page_content_batch[0]:
                page_content = page_content_batch[0][0]
    except Exception as e:
        sentry_sdk.capture_exception(e)
        logger.exception(e)

    return page_content


def find_sentences(paragraphs):
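    """Split the first paragraph into sentences, skipping blacklisted ones and keeping at most ~50 tokens in total."""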
    sentences_list = []
    if paragraphs:
        paragraph = paragraphs[0]
        paragraph, mentions, mention_pages = delete_hyperlinks(paragraph)
        sentences = nltk.sent_tokenize(paragraph)
        cur_len = 0
        max_len = 50
        for sentence in sentences:
            words = re.findall(re_tokenizer, sentence)
            if cur_len + len(words) < max_len and not re.findall(WIKI_BADLIST, sentence):
                sentences_list.append(sentence)
                cur_len += len(words)
    return sentences_list


def find_facts(entity_substr_batch, entity_ids_batch, entity_pages_batch):
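    """Collect entity-centric facts for a batch of utterances.

    For food-like entities ("food", "fruit", "vegetable", "berry") the facts are sentences
    from the intro of a matching wikiHow article; for other typed entities, topical
    paragraphs are extracted from the entity's Wikipedia page. Each returned item has the
    form {"entity_substr": ..., "entity_type": ..., "facts": [{"title": ..., "sentences": [...]}]},
    with at most N_FACTS items per utterance.
    """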
    facts_batch = []
    for entity_substr_list, entity_ids_list, entity_pages_list in zip(
        entity_substr_batch, entity_ids_batch, entity_pages_batch
    ):
        facts_list = []
        for entity_substr, entity_ids, entity_pages in zip(entity_substr_list, entity_ids_list, entity_pages_list):
            for entity_id, entity_page in zip(entity_ids, entity_pages):
                for entity_types_substr in entity_types_sets:
                    if entity_id in entity_types_sets[entity_types_substr]:
                        logger.info(f"found_entity_types_substr {entity_types_substr} entity_page {entity_page}")
                        if entity_types_substr in {"food", "fruit", "vegetable", "berry"}:
                            # look for a wikiHow article whose title shares a word with the entity substring
                            found_page_title = ""
                            entity_tokens = set(re.findall(re_tokenizer, entity_substr))
                            food_subtopics = wikihow_topics["Food and Entertaining"]
                            for subtopic in food_subtopics:
                                page_titles = food_subtopics[subtopic]
                                for page_title in page_titles:
                                    page_title_tokens = set(page_title.lower().split("-"))
                                    if entity_tokens.intersection(page_title_tokens):
                                        found_page_title = page_title
                                        break
                                if found_page_title:
                                    break
                            if found_page_title:
                                page_content = get_wikihow_content(found_page_title)
                                if page_content:
                                    page_title_clean = found_page_title.lower().replace("-", " ")
                                    intro = page_content["intro"]
                                    sentences = nltk.sent_tokenize(intro)
                                    facts_list.append(
                                        {
                                            "entity_substr": entity_substr,
                                            "entity_type": entity_types_substr,
                                            "facts": [{"title": page_title_clean, "sentences": sentences}],
                                        }
                                    )
                        else:
                            facts = []
                            page_content = get_page_content(entity_page)
                            all_titles = find_all_titles([], page_content)
                            if entity_types_substr in topic_titles:
                                cur_topic_titles = topic_titles[entity_types_substr]
                                page_titles = find_topic_titles(all_titles, cur_topic_titles)
                                for title, page_title in page_titles:
                                    paragraphs = find_paragraph(page_content, page_title)
                                    sentences_list = find_sentences(paragraphs)
                                    if sentences_list:
                                        facts.append({"title": title, "sentences": sentences_list})
                                if facts:
                                    facts_list.append(
                                        {
                                            "entity_substr": entity_substr,
                                            "entity_type": entity_types_substr,
                                            # cap the sample size: np.random.choice(..., replace=False) raises
                                            # an error when asked for more items than are available
                                            "facts": list(
                                                np.random.choice(facts, size=min(N_FACTS, len(facts)), replace=False)
                                            ),
                                        }
                                    )
        facts_batch.append(
            list(np.random.choice(facts_list, size=min(N_FACTS, len(facts_list)), replace=False))
            if facts_list
            else facts_list
        )
    return facts_batch


@app.route("/model", methods=["POST"])
def respond():
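    """Annotate a batch of utterances with retrieved facts.

    Expects a JSON body with parallel per-utterance lists: "human_sentences", "dialog_history",
    "nounphrases", "entity_substr", "entity_ids", "entity_pages" and "entity_pages_titles".
    Returns one dict per utterance with "facts" (fact retrieval output) and "topic_facts"
    (entity-centric facts from find_facts); utterances filtered out as trivial get {}.
    """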
    st_time = time.time()
    cur_utt = request.json.get("human_sentences", [" "])
    dialog_history = request.json.get("dialog_history", [" "])
    # strip a leading "alexa" wake word (str.lstrip("alexa") would strip any leading a/l/e/x
    # characters, not just the prefix)
    cur_utt = [re.sub(r"^alexa[\s,]*", "", utt) for utt in cur_utt]
    nounphr_list = request.json.get("nounphrases", [])
    if FILTER_FREQ:
        nounphr_list = [
            [nounphrase for nounphrase in nounphrases if nounphrase not in freq_words] for nounphrases in nounphr_list
        ]
    if not nounphr_list:
        nounphr_list = [[] for _ in cur_utt]

    entity_substr = request.json.get("entity_substr", [])
    if not entity_substr:
        entity_substr = [[] for _ in cur_utt]
    entity_pages = request.json.get("entity_pages", [])
    if not entity_pages:
        entity_pages = [[] for _ in cur_utt]
    entity_pages_titles = request.json.get("entity_pages_titles", [])
    if not entity_pages_titles:
        entity_pages_titles = [[] for _ in cur_utt]
    entity_ids = request.json.get("entity_ids", [])
    if not entity_ids:
        entity_ids = [[] for _ in cur_utt]
    logger.info(
        f"cur_utt {cur_utt} dialog_history {dialog_history} nounphr_list {nounphr_list} entity_pages {entity_pages}"
    )

    # keep only utterances that are not just a single frequent word and that have noun phrases
    nf_numbers, f_utt, f_dh, f_nounphr_list, f_entity_pages = [], [], [], [], []
    for n, (utt, dh, nounphrases, input_pages) in enumerate(zip(cur_utt, dialog_history, nounphr_list, entity_pages)):
        if utt not in freq_words and nounphrases:
            f_utt.append(utt)
            f_dh.append(dh)
            f_nounphr_list.append(nounphrases)
            f_entity_pages.append(input_pages)
        else:
            nf_numbers.append(n)

    out_res = [{"facts": [], "topic_facts": []} for _ in cur_utt]
    try:
        facts_batch = find_facts(entity_substr, entity_ids, entity_pages_titles)
        logger.info(f"f_utt {f_utt}")
        if f_utt:
            fact_res = fact_retrieval(f_utt) if len(f_utt[0].split()) > 3 else fact_retrieval(f_dh)
            if fact_res:
                fact_res = fact_res[0]
            fact_res = [[fact.replace('""', '"') for fact in facts] for facts in fact_res]

            out_res = []
            cnt_fnd = 0
            for i in range(len(cur_utt)):
                if i in nf_numbers:
                    out_res.append({})
                else:
                    if cnt_fnd < len(fact_res):
                        out_res.append(
                            {
                                # facts_batch is aligned with the full utterance batch, so index it by i,
                                # while fact_res is aligned with the filtered utterances (cnt_fnd)
                                "topic_facts": facts_batch[i],
                                "facts": list(
                                    np.random.choice(
                                        fact_res[cnt_fnd], size=min(N_FACTS, len(fact_res[cnt_fnd])), replace=False
                                    )
                                )
                                if len(fact_res[cnt_fnd]) > 0
                                else fact_res[cnt_fnd],
                            }
                        )
                        cnt_fnd += 1
                    else:
                        out_res.append({"facts": [], "topic_facts": []})
    except Exception as e:
        sentry_sdk.capture_exception(e)
        logger.exception(e)
    total_time = time.time() - st_time
    logger.info(f"fact_retrieval exec time: {total_time:.3f}s")
    return jsonify(out_res)


if __name__ == "__main__":
    app.run(debug=False, host="0.0.0.0", port=3000)
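
# Illustrative request for POST /model (field shapes follow how the handler above and
# find_facts consume them; the values are hypothetical):
#
#   {
#       "human_sentences": ["i like apples"],
#       "dialog_history": ["i like apples"],
#       "nounphrases": [["apples"]],
#       "entity_substr": [["apples"]],
#       "entity_ids": [[["Q89"]]],
#       "entity_pages": [[["Apple"]]],
#       "entity_pages_titles": [[["Apple"]]]
#   }
#
# The response is a JSON list with one element per utterance: either {} for utterances
# filtered out as trivial, or a dict with "facts" (retrieved fact sentences) and
# "topic_facts" (entity-centric facts built by find_facts).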