allennlp

Форк
0
/
print_results_test.py 
84 строки · 2.6 Кб
1
import os
2
import json
3
import sys
4
import pathlib
5
import tempfile
6
import io
7
from contextlib import redirect_stdout
8

9
from allennlp.commands import main
10
from allennlp.common.testing import AllenNlpTestCase
11

12

13
class TestPrintResults(AllenNlpTestCase):
14
    def setup_method(self):
15
        super().setup_method()
16

17
        self.out_dir1 = pathlib.Path(tempfile.mkdtemp(prefix="hi"))
18
        self.out_dir2 = pathlib.Path(tempfile.mkdtemp(prefix="hi"))
19

20
        self.directory1 = self.TEST_DIR / "results1"
21
        self.directory2 = self.TEST_DIR / "results2"
22
        self.directory3 = self.TEST_DIR / "results3"
23
        os.makedirs(self.directory1)
24
        os.makedirs(self.directory2)
25
        os.makedirs(self.directory3)
26
        json.dump(
27
            {"train": 1, "test": 2, "dev": 3},
28
            open(os.path.join(self.directory1 / "metrics.json"), "w+"),
29
        )
30
        json.dump(
31
            {"train": 4, "dev": 5}, open(os.path.join(self.directory2 / "metrics.json"), "w+")
32
        )
33
        json.dump(
34
            {"train": 6, "dev": 7}, open(os.path.join(self.directory3 / "cool_metrics.json"), "w+")
35
        )
36

37
    def test_print_results(self):
38
        kebab_args = [
39
            "__main__.py",
40
            "print-results",
41
            str(self.TEST_DIR),
42
            "--keys",
43
            "train",
44
            "dev",
45
            "test",
46
        ]
47
        sys.argv = kebab_args
48
        with io.StringIO() as buf, redirect_stdout(buf):
49
            main()
50
            output = buf.getvalue()
51

52
        lines = output.strip().split("\n")
53
        assert lines[0] == "model_run, train, dev, test"
54

55
        expected_results = {
56
            (str(self.directory1) + "/metrics.json", "1", "3", "2"),
57
            (str(self.directory2) + "/metrics.json", "4", "5", "N/A"),
58
        }
59
        results = {tuple(line.split(", ")) for line in lines[1:]}
60
        assert results == expected_results
61

62
    def test_print_results_with_metrics_filename(self):
63
        kebab_args = [
64
            "__main__.py",
65
            "print-results",
66
            str(self.TEST_DIR),
67
            "--keys",
68
            "train",
69
            "dev",
70
            "test",
71
            "--metrics-filename",
72
            "cool_metrics.json",
73
        ]
74
        sys.argv = kebab_args
75
        with io.StringIO() as buf, redirect_stdout(buf):
76
            main()
77
            output = buf.getvalue()
78

79
        lines = output.strip().split("\n")
80
        assert lines[0] == "model_run, train, dev, test"
81

82
        expected_results = {(str(self.directory3) + "/cool_metrics.json", "6", "7", "N/A")}
83
        results = {tuple(line.split(", ")) for line in lines[1:]}
84
        assert results == expected_results
85

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.