allennlp

Форк
0
/
find_learning_rate_test.py 
211 строк · 7.8 Кб
1
import argparse
2
import os
3

4
import pytest
5

6
from allennlp.common import Params
7
from allennlp.data import Vocabulary
8
from allennlp.models import Model
9
from allennlp.common.checks import ConfigurationError
10
from allennlp.common.testing import AllenNlpTestCase, requires_multi_gpu
11
from allennlp.commands.find_learning_rate import (
12
    search_learning_rate,
13
    find_learning_rate_from_args,
14
    find_learning_rate_model,
15
    FindLearningRate,
16
)
17
from allennlp.training import Trainer
18
from allennlp.training.util import data_loaders_from_params
19

20

21
def is_matplotlib_installed():
    """Return ``True`` if a working matplotlib can be imported.

    Matplotlib is an optional dependency; the plotting test below is
    skipped when it is missing or broken.
    """
    try:
        import matplotlib  # noqa: F401 - Matplotlib is optional.
    except Exception:
        # Any failure during import means we don't have a working matplotlib.
        # (``except Exception`` rather than a bare ``except:`` so that
        # KeyboardInterrupt/SystemExit are not swallowed.)
        return False
    return True
27

28

29
class TestFindLearningRate(AllenNlpTestCase):
    """Tests for the ``find-lr`` command: the model entry point, the CLI
    argument parsing, and the multi-GPU rejection path."""

    def setup_method(self):
        super().setup_method()
        # ``Params.pop`` is destructive, so store a factory producing a fresh
        # Params object for every call rather than a single shared instance.
        self.params = lambda: Params(
            {
                "model": {
                    "type": "simple_tagger",
                    "text_field_embedder": {
                        "token_embedders": {"tokens": {"type": "embedding", "embedding_dim": 5}}
                    },
                    "encoder": {"type": "lstm", "input_size": 5, "hidden_size": 7, "num_layers": 2},
                },
                "dataset_reader": {"type": "sequence_tagging"},
                "train_data_path": str(self.FIXTURES_ROOT / "data" / "sequence_tagging.tsv"),
                "validation_data_path": str(self.FIXTURES_ROOT / "data" / "sequence_tagging.tsv"),
                "data_loader": {"batch_size": 2},
                "trainer": {"cuda_device": -1, "num_epochs": 2, "optimizer": "adam"},
            }
        )

    @pytest.mark.skipif(not is_matplotlib_installed(), reason="matplotlib dependency is optional")
    def test_find_learning_rate(self):
        # Fresh serialization directory: should succeed.
        find_learning_rate_model(
            self.params(),
            os.path.join(self.TEST_DIR, "test_find_learning_rate"),
            start_lr=1e-5,
            end_lr=1,
            num_batches=100,
            linear_steps=True,
            stopping_factor=None,
            force=False,
        )

        # It's OK if serialization dir exists but is empty:
        serialization_dir2 = os.path.join(self.TEST_DIR, "empty_directory")
        assert not os.path.exists(serialization_dir2)
        os.makedirs(serialization_dir2)
        find_learning_rate_model(
            self.params(),
            serialization_dir2,
            start_lr=1e-5,
            end_lr=1,
            num_batches=100,
            linear_steps=True,
            stopping_factor=None,
            force=False,
        )

        # It's not OK if serialization dir exists and is non-empty:
        serialization_dir3 = os.path.join(self.TEST_DIR, "non_empty_directory")
        assert not os.path.exists(serialization_dir3)
        os.makedirs(serialization_dir3)
        with open(os.path.join(serialization_dir3, "README.md"), "w") as f:
            f.write("TEST")

        with pytest.raises(ConfigurationError):
            find_learning_rate_model(
                self.params(),
                serialization_dir3,
                start_lr=1e-5,
                end_lr=1,
                num_batches=100,
                linear_steps=True,
                stopping_factor=None,
                force=False,
            )

        # ... unless you use the --force flag.
        find_learning_rate_model(
            self.params(),
            serialization_dir3,
            start_lr=1e-5,
            end_lr=1,
            num_batches=100,
            linear_steps=True,
            stopping_factor=None,
            force=True,
        )

    def test_find_learning_rate_args(self):
        parser = argparse.ArgumentParser(description="Testing")
        subparsers = parser.add_subparsers(title="Commands", metavar="")
        FindLearningRate().add_subparser(subparsers)

        for serialization_arg in ["-s", "--serialization-dir"]:
            raw_args = ["find-lr", "path/to/params", serialization_arg, "serialization_dir"]

            args = parser.parse_args(raw_args)

            assert args.func == find_learning_rate_from_args
            assert args.param_path == "path/to/params"
            assert args.serialization_dir == "serialization_dir"

        # config is required
        with pytest.raises(SystemExit) as cm:
            parser.parse_args(["find-lr", "-s", "serialization_dir"])
        # NOTE: the assert must live *outside* the ``with`` block — parse_args
        # raises SystemExit, so any statement after it inside the block would
        # never run. pytest's ExceptionInfo exposes the exception as ``.value``.
        assert cm.value.code == 2  # argparse code for incorrect usage

        # serialization dir is required
        with pytest.raises(SystemExit) as cm:
            parser.parse_args(["find-lr", "path/to/params"])
        assert cm.value.code == 2  # argparse code for incorrect usage

    @requires_multi_gpu
    def test_find_learning_rate_multi_gpu(self):
        # Distributed configs are rejected: the LR finder cannot wrap the
        # model in DistributedDataParallel.
        params = self.params()
        del params["trainer"]["cuda_device"]
        params["distributed"] = Params({})
        params["distributed"]["cuda_devices"] = [0, 1]

        with pytest.raises(AssertionError) as execinfo:
            find_learning_rate_model(
                params,
                os.path.join(self.TEST_DIR, "test_find_learning_rate_multi_gpu"),
                start_lr=1e-5,
                end_lr=1,
                num_batches=100,
                linear_steps=True,
                stopping_factor=None,
                force=False,
            )
        assert "DistributedDataParallel" in str(execinfo.value)
151

152

153
class TestSearchLearningRate(AllenNlpTestCase):
    """Tests for ``search_learning_rate`` driven by a real ``Trainer`` built
    over the tiny sequence-tagging fixture."""

    def setup_method(self):
        super().setup_method()
        config = Params(
            {
                "model": {
                    "type": "simple_tagger",
                    "text_field_embedder": {
                        "token_embedders": {"tokens": {"type": "embedding", "embedding_dim": 5}}
                    },
                    "encoder": {"type": "lstm", "input_size": 5, "hidden_size": 7, "num_layers": 2},
                },
                "dataset_reader": {"type": "sequence_tagging"},
                "train_data_path": str(self.FIXTURES_ROOT / "data" / "sequence_tagging.tsv"),
                "validation_data_path": str(self.FIXTURES_ROOT / "data" / "sequence_tagging.tsv"),
                "data_loader": {"batch_size": 2},
                "trainer": {"cuda_device": -1, "num_epochs": 2, "optimizer": "adam"},
            }
        )
        loaders = data_loaders_from_params(config)

        def all_instances():
            # The vocabulary should cover every loader (train and validation).
            for loader in loaders.values():
                yield from loader.iter_instances()

        vocab = Vocabulary.from_params(config.pop("vocabulary", {}), instances=all_instances())
        model = Model.from_params(vocab=vocab, params=config.pop("model"))

        train_loader = loaders["train"]
        train_loader.index_with(vocab)

        self.trainer = Trainer.from_params(
            model=model,
            serialization_dir=os.path.join(self.TEST_DIR, "test_search_learning_rate"),
            data_loader=train_loader,
            params=config.pop("trainer"),
            validation_data=None,
            validation_iterator=None,
        )

    def test_search_learning_rate_with_num_batches_less_than_ten(self):
        # Fewer than ten batches is rejected outright.
        with pytest.raises(ConfigurationError):
            search_learning_rate(self.trainer, num_batches=9)

    def test_search_learning_rate_linear_steps(self):
        lrs_and_losses = search_learning_rate(self.trainer, linear_steps=True)
        assert len(lrs_and_losses) > 1

    def test_search_learning_rate_without_stopping_factor(self):
        lrs, losses = search_learning_rate(
            self.trainer, num_batches=100, stopping_factor=None
        )
        # num_batches=100 yields 101 recorded (lr, loss) points.
        assert len(lrs) == 101
        assert len(losses) == 101
212

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.