db_converter.py 
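# Post-process the reply database: send every stored response text to the
# sentseg service, keep the punctuated variant when it still matches the
# original words, and write the result to replies_v3.pkl.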
# %%

import tqdm
import pickle
import pathlib
import json
import difflib
import logging
import aiohttp
import asyncio
import re
import copy


logger = logging.getLogger(__name__)

# %%
data_dir = pathlib.Path("data")
work_dir = pathlib.Path("tmp")

banned_responses_file = data_dir / "banned_responses_v2.json"
db_file = work_dir / "replies_v2.pkl"
output_db_file = work_dir / "replies_v3.pkl"

cache_file = work_dir / "cache.json"
dropped_responses_file = work_dir / "dropped_responses.json"
log_file = work_dir / "logs.txt"

# replies_v2.pkl holds a pair: the response encodings and the raw response texts.
response_encodings, responses = pickle.load(db_file.open("rb"))
# %%

# Attach the original index to every response so the processed texts can be
# put back in the same order at the end.
marked_responses = [{"text": res, "index": i} for i, res in enumerate(responses)]

# %%
format_1 = re.compile(r"\s+(,|\.|!|\?)")
format_2 = re.compile(r"\s+(’)\s+")

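# Remove stray whitespace before punctuation marks and turn spaced-out
# right single quotes into plain apostrophes.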
def format_text(text):
    text = re.sub(format_1, r"\1", text)
    text = re.sub(format_2, r"'", text)
    return text

other_symbols_compiled = re.compile(r"[^a-zA-Z0-9\- ]", re.IGNORECASE)
space_compiled = re.compile(r"\s+", re.IGNORECASE)

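# Strip everything except letters, digits, hyphens and spaces and collapse
# whitespace, so texts can be compared irrespective of punctuation.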
def cleanup(text):
    cleaned = re.sub(other_symbols_compiled, "", text)
    cleaned = re.sub(space_compiled, " ", cleaned)
    return cleaned.strip()

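# Yield consecutive slices of `data` with at most `split_n` items each.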
def get_data_iter(data, split_n):
    ranges = range(0, len(data) + split_n, split_n)
    for begin_i, end_i in zip(ranges, ranges[1:]):
        yield data[begin_i:end_i]

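# Word-level fuzzy comparison: two texts count as the same when their
# SequenceMatcher ratio is at least `ratio`.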
def is_same(ground_truth, hypothesis, ratio=0.9):
    res_ratio = difflib.SequenceMatcher(None, ground_truth.split(), hypothesis.split()).ratio()
    return res_ratio >= ratio

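# Send one batch of texts to the punctuation service. A returned "punct_sent"
# is stored as "punct_text" when it still closely matches the original words,
# otherwise it is kept as "hyp_text".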
async def aio_request(session, url, samples):
    try:
        new_samples = []
        async with session.post(url, json={"sentences": [subsample["text"] for subsample in samples]}) as resp:
            batch = await resp.json()
            assert len(batch) == len(samples)
            for b_el, subsample in zip(batch, samples):
                punct_text = b_el["punct_sent"]
                if bool(punct_text) and is_same(cleanup(punct_text), cleanup(subsample["text"]), 0.95):
                    subsample["punct_text"] = punct_text
                else:
                    subsample["hyp_text"] = punct_text
                new_samples += [subsample]
        return new_samples
    except Exception:
        # Log the failed batch and re-raise so the caller sees the error.
        logger.exception("sentseg request failed for a batch of %d samples", len(samples))
        raise

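# Pop samples from the shared queue in batches of `batch_n` and collect the
# processed results until the queue is empty.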
async def worker(url, data, batch_n):
    new_samples = []
    async with aiohttp.ClientSession() as session:
        while data["samples"]:
            batch = [data["samples"].pop() for _ in range(batch_n) if data["samples"]]
            new_samples += await aio_request(session, url, batch)
    return new_samples

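# Progress reporter: polls the shared queue once a second and advances the
# tqdm bar by the number of samples taken since the last check.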
async def load_bar(data):
    with tqdm.tqdm(total=len(data["samples"])) as pbar:
        cur_len = len(data["samples"])
        while data["samples"]:
            if cur_len != len(data["samples"]):
                pbar.update(cur_len - len(data["samples"]))
                cur_len = len(data["samples"])
            await asyncio.sleep(1)

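# Fan the samples out to `worker_n` concurrent workers plus the progress-bar
# task, then merge the per-worker results into a single list.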
async def load_data(samples, batch_n=10, worker_n=10):
    url = "http://a737ad642c7cc4356a543c2c58779eb6-1162604519.us-west-2.elb.amazonaws.com/sentseg/sentseg"
    data = {}
    data["samples"] = samples
    new_samples = []
    tasks = [asyncio.ensure_future(load_bar(data))]
    tasks += [asyncio.ensure_future(worker(url, data, batch_n)) for _ in range(worker_n)]
    new_samples = await asyncio.gather(*tasks)
    # load_bar returns None, so drop empty results before flattening the per-worker lists.
    new_samples = [sample for sample in new_samples if sample]
    new_samples = sum(new_samples, [])
    return new_samples

# %%

loop = asyncio.get_event_loop()


# Retry responses that did not come back punctuated; stop after five passes.
unhandled_responses = copy.deepcopy(marked_responses)
handled_responses = [resp for resp in unhandled_responses if "punct_text" in resp]
i = 0
while unhandled_responses:
    i += 1
    if i > 5:
        break
    wip_responses = loop.run_until_complete(load_data(unhandled_responses, batch_n=10, worker_n=10))
    unhandled_responses = [resp for resp in wip_responses if "punct_text" not in resp]
    handled_responses += [resp for resp in wip_responses if "punct_text" in resp]
# Keep the still-unhandled responses too so every original index survives.
handled_responses = handled_responses + unhandled_responses
json.dump(handled_responses, open("cache_v2.json", "wt"), ensure_ascii=False, indent=4)
# %%

handled_responses = json.load(open("cache_v2.json", "rt"))

# %%
# Prefer the punctuated text, then the service hypothesis, then the original,
# and rebuild the list in the original order.
handled_responses = {
    resp["index"]: format_text(resp.get("punct_text", resp.get("hyp_text", resp["text"]))) for resp in handled_responses
}
max_index = max(list(handled_responses.keys()))
handled_responses = [handled_responses[i] for i in range(max_index + 1)]

pickle.dump([response_encodings, handled_responses], output_db_file.open("wb"))

# %%
print("old:")
print(responses[:20])
print("new:")
print(handled_responses[:20])
# %%
loop.close()
