dream
54 строки · 1.3 Кб
1import argparse
2import json
3import pathlib
4
5import numpy as np
6import requests
7import tqdm
8
# Command-line interface: skill endpoint, input fixture and output path.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--url",
    help="skill_url",
    default="http://localhost:8029/convert_reddit",
)
parser.add_argument(
    "--questions_json_file",
    type=pathlib.Path,
    help="path to a json file",
    default="tests/test_question_tasks.json",
)
parser.add_argument(
    "-o",
    "--npy_file_path",
    type=pathlib.Path,
    help="path to npy file",
    default="confidences.npy",
)
args = parser.parse_args()

# Read the whole fixture via read_text so the file handle is closed
# immediately. JSON text is UTF-8 by specification, so decode it explicitly
# rather than relying on the platform's default locale encoding.
# NOTE(review): the script indexes this with the keys "personality" and
# "tasks"; confirm the fixture schema if it changes.
data = json.loads(args.questions_json_file.read_text(encoding="utf-8"))


def history_gen(dialogs):
    """Yield every non-empty prefix of each dialog, shortest first.

    For a dialog ``[a, b, c]`` this produces ``[a]``, ``[a, b]`` and
    ``[a, b, c]``. Dialogs are processed in the order given, and an empty
    dialog contributes nothing. Prefixes are produced by slicing, so for
    built-in sequences (lists, tuples, strings) each prefix is a new object
    of the same type as the dialog.

    Args:
        dialogs: An iterable of sliceable, sized sequences.

    Yields:
        ``dialog[:k]`` for ``k`` from 1 to ``len(dialog)``, per dialog.
    """
    for turns in dialogs:
        yield from (turns[:stop] for stop in range(1, len(turns) + 1))


# Send each task's dialog history to the skill once and collect one value
# per task into `confidences`, in task order.
confidences = []
for task in tqdm.tqdm(data["tasks"]):
    resp = requests.post(
        args.url,
        json={"personality": [data["personality"]], "utterances_histories": [task["utterances_histories"]]},
    )
    # Surface HTTP failures directly instead of trying to index into an
    # error payload.
    resp.raise_for_status()
    # Exactly one dialog is sent per request, so take the first (only)
    # result. It is indexed as a pair and element 1 is collected.
    # NOTE(review): presumably res[0] is the response text and res[1] its
    # confidence; confirm against the skill's API.
    res = resp.json()[0]
    confidences.append(res[1])

# np.save appends ".npy" to the file name if it does not already end with it.
np.save(str(args.npy_file_path), confidences)