# insert_scenario.py — scenario-insertion helpers for the dff wiki skill
# (part of the "dream" dialogue system repository).
1
import logging
2
import os
3
import random
4
import re
5
import nltk
6
import requests
7
import sentry_sdk
8
import common.constants as common_constants
9
import common.dialogflow_framework.utils.state as state_utils
10
from common.universal_templates import if_chat_about_particular_topic
11

12
from common.wiki_skill import (
13
    check_condition,
14
    find_entity_by_types,
15
    check_nounphr,
16
    find_page_title,
17
    find_paragraph,
18
    delete_hyperlinks,
19
    find_all_titles,
20
    used_types_dict,
21
    NEWS_MORE,
22
    WIKI_BADLIST,
23
    QUESTION_TEMPLATES,
24
    QUESTION_TEMPLATES_SHORT,
25
    CONF_DICT,
26
)
27
from common.universal_templates import CONTINUE_PATTERN
28
from common.utils import is_no, is_yes
29

30
# Error reporting and module-level configuration.
sentry_sdk.init(os.getenv("SENTRY_DSN"))
logger = logging.getLogger(__name__)
# Endpoint of the fact-retrieval service that serves Wikipedia/wikiHow content.
WIKI_FACTS_URL = os.getenv("WIKI_FACTS_URL")

# Module-level cache for downloaded wikihow/wikipedia content.
memory = {}
35

36

37
# Lookup tables precomputed from the wiki-skill topic configuration
# (used_types_dict). Keys are entity types or entity substrings; values are
# the pieces of scenario config used to drive the dialogue. A single pass
# over the config fills all five tables.
titles_by_type = {}
titles_by_entity_substr = {}
page_titles_by_entity_substr = {}
questions_by_entity_substr = {}
wikihowq_by_substr = {}

for config_elem in used_types_dict:
    elem_titles = config_elem["titles"]
    # Map every wikidata type of the element to its section titles.
    for entity_type in config_elem.get("types", []):
        titles_by_type[entity_type] = elem_titles
    substrings = config_elem.get("entity_substr", [])
    wiki_page = config_elem.get("page_title", "")
    intro_question = config_elem.get("intro_question", "")
    wikihow_info = config_elem.get("wikihow_info", {})
    for substring in substrings:
        titles_by_entity_substr[substring] = elem_titles
        if wiki_page:
            page_titles_by_entity_substr[substring] = wiki_page
        if intro_question:
            questions_by_entity_substr[substring] = intro_question
        if wikihow_info:
            wikihowq_by_substr[substring] = wikihow_info
70

71

72
def get_page_content(page_title, cache_page_dict=None):
    """Return (page_content, main_pages) for a Wikipedia page title.

    The optional local cache is consulted first; otherwise the wiki-facts
    service is queried. On any failure the error is reported to Sentry and
    two empty dicts are returned.
    """
    content = {}
    related_pages = {}
    try:
        if page_title:
            if cache_page_dict and page_title in cache_page_dict:
                cached = cache_page_dict[page_title]
                content = cached["page_content"]
                related_pages = cached["main_pages"]
            else:
                resp = requests.post(
                    WIKI_FACTS_URL, json={"wikipedia_titles": [[page_title]]}, timeout=1.0
                ).json()
                if resp and resp[0]["main_pages"] and resp[0]["wikipedia_content"]:
                    content = resp[0]["wikipedia_content"][0]
                    related_pages = resp[0]["main_pages"][0]
    except Exception as exc:
        sentry_sdk.capture_exception(exc)
        logger.exception(exc)

    return content, related_pages
90

91

92
def get_wikihow_content(page_title):
    """Return wikiHow article content for *page_title* (empty dict on failure)."""
    content = {}
    try:
        if page_title:
            resp = requests.post(
                WIKI_FACTS_URL, json={"wikihow_titles": [[page_title]]}, timeout=1.0
            ).json()
            if resp and resp[0]["wikihow_content"]:
                content = resp[0]["wikihow_content"][0]
    except Exception as exc:
        sentry_sdk.capture_exception(exc)
        logger.exception(exc)

    return content
104

105

106
def get_titles(found_entity_substr, found_entity_types, page_content):
    """Collect section titles relevant to the found entity.

    Returns (titles_q, titles_we_use, all_titles): the merged title->question
    config, the titles selected by entity type and substring, and every title
    present in the article.
    """
    all_titles = find_all_titles([], page_content)
    titles_q = {}
    titles_we_use = []
    for entity_type in found_entity_types:
        type_titles = titles_by_type.get(entity_type, {})
        titles_we_use.extend(type_titles)
        titles_q.update(type_titles)
    substr_titles = titles_by_entity_substr.get(found_entity_substr, {})
    titles_we_use.extend(substr_titles)
    titles_q.update(substr_titles)
    return titles_q, titles_we_use, all_titles
118

119

120
def get_page_title(vars, entity_substr):
    """Resolve an entity substring to a Wikipedia page title.

    The hand-written config mapping wins; otherwise the entity-linking
    annotation of the last human utterance is searched.
    """
    found_page = ""
    if entity_substr in page_titles_by_entity_substr:
        found_page = page_titles_by_entity_substr[entity_substr]
    else:
        annotations = state_utils.get_last_human_utterance(vars)["annotations"]
        for linked_entity in annotations.get("entity_linking", []):
            matches = isinstance(linked_entity, dict) and linked_entity["entity_substr"] == entity_substr
            if matches and linked_entity["pages_titles"]:
                found_page = linked_entity["pages_titles"][0]
    logger.info(f"found_page {found_page}")
    return found_page
134

135

136
def make_facts_str(paragraphs):
    """Build a short facts string (under ~50 words) from the first paragraph.

    Sentences are sanitized of hyperlinks and filtered against WIKI_BADLIST.
    If no whole sentence fits, comma-separated pieces of the first sentence
    are used instead. Returns (facts_str, mentions_list, mention_pages_list)
    with the hyperlink mentions and their pages collected along the way.
    """
    max_len = 50
    mentions_list = []
    mention_pages_list = []
    first_paragraph = paragraphs[0] if paragraphs else ""
    sentences = nltk.sent_tokenize(first_paragraph)
    chosen = []
    total_words = 0
    for sentence in sentences:
        clean_sent, mentions, mention_pages = delete_hyperlinks(sentence)
        n_words = len(nltk.word_tokenize(clean_sent))
        if total_words + n_words < max_len and not re.findall(WIKI_BADLIST, clean_sent):
            chosen.append(clean_sent)
            total_words += n_words
            mentions_list += mentions
            mention_pages_list += mention_pages
    facts_str = " ".join(chosen) if chosen else ""
    if sentences and not chosen:
        # Fallback: no whole sentence fits — take comma-separated pieces of
        # the first sentence until the word budget is spent.
        clean_sent, mentions, mention_pages = delete_hyperlinks(sentences[0])
        mentions_list += mentions
        mention_pages_list += mention_pages
        total_words = 0
        for part in clean_sent.split(", "):
            n_words = len(nltk.word_tokenize(part))
            if total_words + n_words < max_len and not re.findall(WIKI_BADLIST, part):
                chosen.append(part)
                total_words += n_words
        facts_str = ", ".join(chosen)
        if facts_str and not facts_str.endswith("."):
            facts_str = f"{facts_str}."
    return facts_str, mentions_list, mention_pages_list
173

174

175
def check_utt_cases(vars, utt_info):
    """Return True if utt_info has no utt_cases, or any case condition holds."""
    user_uttr = state_utils.get_last_human_utterance(vars)
    bot_uttr = state_utils.get_last_bot_utterance(vars)
    shared_memory = state_utils.get_shared_memory(vars)
    utt_cases = utt_info.get("utt_cases", [])
    if not utt_cases:
        return True
    matched = False
    # Every condition is evaluated (no short-circuit), mirroring the original.
    for utt_case in utt_cases:
        if check_condition(utt_case["cond"], user_uttr, bot_uttr, shared_memory):
            matched = True
    return matched
189

190

191
def extract_and_save_subtopic(vars, topic_config, found_topic):
    """Match pending expected-subtopic conditions against the last utterances.

    On the first matching condition, records the subtopic (and any newly
    available utterance keys) in shared memory and clears the expectation
    list.
    """
    user_uttr = state_utils.get_last_human_utterance(vars)
    bot_uttr = state_utils.get_last_bot_utterance(vars)
    shared_memory = state_utils.get_shared_memory(vars)
    subtopics = shared_memory.get("subtopics", [])
    for subtopic_info in shared_memory.get("expected_subtopic_info", []):
        # A bare string is a reference into the topic's global subtopic table.
        if isinstance(subtopic_info, str) and found_topic:
            global_subtopics = topic_config[found_topic].get("expected_subtopics", {})
            if subtopic_info in global_subtopics:
                subtopic_info = global_subtopics[subtopic_info]
        if isinstance(subtopic_info, dict):
            subtopic = subtopic_info["subtopic"]
            matched = check_condition(subtopic_info["cond"], user_uttr, bot_uttr, shared_memory)
            logger.info(f"expected_subtopic_info {subtopic_info} flag {matched}")
            if matched and subtopic not in subtopics:
                subtopics.append(subtopic)
                known_utterances = shared_memory.get("available_utterances", [])
                for utt_key in subtopic_info.get("available_utterances", []):
                    if utt_key not in known_utterances:
                        known_utterances.append(utt_key)
                state_utils.save_to_shared_memory(vars, available_utterances=known_utterances)
            if matched:
                state_utils.save_to_shared_memory(vars, subtopics=subtopics)
                state_utils.save_to_shared_memory(vars, expected_subtopic_info={})
                break
219

220

221
def find_trigger(vars, triggers):
    """Check trigger config against the last utterance's annotations.

    Returns (entity_substr, entity_types, wikipedia_page, wikihow_page);
    all empty when nothing triggers.
    """
    annotations = state_utils.get_last_human_utterance(vars)["annotations"]
    if "entity_types" in triggers:
        entity_substr, entity_types, _ = find_entity_by_types(annotations, triggers["entity_types"])
        page = get_page_title(vars, entity_substr)
        if page:
            return entity_substr, entity_types, page, ""
    for entity_info in triggers.get("entity_substr", []):
        for substr_info in entity_info["substr"]:
            matched_substr = check_nounphr(annotations, substr_info["substr"])
            if matched_substr:
                return (
                    matched_substr,
                    [],
                    substr_info.get("wikipedia_page", ""),
                    substr_info.get("wikihow_page", ""),
                )
    return "", [], "", ""
239

240

241
def delete_topic_info(vars):
    """Reset all scenario-related slots in shared memory to empty values."""
    empty_slots = {
        "special_topic": "",
        "expected_subtopic_info": {},
        "available_utterances": [],
        "subtopics": [],
        "cur_facts": [],
        "used_utt_nums": {},
        "cur_mode": "",
        "ackn": [],
    }
    # One save call per slot, as in the original implementation.
    for slot_name, empty_value in empty_slots.items():
        state_utils.save_to_shared_memory(vars, **{slot_name: empty_value})
250

251

252
def preprocess_wikihow_page(article_content):
    """Convert a wikiHow article into a list of dialogue chunks.

    Each chunk is {"facts_str": ..., "question": ...}: up to ~50 words taken
    from a paragraph, plus a follow-up question leading to the next paragraph
    or section. The "intro" section is skipped.
    """
    page_content_list = []
    article_content = list(article_content.items())
    max_len = 50  # soft cap on words per chunk
    for title_num, (title, paragraphs) in enumerate(article_content):
        if title == "intro":
            continue
        for n, paragraph in enumerate(paragraphs):
            facts_str = ""
            question = ""
            sentences = nltk.sent_tokenize(paragraph)
            sentences_list = []
            cur_len = 0
            for sentence in sentences:
                words = nltk.word_tokenize(sentence)
                if cur_len + len(words) < max_len:
                    sentences_list.append(sentence)
                    cur_len += len(words)
            if sentences_list:
                facts_str = " ".join(sentences_list)
            elif sentences:
                # BUGFIX: the original indexed sentences[0] unconditionally and
                # crashed on an empty paragraph. Fallback: the first sentence
                # alone exceeds the budget, so take comma-separated pieces.
                cur_len = 0
                for part in sentences[0].split(", "):
                    words = nltk.word_tokenize(part)
                    if cur_len + len(words) < max_len:
                        sentences_list.append(part)
                        cur_len += len(words)
                facts_str = ", ".join(sentences_list)

            # Bridge to the next section, or ask whether to hear more.
            if n == len(paragraphs) - 1 and title_num != len(article_content) - 1:
                next_title = article_content[title_num + 1][0]
                question = f"Would you like to know about {next_title.lower()}?"
            elif n != len(paragraphs) - 1:
                question = random.choice(NEWS_MORE)
            response_dict = {"facts_str": facts_str, "question": question}
            response = f"{facts_str} {question}".strip().replace("  ", " ")
            if response:
                page_content_list.append(response_dict)
    return page_content_list
291

292

293
def preprocess_wikipedia_page(found_entity_substr, found_entity_types, article_content, predefined_titles=None):
    """Turn a Wikipedia article into a list of {"facts_str", "question"} chunks.

    At most two suitable paragraphs per relevant section title are kept;
    consecutive facts are linked with a question about the next title, and
    the final chunk closes the topic.
    """
    logger.info(f"found_entity_substr {found_entity_substr} found_entity_types {found_entity_types}")
    titles_q, titles_we_use, all_titles = get_titles(found_entity_substr, found_entity_types, article_content)
    if predefined_titles:
        titles_we_use = predefined_titles
    logger.info(f"titles_we_use {titles_we_use} all_titles {all_titles}")
    facts_list = []
    for title in titles_we_use:
        page_title = find_page_title(all_titles, title)
        paragraphs = find_paragraph(article_content, page_title)
        logger.info(f"page_title {page_title} paragraphs {paragraphs[:2]}")
        used_paragraphs = 0
        for paragraph in paragraphs:
            facts_str, *_ = make_facts_str([paragraph])
            # Keep only complete, non-trivial fact sentences.
            if facts_str and facts_str.endswith(".") and len(facts_str.split()) > 4:
                facts_list.append((title, facts_str))
                used_paragraphs += 1
            if used_paragraphs == 2:
                break
    logger.info(f"facts_list {facts_list[:3]}")
    page_content_list = []
    last_index = len(facts_list) - 1
    for n, (title, facts_str) in enumerate(facts_list):
        if n == last_index:
            page_content_list.append(
                {"facts_str": facts_str, "question": f"I was very happy to tell you more about {found_entity_substr}."}
            )
            continue
        next_title = facts_list[n + 1][0]
        if next_title == title:
            question = random.choice(NEWS_MORE)
        elif found_entity_substr.lower() in next_title.lower():
            question = random.choice(QUESTION_TEMPLATES_SHORT).format(next_title)
        else:
            question = random.choice(QUESTION_TEMPLATES).format(next_title, found_entity_substr)
        response_dict = {"facts_str": facts_str, "question": question}
        response = f"{facts_str} {question}".strip().replace("  ", " ")
        if response:
            page_content_list.append(response_dict)
    logger.info(f"page_content_list {page_content_list}")
    return page_content_list
336

337

338
def extract_entity(vars, user_uttr, expected_entity):
    """Try each extraction strategy described by *expected_entity* in order.

    Strategies: cobot entity label match, wiki-parser type match, regexp
    substring match, or "any entity". Returns (entity_text, entity_triplets);
    ("", {}) when nothing matches.
    """
    annotations = user_uttr["annotations"]
    if expected_entity:
        logger.info(f"expected_entity {expected_entity}")
        if "cobot_entities_type" in expected_entity:
            wanted_label = expected_entity["cobot_entities_type"]
            labelled = annotations.get("cobot_entities", {}).get("labelled_entities", [])
            for ent in labelled:
                if ent.get("label", "") == wanted_label:
                    return ent.get("text", ""), {}
        if "wiki_parser_types" in expected_entity:
            found_entity, _, entity_triplets = find_entity_by_types(
                annotations,
                expected_entity["wiki_parser_types"],
                expected_entity.get("relations", []),
            )
            if found_entity:
                return found_entity, entity_triplets
        if "entity_substr" in expected_entity:
            for entity, pattern in expected_entity["entity_substr"]:
                if re.findall(pattern, user_uttr["text"]):
                    return entity, {}
        if expected_entity.get("any_entity", False):
            any_entities = annotations.get("cobot_entities", {}).get("entities", [])
            if any_entities:
                return any_entities[0], {}
    return "", {}
367

368

369
def extract_and_save_entity(vars, topic_config, found_topic):
    """Extract entities expected by the scenario and store them in shared memory.

    Iterates over the shared-memory "expected_entities" (dicts with extraction
    rules, or string references into the topic config), saves every extracted
    entity under its configured name in "user_info", merges any wiki-parser
    triplets into "entity_triplets", and clears the expectation list once
    anything was found.
    """
    user_uttr = state_utils.get_last_human_utterance(vars)
    shared_memory = state_utils.get_shared_memory(vars)
    expected_entities = shared_memory.get("expected_entities", {})
    found = False
    for expected_entity in expected_entities:
        # BUGFIX: reset per expectation. The original left entity_triplets
        # unbound for an unresolved string reference (UnboundLocalError at the
        # log line) and carried found_entity over from the previous iteration,
        # re-saving it under the wrong expectation.
        found_entity, entity_triplets = "", {}
        if isinstance(expected_entity, dict):
            found_entity, entity_triplets = extract_entity(vars, user_uttr, expected_entity)
        elif isinstance(expected_entity, str) and found_topic:
            topic_expected_entities = topic_config[found_topic].get("expected_entities_info", {})
            if expected_entity in topic_expected_entities:
                expected_entity = topic_expected_entities[expected_entity]
                found_entity, entity_triplets = extract_entity(vars, user_uttr, expected_entity)
        logger.info(f"expected_entity {expected_entity} found_entity {found_entity} entity_triplets {entity_triplets}")
        if found_entity:
            entity_name = expected_entity["name"]
            user_info = shared_memory.get("user_info", {})
            user_info[entity_name] = found_entity
            logger.info(f"extracting entity, user_info {user_info}")
            state_utils.save_to_shared_memory(vars, user_info=user_info)
            if entity_triplets:
                merged_triplets = {**shared_memory.get("entity_triplets", {}), **entity_triplets}
                state_utils.save_to_shared_memory(vars, entity_triplets=merged_triplets)
            found = True
    if found:
        state_utils.save_to_shared_memory(vars, expected_entities={})
397

398

399
def if_facts_agree(vars):
    """Return True if any pending fact's condition matches the last utterances."""
    user_uttr = state_utils.get_last_human_utterance(vars)
    bot_uttr = state_utils.get_last_bot_utterance(vars)
    shared_memory = state_utils.get_shared_memory(vars)
    for fact in shared_memory.get("cur_facts", {}):
        if check_condition(fact["cond"], user_uttr, bot_uttr, shared_memory):
            return True
    return False
411

412

413
def extract_and_save_wikipage(vars, save=False):
    """Check pending facts for an attached wiki page whose condition matches.

    Returns True when such a page exists. With save=True the page is stored in
    shared memory (cur_wikihow_page has priority over cur_wikipedia_page) and
    the pending facts are cleared.
    """
    user_uttr = state_utils.get_last_human_utterance(vars)
    bot_uttr = state_utils.get_last_bot_utterance(vars)
    shared_memory = state_utils.get_shared_memory(vars)
    cur_facts = shared_memory.get("cur_facts", {})
    for fact in cur_facts:
        # The original evaluated the same condition twice per fact; evaluating
        # it once is sufficient and avoids a redundant check_condition call.
        checked = check_condition(fact["cond"], user_uttr, bot_uttr, shared_memory)
        if not checked:
            continue
        wikihow_page = fact.get("wikihow_page", "")
        if wikihow_page:
            if save:
                state_utils.save_to_shared_memory(vars, cur_wikihow_page=wikihow_page)
                state_utils.save_to_shared_memory(vars, cur_facts={})
            return True
        wikipedia_page = fact.get("wikipedia_page", "")
        if wikipedia_page:
            if save:
                state_utils.save_to_shared_memory(vars, cur_wikipedia_page=wikipedia_page)
                state_utils.save_to_shared_memory(vars, cur_facts={})
            return True
    return False
439

440

441
def check_used_subtopic_utt(vars, topic_config, subtopic):
    """Return True if every smalltalk utterance of *subtopic* was already used."""
    shared_memory = state_utils.get_shared_memory(vars)
    found_topic = shared_memory.get("special_topic", "")
    cur_topic_smalltalk = topic_config[found_topic]["smalltalk"]
    used_utt_nums_dict = shared_memory.get("used_utt_nums", {})
    used_utt_nums = used_utt_nums_dict.get(found_topic, [])
    if not found_topic:
        return False
    subtopic_nums = [
        num for num, utt_info in enumerate(cur_topic_smalltalk) if utt_info.get("subtopic", "") == subtopic
    ]
    total = len(subtopic_nums)
    used = sum(1 for num in subtopic_nums if num in used_utt_nums)
    logger.info(
        f"check_used_subtopic_utt, subtopic {subtopic} total {total} used {used} "
        f"used_utt_nums_dict {used_utt_nums_dict}"
    )
    return total > 0 and total == used
462

463

464
def make_resp_list(vars, utt_list, topic_config, shared_memory):
    """Fill {slot} placeholders in candidate utterances.

    Plain {name} slots are looked up in shared-memory "user_info". Slots of
    the form {[key1, key2, ...]} are resolved either against the topic's
    "bot_data" (when the first key is "bot_data") or against collected entity
    triplets (first key a user_info variable, second key a relation).
    Utterances that still contain an unresolved slot are dropped.
    """
    resp_list = []
    found_topic = shared_memory.get("special_topic", "")
    user_info = shared_memory.get("user_info", {})
    logger.info(f"make_smalltalk_response, user_info {user_info}")
    for utt in utt_list:
        slots = re.findall(r"{(.*?)}", utt)
        if not slots:
            resp_list.append(utt)
            continue
        entity_triplets = shared_memory.get("entity_triplets", {})
        for slot in slots:
            filler = ""
            if slot.startswith("["):
                path = slot.strip("[]").split(", ")
                bot_data = topic_config.get(found_topic, {}).get("bot_data", {})
                if path and path[0] == "bot_data" and bot_data:
                    # Walk bot_data along the key path; keys that are
                    # user_info variable names are substituted first.
                    filler = bot_data
                    for key in path[1:]:
                        if key in user_info:
                            key = user_info[key]
                        filler = filler[key]
                elif path and path[0] != "bot_data" and path[0] in user_info:
                    subject = user_info[path[0]]
                    relation = path[1]
                    objects = entity_triplets.get(subject, {}).get(relation, "")
                    if len(objects) == 1:
                        filler = objects[0]
                    elif len(objects) == 2:
                        filler = f"{objects[0]} and {objects[1]}"
                    elif len(objects) > 2:
                        filler = ", ".join(objects[:2]) + " and " + objects[2]
                    filler = filler.strip().replace("  ", " ")
            else:
                filler = user_info.get(slot, "")
            if filler:
                utt = utt.replace("{" + slot + "}", filler)
        if "{" not in utt:
            resp_list.append(utt)
    return resp_list
507

508

509
def check_acknowledgements(vars, topic_config):
    """Return an acknowledgement response for the current topic, or ""."""
    user_uttr = state_utils.get_last_human_utterance(vars)
    bot_uttr = state_utils.get_last_bot_utterance(vars)
    shared_memory = state_utils.get_shared_memory(vars)
    found_topic = shared_memory.get("special_topic", "")
    response = ""
    if found_topic:
        for ackn in topic_config[found_topic].get("ackn", []):
            if check_condition(ackn["cond"], user_uttr, bot_uttr, shared_memory):
                resp_list = make_resp_list(vars, ackn["answer"], topic_config, shared_memory)
                if resp_list:
                    response = " ".join(resp_list).strip().replace("  ", " ")
                    break
    return response
526

527

528
def answer_users_question(vars, topic_config):
    """Return the canned answer whose pattern matches the user's utterance, or ""."""
    shared_memory = state_utils.get_shared_memory(vars)
    found_topic = shared_memory.get("special_topic", "")
    user_uttr = state_utils.get_last_human_utterance(vars)
    answer = ""
    if found_topic:
        questions = topic_config[found_topic].get("questions", [])
        logger.info(f"user_uttr {user_uttr.get('text', '')} questions {questions}")
        for question in questions:
            if re.findall(question["pattern"], user_uttr["text"]):
                answer = question["answer"]
                break
    return answer
542

543

544
def check_switch(vars, topic_config):
    """Determine whether some configured topic should be (re)activated.

    A topic triggers via a link-to phrase in the last bot utterance, via its
    regexp pattern in the user utterance, or via a "switch_on" condition.
    Returns (found_topic, first_utt, utt_can_continue, utt_conf).
    """
    user_uttr = state_utils.get_last_human_utterance(vars)
    bot_uttr = state_utils.get_last_bot_utterance(vars)
    # Fetched once; the original fetched shared memory twice redundantly.
    shared_memory = state_utils.get_shared_memory(vars)
    found_topic = shared_memory.get("special_topic", "")
    first_utt = False
    utt_can_continue = "can"
    utt_conf = 0.0
    for topic, topic_info in topic_config.items():
        # 1) the bot just uttered one of the topic's link-to phrases
        for phrase in topic_info.get("linkto", []):
            if phrase.lower() in bot_uttr["text"].lower():
                found_topic = topic
                first_utt = True
                break
        # 2) the user asked to chat about the topic (or merely mentioned it)
        pattern = topic_info.get("pattern", "")
        if pattern:
            if if_chat_about_particular_topic(user_uttr, bot_uttr, compiled_pattern=pattern):
                utt_can_continue = "must"
                utt_conf = 1.0
                found_topic = topic
                first_utt = True
            elif re.findall(pattern, user_uttr["text"]) and not found_topic:
                utt_can_continue = "prompt"
                utt_conf = 0.95
                found_topic = topic
                first_utt = True
        # 3) an explicit switch_on condition holds
        for switch_elem in topic_info.get("switch_on", []):
            if check_condition(switch_elem["cond"], user_uttr, bot_uttr, shared_memory):
                found_topic = topic
                utt_can_continue = switch_elem.get("can_continue", "can")
                utt_conf = switch_elem.get("conf", utt_conf)
                first_utt = True
                break
        if found_topic:
            break
    return found_topic, first_utt, utt_can_continue, utt_conf
584

585

586
def start_or_continue_scenario(vars, topic_config):
    """Decide whether the smalltalk scenario should start or continue.

    Returns True when there is an active (or newly switched-on) topic with
    unused smalltalk utterances and the dialogue is in smalltalk mode; the
    flag is cleared for a non-first utterance when another skill was active.
    """
    flag = False
    bot_uttr = state_utils.get_last_bot_utterance(vars)
    prev_active_skill = bot_uttr.get("active_skill", "")
    shared_memory = state_utils.get_shared_memory(vars)
    isno = is_no(state_utils.get_last_human_utterance(vars))
    cur_mode = shared_memory.get("cur_mode", "smalltalk")
    found_topic = shared_memory.get("special_topic", "")
    logger.info(f"special_topic_request, found_topic {found_topic}")
    user_info = shared_memory.get("user_info", {})
    entity_triplets = shared_memory.get("entity_triplets", {})
    logger.info(f"start_or_continue_scenario, user_info {user_info}, entity_triplets {entity_triplets}")
    if cur_mode == "facts" and isno:
        # The user declined more facts — fall back to smalltalk.
        cur_mode = "smalltalk"
    first_utt = False
    if not found_topic or prev_active_skill not in {"dff_wiki_skill", "dff_music_skill"}:
        found_topic, first_utt, utt_can_continue, utt_conf = check_switch(vars, topic_config)
        logger.info(f"start_or_continue_scenario, {found_topic}, {first_utt}")
    if found_topic:
        cur_topic_smalltalk = topic_config[found_topic].get("smalltalk", [])
        # BUGFIX: look up by the topic variable, not the literal string
        # "found_topic" — the original always got [] here, so the scenario
        # never noticed that all smalltalk utterances were exhausted.
        used_utt_nums = shared_memory.get("used_utt_nums", {}).get(found_topic, [])
        logger.info(f"used_smalltalk {used_utt_nums}")
        if cur_topic_smalltalk and len(used_utt_nums) < len(cur_topic_smalltalk) and cur_mode == "smalltalk":
            flag = True
        if not first_utt and (
            (found_topic != "music" and prev_active_skill != "dff_wiki_skill")
            or (found_topic == "music" and prev_active_skill != "dff_music_skill")
        ):
            flag = False
    return flag
616

617

618
def make_smalltalk_response(vars, topic_config, shared_memory, utt_info, used_utt_nums, num):
619
    user_uttr = state_utils.get_last_human_utterance(vars)
620
    bot_uttr = state_utils.get_last_bot_utterance(vars)
621
    response = ""
622
    utt_list = utt_info["utt"]
623
    found_ackn = ""
624
    ackns = utt_info.get("ackn", [])
625
    for ackn in ackns:
626
        condition = ackn["cond"]
627
        if check_condition(condition, user_uttr, bot_uttr, shared_memory):
628
            found_ackn = ackn["answer"]
629
            break
630
    found_prev_ackn = ""
631
    ackns = shared_memory.get("ackn", [])
632
    for ackn in ackns:
633
        condition = ackn["cond"]
634
        if check_condition(condition, user_uttr, bot_uttr, shared_memory):
635
            found_prev_ackn = ackn["answer"]
636
            break
637
    found_ackn = found_ackn or found_prev_ackn
638
    resp_list = make_resp_list(vars, utt_list, topic_config, shared_memory)
639
    if resp_list:
640
        response = " ".join(resp_list).strip().replace("  ", " ")
641
        used_utt_nums.append(num)
642
        cur_facts = utt_info.get("facts", {})
643
        state_utils.save_to_shared_memory(vars, cur_facts=cur_facts)
644
        next_ackn = utt_info.get("next_ackn", [])
645
        state_utils.save_to_shared_memory(vars, ackn=next_ackn)
646
        expected_entities = utt_info.get("expected_entities", {})
647
        if expected_entities:
648
            state_utils.save_to_shared_memory(vars, expected_entities=expected_entities)
649
        expected_subtopic_info = utt_info.get("expected_subtopic_info", {})
650
        logger.info(f"print expected_subtopic_info {expected_subtopic_info} utt_info {utt_info}")
651
        state_utils.save_to_shared_memory(vars, expected_subtopic_info=expected_subtopic_info)
652
        if found_ackn:
653
            found_ackn_sentences = nltk.sent_tokenize(found_ackn)
654
            found_ackn_list = make_resp_list(vars, found_ackn_sentences, topic_config, shared_memory)
655
            found_ackn = " ".join(found_ackn_list)
656
        response = f"{found_ackn} {response}".strip().replace("  ", " ")
657
    return response, used_utt_nums
658

659

660
def smalltalk_response(vars, topic_config):
    """Produce the next scripted smalltalk utterance for the active special topic.

    Selects the first unused utterance from ``topic_config`` (subtopic-bound
    utterances first, then topic-level ones), prepends an answer to the user's
    question or an acknowledgement, and updates confidence / can-continue
    flags and the shared dialog memory.

    Args:
        vars: dialog-state container consumed by the ``state_utils`` helpers.
        topic_config: per-topic scenario description (smalltalk utterances,
            expected entities/subtopics, triggers, facts).

    Returns:
        str: the response text, or "" when no applicable utterance is left.
    """
    response = ""
    first_utt = False
    shared_memory = state_utils.get_shared_memory(vars)
    bot_uttr = state_utils.get_last_bot_utterance(vars)
    prev_active_skill = bot_uttr.get("active_skill", "")
    # Another skill spoke last turn: the stored topic info is stale, drop it.
    if prev_active_skill not in {"dff_wiki_skill", "dff_music_skill"}:
        delete_topic_info(vars)
    found_topic = shared_memory.get("special_topic", "")
    cur_mode = shared_memory.get("cur_mode", "smalltalk")
    isno = is_no(state_utils.get_last_human_utterance(vars))
    utt_can_continue = "can"
    utt_conf = 0.0
    # The user declined more facts: clear cached wiki pages and their content.
    if cur_mode == "facts" and isno:
        state_utils.save_to_shared_memory(vars, cur_wikihow_page="")
        state_utils.save_to_shared_memory(vars, cur_wikipedia_page="")
        memory["wikihow_content"] = []
        memory["wikipedia_content"] = []
    if not found_topic:
        found_topic, first_utt, utt_can_continue, utt_conf = check_switch(vars, topic_config)
    if found_topic:
        expected_entities = topic_config[found_topic].get("expected_entities", {})
        if expected_entities:
            state_utils.save_to_shared_memory(vars, expected_entities=expected_entities)
        existing_subtopic_info = shared_memory.get("expected_subtopic_info", [])
        expected_subtopic_info = topic_config[found_topic].get("expected_subtopic_info", {})
        # Seed subtopic expectations only on the very first utterance of a topic,
        # and only if nothing is pending from a previous turn.
        if expected_subtopic_info and not existing_subtopic_info and first_utt:
            state_utils.save_to_shared_memory(vars, expected_subtopic_info=expected_subtopic_info)

    extract_and_save_entity(vars, topic_config, found_topic)
    extract_and_save_subtopic(vars, topic_config, found_topic)
    available_utterances = shared_memory.get("available_utterances", [])
    logger.info(f"subtopics {shared_memory.get('subtopics', [])}")
    subtopics_to_delete = 0
    add_general_ackn = False
    if found_topic:
        used_utt_nums_dict = shared_memory.get("used_utt_nums", {})
        used_utt_nums = used_utt_nums_dict.get(found_topic, [])
        state_utils.save_to_shared_memory(vars, special_topic=found_topic)
        subtopics = shared_memory.get("subtopics", [])
        if subtopics:
            # Walk the subtopic stack from the most recently added one.
            for i in range(len(subtopics) - 1, -1, -1):
                cur_subtopic = subtopics[i]
                for num, utt_info in enumerate(topic_config[found_topic]["smalltalk"]):
                    utt_key = utt_info.get("key", "")
                    if num not in used_utt_nums and (
                        not available_utterances or (available_utterances and utt_key in available_utterances)
                    ):
                        if utt_info.get("subtopic", "") == cur_subtopic and check_utt_cases(vars, utt_info):
                            response, used_utt_nums = make_smalltalk_response(
                                vars, topic_config, shared_memory, utt_info, used_utt_nums, num
                            )
                            if response:
                                add_general_ackn = utt_info.get("add_general_ackn", False)
                                utt_can_continue = utt_info.get("can_continue", "can")
                                utt_conf = utt_info.get("conf", utt_conf)
                                break
                if response:
                    used_utt_nums_dict[found_topic] = used_utt_nums
                    state_utils.save_to_shared_memory(vars, used_utt_nums=used_utt_nums_dict)
                    # The subtopic is now exhausted: schedule it for removal.
                    if check_used_subtopic_utt(vars, topic_config, cur_subtopic):
                        subtopics_to_delete += 1
                    break
                else:
                    # Nothing usable for this subtopic: drop it and try the next.
                    subtopics_to_delete += 1
        # Fall back to utterances that are not bound to any subtopic.
        if not subtopics or not response:
            for num, utt_info in enumerate(topic_config[found_topic]["smalltalk"]):
                utt_key = utt_info.get("key", "")
                if (
                    num not in used_utt_nums
                    and check_utt_cases(vars, utt_info)
                    and not utt_info.get("subtopic", "")
                    and (not available_utterances or (available_utterances and utt_key in available_utterances))
                ):
                    response, used_utt_nums = make_smalltalk_response(
                        vars, topic_config, shared_memory, utt_info, used_utt_nums, num
                    )
                    if response:
                        utt_can_continue = utt_info.get("can_continue", "can")
                        utt_conf = utt_info.get("conf", utt_conf)
                        add_general_ackn = utt_info.get("add_general_ackn", False)
                        used_utt_nums_dict[found_topic] = used_utt_nums
                        state_utils.save_to_shared_memory(vars, used_utt_nums=used_utt_nums_dict)
                        break
        if subtopics_to_delete:
            # Pop the exhausted subtopics off the end of the stack.
            for i in range(subtopics_to_delete):
                subtopics.pop()
            state_utils.save_to_shared_memory(vars, subtopics=subtopics)

        logger.info(f"used_utt_nums_dict {used_utt_nums_dict} used_utt_nums {used_utt_nums}")
    acknowledgement = check_acknowledgements(vars, topic_config)
    # An answer to the user's question takes priority over the acknowledgement.
    answer = answer_users_question(vars, topic_config) or acknowledgement
    response = f"{answer} {response}".strip().replace("  ", " ")
    logger.info(f"response {response}")
    if response:
        state_utils.save_to_shared_memory(vars, cur_mode="smalltalk")
        if utt_conf > 0.0:
            state_utils.set_confidence(vars, confidence=utt_conf)
        else:
            state_utils.set_confidence(vars, confidence=CONF_DICT["WIKI_TOPIC"])
        if first_utt or utt_can_continue == "must":
            state_utils.set_can_continue(vars, continue_flag=common_constants.MUST_CONTINUE)
        elif utt_can_continue == "prompt":
            state_utils.set_can_continue(vars, continue_flag=common_constants.CAN_CONTINUE_PROMPT)
        else:
            state_utils.set_can_continue(vars, continue_flag=common_constants.CAN_CONTINUE_SCENARIO)
    else:
        state_utils.set_confidence(vars, confidence=CONF_DICT["UNDEFINED"])
        state_utils.set_can_continue(vars, continue_flag=common_constants.CAN_NOT_CONTINUE)
    if not add_general_ackn:
        state_utils.add_acknowledgement_to_response_parts(vars)
    return response


def start_or_continue_facts(vars, topic_config):
    """Decide whether the skill should enter (or stay in) fact-telling mode.

    Returns True when a trigger/fact page is available for the active topic,
    when the user has agreed to hear facts, or when cached wikiHow/Wikipedia
    content still has unused entries; returns False when the topic was not
    just started and the previous turn belonged to a different skill.
    """
    shared_memory = state_utils.get_shared_memory(vars)
    prev_active_skill = state_utils.get_last_bot_utterance(vars).get("active_skill", "")
    user_said_no = is_no(state_utils.get_last_human_utterance(vars))
    found_topic = shared_memory.get("special_topic", "")
    cur_mode = shared_memory.get("cur_mode", "smalltalk")
    wikipedia_page = shared_memory.get("cur_wikipedia_page", "")
    wikihow_page = shared_memory.get("cur_wikihow_page", "")
    logger.info(f"cur_wikihow_page {wikihow_page} cur_wikipedia_page {wikipedia_page}")

    flag = False
    if found_topic:
        if cur_mode == "smalltalk" and "triggers" in topic_config[found_topic]:
            # Still in smalltalk: a trigger page or explicit consent starts facts.
            _, _, trig_wikipedia, trig_wikihow = find_trigger(vars, topic_config[found_topic]["triggers"])
            flag = bool(trig_wikihow or trig_wikipedia or if_facts_agree(vars))
        else:
            if extract_and_save_wikipage(vars):
                flag = True
            if (wikipedia_page or wikihow_page) and not user_said_no:
                wikihow_content = memory.get("wikihow_content", [])
                wikipedia_content = memory.get("wikipedia_content", [])
                used_wikihow_nums = shared_memory.get("used_wikihow_nums", {}).get(wikihow_page, [])
                used_wikipedia_nums = shared_memory.get("used_wikipedia_nums", {}).get(wikipedia_page, [])
                logger.info(f"request, used_wikihow_nums {used_wikihow_nums} used_wikipedia_nums {used_wikipedia_nums}")
                logger.info(
                    f"request, wikipedia_page_content_list {wikipedia_content[:3]} "
                    f"wikihow_page_content_list {wikihow_content[:3]}"
                )
                # Continue while either cached page still has unserved facts.
                if len(used_wikihow_nums) < len(wikihow_content):
                    flag = True
                if len(used_wikipedia_nums) < len(wikipedia_content):
                    flag = True

    first_utt = False
    if not shared_memory.get("special_topic", "") or prev_active_skill not in {"dff_wiki_skill", "dff_music_skill"}:
        found_topic, first_utt, _, _ = check_switch(vars, topic_config)
    logger.info(f"start_or_continue_facts, first_utt {first_utt}")

    if found_topic:
        if topic_config[found_topic].get("facts", {}):
            flag = True
        # Unless the topic was just started, facts may only continue when the
        # matching skill (music -> dff_music_skill, else dff_wiki_skill) spoke last.
        owning_skill = "dff_music_skill" if found_topic == "music" else "dff_wiki_skill"
        if not first_utt and prev_active_skill != owning_skill:
            flag = False
    return flag


def facts_response(vars, topic_config, wikihow_cache, wikipedia_cache):
    """Serve the next unused fact about the active topic from wikiHow/Wikipedia.

    When entering fact mode from smalltalk, resolves which pages to use (topic
    triggers, then the topic's ``facts`` config, then pages stored in shared
    memory), loads their content (from the caches or by download), preprocesses
    it into the module-level ``memory`` dict, and then emits the first fact not
    yet served (wikiHow content takes priority over Wikipedia).

    Args:
        vars: dialog-state container consumed by the ``state_utils`` helpers.
        topic_config: per-topic scenario description.
        wikihow_cache: pre-fetched wikiHow page content keyed by page name.
        wikipedia_cache: pre-fetched Wikipedia page data keyed by page name.

    Returns:
        str: the response text, or "" when no fact is available.
    """
    shared_memory = state_utils.get_shared_memory(vars)
    user_uttr = state_utils.get_last_human_utterance(vars)
    bot_uttr = state_utils.get_last_bot_utterance(vars)
    prev_active_skill = bot_uttr.get("active_skill", "")
    # Another skill spoke last turn: the stored topic info is stale, drop it.
    if prev_active_skill not in {"dff_wiki_skill", "dff_music_skill"}:
        delete_topic_info(vars)
    isyes = is_yes(user_uttr) or re.findall(CONTINUE_PATTERN, user_uttr["text"])
    response = ""
    cur_mode = shared_memory.get("cur_mode", "smalltalk")
    wikipedia_page = shared_memory.get("cur_wikipedia_page", "")
    wikihow_page = shared_memory.get("cur_wikihow_page", "")
    found_topic = shared_memory.get("special_topic", "")
    # NOTE(review): utt_can_continue starts as a common_constants flag but
    # check_switch appears to return "can"/"must"-style strings (compared
    # against "must" below) — verify the two value domains are compatible.
    utt_can_continue = common_constants.CAN_CONTINUE_SCENARIO
    utt_conf = CONF_DICT["WIKI_TOPIC"]
    first_utt = False
    entity_substr = ""
    entity_types = []
    if not found_topic:
        found_topic, first_utt, utt_can_continue, utt_conf = check_switch(vars, topic_config)
    extract_and_save_entity(vars, topic_config, found_topic)
    extract_and_save_subtopic(vars, topic_config, found_topic)
    extract_and_save_wikipage(vars, True)
    # First turn in facts mode: resolve and load the pages to tell facts from.
    if found_topic and cur_mode == "smalltalk":
        if "triggers" in topic_config[found_topic]:
            triggers = topic_config[found_topic]["triggers"]
            entity_substr, entity_types, wikipedia_page, wikihow_page = find_trigger(vars, triggers)
        facts = topic_config[found_topic].get("facts", {})
        # No trigger pages found: fall back to the topic's configured facts.
        if facts and not wikihow_page and not wikipedia_page:
            entity_substr = facts.get("entity_substr", "")
            entity_types = facts.get("entity_types", [])
            wikihow_page = facts.get("wikihow_page", "")
            wikipedia_page = facts.get("wikipedia_page", "")
            logger.info(f"wikipedia_page {wikipedia_page}")
        if not wikihow_page:
            wikihow_page = shared_memory.get("cur_wikihow_page", "")
        if wikihow_page:
            # Prefer the cache; download only on a cache miss.
            if wikihow_page in wikihow_cache:
                page_content = wikihow_cache[wikihow_page]
            else:
                page_content = get_wikihow_content(wikihow_page)
            wikihow_page_content_list = preprocess_wikihow_page(page_content)
            memory["wikihow_content"] = wikihow_page_content_list
            state_utils.save_to_shared_memory(vars, cur_wikihow_page=wikihow_page)
        if not wikipedia_page:
            wikipedia_page = shared_memory.get("cur_wikipedia_page", "")

        if wikipedia_page:
            if wikipedia_page in wikipedia_cache:
                page_content = wikipedia_cache[wikipedia_page].get("page_content", {})
            else:
                page_content, _ = get_page_content(wikipedia_page)
            if not entity_substr:
                entity_substr = wikipedia_page.lower()
            # Topic config may pin which section titles to use for this page.
            titles_info = topic_config[found_topic].get("titles_info", [])
            predefined_titles = []
            for titles_info_elem in titles_info:
                if wikipedia_page in titles_info_elem["pages"]:
                    predefined_titles = titles_info_elem["titles"]
                    break
            wikipedia_page_content_list = preprocess_wikipedia_page(
                entity_substr, entity_types, page_content, predefined_titles
            )
            memory["wikipedia_content"] = wikipedia_page_content_list
            state_utils.save_to_shared_memory(vars, cur_wikipedia_page=wikipedia_page)
        logger.info(f"wikihow_page {wikihow_page} wikipedia_page {wikipedia_page}")
    if found_topic:
        used_wikihow_nums_dict = shared_memory.get("used_wikihow_nums", {})
        used_wikihow_nums = used_wikihow_nums_dict.get(wikihow_page, [])
        used_wikipedia_nums_dict = shared_memory.get("used_wikipedia_nums", {})
        used_wikipedia_nums = used_wikipedia_nums_dict.get(wikipedia_page, [])
        wikihow_page_content_list = memory.get("wikihow_content", [])
        wikipedia_page_content_list = memory.get("wikipedia_content", [])
        logger.info(f"response, used_wikihow_nums {used_wikihow_nums} used_wikipedia_nums {used_wikipedia_nums}")
        logger.info(
            f"response, wikipedia_page_content_list {wikipedia_page_content_list[:3]} "
            f"wikihow_page_content_list {wikihow_page_content_list[:3]}"
        )
        # Serve the first unused wikiHow fact, if any.
        if wikihow_page and wikihow_page_content_list:
            for num, fact in enumerate(wikihow_page_content_list):
                if num not in used_wikihow_nums:
                    facts_str = fact.get("facts_str", "")
                    question = fact.get("question", "")
                    response = f"{facts_str} {question}".strip().replace("  ", " ")
                    used_wikihow_nums.append(num)
                    used_wikihow_nums_dict[wikihow_page] = used_wikihow_nums
                    state_utils.save_to_shared_memory(vars, used_wikihow_nums=used_wikihow_nums_dict)
                    break
        # Otherwise fall back to the first unused Wikipedia fact.
        if not response and wikipedia_page and wikipedia_page_content_list:
            for num, fact in enumerate(wikipedia_page_content_list):
                if num not in used_wikipedia_nums:
                    facts_str = fact.get("facts_str", "")
                    question = fact.get("question", "")
                    response = f"{facts_str} {question}".strip().replace("  ", " ")
                    used_wikipedia_nums.append(num)
                    used_wikipedia_nums_dict[wikipedia_page] = used_wikipedia_nums
                    state_utils.save_to_shared_memory(vars, used_wikipedia_nums=used_wikipedia_nums_dict)
                    break
        cur_mode = "facts"
        # Both sources exhausted: clear the caches and return to smalltalk mode.
        if len(wikihow_page_content_list) == len(used_wikihow_nums) and len(wikipedia_page_content_list) == len(
            used_wikipedia_nums
        ):
            cur_mode = "smalltalk"
            if len(wikihow_page_content_list) == len(used_wikihow_nums):
                state_utils.save_to_shared_memory(vars, cur_wikihow_page="")
                memory["wikihow_content"] = []
            if len(wikipedia_page_content_list) == len(used_wikipedia_nums):
                state_utils.save_to_shared_memory(vars, cur_wikipedia_page="")
                memory["wikipedia_content"] = []

    answer = answer_users_question(vars, topic_config)
    response = f"{answer} {response}".strip().replace("  ", " ")
    if not shared_memory.get("special_topic", ""):
        found_topic, first_utt, utt_can_continue, utt_conf = check_switch(vars, topic_config)
        state_utils.save_to_shared_memory(vars, special_topic=found_topic)
    if response:
        state_utils.save_to_shared_memory(vars, cur_mode=cur_mode)
        state_utils.set_confidence(vars, confidence=utt_conf)
        if isyes or (first_utt and utt_can_continue == "must"):
            state_utils.set_can_continue(vars, continue_flag=common_constants.MUST_CONTINUE)
        else:
            state_utils.set_can_continue(vars, continue_flag=utt_can_continue)
    else:
        state_utils.set_confidence(vars, confidence=CONF_DICT["UNDEFINED"])
        state_utils.set_can_continue(vars, continue_flag=common_constants.CAN_NOT_CONTINUE)
    state_utils.add_acknowledgement_to_response_parts(vars)
    return response