1
from datetime import timedelta
3
from collections import OrderedDict
8
from allennlp.common import util
9
from allennlp.common.testing import AllenNlpTestCase
10
from allennlp.common.util import push_python_path
19
return {"sanitizable": True}
22
class TestCommonUtils(AllenNlpTestCase):
23
def test_group_by_count(self):
24
assert util.group_by_count([1, 2, 3, 4, 5, 6, 7], 3, 20) == [
30
def test_lazy_groups_of(self):
31
xs = [1, 2, 3, 4, 5, 6, 7]
32
groups = util.lazy_groups_of(iter(xs), group_size=3)
33
assert next(groups) == [1, 2, 3]
34
assert next(groups) == [4, 5, 6]
35
assert next(groups) == [7]
36
with pytest.raises(StopIteration):
39
def test_pad_sequence_to_length(self):
40
assert util.pad_sequence_to_length([1, 2, 3], 5) == [1, 2, 3, 0, 0]
41
assert util.pad_sequence_to_length([1, 2, 3], 5, default_value=lambda: 2) == [1, 2, 3, 2, 2]
42
assert util.pad_sequence_to_length([1, 2, 3], 5, padding_on_right=False) == [0, 0, 1, 2, 3]
44
def test_namespace_match(self):
45
assert util.namespace_match("*tags", "tags")
46
assert util.namespace_match("*tags", "passage_tags")
47
assert util.namespace_match("*tags", "question_tags")
48
assert util.namespace_match("tokens", "tokens")
49
assert not util.namespace_match("tokens", "stemmed_tokens")
51
def test_sanitize(self):
52
assert util.sanitize(torch.Tensor([1, 2])) == [1, 2]
53
assert util.sanitize(torch.LongTensor([1, 2])) == [1, 2]
55
with pytest.raises(ValueError):
56
util.sanitize(Unsanitizable())
58
assert util.sanitize(Sanitizable()) == {"sanitizable": True}
60
x = util.sanitize({1, 2, 3})
61
assert isinstance(x, list)
64
def test_import_submodules(self):
65
(self.TEST_DIR / "mymodule").mkdir()
66
(self.TEST_DIR / "mymodule" / "__init__.py").touch()
67
(self.TEST_DIR / "mymodule" / "submodule").mkdir()
68
(self.TEST_DIR / "mymodule" / "submodule" / "__init__.py").touch()
69
(self.TEST_DIR / "mymodule" / "submodule" / "subsubmodule.py").touch()
71
with push_python_path(self.TEST_DIR):
72
assert "mymodule" not in sys.modules
73
assert "mymodule.submodule" not in sys.modules
75
util.import_module_and_submodules("mymodule")
77
assert "mymodule" in sys.modules
78
assert "mymodule.submodule" in sys.modules
79
assert "mymodule.submodule.subsubmodule" in sys.modules
81
def test_get_frozen_and_tunable_parameter_names(self):
82
model = torch.nn.Sequential(
83
OrderedDict([("conv", torch.nn.Conv1d(5, 5, 5)), ("linear", torch.nn.Linear(5, 10))])
85
named_parameters = dict(model.named_parameters())
86
named_parameters["linear.weight"].requires_grad_(False)
87
named_parameters["linear.bias"].requires_grad_(False)
89
frozen_parameter_names,
90
tunable_parameter_names,
91
) = util.get_frozen_and_tunable_parameter_names(model)
92
assert set(frozen_parameter_names) == {"linear.weight", "linear.bias"}
93
assert set(tunable_parameter_names) == {"conv.weight", "conv.bias"}
95
def test_sanitize_ptb_tokenized_string(self):
96
def create_surrounding_test_case(start_ptb_token, end_ptb_token, start_token, end_token):
98
"a {} b c {} d".format(start_ptb_token, end_ptb_token),
99
"a {}b c{} d".format(start_token, end_token),
102
def create_fwd_token_test_case(fwd_token):
103
return "a {} b".format(fwd_token), "a {}b".format(fwd_token)
105
def create_backward_token_test_case(backward_token):
106
return "a {} b".format(backward_token), "a{} b".format(backward_token)
108
punct_forward = {"`", "$", "#"}
109
punct_backward = {".", ",", "!", "?", ":", ";", "%", "'"}
113
create_surrounding_test_case("-lrb-", "-rrb-", "(", ")"),
114
create_surrounding_test_case("-lsb-", "-rsb-", "[", "]"),
115
create_surrounding_test_case("-lcb-", "-rcb-", "{", "}"),
117
create_surrounding_test_case("-lsb-", "-rcb-", "[", "}"),
119
create_surrounding_test_case("-LsB-", "-rcB-", "[", "}"),
121
create_surrounding_test_case("``", "''", '"', '"'),
123
create_surrounding_test_case("<s>", "</s>", "", ""),
125
*[create_fwd_token_test_case(t) for t in punct_forward],
127
*[create_backward_token_test_case(t) for t in punct_backward],
131
("I do n't", "I don't"),
136
("A b C d", "A b C d"),
139
for ptb_string, expected in test_cases:
140
actual = util.sanitize_ptb_tokenized_string(ptb_string)
141
assert actual == expected
143
def test_cycle_iterator_function(self):
144
global cycle_iterator_function_calls
145
cycle_iterator_function_calls = 0
148
global cycle_iterator_function_calls
149
cycle_iterator_function_calls += 1
153
iterator = iter(util.cycle_iterator_function(one_and_two))
156
assert cycle_iterator_function_calls == 0
158
values = [next(iterator) for _ in range(5)]
159
assert values == [1, 2, 1, 2, 1]
162
assert cycle_iterator_function_calls == 3
165
@pytest.mark.parametrize(
    "size, result",
    [
        (int(1.2 * 1024), "1.2K"),
        (120 * 1024, "120K"),
        (int(1.2 * 1024 * 1024), "1.2M"),
        (12 * 1024 * 1024, "12M"),
        (120 * 1024 * 1024, "120M"),
        (int(1.2 * 1024 * 1024 * 1024), "1.2G"),
        (12 * 1024 * 1024 * 1024, "12G"),
    ],
)
def test_format_size(size: int, result: str):
    """format_size renders a byte count with a binary-unit suffix (K/M/G).

    NOTE(review): the parametrize scaffolding was garbled in this copy; cases
    were rebuilt from the visible entries — some may have been lost.
    """
    assert util.format_size(size) == result
183
@pytest.mark.parametrize(
    "td, result",
    [
        (timedelta(days=2, hours=3), "2 days"),
        (timedelta(days=1, hours=3), "1 day"),
        (timedelta(hours=3, minutes=12), "3 hours"),
        (timedelta(hours=1, minutes=12), "1 hour, 12 mins"),
        (timedelta(minutes=12), "12 mins"),
    ],
)
def test_format_timedelta(td: timedelta, result: str):
    """format_timedelta renders a human-readable duration, keeping at most the
    two most significant units and dropping small remainders for long spans.

    NOTE(review): the parametrize scaffolding was garbled in this copy; cases
    were rebuilt from the visible entries — some may have been lost.
    """
    assert util.format_timedelta(td) == result