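"""
Tests for allennlp.common.cached_transformers: model and tokenizer caching, local_files_only
behavior, override_weights_file, load_weights=False, and reinit_modules.
"""
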
import json
import os

import pytest
import torch
from allennlp.common import cached_transformers
from allennlp.common.checks import ConfigurationError
from allennlp.common.testing import AllenNlpTestCase
from transformers import AutoConfig, AutoModel


class TestCachedTransformers(AllenNlpTestCase):
    def setup_method(self):
        super().setup_method()
        cached_transformers._clear_caches()

    def teardown_method(self):
        super().teardown_method()
        cached_transformers._clear_caches()

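    # With an empty cache directory and local_files_only=True, the underlying Hugging Face call
    # should fail rather than download anything; the exact exception type presumably varies
    # across transformers versions, hence the (OSError, ValueError) tuple below.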
    def test_get_missing_from_cache_local_files_only(self):
        with pytest.raises((OSError, ValueError)):
            cached_transformers.get(
                "bert-base-uncased",
                True,
                cache_dir=self.TEST_DIR,
                local_files_only=True,
            )

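    # Helper (not a test): empties TEST_DIR so the download-counting tests below can assert
    # exactly which files a subsequent call creates.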
    def clear_test_dir(self):
        for f in os.listdir(str(self.TEST_DIR)):
            os.remove(str(self.TEST_DIR) + "/" + f)
        assert len(os.listdir(str(self.TEST_DIR))) == 0

    def test_from_pretrained_avoids_weights_download_if_override_weights(self):
        # Only download the config, because downloading the pretrained weights as well takes too long.
        config = AutoConfig.from_pretrained("epwalsh/bert-xsmall-dummy", cache_dir=self.TEST_DIR)
        transformer = AutoModel.from_config(config)

        # Clear the cache directory.
        self.clear_test_dir()

        save_weights_path = str(self.TEST_DIR / "bert_weights.pth")
        torch.save(transformer.state_dict(), save_weights_path)

        override_transformer = cached_transformers.get(
            "epwalsh/bert-xsmall-dummy",
            False,
            override_weights_file=save_weights_path,
            cache_dir=self.TEST_DIR,
        )
        # Check that only three files were downloaded for config.json (filename.json, filename, filename.lock).
        # If more than three files were downloaded, then the model weights were also (incorrectly) downloaded.
        # NOTE: downloaded files are not explicitly detailed in Huggingface's public API,
        # so this assertion could fail in the future.
        json_fnames = [fname for fname in os.listdir(str(self.TEST_DIR)) if fname.endswith(".json")]
        assert len(json_fnames) == 1
        json_data = json.load(open(str(self.TEST_DIR / json_fnames[0])))
        assert (
            json_data["url"]
            == "https://huggingface.co/epwalsh/bert-xsmall-dummy/resolve/main/config.json"
        )
        resource_id = os.path.splitext(json_fnames[0])[0]
        assert set(os.listdir(str(self.TEST_DIR))) == set(
            [json_fnames[0], resource_id, resource_id + ".lock", "bert_weights.pth"]
        )

        # Check that the override weights were loaded correctly.
        for p1, p2 in zip(transformer.parameters(), override_transformer.parameters()):
            assert p1.data.ne(p2.data).sum() == 0

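    # The reinit_modules tests below exercise the three accepted forms of the argument:
    # None (the default, a no-op), an int or tuple of ints selecting transformer layers to
    # re-initialize, and a tuple of regex strings matched against submodule names.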
    def test_reinit_modules_no_op(self):
        # Test the case where reinit_modules is None (the default).
        preinit_weights = torch.cat(
            [
                # Comparing all weights of the model is rather complicated, so arbitrarily
                # compare the weights of the attention module.
                layer.attention.output.dense.weight
                for layer in cached_transformers.get("bert-base-cased", True).encoder.layer
            ]
        )
        postinit_weights = torch.cat(
            [
                layer.attention.output.dense.weight
                for layer in cached_transformers.get("bert-base-cased", True).encoder.layer
            ]
        )
        assert torch.equal(postinit_weights, preinit_weights)

    def test_reinit_modules_with_layer_indices(self):
        # Comparing all weights of the model is rather complicated, so arbitrarily compare the
        # weights of the attention module.
        preinit_weights = torch.cat(
            [
                layer.attention.output.dense.weight
                for layer in cached_transformers.get("bert-base-cased", True).encoder.layer
            ]
        )

        # Test the case when reinit_modules is a valid int.
        postinit_weights = torch.cat(
            [
                layer.attention.output.dense.weight
                for layer in cached_transformers.get(
                    "bert-base-cased", True, reinit_modules=2
                ).encoder.layer
            ]
        )
        assert torch.equal(postinit_weights[:10], preinit_weights[:10])
        assert not torch.equal(postinit_weights[10:], preinit_weights[10:])

        # Test the case when reinit_modules is a valid list of integers.
        postinit_weights = torch.cat(
            [
                layer.attention.output.dense.weight
                for layer in cached_transformers.get(
                    "bert-base-cased", True, reinit_modules=(10, 11)
                ).encoder.layer
            ]
        )
        assert torch.equal(postinit_weights[:10], preinit_weights[:10])
        assert not torch.equal(postinit_weights[10:], preinit_weights[10:])

        # Should raise a ValueError because reinit_modules contains at least one index that is
        # greater than the model's maximum number of layers.
        with pytest.raises(ValueError):
            _ = cached_transformers.get("bert-base-cased", True, reinit_modules=1000)
        with pytest.raises(ValueError):
            _ = cached_transformers.get("bert-base-cased", True, reinit_modules=(1, 1000))
        # The argument cannot mix layer indices and regex strings.
        with pytest.raises(ValueError):
            _ = cached_transformers.get("bert-base-cased", True, reinit_modules=(1, "attentions"))
        # This model has a non-standard structure, so if a layer index or list of layer indexes
        # is provided, we raise a ConfigurationError.
        with pytest.raises(ConfigurationError):
            _ = cached_transformers.get("sshleifer/tiny-gpt2", True, reinit_modules=1)
        with pytest.raises(ConfigurationError):
            _ = cached_transformers.get("sshleifer/tiny-gpt2", True, reinit_modules=(1, 2))

    def test_reinit_modules_with_regex_strings(self):
        # Comparing all weights of the model is rather complicated, so arbitrarily compare the
        # weights of the wpe module.
        reinit_module = "wpe"
        # This MUST be a deep copy, otherwise the parameters will be re-initialized and the
        # test will break.
        preinit_weights = list(
            cached_transformers.get("sshleifer/tiny-gpt2", True)
            .get_submodule(reinit_module)
            .parameters()
        )

        postinit_weights = list(
            cached_transformers.get(
                "sshleifer/tiny-gpt2",
                True,
                reinit_modules=(reinit_module,),
            )
            .get_submodule(reinit_module)
            .parameters()
        )
        assert all(
            (not torch.equal(pre, post) for pre, post in zip(preinit_weights, postinit_weights))
        )

    def test_from_pretrained_no_load_weights(self):
        _ = cached_transformers.get(
            "epwalsh/bert-xsmall-dummy", False, load_weights=False, cache_dir=self.TEST_DIR
        )
        # Check that only three files were downloaded for config.json (filename.json, filename, filename.lock).
        # If more than three files were downloaded, then the model weights were also (incorrectly) downloaded.
        # NOTE: downloaded files are not explicitly detailed in Huggingface's public API,
        # so this assertion could fail in the future.
        json_fnames = [fname for fname in os.listdir(str(self.TEST_DIR)) if fname.endswith(".json")]
        assert len(json_fnames) == 1
        json_data = json.load(open(str(self.TEST_DIR / json_fnames[0])))
        assert (
            json_data["url"]
            == "https://huggingface.co/epwalsh/bert-xsmall-dummy/resolve/main/config.json"
        )
        resource_id = os.path.splitext(json_fnames[0])[0]
        assert set(os.listdir(str(self.TEST_DIR))) == set(
            [json_fnames[0], resource_id, resource_id + ".lock"]
        )

    def test_from_pretrained_no_load_weights_local_config(self):
        config = AutoConfig.from_pretrained("epwalsh/bert-xsmall-dummy", cache_dir=self.TEST_DIR)
        self.clear_test_dir()

        # Save config to file.
        local_config_path = str(self.TEST_DIR / "local_config.json")
        config.to_json_file(local_config_path, use_diff=False)

        # Now load the model from the local config.
        _ = cached_transformers.get(
            local_config_path, False, load_weights=False, cache_dir=self.TEST_DIR
        )
        # Make sure no other files were downloaded.
        assert os.listdir(str(self.TEST_DIR)) == ["local_config.json"]

    def test_get_tokenizer_missing_from_cache_local_files_only(self):
        with pytest.raises((OSError, ValueError)):
            cached_transformers.get_tokenizer(
                "bert-base-uncased",
                cache_dir=self.TEST_DIR,
                local_files_only=True,
            )
