# coding=utf-8
# Copyright 2019 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
import shutil
import sys
import tempfile
import unittest
import unittest.mock as mock
from pathlib import Path

from huggingface_hub import HfFolder, delete_repo
from requests.exceptions import HTTPError

from transformers import AutoConfig, BertConfig, GPT2Config
from transformers.configuration_utils import PretrainedConfig
from transformers.testing_utils import TOKEN, USER, is_staging_test


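# Make the repo's utils/ directory importable so the dynamic test module
# (test_module.custom_configuration) can be imported below.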
sys.path.append(str(Path(__file__).parent.parent / "utils"))

from test_module.custom_configuration import CustomConfig  # noqa: E402


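# One non-default value for every common `PretrainedConfig` attribute.
# `ConfigTestUtils.test_config_common_kwargs_is_complete` below uses this mapping to catch
# attributes added to `PretrainedConfig` without a corresponding (non-default) entry here.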
config_common_kwargs = {
    "return_dict": False,
    "output_hidden_states": True,
    "output_attentions": True,
    "torchscript": True,
    "torch_dtype": "float16",
    "use_bfloat16": True,
    "tf_legacy_loss": True,
    "pruned_heads": {"a": 1},
    "tie_word_embeddings": False,
    "is_decoder": True,
    "cross_attention_hidden_size": 128,
    "add_cross_attention": True,
    "tie_encoder_decoder": True,
    "max_length": 50,
    "min_length": 3,
    "do_sample": True,
    "early_stopping": True,
    "num_beams": 3,
    "num_beam_groups": 3,
    "diversity_penalty": 0.5,
    "temperature": 2.0,
    "top_k": 10,
    "top_p": 0.7,
    "typical_p": 0.2,
    "repetition_penalty": 0.8,
    "length_penalty": 0.8,
    "no_repeat_ngram_size": 5,
    "encoder_no_repeat_ngram_size": 5,
    "bad_words_ids": [1, 2, 3],
    "num_return_sequences": 3,
    "chunk_size_feed_forward": 5,
    "output_scores": True,
    "return_dict_in_generate": True,
    "forced_bos_token_id": 2,
    "forced_eos_token_id": 3,
    "remove_invalid_values": True,
    "architectures": ["BertModel"],
    "finetuning_task": "translation",
    "id2label": {0: "label"},
    "label2id": {"label": "0"},
    "tokenizer_class": "BertTokenizerFast",
    "prefix": "prefix",
    "bos_token_id": 6,
    "pad_token_id": 7,
    "eos_token_id": 8,
    "sep_token_id": 9,
    "decoder_start_token_id": 10,
    "exponential_decay_length_penalty": (5, 1.01),
    "suppress_tokens": [0, 1],
    "begin_suppress_tokens": 2,
    "task_specific_params": {"translation": "some_params"},
    "problem_type": "regression",
}


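# The tests below create and delete real repos, so `is_staging_test` restricts them to
# runs against the staging Hub endpoint.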
@is_staging_test
class ConfigPushToHubTester(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls._token = TOKEN
        HfFolder.save_token(TOKEN)

    @classmethod
    def tearDownClass(cls):
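        # Remove any repos the tests may have created; deleting one that was never
        # pushed raises an HTTPError, which is ignored.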
        try:
            delete_repo(token=cls._token, repo_id="test-config")
        except HTTPError:
            pass

        try:
            delete_repo(token=cls._token, repo_id="valid_org/test-config-org")
        except HTTPError:
            pass

        try:
            delete_repo(token=cls._token, repo_id="test-dynamic-config")
        except HTTPError:
            pass

    def test_push_to_hub(self):
        config = BertConfig(
            vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
        )
        config.push_to_hub("test-config", token=self._token)

        new_config = BertConfig.from_pretrained(f"{USER}/test-config")
        for k, v in config.to_dict().items():
            if k != "transformers_version":
                self.assertEqual(v, getattr(new_config, k))

        # Reset repo
        delete_repo(token=self._token, repo_id="test-config")

        # Push to hub via save_pretrained
        with tempfile.TemporaryDirectory() as tmp_dir:
            config.save_pretrained(tmp_dir, repo_id="test-config", push_to_hub=True, token=self._token)

        new_config = BertConfig.from_pretrained(f"{USER}/test-config")
        for k, v in config.to_dict().items():
            if k != "transformers_version":
                self.assertEqual(v, getattr(new_config, k))

    def test_push_to_hub_in_organization(self):
        config = BertConfig(
            vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
        )
        config.push_to_hub("valid_org/test-config-org", token=self._token)

        new_config = BertConfig.from_pretrained("valid_org/test-config-org")
        for k, v in config.to_dict().items():
            if k != "transformers_version":
                self.assertEqual(v, getattr(new_config, k))

        # Reset repo
        delete_repo(token=self._token, repo_id="valid_org/test-config-org")

        # Push to hub via save_pretrained
        with tempfile.TemporaryDirectory() as tmp_dir:
            config.save_pretrained(tmp_dir, repo_id="valid_org/test-config-org", push_to_hub=True, token=self._token)

        new_config = BertConfig.from_pretrained("valid_org/test-config-org")
        for k, v in config.to_dict().items():
            if k != "transformers_version":
                self.assertEqual(v, getattr(new_config, k))

    def test_push_to_hub_dynamic_config(self):
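        # `register_for_auto_class` records the class in the config's `auto_map`, which lets
        # `AutoConfig` rebuild it from the pushed module with `trust_remote_code=True`.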
        CustomConfig.register_for_auto_class()
        config = CustomConfig(attribute=42)

        config.push_to_hub("test-dynamic-config", token=self._token)

        # This has added the proper auto_map field to the config
        self.assertDictEqual(config.auto_map, {"AutoConfig": "custom_configuration.CustomConfig"})

        new_config = AutoConfig.from_pretrained(f"{USER}/test-dynamic-config", trust_remote_code=True)
        # Can't make an isinstance check because the new_config is built from the CustomConfig class of a dynamic module
        self.assertEqual(new_config.__class__.__name__, "CustomConfig")
        self.assertEqual(new_config.attribute, 42)


class ConfigTestUtils(unittest.TestCase):
    def test_config_from_string(self):
        c = GPT2Config()

        # attempt to modify each of int/float/bool/str config records and verify they were updated
        n_embd = c.n_embd + 1  # int
        resid_pdrop = c.resid_pdrop + 1.0  # float
        scale_attn_weights = not c.scale_attn_weights  # bool
        summary_type = c.summary_type + "foo"  # str
        c.update_from_string(
            f"n_embd={n_embd},resid_pdrop={resid_pdrop},scale_attn_weights={scale_attn_weights},summary_type={summary_type}"
        )
        self.assertEqual(n_embd, c.n_embd, "mismatch for key: n_embd")
        self.assertEqual(resid_pdrop, c.resid_pdrop, "mismatch for key: resid_pdrop")
        self.assertEqual(scale_attn_weights, c.scale_attn_weights, "mismatch for key: scale_attn_weights")
        self.assertEqual(summary_type, c.summary_type, "mismatch for key: summary_type")

    def test_config_common_kwargs_is_complete(self):
        base_config = PretrainedConfig()
        missing_keys = [key for key in base_config.__dict__ if key not in config_common_kwargs]
        # If this part of the test fails, there are arguments to add to config_common_kwargs above.
        self.assertListEqual(
            missing_keys,
            [
                "is_encoder_decoder",
                "_name_or_path",
                "_commit_hash",
                "_attn_implementation_internal",
                "transformers_version",
            ],
        )
        keys_with_defaults = [key for key, value in config_common_kwargs.items() if value == getattr(base_config, key)]
        if len(keys_with_defaults) > 0:
            raise ValueError(
                "The following keys are set to their default values in"
                " `test_configuration_utils.config_common_kwargs`; pick another value for them:"
                f" {', '.join(keys_with_defaults)}."
            )

    def test_nested_config_load_from_dict(self):
        config = AutoConfig.from_pretrained(
            "hf-internal-testing/tiny-random-CLIPModel", text_config={"num_hidden_layers": 2}
        )
        self.assertNotIsInstance(config.text_config, dict)
        self.assertEqual(config.text_config.__class__.__name__, "CLIPTextConfig")

    def test_from_pretrained_subfolder(self):
        with self.assertRaises(OSError):
            # The config is in a subfolder, so the following should not work without specifying the subfolder
            _ = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert-subfolder")

        config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert-subfolder", subfolder="bert")

        self.assertIsNotNone(config)

    def test_cached_files_are_used_when_internet_is_down(self):
        # A mock response for an HTTP head request to emulate the server being down
        response_mock = mock.Mock()
        response_mock.status_code = 500
        response_mock.headers = {}
        response_mock.raise_for_status.side_effect = HTTPError
        response_mock.json.return_value = {}

        # Download this model to make sure it's in the cache.
        _ = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")

        # Under the mock environment we get a 500 error when trying to reach the model.
        with mock.patch("requests.Session.request", return_value=response_mock) as mock_head:
            _ = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
            # This checks that we did call the fake head request
            mock_head.assert_called()

    def test_legacy_load_from_url(self):
        # This test is for deprecated behavior and can be removed in v5
        _ = BertConfig.from_pretrained(
            "https://huggingface.co/hf-internal-testing/tiny-random-bert/resolve/main/config.json"
        )

    def test_local_versioning(self):
        configuration = AutoConfig.from_pretrained("google-bert/bert-base-cased")
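        # `configuration_files` lists versioned config file names; `from_pretrained` picks the
        # most recent file whose version does not exceed the installed version of Transformers.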
        configuration.configuration_files = ["config.4.0.0.json"]

        with tempfile.TemporaryDirectory() as tmp_dir:
            configuration.save_pretrained(tmp_dir)
            configuration.hidden_size = 2
            with open(os.path.join(tmp_dir, "config.4.0.0.json"), "w") as f:
                json.dump(configuration.to_dict(), f)

            # This should pick the new configuration file as the version of Transformers is > 4.0.0
            new_configuration = AutoConfig.from_pretrained(tmp_dir)
            self.assertEqual(new_configuration.hidden_size, 2)

            # Will need to be adjusted if we reach v42 and this test is still here.
            # Should pick the old configuration file as the version of Transformers is < 42.0.0
            configuration.configuration_files = ["config.42.0.0.json"]
            configuration.hidden_size = 768
            configuration.save_pretrained(tmp_dir)
            shutil.move(os.path.join(tmp_dir, "config.4.0.0.json"), os.path.join(tmp_dir, "config.42.0.0.json"))
            new_configuration = AutoConfig.from_pretrained(tmp_dir)
            self.assertEqual(new_configuration.hidden_size, 768)

    def test_repo_versioning_before(self):
        # This repo has two configuration files; the one for v4.0.0 and above has a different hidden size.
        repo = "hf-internal-testing/test-two-configs"

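        # Simulate a v4.0.0 install by monkey-patching the version that `configuration_utils` compares against.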
        import transformers as new_transformers

        new_transformers.configuration_utils.__version__ = "v4.0.0"
        new_configuration, kwargs = new_transformers.models.auto.AutoConfig.from_pretrained(
            repo, return_unused_kwargs=True
        )
        self.assertEqual(new_configuration.hidden_size, 2)
        # This checks that `_configuration_file` is not kept in the kwargs by mistake.
        self.assertDictEqual(kwargs, {})

        # Testing an older version by monkey-patching the version in the module where it's used.
        import transformers as old_transformers

        old_transformers.configuration_utils.__version__ = "v3.0.0"
        old_configuration = old_transformers.models.auto.AutoConfig.from_pretrained(repo)
        self.assertEqual(old_configuration.hidden_size, 768)

    def test_saving_config_with_custom_generation_kwargs_raises_warning(self):
        config = BertConfig(min_length=3)  # `min_length = 3` is a non-default generation kwarg
        with tempfile.TemporaryDirectory() as tmp_dir:
            with self.assertLogs("transformers.configuration_utils", level="WARNING") as logs:
                config.save_pretrained(tmp_dir)
            self.assertEqual(len(logs.output), 1)
            self.assertIn("min_length", logs.output[0])

    def test_has_non_default_generation_parameters(self):
        config = BertConfig()
        self.assertFalse(config._has_non_default_generation_parameters())
        config = BertConfig(min_length=3)
        self.assertTrue(config._has_non_default_generation_parameters())
        config = BertConfig(min_length=0)  # `min_length = 0` is a default generation kwarg
        self.assertFalse(config._has_non_default_generation_parameters())