import copy
import json
import os

import pytest
import torch

from allennlp.common import cached_transformers
from allennlp.common.checks import ConfigurationError
from allennlp.common.testing import AllenNlpTestCase
from transformers import AutoConfig, AutoModel


class TestCachedTransformers(AllenNlpTestCase):
    def setup_method(self):
        super().setup_method()
        cached_transformers._clear_caches()

    def teardown_method(self):
        super().teardown_method()
        cached_transformers._clear_caches()
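
    # NOTE (hedged summary of the API as these tests exercise it): the second positional
    # argument to cached_transformers.get is make_copy. When True, the call returns a copy
    # of the cached model, so callers can safely mutate the result without affecting the
    # cached instance.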

    def test_get_missing_from_cache_local_files_only(self):
        with pytest.raises((OSError, ValueError)):
            cached_transformers.get(
                # Any model ID that is not already in the local cache works here.
                "bert-base-uncased",
                True,
                cache_dir=self.TEST_DIR,
                local_files_only=True,
            )

    def clear_test_dir(self):
        for f in os.listdir(str(self.TEST_DIR)):
            os.remove(str(self.TEST_DIR) + "/" + f)
        assert len(os.listdir(str(self.TEST_DIR))) == 0
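
    # NOTE: self.TEST_DIR comes from AllenNlpTestCase; it is a scratch directory that is
    # created and removed around each test. Using it as cache_dir keeps these tests from
    # touching the real Hugging Face cache.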

    def test_from_pretrained_avoids_weights_download_if_override_weights(self):
        config = AutoConfig.from_pretrained("epwalsh/bert-xsmall-dummy", cache_dir=self.TEST_DIR)
        # Only download the config, because downloading the pretrained weights as well takes
        # too long: AutoModel.from_config builds a randomly initialized model from the config
        # alone, without fetching any weights.
        transformer = AutoModel.from_config(config)

        # Clear the cache directory.
        self.clear_test_dir()

        save_weights_path = str(self.TEST_DIR / "bert_weights.pth")
        torch.save(transformer.state_dict(), save_weights_path)

        override_transformer = cached_transformers.get(
            "epwalsh/bert-xsmall-dummy",
            False,
            override_weights_file=save_weights_path,
            cache_dir=self.TEST_DIR,
        )
        # Check that only three files were downloaded (filename.json, filename, filename.lock),
        # all for config.json. If more than three files were downloaded, then the model weights
        # were also (incorrectly) downloaded.
        # NOTE: downloaded files are not explicitly detailed in Hugging Face's public API,
        # so this assertion could fail in the future.
        json_fnames = [fname for fname in os.listdir(str(self.TEST_DIR)) if fname.endswith(".json")]
        assert len(json_fnames) == 1
        json_data = json.load(open(str(self.TEST_DIR / json_fnames[0])))
        assert (
            json_data["url"]
            == "https://huggingface.co/epwalsh/bert-xsmall-dummy/resolve/main/config.json"
        )
        resource_id = os.path.splitext(json_fnames[0])[0]
        assert set(os.listdir(str(self.TEST_DIR))) == set(
            [json_fnames[0], resource_id, resource_id + ".lock", "bert_weights.pth"]
        )
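
        # For reference, the cache dir is now expected to look something like the listing
        # below (the hash is shortened and illustrative; the exact layout is an
        # implementation detail of Hugging Face's cache):
        #
        #   41cfd1be2f7365e4.json   metadata, e.g. {"url": ".../config.json", "etag": "..."}
        #   41cfd1be2f7365e4        the cached config.json content
        #   41cfd1be2f7365e4.lock   lock file created during the download
        #   bert_weights.pth        the override weights saved above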

        # Check that the override weights were loaded correctly: every parameter of the
        # returned model should match the weights we saved above.
        for p1, p2 in zip(transformer.parameters(), override_transformer.parameters()):
            assert p1.data.ne(p2.data).sum() == 0

    def test_reinit_modules_no_op(self):
        # Test the case where reinit_modules is None (the default).
        preinit_weights = torch.cat(
            [
                # Comparing all weights of the model is rather complicated, so arbitrarily
                # compare the weights of the attention modules.
                layer.attention.output.dense.weight
                for layer in cached_transformers.get("bert-base-cased", True).encoder.layer
            ]
        )
        postinit_weights = torch.cat(
            [
                layer.attention.output.dense.weight
                for layer in cached_transformers.get("bert-base-cased", True).encoder.layer
            ]
        )
        assert torch.equal(postinit_weights, preinit_weights)
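
    # A summary of reinit_modules semantics, as the tests below exercise them: an int N
    # re-initializes the last N transformer layers, a tuple of ints re-initializes the
    # layers at those indices, and a tuple of regex strings re-initializes every submodule
    # whose name matches one of the patterns.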

    def test_reinit_modules_with_layer_indices(self):
        # Comparing all weights of the model is rather complicated, so arbitrarily compare the
        # weights of the attention modules.
        preinit_weights = torch.cat(
            [
                layer.attention.output.dense.weight
                for layer in cached_transformers.get("bert-base-cased", True).encoder.layer
            ]
        )

        # Test the case when reinit_modules is a valid int: the last 2 layers are
        # re-initialized.
        postinit_weights = torch.cat(
            [
                layer.attention.output.dense.weight
                for layer in cached_transformers.get(
                    "bert-base-cased", True, reinit_modules=2
                ).encoder.layer
            ]
        )
        assert torch.equal(postinit_weights[:10], preinit_weights[:10])
        assert not torch.equal(postinit_weights[10:], preinit_weights[10:])

        # Test the case when reinit_modules is a valid tuple of integers.
        postinit_weights = torch.cat(
            [
                layer.attention.output.dense.weight
                for layer in cached_transformers.get(
                    "bert-base-cased", True, reinit_modules=(10, 11)
                ).encoder.layer
            ]
        )
        assert torch.equal(postinit_weights[:10], preinit_weights[:10])
        assert not torch.equal(postinit_weights[10:], preinit_weights[10:])

        # Should raise a ValueError, because reinit_modules contains at least one index that
        # is greater than the model's maximum number of layers.
        with pytest.raises(ValueError):
            _ = cached_transformers.get("bert-base-cased", True, reinit_modules=1000)
        with pytest.raises(ValueError):
            _ = cached_transformers.get("bert-base-cased", True, reinit_modules=(1, 1000))
        # The argument cannot mix layer indices and regex strings.
        with pytest.raises(ValueError):
            _ = cached_transformers.get("bert-base-cased", True, reinit_modules=(1, "attentions"))
        # This model has a non-standard structure, so if a layer index or a list of layer
        # indices is provided, we raise a ConfigurationError.
        with pytest.raises(ConfigurationError):
            _ = cached_transformers.get("sshleifer/tiny-gpt2", True, reinit_modules=1)
        with pytest.raises(ConfigurationError):
            _ = cached_transformers.get("sshleifer/tiny-gpt2", True, reinit_modules=(1, 2))

    def test_reinit_modules_with_regex_strings(self):
        # Comparing all weights of the model is rather complicated, so arbitrarily compare the
        # weights of the wpe module.
        reinit_module = "wpe"
        # This MUST be a deep copy, otherwise the parameters will be re-initialized in place
        # and the comparison below would be meaningless.
        preinit_weights = list(
            cached_transformers.get("sshleifer/tiny-gpt2", True)
            .get_submodule(reinit_module)
            .parameters()
        )
        preinit_weights = copy.deepcopy(preinit_weights)

        postinit_weights = list(
            cached_transformers.get(
                "sshleifer/tiny-gpt2",
                True,
                reinit_modules=(reinit_module,),
            )
            .get_submodule(reinit_module)
            .parameters()
        )
        assert all(
            not torch.equal(pre, post) for pre, post in zip(preinit_weights, postinit_weights)
        )
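
    # As used below, load_weights=False constructs the model from its config with randomly
    # initialized parameters, so only the config ever needs to be fetched from the hub.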

    def test_from_pretrained_no_load_weights(self):
        _ = cached_transformers.get(
            "epwalsh/bert-xsmall-dummy", False, load_weights=False, cache_dir=self.TEST_DIR
        )
        # Check that only three files were downloaded (filename.json, filename, filename.lock),
        # all for config.json. If more than three files were downloaded, then the model weights
        # were also (incorrectly) downloaded.
        # NOTE: downloaded files are not explicitly detailed in Hugging Face's public API,
        # so this assertion could fail in the future.
        json_fnames = [fname for fname in os.listdir(str(self.TEST_DIR)) if fname.endswith(".json")]
        assert len(json_fnames) == 1
        json_data = json.load(open(str(self.TEST_DIR / json_fnames[0])))
        assert (
            json_data["url"]
            == "https://huggingface.co/epwalsh/bert-xsmall-dummy/resolve/main/config.json"
        )
        resource_id = os.path.splitext(json_fnames[0])[0]
        assert set(os.listdir(str(self.TEST_DIR))) == set(
            [json_fnames[0], resource_id, resource_id + ".lock"]
        )

    def test_from_pretrained_no_load_weights_local_config(self):
        config = AutoConfig.from_pretrained("epwalsh/bert-xsmall-dummy", cache_dir=self.TEST_DIR)
        self.clear_test_dir()

        # Save the config to a local file.
        local_config_path = str(self.TEST_DIR / "local_config.json")
        config.to_json_file(local_config_path, use_diff=False)

        # Now load the model from the local config.
        _ = cached_transformers.get(
            local_config_path, False, load_weights=False, cache_dir=self.TEST_DIR
        )
        # Make sure no other files were downloaded.
        assert os.listdir(str(self.TEST_DIR)) == ["local_config.json"]

    def test_get_tokenizer_missing_from_cache_local_files_only(self):
        with pytest.raises((OSError, ValueError)):
            cached_transformers.get_tokenizer(
                # Any model ID that is not already in the local cache works here.
                "bert-base-uncased",
                cache_dir=self.TEST_DIR,
                local_files_only=True,
            )
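
# NOTE: most of these tests download (small) artifacts from the Hugging Face hub, so they
# need network access; run them with something like `pytest -k cached_transformers`.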