allennlp

Форк
0
/
util_test.py 
194 строки · 7.1 Кб
1
from datetime import timedelta
2
import sys
3
from collections import OrderedDict
4

5
import pytest
6
import torch
7

8
from allennlp.common import util
9
from allennlp.common.testing import AllenNlpTestCase
10
from allennlp.common.util import push_python_path
11

12

13
class Unsanitizable:
    """Plain object with no ``to_json`` hook; the sanitize test expects it to be rejected."""
15

16

17
class Sanitizable:
    """Object exposing a ``to_json`` hook for the sanitize test to pick up."""

    def to_json(self):
        """Return a fresh, fixed JSON-compatible payload."""
        payload = {"sanitizable": True}
        return payload
20

21

22
class TestCommonUtils(AllenNlpTestCase):
    """Tests for the helper functions in ``allennlp.common.util``."""

    def test_group_by_count(self):
        # The trailing short group is filled out with the default value (20).
        assert util.group_by_count([1, 2, 3, 4, 5, 6, 7], 3, 20) == [
            [1, 2, 3],
            [4, 5, 6],
            [7, 20, 20],
        ]

    def test_lazy_groups_of(self):
        xs = [1, 2, 3, 4, 5, 6, 7]
        groups = util.lazy_groups_of(iter(xs), group_size=3)
        assert next(groups) == [1, 2, 3]
        assert next(groups) == [4, 5, 6]
        # The trailing group is returned short, without padding.
        assert next(groups) == [7]
        with pytest.raises(StopIteration):
            _ = next(groups)

    def test_pad_sequence_to_length(self):
        assert util.pad_sequence_to_length([1, 2, 3], 5) == [1, 2, 3, 0, 0]
        assert util.pad_sequence_to_length([1, 2, 3], 5, default_value=lambda: 2) == [1, 2, 3, 2, 2]
        assert util.pad_sequence_to_length([1, 2, 3], 5, padding_on_right=False) == [0, 0, 1, 2, 3]

    def test_namespace_match(self):
        # A pattern starting with "*" matches any namespace ending in the rest of the
        # pattern; a plain pattern only matches itself.
        assert util.namespace_match("*tags", "tags")
        assert util.namespace_match("*tags", "passage_tags")
        assert util.namespace_match("*tags", "question_tags")
        assert util.namespace_match("tokens", "tokens")
        assert not util.namespace_match("tokens", "stemmed_tokens")

    def test_sanitize(self):
        # Tensors (float and long) come back as plain Python lists.
        assert util.sanitize(torch.Tensor([1, 2])) == [1, 2]
        assert util.sanitize(torch.LongTensor([1, 2])) == [1, 2]

        # Arbitrary objects without a to_json hook are rejected...
        with pytest.raises(ValueError):
            util.sanitize(Unsanitizable())

        # ...while objects that define to_json are serialized through it.
        assert util.sanitize(Sanitizable()) == {"sanitizable": True}

        # Sets become lists (element order is not checked).
        x = util.sanitize({1, 2, 3})
        assert isinstance(x, list)
        assert len(x) == 3

    def test_import_submodules(self):
        (self.TEST_DIR / "mymodule").mkdir()
        (self.TEST_DIR / "mymodule" / "__init__.py").touch()
        (self.TEST_DIR / "mymodule" / "submodule").mkdir()
        (self.TEST_DIR / "mymodule" / "submodule" / "__init__.py").touch()
        (self.TEST_DIR / "mymodule" / "submodule" / "subsubmodule.py").touch()

        with push_python_path(self.TEST_DIR):
            assert "mymodule" not in sys.modules
            assert "mymodule.submodule" not in sys.modules

            try:
                util.import_module_and_submodules("mymodule")

                assert "mymodule" in sys.modules
                assert "mymodule.submodule" in sys.modules
                assert "mymodule.submodule.subsubmodule" in sys.modules
            finally:
                # Evict the throwaway package. Otherwise it leaks into other tests, and a
                # rerun in the same process fails the "not in sys.modules" checks above.
                leaked = [
                    name
                    for name in sys.modules
                    if name == "mymodule" or name.startswith("mymodule.")
                ]
                for name in leaked:
                    sys.modules.pop(name, None)

    def test_get_frozen_and_tunable_parameter_names(self):
        model = torch.nn.Sequential(
            OrderedDict([("conv", torch.nn.Conv1d(5, 5, 5)), ("linear", torch.nn.Linear(5, 10))])
        )
        named_parameters = dict(model.named_parameters())
        # Freeze only the linear layer; the conv layer stays trainable.
        named_parameters["linear.weight"].requires_grad_(False)
        named_parameters["linear.bias"].requires_grad_(False)
        (
            frozen_parameter_names,
            tunable_parameter_names,
        ) = util.get_frozen_and_tunable_parameter_names(model)
        assert set(frozen_parameter_names) == {"linear.weight", "linear.bias"}
        assert set(tunable_parameter_names) == {"conv.weight", "conv.bias"}

    def test_sanitize_ptb_tokenized_string(self):
        def create_surrounding_test_case(start_ptb_token, end_ptb_token, start_token, end_token):
            return (
                "a {} b c {} d".format(start_ptb_token, end_ptb_token),
                "a {}b c{} d".format(start_token, end_token),
            )

        def create_fwd_token_test_case(fwd_token):
            return "a {} b".format(fwd_token), "a {}b".format(fwd_token)

        def create_backward_token_test_case(backward_token):
            return "a {} b".format(backward_token), "a{} b".format(backward_token)

        punct_forward = {"`", "$", "#"}
        punct_backward = {".", ",", "!", "?", ":", ";", "%", "'"}

        test_cases = [
            # Parentheses
            create_surrounding_test_case("-lrb-", "-rrb-", "(", ")"),
            create_surrounding_test_case("-lsb-", "-rsb-", "[", "]"),
            create_surrounding_test_case("-lcb-", "-rcb-", "{", "}"),
            # Parentheses don't have to match
            create_surrounding_test_case("-lsb-", "-rcb-", "[", "}"),
            # Also check that casing doesn't matter
            create_surrounding_test_case("-LsB-", "-rcB-", "[", "}"),
            # Quotes
            create_surrounding_test_case("``", "''", '"', '"'),
            # Start/end tokens
            create_surrounding_test_case("<s>", "</s>", "", ""),
            # Tokens that merge forward
            *[create_fwd_token_test_case(t) for t in punct_forward],
            # Tokens that merge backward
            *[create_backward_token_test_case(t) for t in punct_backward],
            # Merge tokens starting with ' backwards
            ("I 'm", "I'm"),
            # Merge tokens backwards when matching (n't or na) (special cases, parentheses behave in the same way)
            ("I do n't", "I don't"),
            ("gon na", "gonna"),
            # Also make sure casing is preserved
            ("gon NA", "gonNA"),
            # This is a no op
            ("A b C d", "A b C d"),
        ]

        for ptb_string, expected in test_cases:
            actual = util.sanitize_ptb_tokenized_string(ptb_string)
            # All cases run in one loop, so name the failing input explicitly.
            assert actual == expected, "{!r} -> {!r}, expected {!r}".format(
                ptb_string, actual, expected
            )

    def test_cycle_iterator_function(self):
        # Count generator-function invocations in a closure variable rather than a
        # module-level global, so no state leaks out of the test.
        calls = 0

        def one_and_two():
            nonlocal calls
            calls += 1
            yield from [1, 2]

        iterator = iter(util.cycle_iterator_function(one_and_two))

        # Function calls should be lazy.
        assert calls == 0

        values = [next(iterator) for _ in range(5)]
        assert values == [1, 2, 1, 2, 1]
        # This is the difference between cycle_iterator_function and itertools.cycle.  We'd only see
        # 1 here with itertools.cycle.
        assert calls == 3
163

164

165
@pytest.mark.parametrize(
    ("size", "result"),
    (
        (12, "12B"),
        (int(1.2 * 1024), "1.2K"),
        (12 * 1024, "12K"),
        (120 * 1024, "120K"),
        (int(1.2 * 1024 ** 2), "1.2M"),
        (12 * 1024 ** 2, "12M"),
        (120 * 1024 ** 2, "120M"),
        (int(1.2 * 1024 ** 3), "1.2G"),
        (12 * 1024 ** 3, "12G"),
    ),
)
def test_format_size(size: int, result: str):
    """Check ``util.format_size`` renders byte counts with 1024-based B/K/M/G suffixes."""
    formatted = util.format_size(size)
    assert formatted == result
181

182

183
@pytest.mark.parametrize(
    ("td", "result"),
    (
        (timedelta(days=2, hours=3), "2 days"),
        (timedelta(days=1, hours=3), "1 day"),
        (timedelta(hours=3, minutes=12), "3 hours"),
        (timedelta(hours=1, minutes=12), "1 hour, 12 mins"),
        (timedelta(minutes=12), "12 mins"),
    ),
)
def test_format_timedelta(td: timedelta, result: str):
    """Check ``util.format_timedelta`` against hand-picked durations.

    In these cases, leftover hours are dropped once the duration reaches a day, and
    leftover minutes appear only next to a single hour.
    """
    formatted = util.format_timedelta(td)
    assert formatted == result
195

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.