dream

Форк
0
/
entity_utils.py 
240 строк · 9.6 Кб
1
from typing import List
2
import logging
3
import os
4
import collections
5

6

7
import en_core_web_sm
8
import sentry_sdk
9
from nltk.stem import WordNetLemmatizer
10

11
from common.utils import get_entities
12
from common.universal_templates import get_entities_with_attitudes
13

14
# Initialise Sentry error reporting at import time; the DSN is read from the
# SENTRY_DSN environment variable (None when unset — NOTE(review): presumably
# this disables event sending; confirm against the sentry_sdk docs).
sentry_sdk.init(dsn=os.getenv("SENTRY_DSN"))

# Maximum number of human and of bot encounters remembered per entity
# (used as the deque ``maxlen`` by Entity).
ENCOUNTERS_MAX_LEN = 3
# Upper bound on the number of entities kept by update_entities().
ENTITY_MAX_NO = 10

logger = logging.getLogger(__name__)

# spaCy English pipeline, loaded once at import time.
# NOTE(review): not referenced elsewhere in this module — possibly imported by
# other modules; confirm before removing, loading it is expensive.
spacy_nlp = en_core_web_sm.load()

# Shared WordNet lemmatizer used to reduce entity names to noun lemmas.
wnl = WordNetLemmatizer()
24

25

26
class HumanEntityEncounter:
    """A single mention of an entity inside a human utterance.

    Serialised with ``dict(encounter)`` (see ``__iter__``) and rebuilt with
    ``HumanEntityEncounter(**data)``.
    """

    def __init__(self, human_utterance_index: int, full_name: str, previous_skill_name: str = "",
                 next_skill_name=None, **kwargs):
        """
        Args:
            human_utterance_index: index of the human utterance containing the mention.
            full_name: the entity text as found in the utterance.
            previous_skill_name: skill name recorded for the bot turn before the mention.
            next_skill_name: skill name of the following bot turn, if already known. It is
                only stored when given, so encounters without it serialise exactly as before.
            **kwargs: unknown keys coming from serialised data; ignored.
        """
        self.human_utterance_index = human_utterance_index
        self.full_name = full_name
        self.previous_skill_name = previous_skill_name
        # Restore a previously recorded next skill; before, it fell into **kwargs and
        # was silently lost on every serialise/reload round-trip.
        if next_skill_name is not None:
            self.next_skill_name = next_skill_name

    def __iter__(self):
        """Yield ``(attribute, value)`` pairs so that ``dict(encounter)`` serialises it."""
        yield from self.__dict__.items()
35

36

37
class BotEntityEncounter:
    """One mention of an entity inside a bot utterance.

    ``dict(encounter)`` yields the serialised form; unknown keyword arguments
    (for example extra keys in stored data) are accepted and discarded.
    """

    def __init__(self, human_utterance_index: int, full_name: str, skill_name: str, **_ignored):
        # The assignment order below defines the key order of dict(encounter).
        self.human_utterance_index = human_utterance_index
        self.full_name = full_name
        self.skill_name = skill_name

    def __iter__(self):
        """Produce ``(attribute, value)`` pairs in assignment order."""
        yield from vars(self).items()
46

47

48
class Entity:
    """Per-dialog record of one entity: where it was mentioned and the attitudes towards it.

    An instance is either created empty from ``name`` or restored, via ``raw_data``,
    from the dict produced by ``dict(entity)`` (see ``__iter__``). Human and bot
    encounters are stored in deques bounded by ``ENCOUNTERS_MAX_LEN``, so appending
    beyond that limit discards the oldest encounter.
    """

    def __init__(self, name=None, raw_data=None):
        """Create an empty entity or restore one from serialised data.

        Args:
            name: entity name. When truthy, a new entity without encounters or
                attitudes is created and ``raw_data`` is ignored.
            raw_data: serialised entity, used when ``name`` is falsy. Expected to be a
                dict with a str ``"name"`` and list ``"human_encounters"`` /
                ``"bot_encounters"``; ``"human_attitude"`` / ``"bot_attitude"`` are
                optional. Any failure is logged, reported to Sentry and produces a
                placeholder entity named ``"#LOAD_ENTITY_ERROR"``.
        """
        if name:
            self.name = name
            self.human_encounters = []
            self.bot_encounters = []
            self.human_attitude = None
            self.bot_attitude = None
        else:
            try:
                # Explicit checks instead of `assert`: asserts are stripped under
                # `python -O`, which would let malformed data through.
                if not isinstance(raw_data, dict):
                    raise TypeError("raw entity data must be a dict, got {}".format(type(raw_data).__name__))
                for field, expected_type in (("name", str), ("human_encounters", list), ("bot_encounters", list)):
                    if not isinstance(raw_data.get(field), expected_type):
                        raise TypeError("raw entity field {!r} must be a {}".format(field, expected_type.__name__))
                self.name = raw_data["name"]
                self.human_encounters = [
                    HumanEntityEncounter(**encounter) for encounter in raw_data["human_encounters"]
                ]
                self.bot_encounters = [BotEntityEncounter(**encounter) for encounter in raw_data["bot_encounters"]]
                self.human_attitude = raw_data.get("human_attitude", None)
                self.bot_attitude = raw_data.get("bot_attitude", None)
            except Exception as exc:
                logger.exception(exc)
                sentry_sdk.capture_exception(exc)
                self.name = "#LOAD_ENTITY_ERROR"
                self.human_encounters = []
                self.bot_encounters = []
                # Previously left unset, so reading the attitudes of a failed entity
                # raised AttributeError.
                self.human_attitude = None
                self.bot_attitude = None

        # Bounded deques: once full, each append drops the oldest encounter.
        self.human_encounters = collections.deque(self.human_encounters, maxlen=ENCOUNTERS_MAX_LEN)
        self.bot_encounters = collections.deque(self.bot_encounters, maxlen=ENCOUNTERS_MAX_LEN)

    def __iter__(self):
        """Yield ``(attribute, value)`` pairs; encounter deques are serialised as lists of dicts."""
        for key, value in self.__dict__.items():
            if key in ("human_encounters", "bot_encounters"):
                yield key, [dict(encounter) for encounter in value]
            else:
                yield key, value

    def add_human_attitude(self, attitude):
        """Store the human's attitude label (update_entities passes "like" or "dislike")."""
        self.human_attitude = attitude

    def add_bot_attitude(self, attitude):
        """Store the bot's attitude label (update_entities passes "like" or "dislike")."""
        self.bot_attitude = attitude

    def _find_mentions(self, utterance):
        """Return entity strings of ``utterance`` whose noun lemma contains this entity's name.

        This is a substring test, so the name also matches longer mentions containing it.
        """
        entities = get_entities(utterance, only_named=False, with_labels=False)
        return [ent for ent in entities if self.name in wnl.lemmatize(ent, "n")]

    def add_human_encounters(self, human_utters, bot_utters, human_utter_index):
        """Record mentions of this entity found in the last human utterance.

        Args:
            human_utters: human utterances, assumed chronological (last = most recent);
                only the last one is scanned.
            bot_utters: bot utterances, assumed chronological; the last one's
                ``active_skill`` becomes ``previous_skill_name`` ("pre_start" if absent).
            human_utter_index: index stored on every new encounter.
        """
        human_utter = human_utters[-1]
        # The bot turn preceding the latest human utterance is the most recent one;
        # this used to read bot_utters[0], i.e. the first bot turn of the dialog.
        bot_utter = bot_utters[-1] if bot_utters else {}
        active_skill = bot_utter.get("active_skill", "pre_start")
        for entity in self._find_mentions(human_utter):
            self.human_encounters.append(
                HumanEntityEncounter(
                    human_utterance_index=human_utter_index,
                    full_name=entity,
                    previous_skill_name=active_skill,
                )
            )

    def add_bot_encounters(self, human_utters, bot_utters, human_utter_index):
        """Record mentions of this entity found in the last bot utterance.

        ``human_utters`` is unused; it is kept for signature parity with
        ``add_human_encounters``. The bot utterance's ``active_skill`` (default
        "pre_start") is stored as ``skill_name``.
        """
        # Latest bot utterance (was bot_utters[0], the first one of the dialog).
        bot_utter = bot_utters[-1] if bot_utters else {}
        active_skill = bot_utter.get("active_skill", "pre_start")
        for entity in self._find_mentions(bot_utter):
            self.bot_encounters.append(
                BotEntityEncounter(
                    human_utterance_index=human_utter_index,
                    full_name=entity,
                    skill_name=active_skill,
                )
            )

    def update_human_encounters(self, human_utters, bot_utters, human_utter_index):
        """Set ``next_skill_name`` on the encounters recorded at ``human_utter_index - 1``.

        The value is the last bot utterance's ``active_skill``; when it is missing the
        fallback is "pre_start" for a single human utterance and "unknown" otherwise.
        """
        bot_utter = bot_utters[-1] if bot_utters else {}
        active_skill = bot_utter.get("active_skill", "pre_start" if len(human_utters) == 1 else "unknown")
        for encounter in self.human_encounters:
            if encounter.human_utterance_index == human_utter_index - 1:
                encounter.next_skill_name = active_skill

    def get_last_utterance_index(self):
        """Return the larger index of the last human and last bot encounter, or -1 if there are none."""
        indexes = [
            encounters[-1].human_utterance_index
            for encounters in (self.human_encounters, self.bot_encounters)
            if encounters
        ]
        return max(indexes, default=-1)
137

138

139
def parse_entities_with_attitude(annotated_uttr: dict, prev_annotated_uttr: dict):
    """Extract liked and disliked entities of an annotated utterance as noun lemmas.

    Args:
        annotated_uttr: the annotated utterance to analyse.
        prev_annotated_uttr: the utterance preceding it (``{}`` if none); forwarded to
            ``get_entities_with_attitudes`` together with ``annotated_uttr``.

    Returns:
        ``{"like": [...], "dislike": [...]}`` with each entity lemmatised as a noun.
    """
    raw_attitudes = get_entities_with_attitudes(annotated_uttr, prev_annotated_uttr)
    lemmatized = {}
    for attitude in ("like", "dislike"):
        lemmatized[attitude] = [wnl.lemmatize(entity, "n") for entity in raw_attitudes[attitude]]
    return lemmatized
146

147

148
def load_raw_entities(raw_entities):
    """Rebuild ``Entity`` objects from their serialised ``{name: raw_data}`` form.

    Every entry is loaded first, then entries are dropped if loading failed (the
    entity name contains "_ERROR"), if the key contains a dot, or if the key is at
    most two characters long. NOTE(review): the reason dotted keys are rejected is
    not visible here — presumably they are unsafe as storage keys; confirm.

    Returns:
        A new ``name -> Entity`` dict preserving the input order of kept entries.
    """
    candidates = [(entity_name, Entity(raw_data=raw)) for entity_name, raw in raw_entities.items()]
    return {
        entity_name: entity
        for entity_name, entity in candidates
        if "_ERROR" not in entity.name and "." not in entity_name and len(entity_name) > 2
    }
154

155

156
def _get_or_create_entity(entities, entity_name):
    """Return ``entities[entity_name]``, inserting a new empty Entity first when it is missing."""
    entity = entities.get(entity_name)
    if entity is None:
        entity = Entity(entity_name)
        entities[entity_name] = entity
    return entity


def _add_bot_entities(entities, human_utterances, bot_utterances, human_utter_index):
    """Record the entities the latest bot utterance expressed an attitude towards."""
    prev_human_utterance = human_utterances[-2] if len(human_utterances) > 1 else {}
    entities_with_attitude = parse_entities_with_attitude(bot_utterances[-1], prev_human_utterance)
    for attitude in ["like", "dislike"]:
        # dict.fromkeys de-duplicates while keeping order, so a repeated name is
        # recorded once per attitude.
        for entity_name in dict.fromkeys(entities_with_attitude[attitude]):
            entity = _get_or_create_entity(entities, entity_name)
            entity.add_bot_encounters(human_utterances, bot_utterances, human_utter_index)
            entity.add_bot_attitude(attitude)


def _add_human_entities(entities, human_utterances, bot_utterances, human_utter_index):
    """Record the entities the latest human utterance expressed an attitude towards."""
    prev_bot_utterance = bot_utterances[-1] if bot_utterances else {}
    entities_with_attitude = parse_entities_with_attitude(human_utterances[-1], prev_bot_utterance)
    for attitude in ["like", "dislike"]:
        for entity_name in dict.fromkeys(entities_with_attitude[attitude]):
            entity = _get_or_create_entity(entities, entity_name)
            entity.add_human_encounters(human_utterances, bot_utterances, human_utter_index)
            entity.add_human_attitude(attitude)


def _add_previous_human_entities(entities, human_utterances, bot_utterances, human_utter_index):
    """Backfill entities of the previous human utterance and set their ``next_skill_name``.

    Only names absent from ``entities`` at this point (neither loaded nor added earlier
    in this update) get a new Entity, so data recorded for the current turn is never
    replaced by an empty entity.
    """
    prev_bot_utterance = bot_utterances[-2] if len(bot_utterances) > 1 else {}
    entities_with_attitude = parse_entities_with_attitude(human_utterances[-2], prev_bot_utterance)
    known_names = set(entities)
    # add_human_encounters scans the LAST utterance it is given, so drop the current
    # one; passing the full list made it search the wrong utterance.
    previous_human_utterances = human_utterances[:-1]
    for attitude in ["like", "dislike"]:
        short_names = entities_with_attitude[attitude]
        for entity_name in dict.fromkeys(short_names):
            if entity_name in known_names:
                continue
            entity = Entity(entity_name)
            entities[entity_name] = entity
            # No bot turn precedes the previous utterance here, hence [] ("pre_start").
            entity.add_human_encounters(previous_human_utterances, [], human_utter_index - 1)
            entity.add_human_attitude(attitude)
        for entity_name in short_names:
            entities[entity_name].update_human_encounters(human_utterances, bot_utterances, human_utter_index)


def _keep_recent_entities(entities):
    """Return the ``ENTITY_MAX_NO`` entities with the most recent last encounter, oldest first.

    The sort is stable, so entities sharing an index keep their insertion order. The
    previous index-keyed dict let entities with the same index overwrite each other.
    """
    by_recency = sorted(entities.items(), key=lambda item: item[1].get_last_utterance_index())
    return dict(by_recency[-ENTITY_MAX_NO:])


def update_entities(dialog, human_utter_index, entities=None):
    """Update the entity store with the latest dialog turn and trim it.

    Args:
        dialog: mapping with "human_utterances" and "bot_utterances" lists, treated as
            chronological (the last element is the most recent). At least one human
            utterance is required.
        human_utter_index: index of ``human_utterances[-1]``; stored on the encounters
            recorded for it (``human_utter_index - 1`` for the previous utterance).
        entities: existing ``{name: Entity}`` store, or None to start empty. The dict
            and its Entity objects are updated in place.

    Returns:
        A new dict with at most ``ENTITY_MAX_NO`` entities, those whose latest
        encounter is most recent, ordered from oldest to newest.
    """
    entities = {} if entities is None else entities
    human_utterances = dialog["human_utterances"]
    bot_utterances = dialog["bot_utterances"]

    if bot_utterances:
        _add_bot_entities(entities, human_utterances, bot_utterances, human_utter_index)

    _add_human_entities(entities, human_utterances, bot_utterances, human_utter_index)

    # NOTE(review): the backfill only runs when exactly two human utterances are
    # present — presumably the caller passes a truncated dialog; confirm.
    if len(human_utterances) == 2:
        _add_previous_human_entities(entities, human_utterances, bot_utterances, human_utter_index)

    return _keep_recent_entities(entities)
224

225

226
def get_new_human_entities(entities, human_utterance_index):
    """Select entities whose only recorded human encounter is at ``human_utterance_index``.

    Args:
        entities: mapping of entity name to an object with a ``human_encounters`` sequence.
        human_utterance_index: utterance index the single human encounter must match.

    Returns:
        A new dict with the matching ``name -> entity`` pairs, in input order.
    """
    selected = {}
    for entity_name, entity in entities.items():
        encounters = entity.human_encounters
        if len(encounters) != 1:
            continue
        if encounters[-1].human_utterance_index != human_utterance_index:
            continue
        selected[entity_name] = entity
    return selected
233

234

235
def get_time_sorted_human_entities(entities):
    """Order the entities with exactly one recorded human encounter by that encounter's index.

    NOTE(review): entities with two or more human encounters are excluded. If the
    intent was "every entity the human mentioned", the filter should probably be
    ``>= 1``; confirm with callers before changing it.

    Returns:
        A new ``name -> entity`` dict, earliest encounter first; ties keep input order.
    """
    single_mention = [
        (entity_name, entity) for entity_name, entity in entities.items() if len(entity.human_encounters) == 1
    ]
    single_mention.sort(key=lambda pair: pair[1].human_encounters[-1].human_utterance_index)
    return dict(single_mention)
241

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.