dream

Форк
0
/
postprocessor.py 
292 строки · 10.9 Кб
1
from typing import Sequence, List, Tuple, Callable, Dict
2
import random
3
import itertools
4
import copy
5
import re
6

7
from core.state_schema import Dialog
8

9

10
def detokenize(tokens):
    """
    Join tokens into one string and undo tokenizer spacing.

    Detokenizing a text undoes the tokenizing operation, restoring
    punctuation and spaces to the places that people expect them to be.
    Ideally, `detokenize(tokenize(text))` should be identical to `text`,
    except for line breaks.

    Parameters:
        tokens: iterable of string tokens

    Returns:
        the detokenized text with leading/trailing whitespace stripped
    """
    text = " ".join(tokens)
    # collapse an ellipsis that was split into separate dots
    text = text.replace(". . .", "...")
    # `` and '' quote tokens become plain double quotes glued to the quoted text
    text = text.replace("`` ", '"').replace(" ''", '"')
    # remove the space inside parentheses
    text = text.replace(" ( ", " (").replace(" ) ", ") ")
    # attach punctuation to the preceding word (mid-text, then at the very end)
    text = re.sub(r' ([.,:;?!%]+)([ \'"`])', r"\1\2", text)
    text = re.sub(r" ([.,:;?!%]+)$", r"\1", text)
    # glue clitics back: "do n't" -> "don't", "it 's" -> "it's"
    text = text.replace(" '", "'").replace(" n't", "n't")
    # word boundaries keep these two rules from firing inside longer words
    # (e.g. "the ntsc" or "pecan nothing")
    text = re.sub(r" nt\b", "nt", text)
    text = re.sub(r"\bcan not\b", "cannot", text)
    text = text.replace(" ` ", " '")
    return text.strip()
26

27

28
class PersonNormalizer:
    """
    Detects mentions of mate user's name and either
    (0) converts them to user's name taken from state, or
    (1) removes them.

    Parameters:
        person_tag: tag name that corresponds to a person entity
        **kwargs: `per_tag` is accepted as an alias of `person_tag`;
            any other keyword is ignored
    """

    def __init__(self, person_tag: str = "PER", **kwargs):
        # Some callers pass the tag as `per_tag=...`; honor that spelling
        # instead of silently dropping it into **kwargs.
        self.per_tag = kwargs.get("per_tag", person_tag)

    def __call__(
        self, tokens: List[List[str]], tags: List[List[str]], names: List[str]
    ) -> Tuple[List[List[str]], List[List[str]]]:
        """
        Normalize addressee names in a batch of tokenized utterances.

        Parameters:
            tokens: batch of token lists
            tags: batch of BIO tag lists aligned with `tokens`
            names: one user name per utterance; a falsy name means that
                addressee mentions are removed instead of replaced

        Returns:
            tuple of (normalized tokens batch, normalized tags batch)
        """
        out_tokens, out_tags = [], []
        for u_name, u_toks, u_tags in zip(names, tokens, tags):
            u_toks, u_tags = self.tag_mate_gooser_name(u_toks, u_tags, person_tag=self.per_tag)
            if u_name:
                u_toks, u_tags = self.replace_mate_gooser_name(u_toks, u_tags, u_name)
                # with probability 0.5 additionally address the user at the start
                if random.random() < 0.5:
                    u_toks = [u_name, ","] + u_toks
                    u_tags = ["B-MATE-GOOSER", "O"] + u_tags

                    u_toks[0] = u_toks[0][0].upper() + u_toks[0][1:]
                    # lowercase the former first word; guard against an
                    # utterance that was empty or starts with an empty token
                    if len(u_toks) > 2 and u_tags[2] == "O" and u_toks[2]:
                        u_toks[2] = u_toks[2][0].lower() + u_toks[2][1:]
            else:
                u_toks, u_tags = self.remove_mate_gooser_name(u_toks, u_tags)
            out_tokens.append(u_toks)
            out_tags.append(u_tags)
        return out_tokens, out_tags

    @staticmethod
    def tag_mate_gooser_name(
        tokens: List[str], tags: List[str], person_tag: str = "PER", mate_tag: str = "MATE-GOOSER"
    ) -> Tuple[List[str], List[str]]:
        """
        Re-tag person entities that look like a direct address.

        A person entity is re-tagged with `mate_tag` when it is set off by
        a comma: either ", NAME" followed by the end of the utterance or by
        a token starting with one of `,.!?;)`, or "NAME ," where the name
        starts the utterance or follows a token ending with one of `,.!?(`.

        Returns:
            tuple of (tokens, tags); the input lists themselves are returned
            when there is no "B-<person_tag>" tag, otherwise `tokens` and a
            new tag list
        """
        if "B-" + person_tag not in tags:
            return tokens, tags
        # str.startswith/endswith with a tuple is False for an empty token,
        # where indexing [0] / [-1] would raise IndexError
        after_name = (",", ".", "!", "?", ";", ")")
        before_name = (",", ".", "!", "?", "(")
        out_tags = []
        i = 0
        while i < len(tokens):
            tok, tag = tokens[i], tags[i]
            if (tok == ",") and (i + 1 < len(tokens)) and (tags[i + 1] == "B-" + person_tag):
                # it might be mate gooser name
                out_tags.append(tag)
                j = 1
                while (i + j < len(tokens)) and (tags[i + j][2:] == person_tag):
                    j += 1
                if (i + j == len(tokens)) or tokens[i + j].startswith(after_name):
                    # it is mate gooser
                    out_tags.extend(t[:2] + mate_tag for t in tags[i + 1 : i + j])
                else:
                    # it isn't
                    out_tags.extend(tags[i + 1 : i + j])
                i += j
                continue
            if (tok == ",") and (i > 0) and (tags[i - 1][2:] == person_tag):
                # it might have been mate gooser name
                j = 1
                while (len(out_tags) >= j) and (out_tags[-j][2:] == person_tag):
                    j += 1
                if (len(out_tags) < j) or tokens[i - j].endswith(before_name):
                    # it was mate gooser
                    for k in range(j - 1):
                        out_tags[-k - 1] = out_tags[-k - 1][:2] + mate_tag
            out_tags.append(tag)
            i += 1
        return tokens, out_tags

    @staticmethod
    def replace_mate_gooser_name(
        tokens: List[str], tags: List[str], replacement: str, mate_tag: str = "MATE-GOOSER"
    ) -> Tuple[List[str], List[str]]:
        """
        Replace every `mate_tag` entity with the whitespace-split `replacement`.

        Returns:
            tuple of (tokens, tags); the input lists themselves are returned
            when there is no "B-<mate_tag>" tag

        Raises:
            AssertionError: if `tokens` and `tags` differ in length
        """
        assert len(tokens) == len(tags), f"tokens({tokens}) and tags({tags}) should have the same length"
        if "B-" + mate_tag not in tags:
            return tokens, tags

        repl_tokens = replacement.split()
        repl_tags = ["B-" + mate_tag] + ["I-" + mate_tag] * (len(repl_tokens) - 1)

        out_tokens, out_tags = [], []
        i = 0
        while i < len(tokens):
            tok, tag = tokens[i], tags[i]
            if tag == "B-" + mate_tag:
                out_tokens.extend(repl_tokens)
                out_tags.extend(repl_tags)
                i += 1
                # skip the rest of the entity: the continuation marker lives
                # in `tags`, not in `tokens`
                while (i < len(tokens)) and (tags[i] == "I-" + mate_tag):
                    i += 1
            else:
                out_tokens.append(tok)
                out_tags.append(tag)
                i += 1
        return out_tokens, out_tags

    @staticmethod
    def remove_mate_gooser_name(
        tokens: List[str], tags: List[str], mate_tag: str = "MATE-GOOSER"
    ) -> Tuple[List[str], List[str]]:
        """
        Drop every `mate_tag` entity together with the commas next to it.

        Returns:
            tuple of (tokens, tags); the input lists themselves are returned
            when there is no "B-<mate_tag>" tag

        Raises:
            AssertionError: if `tokens` and `tags` differ in length
        """
        assert len(tokens) == len(tags), f"tokens({tokens}) and tags({tags}) should have the same length"
        # TODO: uppercase first letter if name was removed
        if "B-" + mate_tag not in tags:
            return tokens, tags
        out_tokens, out_tags = [], []
        for i, (tok, tag) in enumerate(zip(tokens, tags)):
            if tok == ",":
                # a comma right before the name is skipped ...
                if (i + 1 < len(tokens)) and (tags[i + 1] == "B-" + mate_tag):
                    continue
                # ... and so is a comma right after it
                if (i > 0) and (tags[i - 1][2:] == mate_tag):
                    continue
            if tag[2:] != mate_tag:
                out_tokens.append(tok)
                out_tags.append(tag)
        return out_tokens, out_tags
158

159

160
# Three levels of nesting; presumably [dialog][utterance][token or tag] - confirm with callers.
LIST_LIST_STR_BATCH = List[List[List[str]]]


class HistoryPersonNormalize:
    """
    Takes batch of dialog histories and normalizes only bot responses.

    Detects mentions of mate user's name and either
    (0) converts them to user's name taken from state, or
    (1) removes them.

    NOTE: the normalization itself is not implemented yet, see `__call__`.

    Parameters:
        per_tag: tag name that corresponds to a person entity
    """

    def __init__(self, per_tag: str = "PER", **kwargs):
        # PersonNormalizer names this parameter `person_tag`; passing it as
        # `per_tag` would land in its **kwargs.
        self.per_normalizer = PersonNormalizer(person_tag=per_tag)

    def __call__(
        self, history_tokens: LIST_LIST_STR_BATCH, tags: LIST_LIST_STR_BATCH, states: List[Dict]
    ) -> Tuple[LIST_LIST_STR_BATCH, LIST_LIST_STR_BATCH]:
        """
        Placeholder: iterates over the batch but produces no output yet.

        Parameters:
            history_tokens: batch of tokenized dialog histories
            tags: batch of tag sequences aligned with `history_tokens`
            states: one state dict per dialog; a falsy value is replaced by
                empty dicts, one per item of `tags`

        Returns:
            tuple of two lists, currently always empty
        """
        out_tokens, out_tags = [], []
        states = states if states else [{}] * len(tags)
        for u_state, u_hist_tokens, u_hist_tags in zip(states, history_tokens, tags):
            # TODO: normalize bot response history
            pass
        return out_tokens, out_tags
187

188

189
class MyselfDetector:
    """
    Finds first mention of a name and sets it as a user name.

    The current name is read from `state["user"]["profile"]["name"]`; a name
    is looked up in the utterance only when that value is falsy.

    Parameters:
        person_tag: tag name that corresponds to a person entity
    """

    def __init__(self, person_tag: str = "PER", **kwargs):
        self.per_tag = person_tag

    def __call__(self, tokens: List[List[str]], tags: List[List[str]], states: List[dict]) -> List[str]:
        """
        Return one user name per utterance of the batch.

        Parameters:
            tokens: batch of token lists
            tags: batch of BIO tag lists aligned with `tokens`
            states: one state dict per utterance

        Returns:
            list of names; an entry keeps the (falsy) value stored in the
            state when no name is known and none is found in the utterance

        Raises:
            KeyError: if a state lacks the "user" / "profile" / "name" path
        """
        names = []
        for u_state, u_toks, u_tags in zip(states, tokens, tags):
            # plain assignment: `copy` is a module and cannot be called
            new_name = u_state["user"]["profile"]["name"]
            if not new_name:
                name_found = self.find_my_name(u_toks, u_tags, person_tag=self.per_tag)
                if name_found is not None:
                    new_name = name_found
            names.append(new_name)
        return names

    @staticmethod
    def find_my_name(tokens: List[str], tags: List[str], person_tag: str) -> str:
        """
        Return the first person entity of the utterance as one string.

        The entity starts at the first "B-<person_tag>" tag and extends over
        the directly following "I-<person_tag>" tags; its tokens are joined
        with single spaces. Returns None when there is no such entity.
        """
        begin_tag = "B-" + person_tag
        inside_tag = "I-" + person_tag
        if begin_tag not in tags:
            return None
        per_start = tags.index(begin_tag)
        per_excl_end = per_start + 1
        while (per_excl_end < len(tokens)) and (tags[per_excl_end] == inside_tag):
            per_excl_end += 1
        return " ".join(tokens[per_start:per_excl_end])
223

224

225
class NerWithContextWrapper:
    """
    Tokenizes utterance and history of dialogue and gets entity tags for
    utterance's tokens.

    Parameters:
        ner_model: named entity recognition model; called with a batch of
            token lists, its second return value is used as the tags
        tokenizer: tokenizer to use; called with a batch of strings
        context_delimeter: optional token appended after every history
            utterance and after the current utterance
    """

    def __init__(self, ner_model: Callable, tokenizer: Callable, context_delimeter: str = None, **kwargs):
        self.ner_model = ner_model
        self.tokenizer = tokenizer
        self.context_delimeter = context_delimeter

    def __call__(
        self, utterances: List[str], history: List[List[str]] = None, prev_utterances: List[str] = None
    ) -> Tuple[List[List[str]], List[List[str]]]:
        """
        Tag a batch of utterances, using dialog history as left context.

        Parameters:
            utterances: batch of utterance strings
            history: per-utterance list of earlier utterance strings;
                defaults to an empty history for every utterance
            prev_utterances: optional per-utterance previous utterance that
                is appended to the corresponding history

        Returns:
            tuple of (tokenized utterances, tags of the utterance tokens only)
        """
        if history is None:
            # one empty history per utterance; a single shared [[]] default
            # would make the zips below truncate the batch to one item
            history = [[] for _ in utterances]
        if prev_utterances:
            history = history or itertools.repeat([])
            history = [hist + [prev] for prev, hist in zip(prev_utterances, history)]

        delimiter = self.context_delimeter
        history_toks = []
        for hist in history:
            flat = []
            # an empty history is tokenized as a single empty string
            for toks in self.tokenizer(hist or [""]):
                flat.extend(tok for tok in toks if tok is not None)
                if delimiter is not None:
                    flat.append(delimiter)
            history_toks.append(flat)

        utt_toks = self.tokenizer(utterances)
        texts, ranges = [], []
        for utt, hist in zip(utt_toks, history_toks):
            txt = hist + utt
            if delimiter is not None:
                txt = txt + [delimiter]
            # remember where the utterance sits inside the model input
            ranges.append((len(hist), len(hist) + len(utt)))
            texts.append(txt)

        _, tags = self.ner_model(texts)
        # keep only the tags that belong to the current utterance
        tags = [seq_tags[start:end] for seq_tags, (start, end) in zip(tags, ranges)]

        return utt_toks, tags
265

266

267
class DefaultPostprocessor:
    """
    Produces the final response text for every dialog of a batch.

    For responses of the "chitchat" skill that carry NER annotations, names
    addressed to the interlocutor are normalized with `PersonNormalizer` and
    the tokens are detokenized; any other response is returned unchanged.
    """

    def __init__(self) -> None:
        # PersonNormalizer names this parameter `person_tag`
        self.person_normalizer = PersonNormalizer(person_tag="PER")

    def __call__(self, dialogs: Sequence[Dialog]) -> Sequence[str]:
        """
        Parameters:
            dialogs: batch of dialogs; the last item of `d["utterances"]`
                is treated as the response to post-process

        Returns:
            list with one response string per dialog
        """
        new_responses = []
        for d in dialogs:
            # get tokens & tags
            response = d["utterances"][-1]
            try:
                ner_annotations = response["annotations"].get("ner", {})
                user_name = d["user"]["profile"]["name"]
                # replace names with user name
                if ner_annotations and (response["active_skill"] == "chitchat"):
                    response_toks_norm, _ = self.person_normalizer(
                        [ner_annotations["tokens"]], [ner_annotations["tags"]], [user_name]
                    )
                    # detokenize the single utterance of the one-item batch
                    new_responses.append(detokenize(response_toks_norm[0]))
                else:
                    new_responses.append(response["text"])
            except KeyError:
                # any missing field: fall back to the raw response text
                new_responses.append(response["text"])

        return new_responses
293

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.