dream
42 lines · 1.1 KB
1import requests
2import logging
3import json
4import numpy as np
5
6
logging.basicConfig(format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO)
logger = logging.getLogger(__name__)

# Endpoint of the batch scoring service under test. Plain string: the former
# f-prefix had no placeholders and did nothing (flake8/ruff F541).
URL = "http://0.0.0.0:8110/batch_model"
11
# Load the test dialogs and flatten them into a single batch request: every
# hypothesis of every dialog becomes one (context, hypothesis) pair. The
# parallel ``dialog_ids`` array records which dialog each pair came from, so
# the per-pair scores returned by the service can be regrouped per dialog.
# Each sample is read as {"context": ..., "hyp": [...]} (the keys indexed below).
dialogs_path = "test_data.json"
# JSON text is UTF-8 (RFC 8259); don't depend on the platform's locale encoding.
with open(dialogs_path, encoding="utf-8") as f:
    dialogs = json.load(f)

test_config = {"contexts": [], "hypotheses": []}
dialog_ids = []
for i, sample in enumerate(dialogs):
    for hyp in sample["hyp"]:
        # The dialog context is repeated once per hypothesis so both lists align.
        test_config["contexts"].append(sample["context"])
        test_config["hypotheses"].append(hyp)
        dialog_ids.append(i)
# ndarray so ``dialog_ids == i`` yields a boolean mask in main_test().
dialog_ids = np.array(dialog_ids)
24
25
def main_test():
    """Score all hypotheses via the service and check the top pick per dialog.

    Sends the whole flattened batch (``test_config``) to ``URL`` in one POST,
    reads ``[0]["batch"]`` from the JSON reply as one score per submitted
    (context, hypothesis) pair, regroups the scores by ``dialog_ids`` and, for
    each dialog, selects the hypothesis with the highest score.

    Raises:
        requests.HTTPError: if the service responds with an error status.
        AssertionError: if the number of returned scores differs from the
            number of submitted pairs, or the selected hypothesis is falsy.
    """
    response = requests.post(URL, json=test_config)
    # Fail here with the HTTP status rather than with a confusing JSON/Key error below.
    response.raise_for_status()
    batch_responses = np.array(response.json()[0]["batch"])

    # Boolean-mask regrouping below is only meaningful with one score per pair.
    assert len(batch_responses) == len(dialog_ids), (
        f"Expected {len(dialog_ids)} scores, got {len(batch_responses)}"
    )

    for i, sample in enumerate(dialogs):
        curr_responses = batch_responses[dialog_ids == i]
        if curr_responses.size == 0:
            # np.argmax raises on an empty array; nothing to rank for this dialog.
            logger.warning("Dialog %d has no hypotheses; skipping", i)
            continue
        pred_best_hyp_id = np.argmax(curr_responses)

        # NOTE(review): this only checks that the chosen hypothesis is truthy,
        # which holds for any non-empty entry. If test_data.json marks the
        # expected best hypothesis, assert against that instead — confirm schema.
        # The message is passed directly (previously ``print(...)``, whose
        # return value None became the assertion message).
        assert sample["hyp"][pred_best_hyp_id], (
            f"Current responses: {curr_responses}, pred best resp id: {pred_best_hyp_id}"
        )

    logger.info("Success!")


if __name__ == "__main__":
    main_test()
43