# coding=utf-8
# Copyright 2019 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from __future__ import annotations

import inspect
import json
import os
import random
import tempfile
import unittest
import unittest.mock as mock

from huggingface_hub import HfFolder, Repository, delete_repo, snapshot_download
from huggingface_hub.file_download import http_get
from requests.exceptions import HTTPError

from transformers import is_tf_available, is_torch_available
from transformers.configuration_utils import PretrainedConfig
from transformers.testing_utils import (  # noqa: F401
    TOKEN,
    USER,
    CaptureLogger,
    _tf_gpu_memory_limit,
    is_pt_tf_cross_test,
    is_staging_test,
    require_safetensors,
    require_tf,
    require_torch,
    slow,
)
from transformers.utils import SAFE_WEIGHTS_NAME, TF2_WEIGHTS_INDEX_NAME, TF2_WEIGHTS_NAME, logging

logger = logging.get_logger(__name__)


# TensorFlow-only imports and global TF configuration. Guarded so that the module can still be imported
# (and its tests skipped via `require_tf`) when TensorFlow is not installed.
if is_tf_available():
    import h5py
    import numpy as np
    import tensorflow as tf

    from transformers import (
        BertConfig,
        PreTrainedModel,
        PushToHubCallback,
        RagRetriever,
        TFBertForMaskedLM,
        TFBertForSequenceClassification,
        TFBertModel,
        TFPreTrainedModel,
        TFRagModel,
    )
    from transformers.modeling_tf_utils import keras, tf_shard_checkpoint, unpack_inputs
    from transformers.tf_utils import stable_softmax

    # Turn off TF32 execution. NOTE(review): presumably so float32 results stay precise enough for the
    # `allclose` comparisons in these tests — confirm.
    tf.config.experimental.enable_tensor_float_32_execution(False)

    if _tf_gpu_memory_limit is not None:
        gpus = tf.config.list_physical_devices("GPU")
        for gpu in gpus:
            # Restrict TensorFlow to at most `_tf_gpu_memory_limit` of memory on each GPU.
            try:
                tf.config.set_logical_device_configuration(
                    gpu, [tf.config.LogicalDeviceConfiguration(memory_limit=_tf_gpu_memory_limit)]
                )
                logical_gpus = tf.config.list_logical_devices("GPU")
                print("Logical GPUs", logical_gpus)
            except RuntimeError as e:
                # Virtual devices must be set before GPUs have been initialized; report and continue.
                print(e)

# PyTorch model used by the PT -> TF cross tests.
if is_torch_available():
    from transformers import BertModel


@require_tf
class TFModelUtilsTest(unittest.TestCase):
    """Tests for the TF modeling utilities: Hub/local/URL loading, `unpack_inputs`, `stable_softmax`,
    checkpoint sharding, SavedModel signatures and safetensors (de)serialization."""

    def test_cached_files_are_used_when_internet_is_down(self):
        # A mock response for an HTTP head request to emulate server down
        response_mock = mock.Mock()
        response_mock.status_code = 500
        response_mock.headers = {}
        response_mock.raise_for_status.side_effect = HTTPError
        response_mock.json.return_value = {}

        # Download this model to make sure it's in the cache.
        _ = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")

        # Under the mock environment we get a 500 error when trying to reach the model.
        with mock.patch("requests.Session.request", return_value=response_mock) as mock_head:
            _ = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
            # Check that the fake request was actually made.
            mock_head.assert_called()

    def test_load_from_one_file(self):
        # A private temporary directory replaces `tempfile.mktemp()`, which is deprecated and racy (another
        # process may create the returned path before we open it). The directory and its content are removed
        # automatically, even if the download fails before the file is created.
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_file = os.path.join(tmp_dir, "tf_model.h5")
            with open(tmp_file, "wb") as f:
                http_get("https://huggingface.co/hf-internal-testing/tiny-random-bert/resolve/main/tf_model.h5", f)

            config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
            _ = TFBertModel.from_pretrained(tmp_file, config=config)

    def test_legacy_load_from_url(self):
        # This test is for deprecated behavior and can be removed in v5
        config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
        _ = TFBertModel.from_pretrained(
            "https://huggingface.co/hf-internal-testing/tiny-random-bert/resolve/main/tf_model.h5", config=config
        )

    # tests whether the unpack_inputs function behaves as expected
    def test_unpack_inputs(self):
        class DummyModel:
            def __init__(self):
                config_kwargs = {"output_attentions": False, "output_hidden_states": False, "return_dict": False}
                self.config = PretrainedConfig(**config_kwargs)
                self.main_input_name = "input_ids"

            @unpack_inputs
            def call(
                self,
                input_ids=None,
                past_key_values=None,
                output_attentions=None,
                output_hidden_states=None,
                return_dict=None,
            ):
                return input_ids, past_key_values, output_attentions, output_hidden_states, return_dict

            @unpack_inputs
            def foo(self, pixel_values, output_attentions=None, output_hidden_states=None, return_dict=None):
                return pixel_values, output_attentions, output_hidden_states, return_dict

        dummy_model = DummyModel()
        input_ids = tf.constant([0, 1, 2, 3], dtype=tf.int32)
        past_key_values = tf.constant([4, 5, 6, 7], dtype=tf.int32)
        pixel_values = tf.constant([8, 9, 10, 11], dtype=tf.int32)

        # test case 1: Pass inputs as keyword arguments; Booleans are inherited from the config.
        output = dummy_model.call(input_ids=input_ids, past_key_values=past_key_values)
        tf.debugging.assert_equal(output[0], input_ids)
        tf.debugging.assert_equal(output[1], past_key_values)
        self.assertFalse(output[2])
        self.assertFalse(output[3])
        self.assertFalse(output[4])

        # test case 2: Same as above, but with positional arguments.
        output = dummy_model.call(input_ids, past_key_values)
        tf.debugging.assert_equal(output[0], input_ids)
        tf.debugging.assert_equal(output[1], past_key_values)
        self.assertFalse(output[2])
        self.assertFalse(output[3])
        self.assertFalse(output[4])

        # test case 3: We can also pack everything in the first input.
        output = dummy_model.call(input_ids={"input_ids": input_ids, "past_key_values": past_key_values})
        tf.debugging.assert_equal(output[0], input_ids)
        tf.debugging.assert_equal(output[1], past_key_values)
        self.assertFalse(output[2])
        self.assertFalse(output[3])
        self.assertFalse(output[4])

        # test case 4: Explicit boolean arguments should override the config.
        output = dummy_model.call(
            input_ids=input_ids, past_key_values=past_key_values, output_attentions=False, return_dict=True
        )
        tf.debugging.assert_equal(output[0], input_ids)
        tf.debugging.assert_equal(output[1], past_key_values)
        self.assertFalse(output[2])
        self.assertFalse(output[3])
        self.assertTrue(output[4])

        # test case 5: Unexpected arguments should raise an exception.
        with self.assertRaises(ValueError):
            dummy_model.call(input_ids=input_ids, past_key_values=past_key_values, foo="bar")

        # test case 6: the decorator is independent from `main_input_name` -- it treats the first argument of the
        # decorated function as its main input.
        output = dummy_model.foo(pixel_values=pixel_values)
        tf.debugging.assert_equal(output[0], pixel_values)
        self.assertFalse(output[1])
        self.assertFalse(output[2])
        self.assertFalse(output[3])

    # Tests whether the stable softmax is stable on CPU, with and without XLA
    def test_xla_stable_softmax(self):
        large_penalty = -1e9
        n_tokens = 10
        batch_size = 8

        def masked_softmax(x, boolean_mask):
            numerical_mask = (1.0 - tf.cast(boolean_mask, dtype=tf.float32)) * large_penalty
            masked_x = x + numerical_mask
            return stable_softmax(masked_x)

        xla_masked_softmax = tf.function(masked_softmax, jit_compile=True)
        xla_stable_softmax = tf.function(stable_softmax, jit_compile=True)
        x = tf.random.normal((batch_size, n_tokens))

        # Same outcome regardless of the boolean mask here. `random.randint` is inclusive on both ends, so
        # anywhere from zero to all of the tokens may end up masked.
        masked_tokens = random.randint(0, n_tokens)
        boolean_mask = tf.convert_to_tensor([[1] * (n_tokens - masked_tokens) + [0] * masked_tokens], dtype=tf.int32)

        # We can randomly mask a random numerical input OUTSIDE XLA
        numerical_mask = (1.0 - tf.cast(boolean_mask, dtype=tf.float32)) * large_penalty
        masked_x = x + numerical_mask
        xla_out = xla_stable_softmax(masked_x)
        out = stable_softmax(masked_x)
        # `self.assertTrue` instead of a bare `assert`, which would be stripped under `python -O`.
        self.assertTrue(tf.experimental.numpy.allclose(xla_out, out))

        # The stable softmax has the same output as the original softmax
        unstable_out = tf.nn.softmax(masked_x)
        self.assertTrue(tf.experimental.numpy.allclose(unstable_out, out))

        # We can randomly mask a random numerical input INSIDE XLA
        xla_out = xla_masked_softmax(x, boolean_mask)
        out = masked_softmax(x, boolean_mask)
        self.assertTrue(tf.experimental.numpy.allclose(xla_out, out))

    def test_checkpoint_sharding_from_hub(self):
        model = TFBertModel.from_pretrained("ArthurZ/tiny-random-bert-sharded")
        # the model above is the same as the model below, just a sharded version.
        ref_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
        for p1, p2 in zip(model.weights, ref_model.weights):
            self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))

    def test_sharded_checkpoint_with_prefix(self):
        model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert", load_weight_prefix="a/b")
        sharded_model = TFBertModel.from_pretrained("ArthurZ/tiny-random-bert-sharded", load_weight_prefix="a/b")
        for p1, p2 in zip(model.weights, sharded_model.weights):
            self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
            self.assertTrue(p1.name.startswith("a/b/"))
            self.assertTrue(p2.name.startswith("a/b/"))

    def test_sharded_checkpoint_transfer(self):
        # If this doesn't throw an error then the test passes
        TFBertForSequenceClassification.from_pretrained("ArthurZ/tiny-random-bert-sharded")

    @is_pt_tf_cross_test
    def test_checkpoint_sharding_local_from_pt(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            _ = Repository(local_dir=tmp_dir, clone_from="hf-internal-testing/tiny-random-bert-sharded")
            model = TFBertModel.from_pretrained(tmp_dir, from_pt=True)
            # the model above is the same as the model below, just a sharded pytorch version.
            ref_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
            for p1, p2 in zip(model.weights, ref_model.weights):
                self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))

    @is_pt_tf_cross_test
    def test_checkpoint_loading_with_prefix_from_pt(self):
        model = TFBertModel.from_pretrained(
            "hf-internal-testing/tiny-random-bert", from_pt=True, load_weight_prefix="a/b"
        )
        ref_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert", from_pt=True)
        for p1, p2 in zip(model.weights, ref_model.weights):
            self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
            self.assertTrue(p1.name.startswith("a/b/"))

    @is_pt_tf_cross_test
    def test_checkpoint_sharding_hub_from_pt(self):
        model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded", from_pt=True)
        # the model above is the same as the model below, just a sharded pytorch version.
        ref_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
        for p1, p2 in zip(model.weights, ref_model.weights):
            self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))

    def test_shard_checkpoint(self):
        # This is the model we will use, total size 340,000 bytes.
        model = keras.Sequential(
            [
                keras.layers.Dense(200, use_bias=False),  # size 80,000
                keras.layers.Dense(200, use_bias=False),  # size 160,000
                keras.layers.Dense(100, use_bias=False),  # size 80,000
                keras.layers.Dense(50, use_bias=False),  # size 20,000
            ]
        )
        inputs = tf.zeros((1, 100), dtype=tf.float32)
        model(inputs)
        weights = model.weights
        weights_dict = {w.name: w for w in weights}
        with self.subTest("No shard when max size is bigger than model size"):
            shards, index = tf_shard_checkpoint(weights)
            self.assertIsNone(index)
            self.assertDictEqual(shards, {TF2_WEIGHTS_NAME: weights})

        with self.subTest("Test sharding, no weights bigger than max size"):
            shards, index = tf_shard_checkpoint(weights, max_shard_size="300kB")
            # Split is first two layers then last two.
            self.assertDictEqual(
                index,
                {
                    "metadata": {"total_size": 340000},
                    "weight_map": {
                        "dense/kernel:0": "tf_model-00001-of-00002.h5",
                        "dense_1/kernel:0": "tf_model-00001-of-00002.h5",
                        "dense_2/kernel:0": "tf_model-00002-of-00002.h5",
                        "dense_3/kernel:0": "tf_model-00002-of-00002.h5",
                    },
                },
            )

            shard1 = [weights_dict["dense/kernel:0"], weights_dict["dense_1/kernel:0"]]
            shard2 = [weights_dict["dense_2/kernel:0"], weights_dict["dense_3/kernel:0"]]
            self.assertDictEqual(shards, {"tf_model-00001-of-00002.h5": shard1, "tf_model-00002-of-00002.h5": shard2})

        with self.subTest("Test sharding with weights bigger than max size"):
            shards, index = tf_shard_checkpoint(weights, max_shard_size="100kB")
            # Split is first layer, second layer then last 2.
            self.assertDictEqual(
                index,
                {
                    "metadata": {"total_size": 340000},
                    "weight_map": {
                        "dense/kernel:0": "tf_model-00001-of-00003.h5",
                        "dense_1/kernel:0": "tf_model-00002-of-00003.h5",
                        "dense_2/kernel:0": "tf_model-00003-of-00003.h5",
                        "dense_3/kernel:0": "tf_model-00003-of-00003.h5",
                    },
                },
            )

            shard1 = [weights_dict["dense/kernel:0"]]
            shard2 = [weights_dict["dense_1/kernel:0"]]
            shard3 = [weights_dict["dense_2/kernel:0"], weights_dict["dense_3/kernel:0"]]
            self.assertDictEqual(
                shards,
                {
                    "tf_model-00001-of-00003.h5": shard1,
                    "tf_model-00002-of-00003.h5": shard2,
                    "tf_model-00003-of-00003.h5": shard3,
                },
            )

    @slow
    def test_special_layer_name_sharding(self):
        retriever = RagRetriever.from_pretrained("facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True)
        model = TFRagModel.from_pretrained("facebook/rag-token-nq", retriever=retriever)

        with tempfile.TemporaryDirectory() as tmp_dir:
            for max_size in ["150kB", "150kiB", "200kB", "200kiB"]:
                model.save_pretrained(tmp_dir, max_shard_size=max_size)
                ref_model = TFRagModel.from_pretrained(tmp_dir, retriever=retriever)
                for p1, p2 in zip(model.weights, ref_model.weights):
                    self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))

    def test_checkpoint_sharding_local(self):
        model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")

        with tempfile.TemporaryDirectory() as tmp_dir:
            # We use the same folder for various sizes to make sure a new save erases the old checkpoint.
            for max_size in ["150kB", "150kiB", "200kB", "200kiB"]:
                model.save_pretrained(tmp_dir, max_shard_size=max_size)

                # Requested shard size in bytes: "kiB" means 2**10 bytes, "kB" means 10**3 bytes.
                # Computed once per size rather than once per shard file.
                if max_size.endswith("kiB"):
                    max_size_int = int(max_size[:-3]) * 2**10
                else:
                    max_size_int = int(max_size[:-2]) * 10**3

                # Get each shard file and its size
                shard_to_size = {}
                for shard in os.listdir(tmp_dir):
                    if shard.endswith(".h5"):
                        shard_file = os.path.join(tmp_dir, shard)
                        shard_to_size[shard_file] = os.path.getsize(shard_file)

                index_file = os.path.join(tmp_dir, TF2_WEIGHTS_INDEX_NAME)
                # Check there is an index but no regular weight file
                self.assertTrue(os.path.isfile(index_file))
                self.assertFalse(os.path.isfile(os.path.join(tmp_dir, TF2_WEIGHTS_NAME)))

                # Check a file is bigger than max_size only when it has a single weight
                for shard_file, size in shard_to_size.items():
                    # Note: the h5 container adds some overhead, so a file can end up slightly bigger than the
                    # size asked for (since the limit only counts parameter bytes).
                    if size >= max_size_int + 50000:
                        with h5py.File(shard_file, "r") as state_file:
                            self.assertEqual(len(state_file), 1)

                # Check the index and the shard files found match
                with open(index_file, "r", encoding="utf-8") as f:
                    index = json.loads(f.read())

                all_shards = set(index["weight_map"].values())
                shards_found = {f for f in os.listdir(tmp_dir) if f.endswith(".h5")}
                self.assertSetEqual(all_shards, shards_found)

                # Finally, check the model can be reloaded
                new_model = TFBertModel.from_pretrained(tmp_dir)

                model.build_in_name_scope()
                new_model.build_in_name_scope()

                for p1, p2 in zip(model.weights, new_model.weights):
                    self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))

    @slow
    def test_save_pretrained_signatures(self):
        model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")

        # Short custom TF signature function.
        # `input_signature` is specific to BERT.
        @tf.function(
            input_signature=[
                [
                    tf.TensorSpec([None, None], tf.int32, name="input_ids"),
                    tf.TensorSpec([None, None], tf.int32, name="token_type_ids"),
                    tf.TensorSpec([None, None], tf.int32, name="attention_mask"),
                ]
            ]
        )
        def serving_fn(inputs):  # named `inputs` to avoid shadowing the builtin `input`
            return model(inputs)

        # Using default signature (default behavior) overrides 'serving_default'
        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir, saved_model=True, signatures=None)
            model_loaded = keras.models.load_model(f"{tmp_dir}/saved_model/1")
            self.assertIn("serving_default", model_loaded.signatures.keys())

        # Providing custom signature function
        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir, saved_model=True, signatures={"custom_signature": serving_fn})
            model_loaded = keras.models.load_model(f"{tmp_dir}/saved_model/1")
            self.assertIn("custom_signature", model_loaded.signatures.keys())

        # Providing multiple custom signature function
        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(
                tmp_dir,
                saved_model=True,
                signatures={"custom_signature_1": serving_fn, "custom_signature_2": serving_fn},
            )
            model_loaded = keras.models.load_model(f"{tmp_dir}/saved_model/1")
            self.assertIn("custom_signature_1", model_loaded.signatures.keys())
            self.assertIn("custom_signature_2", model_loaded.signatures.keys())

    @require_safetensors
    def test_safetensors_save_and_load(self):
        model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir, safe_serialization=True)
            # No tf_model.h5 file, only a model.safetensors
            self.assertTrue(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_NAME)))
            self.assertFalse(os.path.isfile(os.path.join(tmp_dir, TF2_WEIGHTS_NAME)))

            new_model = TFBertModel.from_pretrained(tmp_dir)

            # Check models are equal
            for p1, p2 in zip(model.weights, new_model.weights):
                self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))

    @is_pt_tf_cross_test
    def test_safetensors_save_and_load_pt_to_tf(self):
        model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
        pt_model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
        with tempfile.TemporaryDirectory() as tmp_dir:
            pt_model.save_pretrained(tmp_dir, safe_serialization=True)
            # Check we have a model.safetensors file
            self.assertTrue(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_NAME)))

            new_model = TFBertModel.from_pretrained(tmp_dir)

            # Check models are equal
            for p1, p2 in zip(model.weights, new_model.weights):
                self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))

    @require_safetensors
    def test_safetensors_load_from_hub(self):
        tf_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")

        # Can load from the TF-formatted checkpoint
        safetensors_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-safetensors-tf")

        # Check models are equal
        for p1, p2 in zip(safetensors_model.weights, tf_model.weights):
            self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))

        # Can load from the PyTorch-formatted checkpoint
        safetensors_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-safetensors")

        # Check models are equal
        for p1, p2 in zip(safetensors_model.weights, tf_model.weights):
            self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))

    @require_safetensors
    def test_safetensors_tf_from_tf(self):
        model = TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-tf-only")

        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir, safe_serialization=True)
            new_model = TFBertModel.from_pretrained(tmp_dir)

        for p1, p2 in zip(model.weights, new_model.weights):
            self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))

    @require_safetensors
    @is_pt_tf_cross_test
    def test_safetensors_tf_from_torch(self):
        hub_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-tf-only")
        model = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only")

        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir, safe_serialization=True)
            new_model = TFBertModel.from_pretrained(tmp_dir)

        for p1, p2 in zip(hub_model.weights, new_model.weights):
            self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))

    @require_safetensors
    def test_safetensors_tf_from_sharded_h5_with_sharded_safetensors_local(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            path = snapshot_download("hf-internal-testing/tiny-bert-tf-safetensors-h5-sharded", cache_dir=tmp_dir)

            # This should not raise even if there are two types of sharded weights
            TFBertModel.from_pretrained(path)

    @require_safetensors
    def test_safetensors_tf_from_sharded_h5_with_sharded_safetensors_hub(self):
        # This should not raise even if there are two types of sharded weights
        # This should discard the safetensors weights in favor of the .h5 sharded weights
        TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-tf-safetensors-h5-sharded")

    @require_safetensors
    def test_safetensors_load_from_local(self):
        """
        This test checks that we can load safetensors from a local snapshot of a checkpoint that only has those
        """
        with tempfile.TemporaryDirectory() as tmp:
            location = snapshot_download("hf-internal-testing/tiny-bert-tf-only", cache_dir=tmp)
            tf_model = TFBertModel.from_pretrained(location)

        with tempfile.TemporaryDirectory() as tmp:
            location = snapshot_download("hf-internal-testing/tiny-bert-tf-safetensors-only", cache_dir=tmp)
            safetensors_model = TFBertModel.from_pretrained(location)

        for p1, p2 in zip(tf_model.weights, safetensors_model.weights):
            self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))

    @require_safetensors
    def test_safetensors_load_from_hub_from_safetensors_pt(self):
        """
        This test checks that we can load safetensors from a checkpoint that only has those on the Hub,
        saved in the "pt" format.
        """
        tf_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-h5")

        # Can load from the PyTorch-formatted checkpoint
        safetensors_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors")
        for p1, p2 in zip(tf_model.weights, safetensors_model.weights):
            self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))

    @require_safetensors
    def test_safetensors_load_from_local_from_safetensors_pt(self):
        """
        This test checks that we can load safetensors from a local checkpoint that only has those
        saved in the "pt" format.
        """
        with tempfile.TemporaryDirectory() as tmp:
            location = snapshot_download("hf-internal-testing/tiny-bert-h5", cache_dir=tmp)
            tf_model = TFBertModel.from_pretrained(location)

        # Can load from the PyTorch-formatted checkpoint
        with tempfile.TemporaryDirectory() as tmp:
            location = snapshot_download("hf-internal-testing/tiny-bert-pt-safetensors", cache_dir=tmp)
            safetensors_model = TFBertModel.from_pretrained(location)

        for p1, p2 in zip(tf_model.weights, safetensors_model.weights):
            self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))

    @require_safetensors
    def test_safetensors_load_from_hub_h5_before_safetensors(self):
        """
        This test checks that we'll first download h5 weights before safetensors
        The safetensors file on that repo is a pt safetensors and therefore cannot be loaded without PyTorch
        """
        TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors-msgpack")

    @require_safetensors
    def test_safetensors_load_from_local_h5_before_safetensors(self):
        """
        This test checks that we'll first download h5 weights before safetensors
        The safetensors file on that repo is a pt safetensors and therefore cannot be loaded without PyTorch
        """
        with tempfile.TemporaryDirectory() as tmp:
            location = snapshot_download("hf-internal-testing/tiny-bert-pt-safetensors-msgpack", cache_dir=tmp)
            TFBertModel.from_pretrained(location)


@require_tf
@is_staging_test
class TFModelPushToHubTester(unittest.TestCase):
    """Staging tests for pushing TF models to the Hub via ``push_to_hub``, ``save_pretrained`` and ``PushToHubCallback``."""

    # Repos created by the tests below; deleted again in ``tearDownClass``.
    _REPO_IDS = ("test-model-tf", "test-model-tf-callback", "valid_org/test-model-tf-org")

    @classmethod
    def setUpClass(cls):
        cls._token = TOKEN
        HfFolder.save_token(TOKEN)

    @classmethod
    def tearDownClass(cls):
        # Best-effort cleanup: a repo may be missing if its test failed early or was skipped.
        for repo_id in cls._REPO_IDS:
            try:
                delete_repo(token=cls._token, repo_id=repo_id)
            except HTTPError:
                pass

    def _assert_same_weights(self, model, other):
        """Assert that ``model`` and ``other`` have the same number of weights and that all of them are equal."""
        # ``zip`` silently truncates, so check the counts first to avoid a vacuous pass.
        self.assertEqual(len(model.weights), len(other.weights))
        for p1, p2 in zip(model.weights, other.weights):
            self.assertTrue(tf.math.reduce_all(p1 == p2), f"Weight {p1.name} differs after the round trip")

    def test_push_to_hub(self):
        """Push a model with ``push_to_hub`` and then with ``save_pretrained(push_to_hub=True)``; reload and compare."""
        config = BertConfig(
            vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
        )
        model = TFBertModel(config)
        # Make sure model is properly initialized
        model.build_in_name_scope()

        logging.set_verbosity_info()
        hub_logger = logging.get_logger("transformers.utils.hub")
        try:
            with CaptureLogger(hub_logger) as cl:
                model.push_to_hub("test-model-tf", token=self._token)
        finally:
            # Put verbosity back to warning even if the push fails, so later tests are not affected.
            logging.set_verbosity_warning()
        # Check that the upload to the user's repo was logged.
        self.assertIn(f"Uploading the following files to {USER}/test-model-tf", cl.out)

        new_model = TFBertModel.from_pretrained(f"{USER}/test-model-tf")
        self._assert_same_weights(model, new_model)

        # Reset repo
        delete_repo(token=self._token, repo_id="test-model-tf")

        # Push to hub via save_pretrained
        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir, repo_id="test-model-tf", push_to_hub=True, token=self._token)

        new_model = TFBertModel.from_pretrained(f"{USER}/test-model-tf")
        self._assert_same_weights(model, new_model)

    @is_pt_tf_cross_test
    def test_push_to_hub_callback(self):
        """
        Train for one epoch with ``PushToHubCallback`` and check the pushed weights match. Also check that
        the TF and PT ``push_to_hub`` signatures agree (ignoring framework-specific extras).
        """
        config = BertConfig(
            vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
        )
        model = TFBertForMaskedLM(config)
        model.compile()

        with tempfile.TemporaryDirectory() as tmp_dir:
            push_to_hub_callback = PushToHubCallback(
                output_dir=tmp_dir,
                hub_model_id="test-model-tf-callback",
                hub_token=self._token,
            )
            model.fit(model.dummy_inputs, model.dummy_inputs, epochs=1, callbacks=[push_to_hub_callback])

        new_model = TFBertForMaskedLM.from_pretrained(f"{USER}/test-model-tf-callback")
        self._assert_same_weights(model, new_model)

        tf_push_to_hub_params = dict(inspect.signature(TFPreTrainedModel.push_to_hub).parameters)
        tf_push_to_hub_params.pop("base_model_card_args")
        pt_push_to_hub_params = dict(inspect.signature(PreTrainedModel.push_to_hub).parameters)
        pt_push_to_hub_params.pop("deprecated_kwargs")
        self.assertDictEqual(tf_push_to_hub_params, pt_push_to_hub_params)

    def test_push_to_hub_in_organization(self):
        """Same round trip as ``test_push_to_hub``, but into a repo owned by an organization."""
        config = BertConfig(
            vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
        )
        model = TFBertModel(config)
        # Make sure model is properly initialized
        model.build_in_name_scope()

        model.push_to_hub("valid_org/test-model-tf-org", token=self._token)

        new_model = TFBertModel.from_pretrained("valid_org/test-model-tf-org")
        self._assert_same_weights(model, new_model)

        # Reset repo
        delete_repo(token=self._token, repo_id="valid_org/test-model-tf-org")

        # Push to hub via save_pretrained
        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir, push_to_hub=True, token=self._token, repo_id="valid_org/test-model-tf-org")

        new_model = TFBertModel.from_pretrained("valid_org/test-model-tf-org")
        self._assert_same_weights(model, new_model)
730

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.