7
from contextlib import redirect_stdout
9
from allennlp.commands import main
10
from allennlp.common.testing import AllenNlpTestCase
13
# Tests that check the text printed to stdout against metrics files created in
# setup_method: a header row followed by one comma-separated row per metrics file.
#
# NOTE(review): this copy of the file looks damaged. The bare integers between
# statements appear to be line-number residue, indentation has been stripped,
# and some statements seem to be missing (flagged below where it matters).
# Restore from the original before relying on it.
class TestPrintResults(AllenNlpTestCase):
14
# Builds the on-disk fixture: three result directories under self.TEST_DIR,
# each paired with one metrics payload and one metrics file name.
def setup_method(self):
15
super().setup_method()
17
# Two fresh temporary directories (prefix "hi"), wrapped as pathlib.Path.
# NOTE(review): neither attribute is referenced again in the visible code.
self.out_dir1 = pathlib.Path(tempfile.mkdtemp(prefix="hi"))
18
self.out_dir2 = pathlib.Path(tempfile.mkdtemp(prefix="hi"))
20
# Result directories live under self.TEST_DIR, which is not defined in this
# class body (presumably supplied by AllenNlpTestCase; confirm).
self.directory1 = self.TEST_DIR / "results1"
21
self.directory2 = self.TEST_DIR / "results2"
22
self.directory3 = self.TEST_DIR / "results3"
23
os.makedirs(self.directory1)
24
os.makedirs(self.directory2)
25
os.makedirs(self.directory3)
27
# Each dict below is paired with a file opened in "w+" mode. The call that
# consumes each pair is not visible here (presumably json.dump; confirm).
# NOTE(review): the file objects are never bound to a name, so nothing in the
# visible code closes them explicitly; a `with open(...)` block would.
# results1/metrics.json: has train, test and dev values.
{"train": 1, "test": 2, "dev": 3},
28
open(os.path.join(self.directory1 / "metrics.json"), "w+"),
31
# results2/metrics.json: no "test" entry.
{"train": 4, "dev": 5}, open(os.path.join(self.directory2 / "metrics.json"), "w+")
34
# results3 uses a different file name, cool_metrics.json, and has no "test" entry.
{"train": 6, "dev": 7}, open(os.path.join(self.directory3 / "cool_metrics.json"), "w+")
37
# Expects one row for each of results1/metrics.json and results2/metrics.json.
def test_print_results(self):
48
# Capture everything written to stdout into an in-memory buffer.
# NOTE(review): the statement that produces the output under this redirect is
# not visible here (presumably a call to the imported `main`; confirm).
with io.StringIO() as buf, redirect_stdout(buf):
50
output = buf.getvalue()
52
lines = output.strip().split("\n")
53
# The first printed line must be exactly this header.
assert lines[0] == "model_run, train, dev, test"
56
# Expected rows as (path, train, dev, test) tuples, in the header's column order.
# NOTE(review): these presumably sit inside an `expected_results = {...}` set
# literal whose opening and closing lines are not visible here.
# NOTE(review): the "/" separator is hard-coded into the expected path.
(str(self.directory1) + "/metrics.json", "1", "3", "2"),
57
# results2 has no "test" metric; the expected row carries "N/A" in that column.
(str(self.directory2) + "/metrics.json", "4", "5", "N/A"),
59
# Rows are split on ", " and compared as a set, so print order does not matter.
results = {tuple(line.split(", ")) for line in lines[1:]}
60
assert results == expected_results
62
# Expects exactly one row, for results3/cool_metrics.json.
def test_print_results_with_metrics_filename(self):
75
# Capture stdout as in test_print_results.
# NOTE(review): the statement that runs under this redirect, and whatever
# selects the "cool_metrics.json" file name, are not visible here.
with io.StringIO() as buf, redirect_stdout(buf):
77
output = buf.getvalue()
79
lines = output.strip().split("\n")
80
# Same header as in test_print_results.
assert lines[0] == "model_run, train, dev, test"
82
# Single expected row; results3 has no "test" metric, so that column is "N/A".
expected_results = {(str(self.directory3) + "/cool_metrics.json", "6", "7", "N/A")}
83
results = {tuple(line.split(", ")) for line in lines[1:]}
84
assert results == expected_results