dream
52 строки · 1.7 Кб
import argparse
import json
import logging

import requests
# Command-line interface of the script. Every option is a string with a
# default, so the script can be launched without any arguments.
parser = argparse.ArgumentParser()
for _short_flag, _long_flag, _default_value in (
    # Where the evaluated tasks (with responses and validity flags) are written.
    ("-pred_f", "--pred_file", "tests/test_results.json"),
    # Where the task definitions (personality, histories, targets) are read from.
    ("-true_f", "--true_file", "tests/test_tasks.json"),
    # Endpoint that receives the POST requests.
    ("-from_url", "--from_url", "http://0.0.0.0:8007/transfertransfo"),
):
    parser.add_argument(_short_flag, _long_flag, type=str, default=_default_value)
del _short_flag, _long_flag, _default_value
10
def get_response(url, personality, history, timeout=None):
    """Request one reply from the service at *url*.

    The payload wraps *personality* and *history* in single-element lists
    under the keys ``"personality"`` and ``"utterances_histories"``.

    Args:
        url: Endpoint that receives the POST request.
        personality: Personality description forwarded to the service.
        history: Utterance history forwarded to the service.
        timeout: Optional timeout in seconds passed to ``requests.post``.
            ``None`` (the default) waits indefinitely, as before.

    Returns:
        The first element of the JSON body returned by the service, or
        ``("", 0)`` when the request or the decoding fails for any reason.
        The caller in this file unpacks the result as a
        ``(response, confidence)`` pair.
    """
    try:
        data = requests.post(
            url,
            json={"personality": [personality], "utterances_histories": [history]},
            timeout=timeout,
        )
        return data.json()[0]
    except Exception:
        # Deliberately best-effort: one failed attempt must not abort the whole
        # run, but the failure is logged so it is not swallowed silently.
        logging.getLogger(__name__).warning(
            "Request to %s failed; using empty response", url, exc_info=True
        )
        return ("", 0)
18
def _score_responses(raw_responses, targets):
    """Rank ``(text, confidence)`` pairs and mark which ones hit a target.

    Pairs are ordered by descending confidence and empty texts are dropped.
    A response is valid when *targets* is empty or when at least one target
    is a substring of the response text.
    """
    # reverse=True keeps the original order of equal-confidence pairs, exactly
    # like sorting on the negated confidence did.
    ranked = sorted(raw_responses, key=lambda pair: pair[1], reverse=True)
    return [
        {
            "valid": not targets or any(tgt in text for tgt in targets),
            "response": text,
            "confidence": confidence,
        }
        for text, confidence in ranked
        if text
    ]


def main():
    """Query the service for every task, save the results and check them.

    Reads the task file (``--true_file``), which must provide a
    ``"personality"`` entry and a ``"tasks"`` list whose items carry
    ``"num_try"``, ``"utterances_histories"`` and ``"targets"``. Each task is
    sent ``num_try`` times through ``get_response``; the scored responses and
    an overall ``"valid"`` flag are stored on the task, and the whole
    structure is written to ``--pred_file``.

    Raises:
        AssertionError: if a task with targets received no valid response.
            Raised explicitly so the check also runs under ``python -O``.
    """
    args = parser.parse_args()
    # Read with the same encoding that is used for writing the results.
    with open(args.true_file, "rt", encoding="utf-8") as true_file:
        cntx = json.load(true_file)
    personality = cntx["personality"]

    failed_tasks = []
    res_tasks = []
    for index, task in enumerate(cntx["tasks"]):
        raw_responses = [
            get_response(args.from_url, personality, task["utterances_histories"])
            for _ in range(task["num_try"])
        ]
        task["responses"] = _score_responses(raw_responses, task["targets"])
        task["valid"] = any(res["valid"] for res in task["responses"])
        res_tasks.append(task)
        # Tasks without targets pass even when no response came back.
        if task["targets"] and not task["valid"]:
            failed_tasks.append(index)
    cntx["tasks"] = res_tasks

    # Results are written before the check so they can be inspected on failure.
    with open(args.pred_file, "wt", encoding="utf-8") as pred_file:
        json.dump(cntx, pred_file, indent=4)

    if failed_tasks:
        raise AssertionError(
            f"No valid response for task(s) at index: {failed_tasks}"
        )
50
# Run the checks only when the file is executed as a script, not on import.
if __name__ == "__main__":
    main()