import argparse
import os

import pytest

from allennlp.commands.find_learning_rate import (
    FindLearningRate,
    find_learning_rate_from_args,
    find_learning_rate_model,
    search_learning_rate,
)
from allennlp.common import Params
from allennlp.common.checks import ConfigurationError
from allennlp.common.testing import AllenNlpTestCase, requires_multi_gpu
from allennlp.data import Vocabulary
from allennlp.models import Model
from allennlp.training import Trainer
from allennlp.training.util import data_loaders_from_params
def is_matplotlib_installed() -> bool:
    """Return ``True`` if the optional ``matplotlib`` dependency can be imported.

    Used below to skip the plotting test when matplotlib is absent
    (``find_learning_rate_model`` saves a learning-rate plot).
    """
    try:
        import matplotlib  # noqa: F401  # only probing availability
    except ImportError:
        return False
    return True
class TestFindLearningRate(AllenNlpTestCase):
    """Tests for the ``find-lr`` command: the model-level entry point
    (``find_learning_rate_model``), its CLI argument parsing, and its
    behavior under a distributed config."""

    def setup_method(self):
        super().setup_method()
        # ``find_learning_rate_model`` consumes (pops from) the ``Params`` it is
        # given, so expose a factory lambda that builds a fresh one per call.
        self.params = lambda: Params(
            {
                "model": {
                    "type": "simple_tagger",
                    "text_field_embedder": {
                        "token_embedders": {"tokens": {"type": "embedding", "embedding_dim": 5}}
                    },
                    "encoder": {"type": "lstm", "input_size": 5, "hidden_size": 7, "num_layers": 2},
                },
                "dataset_reader": {"type": "sequence_tagging"},
                "train_data_path": str(self.FIXTURES_ROOT / "data" / "sequence_tagging.tsv"),
                "validation_data_path": str(self.FIXTURES_ROOT / "data" / "sequence_tagging.tsv"),
                "data_loader": {"batch_size": 2},
                "trainer": {"cuda_device": -1, "num_epochs": 2, "optimizer": "adam"},
            }
        )

    @pytest.mark.skipif(not is_matplotlib_installed(), reason="matplotlib dependency is optional")
    def test_find_learning_rate(self):
        # A fresh (non-existent) serialization directory works.
        find_learning_rate_model(
            self.params(),
            os.path.join(self.TEST_DIR, "test_find_learning_rate"),
            start_lr=1e-5,
            end_lr=1,
            num_batches=100,
            linear_steps=True,
            stopping_factor=None,
            force=False,
        )

        # An existing but empty serialization directory also works.
        serialization_dir2 = os.path.join(self.TEST_DIR, "empty_directory")
        assert not os.path.exists(serialization_dir2)
        os.makedirs(serialization_dir2)
        find_learning_rate_model(
            self.params(),
            serialization_dir2,
            start_lr=1e-5,
            end_lr=1,
            num_batches=100,
            linear_steps=True,
            stopping_factor=None,
            force=False,
        )

        # An existing, non-empty serialization directory must be rejected ...
        serialization_dir3 = os.path.join(self.TEST_DIR, "non_empty_directory")
        assert not os.path.exists(serialization_dir3)
        os.makedirs(serialization_dir3)
        with open(os.path.join(serialization_dir3, "README.md"), "w") as f:
            f.write("TEST")

        with pytest.raises(ConfigurationError):
            find_learning_rate_model(
                self.params(),
                serialization_dir3,
                start_lr=1e-5,
                end_lr=1,
                num_batches=100,
                linear_steps=True,
                stopping_factor=None,
                force=False,
            )

        # ... unless ``force=True`` is passed.
        find_learning_rate_model(
            self.params(),
            serialization_dir3,
            start_lr=1e-5,
            end_lr=1,
            num_batches=100,
            linear_steps=True,
            stopping_factor=None,
            force=True,
        )

    def test_find_learning_rate_args(self):
        parser = argparse.ArgumentParser(description="Testing")
        subparsers = parser.add_subparsers(title="Commands", metavar="")
        FindLearningRate().add_subparser(subparsers)

        for serialization_arg in ["-s", "--serialization-dir"]:
            raw_args = ["find-lr", "path/to/params", serialization_arg, "serialization_dir"]

            args = parser.parse_args(raw_args)

            assert args.func == find_learning_rate_from_args
            assert args.param_path == "path/to/params"
            assert args.serialization_dir == "serialization_dir"

        # param_path is required.  NOTE: the assert must sit *outside* the
        # ``with`` block (``parse_args`` raises before a trailing statement
        # would run) and pytest's ExceptionInfo exposes ``.value``, not
        # ``.exception``.
        with pytest.raises(SystemExit) as cm:
            parser.parse_args(["find-lr", "-s", "serialization_dir"])
        assert cm.value.code == 2  # argparse exit code for incorrect usage

        # serialization dir is required.
        with pytest.raises(SystemExit) as cm:
            parser.parse_args(["find-lr", "path/to/params"])
        assert cm.value.code == 2  # argparse exit code for incorrect usage

    @requires_multi_gpu
    def test_find_learning_rate_multi_gpu(self):
        params = self.params()
        del params["trainer"]["cuda_device"]
        params["distributed"] = Params({})
        params["distributed"]["cuda_devices"] = [0, 1]

        # find-lr does not support distributed training; it should fail fast.
        with pytest.raises(AssertionError) as execinfo:
            find_learning_rate_model(
                params,
                os.path.join(self.TEST_DIR, "test_find_learning_rate_multi_gpu"),
                start_lr=1e-5,
                end_lr=1,
                num_batches=100,
                linear_steps=True,
                stopping_factor=None,
                force=False,
            )
        assert "DistributedDataParallel" in str(execinfo.value)
class TestSearchLearningRate(AllenNlpTestCase):
    """Tests for ``search_learning_rate`` run against a real, tiny ``Trainer``
    built from the sequence-tagging fixture."""

    def setup_method(self):
        super().setup_method()
        params = Params(
            {
                "model": {
                    "type": "simple_tagger",
                    "text_field_embedder": {
                        "token_embedders": {"tokens": {"type": "embedding", "embedding_dim": 5}}
                    },
                    "encoder": {"type": "lstm", "input_size": 5, "hidden_size": 7, "num_layers": 2},
                },
                "dataset_reader": {"type": "sequence_tagging"},
                "train_data_path": str(self.FIXTURES_ROOT / "data" / "sequence_tagging.tsv"),
                "validation_data_path": str(self.FIXTURES_ROOT / "data" / "sequence_tagging.tsv"),
                "data_loader": {"batch_size": 2},
                "trainer": {"cuda_device": -1, "num_epochs": 2, "optimizer": "adam"},
            }
        )
        all_data_loaders = data_loaders_from_params(params)
        # Build the vocabulary lazily from every instance in every data loader.
        vocab = Vocabulary.from_params(
            params.pop("vocabulary", {}),
            instances=(
                instance
                for data_loader in all_data_loaders.values()
                for instance in data_loader.iter_instances()
            ),
        )
        model = Model.from_params(vocab=vocab, params=params.pop("model"))
        data_loader = all_data_loaders["train"]
        data_loader.index_with(vocab)
        trainer_params = params.pop("trainer")
        serialization_dir = os.path.join(self.TEST_DIR, "test_search_learning_rate")

        self.trainer = Trainer.from_params(
            model=model,
            serialization_dir=serialization_dir,
            data_loader=data_loader,
            params=trainer_params,
            validation_data=None,
            validation_iterator=None,
        )

    def test_search_learning_rate_with_num_batches_less_than_ten(self):
        # Fewer than 10 batches is too little to sweep a learning-rate range;
        # the command rejects it with a ConfigurationError.
        with pytest.raises(ConfigurationError):
            search_learning_rate(self.trainer, num_batches=9)

    def test_search_learning_rate_linear_steps(self):
        learning_rates_losses = search_learning_rate(self.trainer, linear_steps=True)
        assert len(learning_rates_losses) > 1

    def test_search_learning_rate_without_stopping_factor(self):
        # With no stopping factor the sweep runs the full ``num_batches``,
        # recording the starting point plus one (lr, loss) pair per batch.
        learning_rates, losses = search_learning_rate(
            self.trainer, num_batches=100, stopping_factor=None
        )
        assert len(learning_rates) == 101
        assert len(losses) == 101