# dream — punctuation post-processing pipeline (originally 156 lines, 4.6 KB)
1# %%
2
3import tqdm4import pickle5import pathlib6import json7import difflib8import logging9import aiohttp10import asyncio11import re12import copy13
14
# Module-level logger; handler/level configuration is left to the importing
# application. NOTE(review): never used in the visible portion of this script.
logger = logging.getLogger(__name__)
17
# %%
# Input/output locations. NOTE(review): "chache" is a pre-existing typo in the
# variable name; kept as-is because other code may reference it.
data_dir = pathlib.Path("data")
work_dir = pathlib.Path("tmp")

banned_responses_file = data_dir / "banned_responses_v2.json"
db_file = work_dir / "replies_v2.pkl"
output_db_file = work_dir / "replies_v3.pkl"

chache_file = work_dir / "chache.json"
dropped_responses_file = work_dir / "dropped_responses.json"
log_file = work_dir / "logs.txt"

# BUG FIX: the file handle returned by db_file.open("rb") was never closed;
# read the pickle through a context manager instead.
with db_file.open("rb") as db_f:
    response_encodings, responses = pickle.load(db_f)
# %%

# Wrap every response with its original index so order can be restored after
# the async workers shuffle them around.
marked_responses = [{"text": res, "index": i} for i, res in enumerate(responses)]
# %%
# Collapse whitespace the punctuation service inserts before sentence
# punctuation and around curly apostrophes.
format_1 = re.compile(r"\s+(,|\.|!|\?)")
format_2 = re.compile(r"\s+(’)\s+")


def format_text(text):
    """Normalise spacing around punctuation in the service output."""
    without_space_before_punct = format_1.sub(r"\1", text)
    # A spaced curly apostrophe becomes a plain straight one.
    return format_2.sub("'", without_space_before_punct)
# Patterns used to reduce a sentence to bare comparable tokens before the
# similarity check.
other_symbols_compiled = re.compile(r"[^a-zA-Z0-9\- ]", re.IGNORECASE)
space_compiled = re.compile(r"\s+", re.IGNORECASE)


def cleanup(text):
    """Strip everything except letters, digits, hyphens and single spaces."""
    letters_only = other_symbols_compiled.sub("", text)
    return space_compiled.sub(" ", letters_only).strip()
def get_data_iter(data, split_n):
    """Yield consecutive chunks of *data*, each at most *split_n* items long."""
    for start in range(0, len(data), split_n):
        yield data[start:start + split_n]
def is_same(ground_truth, hypothesis, ratio=0.9):
    """Return True when the two texts' word sequences are at least *ratio* similar."""
    matcher = difflib.SequenceMatcher(None, ground_truth.split(), hypothesis.split())
    return matcher.ratio() >= ratio
async def aio_request(session, url, samples):
    """POST one batch of samples to the punctuation service and tag each result.

    Each sample dict gains exactly one new key:
      * ``punct_text`` — the punctuated text, when it is non-empty and at
        least 0.95 token-similar to the original after cleanup, or
      * ``hyp_text``  — the raw hypothesis otherwise (kept for retry logic).

    Returns the list of updated sample dicts. Exceptions from the HTTP call
    propagate to the caller.

    BUG FIX: the original wrapped the body in ``try/except Exception as exc:
    raise exc`` followed by two unreachable ``return samples`` statements —
    pure dead code that obscured the control flow; it has been removed.
    """
    payload = {"sentences": [subsample["text"] for subsample in samples]}
    new_samples = []
    async with session.post(url, json=payload) as resp:
        batch = await resp.json()
        # The service must echo one result per input sentence.
        assert len(batch) == len(samples)
        for b_el, subsample in zip(batch, samples):
            punct_text = b_el["punct_sent"]
            if punct_text and is_same(cleanup(punct_text), cleanup(subsample["text"]), 0.95):
                subsample["punct_text"] = punct_text
            else:
                subsample["hyp_text"] = punct_text
            new_samples.append(subsample)
    return new_samples
async def worker(url, data, batch_n):
    """Drain ``data["samples"]``, sending batches of *batch_n* to the service.

    Several workers share the same ``data`` dict, so samples are popped from
    the shared list (cooperative, single event loop — no locking needed).
    Returns the processed samples this worker handled.
    """
    processed = []
    async with aiohttp.ClientSession() as session:
        while data["samples"]:
            batch = []
            for _ in range(batch_n):
                if not data["samples"]:
                    break
                batch.append(data["samples"].pop())
            processed += await aio_request(session, url, batch)
    return processed
async def load_bar(data):
    """Render a tqdm progress bar while workers drain ``data["samples"]``.

    Polls once per second and advances the bar by however much the shared
    list shrank since the previous poll.
    """
    total = len(data["samples"])
    with tqdm.tqdm(total=total) as pbar:
        remaining = len(data["samples"])
        while data["samples"]:
            now = len(data["samples"])
            if now != remaining:
                pbar.update(remaining - now)
                remaining = now
            await asyncio.sleep(1)
        # BUG FIX: the loop exits as soon as the list empties, so the items
        # consumed since the last poll were never reflected and the bar ended
        # short of its total; flush the remainder here.
        if remaining:
            pbar.update(remaining)
async def load_data(
    samples,
    batch_n=10,
    worker_n=10,
    url="http://a737ad642c7cc4356a543c2c58779eb6-1162604519.us-west-2.elb.amazonaws.com/sentseg/sentseg",
):
    """Punctuate *samples* concurrently via the sentseg service.

    Spawns *worker_n* worker tasks that drain the shared sample list in
    batches of *batch_n*, plus one progress-bar task. The service endpoint is
    now a parameter (``url``) with the previous hard-coded address as its
    default, so existing callers are unaffected.

    Returns the flattened list of processed sample dicts; the progress-bar
    task returns ``None`` and is filtered out.
    """
    data = {"samples": samples}
    tasks = [asyncio.ensure_future(load_bar(data))]
    tasks += [asyncio.ensure_future(worker(url, data, batch_n)) for _ in range(worker_n)]
    results = await asyncio.gather(*tasks)
    # load_bar yields None and idle workers yield []; keep only real batches.
    # (Flattening with a comprehension avoids the quadratic ``sum(lists, [])``.)
    batches = [chunk for chunk in results if chunk]
    return [sample for batch in batches for sample in batch]
120# %%
121
loop = asyncio.get_event_loop()

# Retry loop: re-send anything that did not get a confident punctuation match,
# up to 5 passes; whatever is still unmatched is kept with its hypothesis text.
unhandled_responses = copy.deepcopy(marked_responses)
handled_responses = [resp for resp in unhandled_responses if "punct_text" in resp]
i = 0
while unhandled_responses:
    i += 1
    if i > 5:
        break
    wip_responses = loop.run_until_complete(load_data(unhandled_responses, batch_n=10, worker_n=10))
    unhandled_responses = [resp for resp in wip_responses if "punct_text" not in resp]
    handled_responses += [resp for resp in wip_responses if "punct_text" in resp]
handled_responses = handled_responses + unhandled_responses
# BUG FIX: json.dump wrote to an open() handle that was never closed, so the
# tail of the cache file could be left unflushed; use a context manager.
with open("cache_v2.json", "wt") as cache_f:
    json.dump(handled_responses, cache_f, ensure_ascii=False, indent=4)
# %%
138
# Reload from the JSON cache so this cell can be re-run without re-querying
# the service.
# BUG FIX: both the json.load handle and the pickle output handle were never
# closed; use context managers so data is flushed and descriptors released.
with open("cache_v2.json", "rt") as cache_f:
    handled_responses = json.load(cache_f)

# %%
# Map index -> best available text: confident punctuation first, then the raw
# hypothesis, then the untouched original.
handled_responses = {
    resp["index"]: format_text(resp.get("punct_text", resp.get("hyp_text", resp["text"])))
    for resp in handled_responses
}
# Rebuild the positional list. Every index in [0, max_index] is expected to be
# present because marked_responses enumerated all responses — TODO confirm no
# samples are dropped by the workers.
max_index = max(handled_responses.keys())
handled_responses = [handled_responses[i] for i in range(max_index + 1)]

with output_db_file.open("wb") as out_f:
    pickle.dump([response_encodings, handled_responses], out_f)
150# %%
# Spot-check: show the first few replies before and after punctuation.
for label, sample in (("old:", responses[:20]), ("new:", handled_responses[:20])):
    print(label)
    print(sample)
# %%
loop.close()