# transformers
#
# Fork
# 0
# /
# test_modeling_utils.py
# 2086 lines · 91.2 KB
1
# coding=utf-8
2
# Copyright 2019 HuggingFace Inc.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
import copy
16
import gc
17
import glob
18
import json
19
import os
20
import os.path
21
import sys
22
import tempfile
23
import unittest
24
import unittest.mock as mock
25
import uuid
26
from pathlib import Path
27

28
import requests
29
from huggingface_hub import HfApi, HfFolder, delete_repo
30
from huggingface_hub.file_download import http_get
31
from pytest import mark
32
from requests.exceptions import HTTPError
33

34
from transformers import (
35
    AutoConfig,
36
    AutoModel,
37
    AutoModelForSequenceClassification,
38
    OwlViTForObjectDetection,
39
    PretrainedConfig,
40
    is_torch_available,
41
    logging,
42
)
43
from transformers.testing_utils import (
44
    TOKEN,
45
    USER,
46
    CaptureLogger,
47
    LoggingLevel,
48
    TestCasePlus,
49
    is_staging_test,
50
    require_accelerate,
51
    require_flax,
52
    require_safetensors,
53
    require_tf,
54
    require_torch,
55
    require_torch_accelerator,
56
    require_torch_gpu,
57
    require_torch_multi_accelerator,
58
    require_usr_bin_time,
59
    slow,
60
    torch_device,
61
)
62
from transformers.utils import (
63
    SAFE_WEIGHTS_INDEX_NAME,
64
    SAFE_WEIGHTS_NAME,
65
    WEIGHTS_INDEX_NAME,
66
    WEIGHTS_NAME,
67
)
68
from transformers.utils.import_utils import (
69
    is_flash_attn_2_available,
70
    is_flax_available,
71
    is_tf_available,
72
    is_torch_sdpa_available,
73
    is_torchdynamo_available,
74
)
75

76

77
# Make the `utils` directory one level above this file's directory importable; presumably this is
# where the `test_module.custom_*` helpers imported below live.
sys.path.append(str(Path(__file__).parent.parent / "utils"))
78

79
from test_module.custom_configuration import CustomConfig, NoSuperInitConfig  # noqa E402
80

81

82
if is_torch_available():
83
    import torch
84
    from safetensors.torch import save_file as safe_save_file
85
    from test_module.custom_modeling import CustomModel, NoSuperInitModel
86
    from torch import nn
87

88
    from transformers import (
89
        BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
90
        AutoModelForCausalLM,
91
        AutoTokenizer,
92
        BertConfig,
93
        BertModel,
94
        CLIPTextModel,
95
        PreTrainedModel,
96
        T5Config,
97
        T5ForConditionalGeneration,
98
    )
99
    from transformers.modeling_attn_mask_utils import (
100
        AttentionMaskConverter,
101
        _create_4d_causal_attention_mask,
102
        _prepare_4d_attention_mask,
103
        _prepare_4d_causal_attention_mask,
104
    )
105
    from transformers.modeling_utils import shard_checkpoint
106

107
    # Fake pretrained models for tests
108
    class BaseModel(PreTrainedModel):
        """Tiny backbone made of two stacked 5 -> 5 linear layers.

        Declares ``base_model_prefix = "base"``, matching the ``self.base`` attribute name used by
        the head models in this file.
        """

        config_class = PretrainedConfig
        base_model_prefix = "base"

        def __init__(self, config):
            super().__init__(config)
            # Registration order determines the order of ``parameters()`` / ``state_dict()``.
            self.linear = nn.Linear(5, 5)
            self.linear_2 = nn.Linear(5, 5)

        def forward(self, x):
            hidden = self.linear(x)
            return self.linear_2(hidden)

    class BaseModelWithTiedWeights(PreTrainedModel):
        """Two 5 -> 5 linear layers whose weight matrices are shared by ``tie_weights``.

        No ``base_model_prefix`` is declared on this class.
        """

        config_class = PretrainedConfig

        def __init__(self, config):
            super().__init__(config)
            self.linear = nn.Linear(5, 5)
            self.linear_2 = nn.Linear(5, 5)

        def forward(self, x):
            hidden = self.linear(x)
            return self.linear_2(hidden)

        def tie_weights(self):
            # Make the second layer reuse the first layer's weight Parameter (biases stay separate).
            shared_weight = self.linear.weight
            self.linear_2.weight = shared_weight

    class ModelWithHead(PreTrainedModel):
        """``BaseModel`` backbone (stored as ``self.base``) followed by two 5 -> 5 head layers."""

        config_class = PretrainedConfig
        base_model_prefix = "base"

        def __init__(self, config):
            super().__init__(config)
            self.base = BaseModel(config)
            # The head intentionally reuses the attribute name ``linear`` that BaseModel also has.
            self.linear = nn.Linear(5, 5)
            self.linear2 = nn.Linear(5, 5)

        def _init_weights(self, module):
            # No-op override: this class adds no custom weight initialization.
            pass

        def forward(self, x):
            features = self.base(x)
            return self.linear2(self.linear(features))

    class ModelWithHeadAndTiedWeights(PreTrainedModel):
        """``BaseModel`` backbone plus a ``decoder`` layer whose weight is tied to ``base.linear``."""

        config_class = PretrainedConfig
        base_model_prefix = "base"

        def __init__(self, config):
            super().__init__(config)
            self.base = BaseModel(config)
            self.decoder = nn.Linear(5, 5)

        def _init_weights(self, module):
            # Intentionally empty: no custom weight initialization.
            pass

        def forward(self, x):
            features = self.base(x)
            return self.decoder(features)

        def tie_weights(self):
            # The decoder shares the backbone's first-layer weight Parameter.
            self.decoder.weight = self.base.linear.weight

    class Prepare4dCausalAttentionMaskModel(nn.Module):
        """Module wrapper around ``_prepare_4d_causal_attention_mask`` (no input mask, past length 4)."""

        def forward(self, inputs_embeds):
            # Unpacking three dims also rejects inputs that are not rank 3.
            batch_size, seq_length, _ = inputs_embeds.shape
            past_len = 4
            return _prepare_4d_causal_attention_mask(None, (batch_size, seq_length), inputs_embeds, past_len)

    class Create4dCausalAttentionMaskModel(nn.Module):
        """Module wrapper around ``_create_4d_causal_attention_mask`` with a fixed past length of 4.

        The mask is built with the dtype and device of ``inputs_embeds``.
        """

        def forward(self, inputs_embeds):
            batch_size, seq_length, _ = inputs_embeds.shape
            past_len = 4
            return _create_4d_causal_attention_mask(
                (batch_size, seq_length),
                past_key_values_length=past_len,
                device=inputs_embeds.device,
                dtype=inputs_embeds.dtype,
            )

    class Prepare4dAttentionMaskModel(nn.Module):
        """Module wrapper passing ``mask`` to ``_prepare_4d_attention_mask`` in ``inputs_embeds.dtype``."""

        def forward(self, mask, inputs_embeds):
            target_dtype = inputs_embeds.dtype
            return _prepare_4d_attention_mask(mask, dtype=target_dtype)


if is_flax_available():
197
    from transformers import FlaxBertModel
198

199
if is_tf_available():
200
    from transformers import TFBertModel
201

202

203
# Tiny Hub checkpoints used as download fixtures by the tests in this module.
# Random T5; the dtype tests treat it as an fp32 checkpoint.
TINY_T5 = "patrickvonplaten/t5-tiny-random"
# Used by the torch_dtype="auto" test as a model whose first parameter is not floating point.
TINY_BERT_FOR_TOKEN_CLASSIFICATION = "hf-internal-testing/tiny-bert-for-token-classification"
# Loaded through AutoModelForSequenceClassification to exercise the missing-keys loading branch.
TINY_MISTRAL = "hf-internal-testing/tiny-random-MistralForCausalLM"


def check_models_equal(model1, model2):
    """Return whether two models hold exactly the same parameter values.

    Parameters are compared pairwise in ``parameters()`` order. The models are
    reported as different when they expose a different number of parameters,
    when any pair differs in shape, or when any element differs.

    Args:
        model1: first module to compare.
        model2: second module to compare.

    Returns:
        ``True`` if every parameter pair matches exactly, ``False`` otherwise.
    """
    params1 = list(model1.parameters())
    params2 = list(model2.parameters())
    # A bare zip() stops at the shorter list, which would report a model with
    # extra parameters as equal to its prefix.
    if len(params1) != len(params2):
        return False
    for p1, p2 in zip(params1, params2):
        # Compare shapes first so tensors of different shapes are never broadcast
        # against each other by the element-wise comparison.
        if p1.shape != p2.shape or bool(p1.ne(p2).any()):
            return False
    return True


@require_torch
218
class ModelUtilsTest(TestCasePlus):
219
    @slow
    def test_model_from_pretrained(self):
        """Load the first archived BERT checkpoint and check the config, the model and the loading info."""
        for model_name in BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            config = BertConfig.from_pretrained(model_name)
            self.assertIsNotNone(config)
            self.assertIsInstance(config, PretrainedConfig)

            # Verify the plain load as well, not only the output_loading_info variant below.
            model = BertModel.from_pretrained(model_name)
            self.assertIsNotNone(model)
            self.assertIsInstance(model, PreTrainedModel)

            model, loading_info = BertModel.from_pretrained(model_name, output_loading_info=True)
            self.assertIsNotNone(model)
            self.assertIsInstance(model, PreTrainedModel)

            self.assertEqual(len(loading_info["missing_keys"]), 0)
            self.assertEqual(len(loading_info["unexpected_keys"]), 8)
            self.assertEqual(len(loading_info["mismatched_keys"]), 0)
            self.assertEqual(len(loading_info["error_msgs"]), 0)

            config = BertConfig.from_pretrained(model_name, output_attentions=True, output_hidden_states=True)

            # Not sure this is the intended behavior. TODO fix Lysandre & Thom
            config.name_or_path = model_name

            model = BertModel.from_pretrained(model_name, output_attentions=True, output_hidden_states=True)
            self.assertEqual(model.config.output_hidden_states, True)
            self.assertEqual(model.config, config)

    def test_model_from_pretrained_subfolder(self):
        # A checkpoint saved under <tmp>/bert must load only when subfolder="bert" is given.
        config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
        model = BertModel(config)
        sub_dir = "bert"

        with tempfile.TemporaryDirectory() as root:
            model.save_pretrained(os.path.join(root, sub_dir))

            # Nothing was saved at the root itself, so a plain load must fail.
            with self.assertRaises(OSError):
                _ = BertModel.from_pretrained(root)

            reloaded = BertModel.from_pretrained(root, subfolder=sub_dir)

        self.assertTrue(check_models_equal(model, reloaded))

    def test_model_from_pretrained_subfolder_sharded(self):
        # Same subfolder round trip as the unsharded case, but the checkpoint is split into
        # 10KB shards.
        config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
        model = BertModel(config)
        sub_dir = "bert"

        with tempfile.TemporaryDirectory() as root:
            model.save_pretrained(os.path.join(root, sub_dir), max_shard_size="10KB")

            # The root holds no checkpoint, so loading without subfolder must fail.
            with self.assertRaises(OSError):
                _ = BertModel.from_pretrained(root)

            reloaded = BertModel.from_pretrained(root, subfolder=sub_dir)

        self.assertTrue(check_models_equal(model, reloaded))

    def test_model_from_pretrained_hub_subfolder(self):
        # Per this test, the Hub repo root must not be loadable, while subfolder="bert" must be.
        repo_id = "hf-internal-testing/tiny-random-bert-subfolder"
        sub_dir = "bert"

        with self.assertRaises(OSError):
            _ = BertModel.from_pretrained(repo_id)

        loaded = BertModel.from_pretrained(repo_id, subfolder=sub_dir)
        self.assertIsNotNone(loaded)

    def test_model_from_pretrained_hub_subfolder_sharded(self):
        # Sharded Hub repo: loading the root must fail, and loading subfolder="bert" must succeed.
        repo_id = "hf-internal-testing/tiny-random-bert-sharded-subfolder"
        sub_dir = "bert"

        with self.assertRaises(OSError):
            _ = BertModel.from_pretrained(repo_id)

        loaded = BertModel.from_pretrained(repo_id, subfolder=sub_dir)
        self.assertIsNotNone(loaded)

    def test_model_from_pretrained_with_different_pretrained_model_name(self):
        """Loading a T5 checkpoint into BertModel must log a model-type mismatch warning."""
        model = T5ForConditionalGeneration.from_pretrained(TINY_T5)
        self.assertIsNotNone(model)

        logger = logging.get_logger("transformers.configuration_utils")
        with LoggingLevel(logging.WARNING):
            with CaptureLogger(logger) as cl:
                BertModel.from_pretrained(TINY_T5)
        # assertIn shows the captured log on failure, unlike assertTrue(... in ...).
        self.assertIn("You are using a model of type t5 to instantiate a model of type bert", cl.out)

    @require_accelerate
    def test_model_from_pretrained_with_none_quantization_config(self):
        # A device_map is required to enter the low_cpu_mem branch, and
        # AutoModelForSequenceClassification is loaded on purpose so that the missing-keys branch is
        # entered too. quantization_config=None must be accepted along that path.
        loaded = AutoModelForSequenceClassification.from_pretrained(
            TINY_MISTRAL,
            device_map="auto",
            quantization_config=None,
        )
        self.assertIsNotNone(loaded)

    def test_model_from_config_torch_dtype(self):
        # A model built from a config can be created in any floating-point dtype the user picks
        # (the requested dtype has to be in place before the model is instantiated from the config);
        # a non-float dtype must be rejected.
        config = T5Config.from_pretrained(TINY_T5)

        # XXX: T5ForConditionalGeneration.from_config(config) isn't supported, hence AutoModel.
        default_model = AutoModel.from_config(config)
        self.assertEqual(default_model.dtype, torch.float32)

        half_model = AutoModel.from_config(config, torch_dtype=torch.float16)
        self.assertEqual(half_model.dtype, torch.float16)

        # torch.set_default_dtype() only accepts floating-point dtypes, so int64 must fail.
        with self.assertRaises(ValueError):
            AutoModel.from_config(config, torch_dtype=torch.int64)

    def test_model_from_pretrained_torch_dtype(self):
        """Check which dtype ``from_pretrained`` instantiates a model in.

        The dtype may come from:
        1. an explicit ``torch_dtype`` argument;
        2. ``torch_dtype="auto"``, which uses ``config.torch_dtype`` when present and otherwise
           derives the dtype from the saved weights (so a saved ``model.half()`` reloads as fp16).

        An explicit model class is tested, and ``AutoModel`` separately, since the latter goes
        through a different code path.
        """
        model_path = self.get_auto_remove_tmp_dir()

        # baseline - we know TINY_T5 is fp32 model
        model = T5ForConditionalGeneration.from_pretrained(TINY_T5)
        self.assertEqual(model.dtype, torch.float32)

        def remove_torch_dtype(model_path):
            # Drop "torch_dtype" from the saved config.json so that torch_dtype="auto" has to
            # derive the dtype from the weights instead.
            file = f"{model_path}/config.json"
            with open(file, "r", encoding="utf-8") as f:
                s = json.load(f)
            s.pop("torch_dtype")
            with open(file, "w", encoding="utf-8") as f:
                json.dump(s, f)

        # test the default fp32 save_pretrained => from_pretrained cycle
        model.save_pretrained(model_path)
        model = T5ForConditionalGeneration.from_pretrained(model_path)
        self.assertEqual(model.dtype, torch.float32)
        # 1. test torch_dtype="auto" via `config.torch_dtype`
        model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype="auto")
        self.assertEqual(model.dtype, torch.float32)
        # 2. test torch_dtype="auto" via auto-derivation
        # now remove the torch_dtype entry from config.json and try "auto" again which should
        # perform auto-derivation from weights
        remove_torch_dtype(model_path)
        model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype="auto")
        self.assertEqual(model.dtype, torch.float32)

        # test forced loading in fp16 (even though the weights are in fp32)
        model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype=torch.float16)
        self.assertEqual(model.dtype, torch.float16)

        # test fp16 save_pretrained, loaded with auto-detection
        model = model.half()
        model.save_pretrained(model_path)
        # 1. test torch_dtype="auto" via `config.torch_dtype`
        model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype="auto")
        self.assertEqual(model.config.torch_dtype, torch.float16)
        self.assertEqual(model.dtype, torch.float16)
        # tests `config.torch_dtype` saving
        with open(f"{model_path}/config.json", "r", encoding="utf-8") as f:
            config_dict = json.load(f)
        self.assertEqual(config_dict["torch_dtype"], "float16")
        # 2. test torch_dtype="auto" via auto-derivation
        # config.json no longer has torch_dtype, so fp16 must be derived from the saved weights
        remove_torch_dtype(model_path)
        model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype="auto")
        self.assertEqual(model.dtype, torch.float16)

        # 3. now retest that AutoModel behaves the same wrt torch_dtype="auto" as T5ForConditionalGeneration
        model = AutoModel.from_pretrained(model_path, torch_dtype="auto")
        self.assertEqual(model.dtype, torch.float16)

        # test fp16 save_pretrained, loaded with the explicit fp16
        model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype=torch.float16)
        self.assertEqual(model.dtype, torch.float16)

        # test AutoModel separately as it goes through a different path
        # test auto-detection - as currently TINY_T5 doesn't have torch_dtype entry
        model = AutoModel.from_pretrained(TINY_T5, torch_dtype="auto")
        # test that the config object didn't get polluted with torch_dtype="auto"
        # there was a bug that after this call we ended up with config.torch_dtype=="auto"
        self.assertNotEqual(model.config.torch_dtype, "auto")
        # now test the outcome
        self.assertEqual(model.dtype, torch.float32)
        model = AutoModel.from_pretrained(TINY_T5, torch_dtype=torch.float16)
        self.assertEqual(model.dtype, torch.float16)

        # test model whose first param is not of a floating type, but int
        model = AutoModel.from_pretrained(TINY_BERT_FOR_TOKEN_CLASSIFICATION, torch_dtype="auto")
        self.assertEqual(model.dtype, torch.float32)

    def test_no_super_init_config_and_model(self):
        # The NoSuperInit* config/model classes (presumably, per their names, not calling
        # super().__init__) must still survive a save_pretrained/from_pretrained round trip.
        model = NoSuperInitModel(NoSuperInitConfig(attribute=32))

        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir)
            reloaded = NoSuperInitModel.from_pretrained(tmp_dir)

        for original_param, reloaded_param in zip(model.parameters(), reloaded.parameters()):
            self.assertTrue(torch.equal(original_param, reloaded_param))

    def test_shard_checkpoint(self):
        # Four bias-free linear layers; each weight takes in_features * out_features * 4 bytes:
        # 80,000 + 160,000 + 80,000 + 20,000 = 340,000 bytes in total.
        model = torch.nn.Sequential(
            torch.nn.Linear(100, 200, bias=False),
            torch.nn.Linear(200, 200, bias=False),
            torch.nn.Linear(200, 100, bias=False),
            torch.nn.Linear(100, 50, bias=False),
        )
        state_dict = model.state_dict()

        def expected_sharding(shard_of):
            """Build the expected (index, shards) for ``shard_of``: weight name -> 1-based shard number."""
            total = max(shard_of.values())
            weight_map = {
                name: f"pytorch_model-{number:05d}-of-{total:05d}.bin" for name, number in shard_of.items()
            }
            shards = {}
            for name, file_name in weight_map.items():
                shards.setdefault(file_name, {})[name] = state_dict[name]
            return {"metadata": {"total_size": 340000}, "weight_map": weight_map}, shards

        with self.subTest("No shard when max size is bigger than model size"):
            shards, index = shard_checkpoint(state_dict)
            self.assertIsNone(index)
            self.assertDictEqual(shards, {WEIGHTS_NAME: state_dict})

        with self.subTest("Test sharding, no weights bigger than max size"):
            shards, index = shard_checkpoint(state_dict, max_shard_size="300kB")
            # The first two layers go to shard 1, the last two to shard 2.
            expected_index, expected_shards = expected_sharding(
                {"0.weight": 1, "1.weight": 1, "2.weight": 2, "3.weight": 2}
            )
            self.assertDictEqual(index, expected_index)
            self.assertDictEqual(shards, expected_shards)

        with self.subTest("Test sharding with weights bigger than max size"):
            shards, index = shard_checkpoint(state_dict, max_shard_size="100kB")
            # The 160,000-byte layer exceeds the limit and gets a shard of its own.
            expected_index, expected_shards = expected_sharding(
                {"0.weight": 1, "1.weight": 2, "2.weight": 3, "3.weight": 3}
            )
            self.assertDictEqual(index, expected_index)
            self.assertDictEqual(shards, expected_shards)

    def test_checkpoint_sharding_local_bin(self):
        """Save tiny BERT as sharded ``.bin`` files at several shard sizes, check the files, and reload it."""
        model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")

        with tempfile.TemporaryDirectory() as tmp_dir:
            # We use the same folder for various sizes to make sure a new save erases the old checkpoint.
            for max_size in ["50kB", "50kiB", "100kB", "100kiB", "200kB", "200kiB"]:
                model.save_pretrained(tmp_dir, max_shard_size=max_size, safe_serialization=False)

                # The byte limit depends only on max_size, so compute it once per size rather than
                # once per shard file: "kiB" means 2**10 bytes, "kB" means 10**3 bytes.
                if max_size.endswith("kiB"):
                    max_size_int = int(max_size[:-3]) * 2**10
                else:
                    max_size_int = int(max_size[:-2]) * 10**3

                # Get each shard file and its size
                shard_to_size = {}
                for shard in os.listdir(tmp_dir):
                    if shard.endswith(".bin"):
                        shard_file = os.path.join(tmp_dir, shard)
                        shard_to_size[shard_file] = os.path.getsize(shard_file)

                index_file = os.path.join(tmp_dir, WEIGHTS_INDEX_NAME)
                # Check there is an index but no regular weight file
                self.assertTrue(os.path.isfile(index_file))
                self.assertFalse(os.path.isfile(os.path.join(tmp_dir, WEIGHTS_NAME)))

                # Check a file is bigger than max_size only when it has a single weight
                for shard_file, size in shard_to_size.items():
                    # Note: pickle adds some junk so the weight of the file can end up being slightly bigger than
                    # the size asked for (since we count parameters)
                    if size >= max_size_int + 50000:
                        state_dict = torch.load(shard_file)
                        self.assertEqual(len(state_dict), 1)

                # Check the index and the shard files found match
                with open(index_file, "r", encoding="utf-8") as f:
                    index = json.loads(f.read())

                all_shards = set(index["weight_map"].values())
                shards_found = {f for f in os.listdir(tmp_dir) if f.endswith(".bin")}
                self.assertSetEqual(all_shards, shards_found)

                # Finally, check the model can be reloaded
                new_model = BertModel.from_pretrained(tmp_dir)
                for p1, p2 in zip(model.parameters(), new_model.parameters()):
                    self.assertTrue(torch.allclose(p1, p2))

    def test_checkpoint_sharding_from_hub(self):
        # The sharded Hub checkpoint must hold the same weights as its unsharded counterpart.
        sharded = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded")
        reference = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
        for sharded_param, reference_param in zip(sharded.parameters(), reference.parameters()):
            self.assertTrue(torch.allclose(sharded_param, reference_param))

    def test_checkpoint_variant_local_bin(self):
        # A "v2" .bin save is written as WEIGHTS_NAME with ".v2" inserted before the extension,
        # no default weights file is written, and the checkpoint loads only with variant="v2".
        model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
        variant_file_name = ".".join([*WEIGHTS_NAME.split(".")[:-1], "v2", "bin"])

        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir, variant="v2", safe_serialization=False)

            self.assertTrue(os.path.isfile(os.path.join(tmp_dir, variant_file_name)))
            self.assertFalse(os.path.isfile(os.path.join(tmp_dir, WEIGHTS_NAME)))

            # Loading without the variant argument must fail.
            with self.assertRaises(EnvironmentError):
                _ = BertModel.from_pretrained(tmp_dir)

            reloaded = BertModel.from_pretrained(tmp_dir, variant="v2")

        for p_saved, p_loaded in zip(model.parameters(), reloaded.parameters()):
            self.assertTrue(torch.allclose(p_saved, p_loaded))

    def test_checkpoint_variant_local_sharded_bin(self):
        """Sharded ``.bin`` save with variant "v2": check the variant index, every shard file, and variant-only loading."""
        model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")

        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir, variant="v2", max_shard_size="50kB", safe_serialization=False)

            weights_index_name = ".".join(WEIGHTS_INDEX_NAME.split(".")[:-1] + ["v2"] + ["json"])
            weights_index_file = os.path.join(tmp_dir, weights_index_name)
            self.assertTrue(os.path.isfile(weights_index_file))
            self.assertFalse(os.path.isfile(os.path.join(tmp_dir, WEIGHTS_INDEX_NAME)))

            # The save is expected to produce 5 shards. Check every one of them, including the
            # last (a range(1, 5) loop would never look at shard 5).
            num_shards = 5
            for i in range(1, num_shards + 1):
                shard_suffix = f"v2-{i:05d}-of-{num_shards:05d}"
                weights_name = ".".join(WEIGHTS_NAME.split(".")[:-1] + [shard_suffix] + ["bin"])
                weights_name_file = os.path.join(tmp_dir, weights_name)
                self.assertTrue(os.path.isfile(weights_name_file))

            # Loading without the variant argument must fail.
            with self.assertRaises(EnvironmentError):
                _ = BertModel.from_pretrained(tmp_dir)

            new_model = BertModel.from_pretrained(tmp_dir, variant="v2")

        for p1, p2 in zip(model.parameters(), new_model.parameters()):
            self.assertTrue(torch.allclose(p1, p2))

    @require_safetensors
    def test_checkpoint_variant_local_safe(self):
        # A "v2" safetensors save is written as SAFE_WEIGHTS_NAME with ".v2" before the extension,
        # no default safetensors file is written, and the checkpoint loads only with variant="v2".
        model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
        variant_file_name = ".".join([*SAFE_WEIGHTS_NAME.split(".")[:-1], "v2", "safetensors"])

        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir, variant="v2", safe_serialization=True)

            self.assertTrue(os.path.isfile(os.path.join(tmp_dir, variant_file_name)))
            self.assertFalse(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_NAME)))

            # Loading without the variant argument must fail.
            with self.assertRaises(EnvironmentError):
                _ = BertModel.from_pretrained(tmp_dir)

            reloaded = BertModel.from_pretrained(tmp_dir, variant="v2")

        for p_saved, p_loaded in zip(model.parameters(), reloaded.parameters()):
            self.assertTrue(torch.allclose(p_saved, p_loaded))

    @require_safetensors
    def test_checkpoint_variant_local_sharded_safe(self):
        """Sharded safetensors save with variant "v2": check the variant index, every shard file, and variant-only loading."""
        model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")

        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir, variant="v2", max_shard_size="50kB", safe_serialization=True)

            weights_index_name = ".".join(SAFE_WEIGHTS_INDEX_NAME.split(".")[:-1] + ["v2"] + ["json"])
            weights_index_file = os.path.join(tmp_dir, weights_index_name)
            self.assertTrue(os.path.isfile(weights_index_file))
            self.assertFalse(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)))

            # The save is expected to produce 5 shards. Check every one of them, including the
            # last (a range(1, 5) loop would never look at shard 5).
            num_shards = 5
            for i in range(1, num_shards + 1):
                shard_suffix = f"v2-{i:05d}-of-{num_shards:05d}"
                weights_name = ".".join(SAFE_WEIGHTS_NAME.split(".")[:-1] + [shard_suffix] + ["safetensors"])
                weights_name_file = os.path.join(tmp_dir, weights_name)
                self.assertTrue(os.path.isfile(weights_name_file))

            # Loading without the variant argument must fail.
            with self.assertRaises(EnvironmentError):
                _ = BertModel.from_pretrained(tmp_dir)

            new_model = BertModel.from_pretrained(tmp_dir, variant="v2")

        for p1, p2 in zip(model.parameters(), new_model.parameters()):
            self.assertTrue(torch.allclose(p1, p2))

    def test_checkpoint_variant_hub(self):
        # Per this test, the Hub repo must fail to load without a variant and load with variant="v2".
        repo_id = "hf-internal-testing/tiny-random-bert-variant"
        with tempfile.TemporaryDirectory() as cache_dir:
            with self.assertRaises(EnvironmentError):
                _ = BertModel.from_pretrained(repo_id, cache_dir=cache_dir)
            model = BertModel.from_pretrained(repo_id, cache_dir=cache_dir, variant="v2")
        self.assertIsNotNone(model)

    def test_checkpoint_variant_hub_sharded(self):
        # Sharded Hub repo: loading without a variant must fail, and loading with variant="v2" must work.
        repo_id = "hf-internal-testing/tiny-random-bert-variant-sharded"
        with tempfile.TemporaryDirectory() as cache_dir:
            with self.assertRaises(EnvironmentError):
                _ = BertModel.from_pretrained(repo_id, cache_dir=cache_dir)
            model = BertModel.from_pretrained(repo_id, cache_dir=cache_dir, variant="v2")
        self.assertIsNotNone(model)

    @require_safetensors
    def test_checkpoint_variant_hub_safe(self):
        # Safetensors Hub repo: loading without a variant must fail, and loading with variant="v2" must work.
        repo_id = "hf-internal-testing/tiny-random-bert-variant-safe"
        with tempfile.TemporaryDirectory() as cache_dir:
            with self.assertRaises(EnvironmentError):
                _ = BertModel.from_pretrained(repo_id, cache_dir=cache_dir)
            model = BertModel.from_pretrained(repo_id, cache_dir=cache_dir, variant="v2")
        self.assertIsNotNone(model)

    @require_safetensors
    def test_checkpoint_variant_hub_sharded_safe(self):
        # Sharded safetensors Hub repo: loading without a variant must fail, and loading with
        # variant="v2" must work.
        repo_id = "hf-internal-testing/tiny-random-bert-variant-sharded-safe"
        with tempfile.TemporaryDirectory() as cache_dir:
            with self.assertRaises(EnvironmentError):
                _ = BertModel.from_pretrained(repo_id, cache_dir=cache_dir)
            model = BertModel.from_pretrained(repo_id, cache_dir=cache_dir, variant="v2")
        self.assertIsNotNone(model)

    def test_checkpoint_variant_save_load_bin(self):
        # Saving without a variant must keep an existing "v2" checkpoint and add the regular
        # weights file next to it.
        variant_file_name = ".".join([*WEIGHTS_NAME.split(".")[:-1], "v2", "bin"])

        with tempfile.TemporaryDirectory() as tmp_dir:
            model = BertModel.from_pretrained(
                "hf-internal-testing/tiny-random-bert-variant", cache_dir=tmp_dir, variant="v2"
            )
            variant_path = os.path.join(tmp_dir, variant_file_name)

            # A variant save writes the variant checkpoint.
            model.save_pretrained(tmp_dir, variant="v2", safe_serialization=False)
            self.assertTrue(os.path.isfile(variant_path))

            # A plain save must leave the variant checkpoint in place...
            model.save_pretrained(tmp_dir, safe_serialization=False)
            self.assertTrue(os.path.isfile(variant_path))

            # ...and produce the regular checkpoint as well.
            self.assertTrue(os.path.isfile(os.path.join(tmp_dir, WEIGHTS_NAME)))

        self.assertIsNotNone(model)

    @require_accelerate
    @mark.accelerate_tests
    def test_from_pretrained_low_cpu_mem_usage_functional(self):
        """`from_pretrained(..., low_cpu_mem_usage=True)` must work for both sharded and single-file checkpoints."""
        for repo_id in ("hf-internal-testing/tiny-random-bert-sharded", "hf-internal-testing/tiny-random-bert"):
            BertModel.from_pretrained(repo_id, low_cpu_mem_usage=True)
705

706
    @require_usr_bin_time
    @require_accelerate
    @mark.accelerate_tests
    def test_from_pretrained_low_cpu_mem_usage_measured(self):
        """`low_cpu_mem_usage=True` must lower the peak RSS of `from_pretrained` by more than 15%.

        Each variant is measured by `python_one_liner_max_rss` on a standalone Python one-liner.
        """
        mname = "google-bert/bert-base-cased"
        preamble = "from transformers import AutoModel"

        max_rss_normal = self.python_one_liner_max_rss(
            f'{preamble}; AutoModel.from_pretrained("{mname}", low_cpu_mem_usage=False)'
        )
        max_rss_low_mem = self.python_one_liner_max_rss(
            f'{preamble};  AutoModel.from_pretrained("{mname}", low_cpu_mem_usage=True)'
        )

        # Saving expressed relative to the low-memory run's peak RSS.
        diff_bytes = max_rss_normal - max_rss_low_mem
        diff_percent = diff_bytes / max_rss_low_mem

        # Ideally the saving would be close to one checkpoint size, but measuring CPU memory on Linux is
        # noisy and inconsistent, so only require at least 15% less peak memory.
        self.assertGreater(
            diff_percent,
            0.15,
            "should use less CPU memory for low_cpu_mem_usage=True, "
            f"but got max_rss_normal={max_rss_normal} and max_rss_low_mem={max_rss_low_mem}",
        )

        # Manual comparison: start from the model size in bytes, e.g.
        #   model = BertModel.from_pretrained(mname, low_cpu_mem_usage=False)
        #   total_numel = sum(dict((p.data_ptr(), p.numel()) for p in model.parameters()).values())
        #   total_bytes = total_numel * 4  # 420MB
        # diff_bytes should then be very close to total_bytes, but the RSS reports are inconsistent. Doing the
        # model construction and torch.load entirely on GPU would allow measuring total and peak memory exactly
        # with torch's CUDA memory tracking, which would make this a much more precise test.
746

747
    @require_accelerate
    @mark.accelerate_tests
    @require_torch_multi_accelerator
    @slow
    def test_model_parallelism_gpt2(self):
        """Split GPT-2 across two accelerators and check the generated text.

        Embeddings, blocks 0-5 and `lm_head` go on device 0; blocks 6-11 and `ln_f` go on device 1.
        """
        device_map = {"transformer.wte": 0, "transformer.wpe": 0, "lm_head": 0, "transformer.ln_f": 1}
        device_map.update({f"transformer.h.{layer}": (0 if layer < 6 else 1) for layer in range(12)})

        model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2", device_map=device_map)
        tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")

        input_ids = tokenizer("Hello, my name is", return_tensors="pt")["input_ids"].to(0)
        generated = model.generate(input_ids)

        self.assertEqual(
            tokenizer.decode(generated[0].tolist()),
            "Hello, my name is John. I'm a writer, and I'm a writer. I'm",
        )
764

765
    @require_accelerate
    @mark.accelerate_tests
    @require_torch_accelerator
    def test_from_pretrained_disk_offload_task_model(self):
        """Reload a saved GPT-2 checkpoint as a causal-LM with blocks offloaded to CPU and disk.

        The offloaded model's logits must match a fully on-device reload, both with and without
        `offload_state_dict`.
        """
        model = AutoModel.from_pretrained("hf-internal-testing/tiny-random-gpt2")
        device_map = {
            "transformer.wte": 0,
            "transformer.wpe": 0,
            **{f"transformer.h.{i}": "cpu" for i in (0, 1, 2)},
            **{f"transformer.h.{i}": "disk" for i in (3, 4)},
            "transformer.ln_f": 0,
            "lm_head": 0,
        }
        with tempfile.TemporaryDirectory() as tmp_dir:
            inputs = torch.tensor([[1, 2, 3]]).to(0)

            model.save_pretrained(tmp_dir)
            reference_model = AutoModelForCausalLM.from_pretrained(tmp_dir).to(0)
            reference_logits = reference_model(inputs).logits.cpu()

            offload_folder = os.path.join(tmp_dir, "offload")
            # First plain disk offload, then additionally with the state dict temporarily offloaded.
            for extra_kwargs in ({}, {"offload_state_dict": True}):
                offloaded_model = AutoModelForCausalLM.from_pretrained(
                    tmp_dir, device_map=device_map, offload_folder=offload_folder, **extra_kwargs
                )
                offloaded_logits = offloaded_model(inputs).logits.cpu()
                self.assertTrue(torch.allclose(reference_logits, offloaded_logits))
807

808
    @require_accelerate
    @mark.accelerate_tests
    @require_torch_accelerator
    def test_from_pretrained_disk_offload_derived_to_base_model(self):
        """Reload only the base model of a saved causal-LM GPT-2 checkpoint, with blocks offloaded to CPU and disk.

        The first output must match a fully on-device reload, both with and without `offload_state_dict`.
        """
        derived_model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")

        # Module names are relative to the base model: no `transformer.` prefix and no `lm_head`.
        device_map = {
            "wte": 0,
            "wpe": 0,
            "h.0": "cpu",
            "h.1": "cpu",
            "h.2": "cpu",
            "h.3": "disk",
            "h.4": "disk",
            "ln_f": 0,
        }
        with tempfile.TemporaryDirectory() as tmp_dir:
            inputs = torch.tensor([[1, 2, 3]]).to(0)
            # `safe_serialization` is the format switch of `save_pretrained`; `use_safetensors` is a
            # `from_pretrained` option, not a `save_pretrained` parameter.
            derived_model.save_pretrained(tmp_dir, safe_serialization=True)
            base_model = AutoModel.from_pretrained(tmp_dir)
            outputs1 = base_model.to(0)(inputs)

            # With disk offload
            offload_folder = os.path.join(tmp_dir, "offload")
            base_model_with_offload = AutoModel.from_pretrained(
                tmp_dir, device_map=device_map, offload_folder=offload_folder
            )
            outputs2 = base_model_with_offload(inputs)
            self.assertTrue(torch.allclose(outputs1[0].cpu(), outputs2[0].cpu()))

            # With state dict temp offload
            new_model_with_offload = AutoModel.from_pretrained(
                tmp_dir,
                device_map=device_map,
                offload_folder=offload_folder,
                offload_state_dict=True,
            )
            outputs2 = new_model_with_offload(inputs)
            self.assertTrue(torch.allclose(outputs1[0].cpu(), outputs2[0].cpu()))
847

848
    @slow
    @require_torch
    def test_from_pretrained_non_contiguous_checkpoint(self):
        """Weights stored non-contiguously must come out contiguous, and the model must then save in both formats.

        Loading is checked with and without `device_map`. See:
        https://github.com/huggingface/transformers/pull/28414. Tiny models on the Hub have contiguous weights,
        contrarily to google/owlvit, hence the dedicated test repo.
        """
        repo_id = "fxmarty/owlvit-tiny-non-contiguous-weight"
        for load_kwargs in ({}, {"device_map": "auto"}):
            model = OwlViTForObjectDetection.from_pretrained(repo_id, **load_kwargs)
            self.assertTrue(model.owlvit.visual_projection.weight.is_contiguous())

        # Serialization must not choke on the (now contiguous) weights in either format.
        with tempfile.TemporaryDirectory() as tmp_dir:
            for safe_serialization in (False, True):
                model.save_pretrained(tmp_dir, safe_serialization=safe_serialization)
864

865
    def test_cached_files_are_used_when_internet_is_down(self):
        """With the model already cached, `from_pretrained` must still succeed when every HTTP request gets a 500."""
        # Fake response emulating a server that is down.
        server_error = mock.Mock()
        server_error.status_code = 500
        server_error.headers = {}
        server_error.raise_for_status.side_effect = HTTPError
        server_error.json.return_value = {}

        # Populate the cache first.
        BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")

        with mock.patch("requests.Session.request", return_value=server_error) as mocked_request:
            BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
            # Make sure the failing network path was actually exercised.
            mocked_request.assert_called()
881

882
    def test_load_from_one_file(self):
        """`from_pretrained` accepts a path to a single weights file when the config is passed explicitly."""
        # Use a temporary directory rather than `tempfile.mktemp` (deprecated and race-prone). The path is
        # created safely, and everything is removed even if the download fails part-way, without a
        # cleanup-time FileNotFoundError masking the real error.
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_file = os.path.join(tmp_dir, "weights")
            with open(tmp_file, "wb") as f:
                http_get(
                    "https://huggingface.co/hf-internal-testing/tiny-random-bert/resolve/main/pytorch_model.bin", f
                )

            config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
            _ = BertModel.from_pretrained(tmp_file, config=config)
894

895
    def test_legacy_load_from_url(self):
        """Load weights directly from a file URL (deprecated behavior; this test can be removed in v5)."""
        weights_url = "https://huggingface.co/hf-internal-testing/tiny-random-bert/resolve/main/pytorch_model.bin"
        config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
        _ = BertModel.from_pretrained(weights_url, config=config)
901

902
    @require_safetensors
    def test_use_safetensors(self):
        """Check how `use_safetensors` selects, or refuses, the weight format when loading from the Hub."""
        # `use_safetensors=True` must not raise for this repo.
        AutoModel.from_pretrained("hf-internal-testing/tiny-random-RobertaModel", use_safetensors=True)

        # Requesting .bin weights from a repo that only has safetensors must raise, naming the missing file.
        with self.assertRaises(OSError) as env_error:
            BertModel.from_pretrained("hf-internal-testing/tiny-random-bert-safetensors", use_safetensors=False)
        self.assertIn("does not appear to have a file named pytorch_model.bin", str(env_error.exception))

        # When both formats are available, `use_safetensors` decides which one is downloaded:
        # False -> only .bin files, True -> only .safetensors files.
        for use_safetensors, expected_ext, unexpected_ext in ((False, "bin", "safetensors"), (True, "safetensors", "bin")):
            with tempfile.TemporaryDirectory() as tmp_dir:
                CLIPTextModel.from_pretrained(
                    "hf-internal-testing/diffusers-stable-diffusion-tiny-all",
                    subfolder="text_encoder",
                    use_safetensors=use_safetensors,
                    cache_dir=tmp_dir,
                )

                all_downloaded_files = glob.glob(os.path.join(tmp_dir, "*", "snapshots", "*", "*", "*"))
                self.assertTrue(any(f.endswith(expected_ext) for f in all_downloaded_files))
                self.assertFalse(any(f.endswith(unexpected_ext) for f in all_downloaded_files))
938

939
    @require_safetensors
    def test_safetensors_save_and_load(self):
        """A safetensors save writes only `model.safetensors` and reloads to matching parameters."""
        model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir, safe_serialization=True)
            # Only the safetensors file is written, no pytorch_model.bin.
            self.assertTrue(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_NAME)))
            self.assertFalse(os.path.isfile(os.path.join(tmp_dir, WEIGHTS_NAME)))

            reloaded = BertModel.from_pretrained(tmp_dir)

            for original_param, reloaded_param in zip(model.parameters(), reloaded.parameters()):
                self.assertTrue(torch.allclose(original_param, reloaded_param))
953

954
    @require_safetensors
    def test_safetensors_load_from_hub(self):
        """The safetensors and .bin copies of the same Hub model load to matching parameters."""
        from_safetensors = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert-safetensors")
        from_bin = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")

        for safetensors_param, bin_param in zip(from_safetensors.parameters(), from_bin.parameters()):
            self.assertTrue(torch.allclose(safetensors_param, bin_param))
962

963
    @require_safetensors
    def test_safetensors_save_and_load_sharded(self):
        """A sharded safetensors save writes only a safetensors index plus shards and reloads to matching params."""
        model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir, safe_serialization=True, max_shard_size="100kB")

            def exists(name):
                return os.path.isfile(os.path.join(tmp_dir, name))

            # Only the safetensors index, no pytorch_model.bin index.
            self.assertFalse(exists(WEIGHTS_INDEX_NAME))
            self.assertTrue(exists(SAFE_WEIGHTS_INDEX_NAME))
            # No single-file weights in either format.
            self.assertFalse(exists(WEIGHTS_NAME))
            self.assertFalse(exists(SAFE_WEIGHTS_NAME))

            reloaded = BertModel.from_pretrained(tmp_dir)

            for original_param, reloaded_param in zip(model.parameters(), reloaded.parameters()):
                self.assertTrue(torch.allclose(original_param, reloaded_param))
980

981
    @require_safetensors
    def test_safetensors_load_from_hub_sharded(self):
        """Sharded safetensors and sharded .bin copies of the same Hub model load to matching parameters."""
        from_safetensors = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded-safetensors")
        from_bin = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded")

        for safetensors_param, bin_param in zip(from_safetensors.parameters(), from_bin.parameters()):
            self.assertTrue(torch.allclose(safetensors_param, bin_param))
989

990
    def test_base_model_to_head_model_load(self):
        """A base checkpoint loads into the model with head; base keys mixed with unprefixed head keys are rejected."""
        base_model = BaseModel(PretrainedConfig())
        with tempfile.TemporaryDirectory() as tmp_dir:
            base_model.save_pretrained(tmp_dir, safe_serialization=False)

            # The base weights end up in the `base` submodule of the model with head.
            model = ModelWithHead.from_pretrained(tmp_dir)
            for head_side_param, base_param in zip(model.base.parameters(), base_model.parameters()):
                self.assertTrue(torch.allclose(head_side_param, base_param))

            # Mixing unprefixed head keys into a base state dict must be reported as corruption.
            mixed_state_dict = base_model.state_dict()
            head_state_dict = model.state_dict()
            for key in ("linear2.weight", "linear2.bias"):
                mixed_state_dict[key] = head_state_dict[key]
            safe_save_file(mixed_state_dict, os.path.join(tmp_dir, SAFE_WEIGHTS_NAME), metadata={"format": "pt"})

            with self.assertRaisesRegex(
                ValueError, "The state dictionary of the model you are trying to load is corrupted."
            ):
                _ = ModelWithHead.from_pretrained(tmp_dir)
1011

1012
    def test_tied_weights_reload(self):
        """Tied weights stay tied after reload; a checkpoint lacking the tied copy loads without missing keys."""
        # Base
        model = BaseModelWithTiedWeights(PretrainedConfig())
        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir)

            new_model = BaseModelWithTiedWeights.from_pretrained(tmp_dir)
            self.assertIs(new_model.linear.weight, new_model.linear_2.weight)

            state_dict = model.state_dict()
            # Remove tied weight from state_dict -> model should load with no complaint about missing keys.
            del state_dict["linear_2.weight"]
            torch.save(state_dict, os.path.join(tmp_dir, WEIGHTS_NAME))
            # `save_pretrained` above writes safetensors by default, and `from_pretrained` prefers that file when
            # both exist. Ask for the `.bin` explicitly so the pruned state dict is what actually gets loaded.
            new_model, load_info = BaseModelWithTiedWeights.from_pretrained(
                tmp_dir, output_loading_info=True, use_safetensors=False
            )
            self.assertListEqual(load_info["missing_keys"], [])
            self.assertIs(new_model.linear.weight, new_model.linear_2.weight)

            # With head
            model.save_pretrained(tmp_dir)
            new_model, load_info = ModelWithHeadAndTiedWeights.from_pretrained(tmp_dir, output_loading_info=True)
            self.assertIs(new_model.base.linear.weight, new_model.decoder.weight)
            # Should only complain about the missing bias
            self.assertListEqual(load_info["missing_keys"], ["decoder.bias"])
1035

1036
    def test_unexpected_keys_warnings(self):
        """Unexpected checkpoint keys produce a WARNING only when loading into the checkpoint's own class."""
        model = ModelWithHead(PretrainedConfig())
        logger = logging.get_logger("transformers.modeling_utils")
        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir)

            # Loading the model with a new class, we don't get a warning for unexpected weights, just an info
            with LoggingLevel(logging.WARNING):
                with CaptureLogger(logger) as cl:
                    _, loading_info = BaseModel.from_pretrained(tmp_dir, output_loading_info=True)
            # The warning names the class being initialized (here `BaseModel`), so match the class-agnostic part.
            # Checking for "...initializing ModelWithHead" here could never fail.
            self.assertNotIn("were not used when initializing", cl.out)
            self.assertEqual(
                set(loading_info["unexpected_keys"]),
                {"linear.weight", "linear.bias", "linear2.weight", "linear2.bias"},
            )

            # Loading the model with the same class, we do get a warning for unexpected weights
            state_dict = model.state_dict()
            state_dict["added_key"] = copy.deepcopy(state_dict["linear.weight"])
            safe_save_file(state_dict, os.path.join(tmp_dir, SAFE_WEIGHTS_NAME), metadata={"format": "pt"})
            with LoggingLevel(logging.WARNING):
                with CaptureLogger(logger) as cl:
                    _, loading_info = ModelWithHead.from_pretrained(tmp_dir, output_loading_info=True)
            self.assertIn("were not used when initializing ModelWithHead: ['added_key']", cl.out)
            self.assertEqual(loading_info["unexpected_keys"], ["added_key"])
1061

1062
    def test_warn_if_padding_and_no_attention_mask(self):
        """`warn_if_padding_and_no_attention_mask` warns (once) only when input looks padded and no mask is given."""
        logger = logging.get_logger("transformers.modeling_utils")
        padding_warning = "We strongly recommend passing in an `attention_mask`"

        def capture_warnings(input_ids, attention_mask=None, pad_token_id=0, bos_token_id=None, num_calls=1):
            """Run the padding check `num_calls` times on a fresh `ModelWithHead`; return the WARNING-level log."""
            # Reset the `warning_once` cache so every scenario starts clean.
            logger.warning_once.cache_clear()
            with LoggingLevel(logging.WARNING):
                with CaptureLogger(logger) as cl:
                    config = PretrainedConfig()
                    config.pad_token_id = pad_token_id
                    if bos_token_id is not None:
                        config.bos_token_id = bos_token_id
                    model = ModelWithHead(config)
                    input_tensor = torch.tensor([input_ids])
                    for _ in range(num_calls):
                        model.warn_if_padding_and_no_attention_mask(input_tensor, attention_mask=attention_mask)
            return cl.out

        # Starts and ends with token 0 (the pad token in most scenarios below).
        padded_ids = [0, 345, 232, 328, 740, 140, 1695, 69, 6078, 0, 0]

        with self.subTest("Ensure no warnings when pad_token_id is None."):
            self.assertNotIn(padding_warning, capture_warnings(padded_ids, pad_token_id=None))

        with self.subTest("Ensure no warnings when there is an attention_mask."):
            attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0]])
            self.assertNotIn(padding_warning, capture_warnings(padded_ids, attention_mask=attention_mask))

        with self.subTest("Ensure no warnings when there are no pad_token_ids in the input_ids."):
            unpadded_ids = [1, 345, 232, 328, 740, 140, 1695, 69, 6078, 2341, 25]
            self.assertNotIn(padding_warning, capture_warnings(unpadded_ids))

        with self.subTest("Ensure a warning is shown when the input_ids start with a pad_token_id."):
            left_padded_ids = [0, 345, 232, 328, 740, 140, 1695, 69, 6078, 432, 5232]
            self.assertIn(padding_warning, capture_warnings(left_padded_ids))

        with self.subTest("Ensure a warning is shown when the input_ids end with a pad_token_id."):
            right_padded_ids = [432, 345, 232, 328, 740, 140, 1695, 69, 6078, 0, 0]
            self.assertIn(padding_warning, capture_warnings(right_padded_ids))

        with self.subTest("Ensure that the warning is shown at most once."):
            self.assertEqual(capture_warnings(padded_ids, num_calls=2).count(padding_warning), 1)

        with self.subTest("Ensure a different warning is shown when the pad_token_id is equal to the bos_token_id."):
            self.assertIn(
                "You may ignore this warning if your `pad_token_id`",
                capture_warnings(padded_ids, bos_token_id=0),
            )

        if not is_torchdynamo_available():
            return
        with self.subTest("Ensure that the warning code is skipped when compiling with torchdynamo."):
            logger.warning_once.cache_clear()
            # Only `testing` is needed; `config` below is the model config.
            from torch._dynamo import testing

            config = PretrainedConfig()
            config.pad_token_id = 0
            model = ModelWithHead(config)
            input_ids = torch.tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 432, 5232]])

            def f(input_ids):
                model.warn_if_padding_and_no_attention_mask(input_ids, attention_mask=None)

            compile_counter = testing.CompileCounter()
            opt_fn = torch.compile(f, dynamic=True, backend=compile_counter)
            opt_fn(input_ids)
            self.assertEqual(compile_counter.frame_count, 0)
1163

1164
    @require_torch_accelerator
    @slow
    def test_pretrained_low_mem_new_config(self):
        """A resized-config load must build the same model class as a plain load.

        The resized load uses `ignore_mismatched_sizes`, fp16 and `low_cpu_mem_usage`.
        """
        # Only the model described in the original issue is checked.
        for model_id in ["openai-community/gpt2"]:
            resized_config = AutoConfig.from_pretrained(pretrained_model_name_or_path=model_id)
            # Change the architecture so it no longer matches the checkpoint's shapes.
            resized_config.n_layer = 48
            resized_config.n_head = 25
            resized_config.n_embd = 1600

            resized_model = AutoModelForCausalLM.from_pretrained(
                pretrained_model_name_or_path=model_id,
                config=resized_config,
                ignore_mismatched_sizes=True,
                torch_dtype=torch.float16,
                low_cpu_mem_usage=True,
            )
            reference_model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path=model_id)

            self.assertEqual(type(resized_model).__name__, type(reference_model).__name__)
1185

1186
    @require_accelerate
    @mark.accelerate_tests
    def test_generation_config_is_loaded_with_model(self):
        """The repo's `generation_config.json` is attached to the model, with and without `device_map`.

        Note: `joaogante/tiny-random-gpt2-with-generation-config` has a `generation_config.json` containing a
        dummy `transformers_version` field set to `foo`. If loading the file fails, this test also fails.
        Marked `require_accelerate` like the other `device_map` tests here, since step 2 loads with
        `device_map="auto"`.
        """
        # 1. Load without further parameters
        model = AutoModelForCausalLM.from_pretrained("joaogante/tiny-random-gpt2-with-generation-config")
        self.assertEqual(model.generation_config.transformers_version, "foo")

        # 2. Load with `device_map`
        model = AutoModelForCausalLM.from_pretrained(
            "joaogante/tiny-random-gpt2-with-generation-config", device_map="auto"
        )
        self.assertEqual(model.generation_config.transformers_version, "foo")
1199

1200
    @require_safetensors
    def test_safetensors_torch_from_torch(self):
        """A PyTorch model round-trips exactly through a safetensors save."""
        model = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only")

        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir, safe_serialization=True)
            reloaded = BertModel.from_pretrained(tmp_dir)

        for original_param, reloaded_param in zip(model.parameters(), reloaded.parameters()):
            self.assertTrue(torch.equal(original_param, reloaded_param))
1210

1211
    @require_safetensors
    @require_flax
    def test_safetensors_torch_from_flax(self):
        """Safetensors written by a Flax model load in PyTorch identical to the PyTorch Hub weights."""
        reference = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only")
        flax_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")

        with tempfile.TemporaryDirectory() as tmp_dir:
            flax_model.save_pretrained(tmp_dir, safe_serialization=True)
            converted = BertModel.from_pretrained(tmp_dir)

        for reference_param, converted_param in zip(reference.parameters(), converted.parameters()):
            self.assertTrue(torch.equal(reference_param, converted_param))
1223

1224
    @require_tf
    @require_safetensors
    def test_safetensors_torch_from_tf(self):
        """Safetensors written by a TF model load in PyTorch identical to the PyTorch Hub weights."""
        reference = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only")
        tf_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-tf-only")

        with tempfile.TemporaryDirectory() as tmp_dir:
            tf_model.save_pretrained(tmp_dir, safe_serialization=True)
            converted = BertModel.from_pretrained(tmp_dir)

        for reference_param, converted_param in zip(reference.parameters(), converted.parameters()):
            self.assertTrue(torch.equal(reference_param, converted_param))
1236

1237
    @require_safetensors
    def test_safetensors_torch_from_torch_sharded(self):
        """A PyTorch model round-trips exactly through a sharded (100kB shards) safetensors save."""
        model = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only")

        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir, safe_serialization=True, max_shard_size="100kB")
            reloaded = BertModel.from_pretrained(tmp_dir)

        for original_param, reloaded_param in zip(model.parameters(), reloaded.parameters()):
            self.assertTrue(torch.equal(original_param, reloaded_param))
1247

1248
    def test_modifying_model_config_causes_warning_saving_generation_config(self):
        """Setting `top_k` on `model.config` after loading must emit exactly one warning on save."""
        model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
        model.config.top_k = 1
        with tempfile.TemporaryDirectory() as tmp_dir:
            with self.assertLogs("transformers.modeling_utils", level="WARNING") as captured:
                model.save_pretrained(tmp_dir)
            warnings_emitted = captured.output
            self.assertEqual(len(warnings_emitted), 1)
            self.assertIn("Your generation config was originally created from the model config", warnings_emitted[0])
1256

1257

1258
@slow
@require_torch
class ModelOnTheFlyConversionTester(unittest.TestCase):
    """Tests for the safetensors conversion triggered by `from_pretrained(..., use_safetensors=True)`.

    Each test pushes a tiny, randomly initialised `BertModel` with `safe_serialization=False` to a
    fresh repository, reloads it with `use_safetensors=True`, and then inspects the repository
    discussions to check which account opened the conversion PR. Requires the
    `HUGGINGFACE_PRODUCTION_USER_TOKEN` environment variable.
    """

    # Title every test expects on the conversion PR.
    PR_TITLE = "Adding `safetensors` variant of this model"
    # Author name expected when the PR is opened by the conversion bot.
    BOT_AUTHOR = "SFconvertbot"

    @classmethod
    def setUpClass(cls):
        cls.user = "huggingface-hub-ci"
        cls.token = os.getenv("HUGGINGFACE_PRODUCTION_USER_TOKEN", None)

        if cls.token is None:
            raise ValueError("Cannot run tests as secret isn't setup.")

        cls.api = HfApi(token=cls.token)

    def setUp(self) -> None:
        # A unique repository per test so tests never see each other's discussions.
        self.repo_name = f"{self.user}/test-model-on-the-fly-{uuid.uuid4()}"

    def tearDown(self) -> None:
        self.api.delete_repo(self.repo_name)

    @staticmethod
    def _create_model():
        """Return a tiny randomly initialised `BertModel`."""
        config = BertConfig(
            vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
        )
        return BertModel(config)

    def _set_gated(self):
        """Mark the test repository as gated (`"auto"`) through the Hub settings endpoint."""
        headers = {"Authorization": f"Bearer {self.token}"}
        response = requests.put(
            f"https://huggingface.co/api/models/{self.repo_name}/settings", json={"gated": "auto"}, headers=headers
        )
        # Fail loudly: a silently rejected request would turn a "gated" test into a public-repo test.
        response.raise_for_status()

    def _assert_same_parameters(self, initial_model, converted_model):
        """Assert both models have the same number of parameters and that all of them are equal."""
        with self.subTest("Initial and converted models are equal"):
            initial_params = list(initial_model.parameters())
            converted_params = list(converted_model.parameters())
            # `zip` would silently truncate on a length mismatch.
            self.assertEqual(len(initial_params), len(converted_params))
            for p1, p2 in zip(initial_params, converted_params):
                self.assertTrue(torch.equal(p1, p2))

    def _assert_first_pr(self, expected_author, description, **discussion_kwargs):
        """Assert the first discussion listed for the repo is the conversion PR by `expected_author`.

        `discussion_kwargs` are forwarded to `HfApi.get_repo_discussions` (e.g. `token=`).
        """
        with self.subTest(description):
            discussions = self.api.get_repo_discussions(self.repo_name, **discussion_kwargs)
            discussion = next(discussions)
            self.assertEqual(discussion.author, expected_author)
            self.assertEqual(discussion.title, self.PR_TITLE)

    def test_safetensors_on_the_fly_conversion(self):
        initial_model = self._create_model()

        initial_model.push_to_hub(self.repo_name, token=self.token, safe_serialization=False)
        converted_model = BertModel.from_pretrained(self.repo_name, use_safetensors=True)

        self._assert_same_parameters(initial_model, converted_model)
        self._assert_first_pr(self.BOT_AUTHOR, "PR was open with the safetensors account")

    def test_safetensors_on_the_fly_conversion_private(self):
        initial_model = self._create_model()

        initial_model.push_to_hub(self.repo_name, token=self.token, safe_serialization=False, private=True)
        converted_model = BertModel.from_pretrained(self.repo_name, use_safetensors=True, token=self.token)

        self._assert_same_parameters(initial_model, converted_model)
        # For a private repo the PR is expected to come from the user's own account.
        self._assert_first_pr(self.user, "PR was open with the user's account", token=self.token)

    def test_safetensors_on_the_fly_conversion_gated(self):
        initial_model = self._create_model()

        initial_model.push_to_hub(self.repo_name, token=self.token, safe_serialization=False)
        self._set_gated()
        converted_model = BertModel.from_pretrained(self.repo_name, use_safetensors=True, token=self.token)

        self._assert_same_parameters(initial_model, converted_model)
        self._assert_first_pr(self.BOT_AUTHOR, "PR was open with the safetensors account")

    def test_safetensors_on_the_fly_sharded_conversion(self):
        initial_model = self._create_model()

        initial_model.push_to_hub(self.repo_name, token=self.token, safe_serialization=False, max_shard_size="200kb")
        converted_model = BertModel.from_pretrained(self.repo_name, use_safetensors=True)

        self._assert_same_parameters(initial_model, converted_model)
        self._assert_first_pr(self.BOT_AUTHOR, "PR was open with the safetensors account")

    def test_safetensors_on_the_fly_sharded_conversion_private(self):
        initial_model = self._create_model()

        initial_model.push_to_hub(
            self.repo_name, token=self.token, safe_serialization=False, max_shard_size="200kb", private=True
        )
        converted_model = BertModel.from_pretrained(self.repo_name, use_safetensors=True, token=self.token)

        self._assert_same_parameters(initial_model, converted_model)
        # For a private repo the PR is expected to come from the user's own account.
        self._assert_first_pr(self.user, "PR was open with the user's account")

    def test_safetensors_on_the_fly_sharded_conversion_gated(self):
        initial_model = self._create_model()

        initial_model.push_to_hub(self.repo_name, token=self.token, max_shard_size="200kb", safe_serialization=False)
        self._set_gated()
        converted_model = BertModel.from_pretrained(self.repo_name, use_safetensors=True, token=self.token)

        self._assert_same_parameters(initial_model, converted_model)
        self._assert_first_pr(self.BOT_AUTHOR, "PR was open with the safetensors account")

    @unittest.skip("Edge case, should work once the Space is updated`")
    def test_safetensors_on_the_fly_wrong_user_opened_pr(self):
        initial_model = self._create_model()

        initial_model.push_to_hub(self.repo_name, token=self.token, safe_serialization=False, private=True)
        BertModel.from_pretrained(self.repo_name, use_safetensors=True, token=self.token)

        # This should have opened a PR with the user's account
        self._assert_first_pr(self.user, "PR was open with the user's account")

        # We now switch the repo visibility to public
        self.api.update_repo_visibility(self.repo_name, private=False)

        # We once again call from_pretrained, which should call the bot to open a PR
        BertModel.from_pretrained(self.repo_name, use_safetensors=True, token=self.token)

        with self.subTest("PR was open with the safetensors account"):
            # Titles of every discussion opened by the bot, in listing order.
            bot_titles = [
                discussion.title
                for discussion in self.api.get_repo_discussions(self.repo_name)
                if discussion.author == self.BOT_AUTHOR
            ]
            self.assertTrue(bot_titles, "No PR was opened by the conversion bot")
            self.assertEqual(bot_titles[-1], self.PR_TITLE)

    def test_safetensors_on_the_fly_specific_revision(self):
        initial_model = self._create_model()

        # Push a model on `main`
        initial_model.push_to_hub(self.repo_name, token=self.token, safe_serialization=False)

        # Push a model on a given revision
        initial_model.push_to_hub(self.repo_name, token=self.token, safe_serialization=False, revision="new-branch")

        # Try to convert the model on that revision should raise
        with self.assertRaises(EnvironmentError):
            BertModel.from_pretrained(self.repo_name, use_safetensors=True, token=self.token, revision="new-branch")
1453

1454

1455
@require_torch
@is_staging_test
class ModelPushToHubTester(unittest.TestCase):
    """Tests for `push_to_hub` / `save_pretrained(push_to_hub=True)` (marked with `is_staging_test`)."""

    # Repositories the tests below may create; each is deleted in `tearDownClass`.
    _REPO_IDS = (
        "test-model",
        "valid_org/test-model-org",
        "test-dynamic-model",
        "test-dynamic-model-with-tags",
    )

    @classmethod
    def setUpClass(cls):
        cls._token = TOKEN
        HfFolder.save_token(TOKEN)

    @classmethod
    def tearDownClass(cls):
        for repo_id in cls._REPO_IDS:
            try:
                delete_repo(token=cls._token, repo_id=repo_id)
            except HTTPError:
                # Best-effort cleanup: the repo may never have been created (skipped/failed test).
                pass

    @staticmethod
    def _tiny_bert_config():
        """Return the small `BertConfig` shared by the BERT push tests."""
        return BertConfig(
            vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
        )

    def _assert_same_parameters(self, model, new_model):
        """Assert both models have the same number of parameters and that all of them are equal."""
        params = list(model.parameters())
        new_params = list(new_model.parameters())
        # `zip` would silently truncate on a length mismatch.
        self.assertEqual(len(params), len(new_params))
        for p1, p2 in zip(params, new_params):
            self.assertTrue(torch.equal(p1, p2))

    @unittest.skip("This test is flaky")
    def test_push_to_hub(self):
        model = BertModel(self._tiny_bert_config())
        model.push_to_hub("test-model", token=self._token)

        new_model = BertModel.from_pretrained(f"{USER}/test-model")
        self._assert_same_parameters(model, new_model)

        # Reset repo
        delete_repo(token=self._token, repo_id="test-model")

        # Push to hub via save_pretrained
        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir, repo_id="test-model", push_to_hub=True, token=self._token)

        new_model = BertModel.from_pretrained(f"{USER}/test-model")
        self._assert_same_parameters(model, new_model)

    def test_push_to_hub_with_description(self):
        model = BertModel(self._tiny_bert_config())
        COMMIT_DESCRIPTION = """
The commit description supports markdown synthax see:
```python
>>> form transformers import AutoConfig
>>> config = AutoConfig.from_pretrained("google-bert/bert-base-uncased")
```
"""
        # `token=` instead of the deprecated `use_auth_token=`, consistent with the other tests.
        commit_details = model.push_to_hub(
            "test-model", token=self._token, create_pr=True, commit_description=COMMIT_DESCRIPTION
        )
        self.assertEqual(commit_details.commit_description, COMMIT_DESCRIPTION)

    @unittest.skip("This test is flaky")
    def test_push_to_hub_in_organization(self):
        model = BertModel(self._tiny_bert_config())
        model.push_to_hub("valid_org/test-model-org", token=self._token)

        new_model = BertModel.from_pretrained("valid_org/test-model-org")
        self._assert_same_parameters(model, new_model)

        # Reset repo
        delete_repo(token=self._token, repo_id="valid_org/test-model-org")

        # Push to hub via save_pretrained
        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir, push_to_hub=True, token=self._token, repo_id="valid_org/test-model-org")

        new_model = BertModel.from_pretrained("valid_org/test-model-org")
        self._assert_same_parameters(model, new_model)

    def test_push_to_hub_dynamic_model(self):
        CustomConfig.register_for_auto_class()
        CustomModel.register_for_auto_class()

        config = CustomConfig(hidden_size=32)
        model = CustomModel(config)

        model.push_to_hub("test-dynamic-model", token=self._token)
        # checks
        self.assertDictEqual(
            config.auto_map,
            {"AutoConfig": "custom_configuration.CustomConfig", "AutoModel": "custom_modeling.CustomModel"},
        )

        new_model = AutoModel.from_pretrained(f"{USER}/test-dynamic-model", trust_remote_code=True)
        # Can't make an isinstance check because the new_model is from the CustomModel class of a dynamic module
        self.assertEqual(new_model.__class__.__name__, "CustomModel")
        self._assert_same_parameters(model, new_model)

        config = AutoConfig.from_pretrained(f"{USER}/test-dynamic-model", trust_remote_code=True)
        new_model = AutoModel.from_config(config, trust_remote_code=True)
        self.assertEqual(new_model.__class__.__name__, "CustomModel")

    def test_push_to_hub_with_tags(self):
        from huggingface_hub import ModelCard

        new_tags = ["tag-1", "tag-2"]

        CustomConfig.register_for_auto_class()
        CustomModel.register_for_auto_class()

        config = CustomConfig(hidden_size=32)
        model = CustomModel(config)

        self.assertIsNone(model.model_tags)

        model.add_model_tags(new_tags)

        self.assertEqual(model.model_tags, new_tags)

        model.push_to_hub("test-dynamic-model-with-tags", token=self._token)

        loaded_model_card = ModelCard.load(f"{USER}/test-dynamic-model-with-tags")
        self.assertEqual(loaded_model_card.data.tags, new_tags)
1593

1594

1595
@require_torch
class AttentionMaskTester(unittest.TestCase):
    """Tests for `AttentionMaskConverter` and the 4D attention-mask helper models."""

    def check_non_causal(self, bsz, q_len, kv_len, mask_2d, mask_4d):
        """Check that every position zeroed out in `mask_2d` is masked in `mask_4d`.

        A position counts as masked when its value is `-inf` or the minimum finite value of
        `mask_4d.dtype`. Only head index 0 of `mask_4d` is inspected.
        """
        # Broadcast the 2D padding mask over the query axis: (bsz, kv_len) -> (bsz, q_len, kv_len).
        mask_indices = (mask_2d != 1)[:, None].broadcast_to((bsz, q_len, kv_len))
        mask_4d_values = mask_4d[:, 0][mask_indices]
        is_inf = mask_4d_values == -float("inf")
        is_min = mask_4d_values == torch.finfo(mask_4d.dtype).min
        assert torch.logical_or(is_inf, is_min).all()

    def check_to_4d(self, mask_converter, q_len, kv_len, additional_mask=None, bsz=3):
        """Expand an all-ones 2D mask with `to_4d` and check shape, overflow and masked-entry count.

        Args:
            mask_converter: the `AttentionMaskConverter` under test.
            q_len: query length passed to `to_4d`.
            kv_len: key/value length passed to `to_4d`.
            additional_mask: optional iterable of `(batch_idx, seq_idx)` pairs to zero out in the 2D mask.
            bsz: batch size.
        """
        mask_2d = torch.ones((bsz, kv_len), device=torch_device, dtype=torch.long)

        if additional_mask is not None:
            for bsz_idx, seq_idx in additional_mask:
                mask_2d[bsz_idx, seq_idx] = 0

        mask_4d = mask_converter.to_4d(mask_2d, query_length=q_len, key_value_length=kv_len, dtype=torch.float32)

        assert mask_4d.shape == (bsz, 1, q_len, kv_len)

        # make sure there are no overflows
        assert mask_4d.min() != float("-inf")

        has_padding = 0 in mask_2d
        num_masked = (mask_4d != 0).sum().cpu().item()

        context = mask_converter.sliding_window
        if mask_converter.is_causal and context is None:
            # The causal part masks q_len * (q_len - 1) / 2 entries per sequence.
            num_tokens_masked = bsz * (q_len * (q_len - 1) // 2)

            if has_padding:
                # at least the causal mask, plus the padded positions
                assert num_masked >= num_tokens_masked
                self.check_non_causal(bsz, q_len, kv_len, mask_2d, mask_4d)
            else:
                assert num_masked == num_tokens_masked
        elif not mask_converter.is_causal and context is None:
            if has_padding:
                self.check_non_causal(bsz, q_len, kv_len, mask_2d, mask_4d)
            else:
                # without causality or padding nothing is masked
                assert num_masked == 0
        elif mask_converter.is_causal and context is not None:
            # q_len * (q_len - 1) / 2 causal entries plus those falling outside the sliding window
            num_tokens_masked = (q_len * (q_len - 1) // 2) + self.compute_num_context_mask(kv_len, context, q_len)
            num_tokens_masked = bsz * num_tokens_masked

            if has_padding:
                # at least causal + sliding-window mask, plus the padded positions
                assert num_masked >= num_tokens_masked
                self.check_non_causal(bsz, q_len, kv_len, mask_2d, mask_4d)
            else:
                assert num_masked == num_tokens_masked

    def check_to_causal(self, mask_converter, q_len, kv_len, bsz=3):
        """Build a causal 4D mask with `to_causal_4d` and check the number of masked entries."""
        mask_4d = mask_converter.to_causal_4d(
            bsz, query_length=q_len, key_value_length=kv_len, device=torch_device, dtype=torch.float32
        )

        if q_len == 1 and mask_converter.sliding_window is None:
            # no causal mask if q_len is 1
            assert mask_4d is None
            return

        context = mask_converter.sliding_window
        if mask_converter.is_causal and context is None:
            # The causal part masks q_len * (q_len - 1) / 2 entries per sequence.
            num_tokens_masked = bsz * (q_len * (q_len - 1) // 2)

            assert (mask_4d != 0).sum().cpu().item() == num_tokens_masked
        elif not mask_converter.is_causal and context is None:
            assert (mask_4d != 0).sum().cpu().item() == 0
        elif mask_converter.is_causal and context is not None:
            # q_len * (q_len - 1) / 2 causal entries plus those falling outside the sliding window
            num_tokens_masked = (q_len * (q_len - 1) // 2) + self.compute_num_context_mask(kv_len, context, q_len)
            num_tokens_masked = bsz * num_tokens_masked

            assert (mask_4d != 0).sum().cpu().item() == num_tokens_masked

    def compute_num_context_mask(self, kv_len, context, q_len):
        """Return the number of extra entries masked by a sliding window of size `context`.

        This is the count on top of the causal triangle, for one sequence.
        """
        # Triangle of positions that fall outside the window, counted over all kv_len rows...
        c_mask_len = kv_len - context
        num_mask_triangle = c_mask_len * (c_mask_len + 1) // 2
        # ...minus the part belonging to rows that are not among the last q_len query rows.
        cut_mask_len = max(c_mask_len - q_len, 0)
        num_cut_mask = cut_mask_len * (cut_mask_len + 1) // 2
        return num_mask_triangle - num_cut_mask

    def test_2d_to_4d_causal(self):
        mask_converter = AttentionMaskConverter(is_causal=True)

        # auto-regressive use case
        self.check_to_4d(mask_converter, q_len=1, kv_len=7)
        # special auto-regressive case
        self.check_to_4d(mask_converter, q_len=3, kv_len=7)
        # non auto-regressive case
        self.check_to_4d(mask_converter, q_len=7, kv_len=7)

        # same with extra attention masks
        self.check_to_4d(mask_converter, q_len=1, kv_len=7, additional_mask=[(0, 2), (1, 3), (2, 0)])
        self.check_to_4d(mask_converter, q_len=3, kv_len=7, additional_mask=[(0, 2), (1, 3), (2, 0)])
        self.check_to_4d(mask_converter, q_len=7, kv_len=7, additional_mask=[(0, 2), (1, 3), (2, 0)])

        # check that the mask does not overflow on causal masked tokens
        self.check_to_4d(mask_converter, q_len=7, kv_len=7, additional_mask=[(0, 0), (1, 0), (1, 1)])

    def test_2d_to_4d(self):
        mask_converter = AttentionMaskConverter(is_causal=False)

        # non auto-regressive case
        self.check_to_4d(mask_converter, q_len=7, kv_len=7)

        # same with extra attention masks
        self.check_to_4d(mask_converter, q_len=7, kv_len=7, additional_mask=[(0, 2), (1, 3), (2, 0)])

    def test_2d_to_4d_causal_sliding(self):
        mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=5)

        # auto-regressive use case
        self.check_to_4d(mask_converter, q_len=1, kv_len=7)
        # special auto-regressive case
        self.check_to_4d(mask_converter, q_len=3, kv_len=7)
        # non auto-regressive case
        self.check_to_4d(mask_converter, q_len=7, kv_len=7)

        # same with extra attention masks
        self.check_to_4d(mask_converter, q_len=1, kv_len=7, additional_mask=[(0, 2), (1, 3), (2, 0)])
        self.check_to_4d(mask_converter, q_len=3, kv_len=7, additional_mask=[(0, 2), (1, 3), (2, 0)])
        self.check_to_4d(mask_converter, q_len=7, kv_len=7, additional_mask=[(0, 2), (1, 3), (2, 0)])

    def test_causal_mask(self):
        mask_converter = AttentionMaskConverter(is_causal=True)

        # auto-regressive use case
        self.check_to_causal(mask_converter, q_len=1, kv_len=7)
        # special auto-regressive case
        self.check_to_causal(mask_converter, q_len=3, kv_len=7)
        # non auto-regressive case
        self.check_to_causal(mask_converter, q_len=7, kv_len=7)

    def test_causal_mask_sliding(self):
        mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=3)

        # auto-regressive use case
        self.check_to_causal(mask_converter, q_len=1, kv_len=7)
        # special auto-regressive case
        self.check_to_causal(mask_converter, q_len=3, kv_len=7)
        # non auto-regressive case
        self.check_to_causal(mask_converter, q_len=7, kv_len=7)

    def test_torch_compile_fullgraph(self):
        model = Prepare4dCausalAttentionMaskModel()

        inputs_embeds = torch.rand([1, 3, 32])
        res_non_compiled = model(inputs_embeds)

        compiled_model = torch.compile(model, fullgraph=True)

        res_compiled = compiled_model(inputs_embeds)

        self.assertTrue(torch.equal(res_non_compiled, res_compiled))

        model = Create4dCausalAttentionMaskModel()

        inputs_embeds = torch.rand(2, 4, 16)
        res_non_compiled = model(inputs_embeds)

        compiled_model = torch.compile(model, fullgraph=True)
        res_compiled = compiled_model(inputs_embeds)

        self.assertTrue(torch.equal(res_non_compiled, res_compiled))

        model = Prepare4dAttentionMaskModel()

        mask = torch.ones(2, 4)
        mask[0, :2] = 0
        inputs_embeds = torch.rand(2, 4, 16)

        res_non_compiled = model(mask, inputs_embeds)

        compiled_model = torch.compile(model, fullgraph=True)
        res_compiled = compiled_model(mask, inputs_embeds)

        self.assertTrue(torch.equal(res_non_compiled, res_compiled))

    @slow
    def test_unmask_unattended_left_padding(self):
        attention_mask = torch.Tensor([[0, 0, 1], [1, 1, 1], [0, 1, 1]]).to(torch.int64)

        expanded_mask = torch.Tensor(
            [
                [[[0, 0, 0], [0, 0, 0], [0, 0, 1]]],
                [[[1, 0, 0], [1, 1, 0], [1, 1, 1]]],
                [[[0, 0, 0], [0, 1, 0], [0, 1, 1]]],
            ]
        ).to(torch.int64)

        reference_output = torch.Tensor(
            [
                [[[1, 1, 1], [1, 1, 1], [0, 0, 1]]],
                [[[1, 0, 0], [1, 1, 0], [1, 1, 1]]],
                [[[1, 1, 1], [0, 1, 0], [0, 1, 1]]],
            ]
        ).to(torch.int64)

        result = AttentionMaskConverter._unmask_unattended(expanded_mask, attention_mask, unmasked_value=1)

        self.assertTrue(torch.equal(result, reference_output))

        attention_mask = torch.Tensor([[0, 0, 1, 1, 1], [1, 1, 1, 1, 1], [0, 1, 1, 1, 1]]).to(torch.int64)

        attn_mask_converter = AttentionMaskConverter(is_causal=True)
        past_key_values_length = 0
        key_value_length = attention_mask.shape[-1] + past_key_values_length

        expanded_mask = attn_mask_converter.to_4d(
            attention_mask, attention_mask.shape[-1], key_value_length=key_value_length, dtype=torch.float32
        )

        result = AttentionMaskConverter._unmask_unattended(expanded_mask, attention_mask, unmasked_value=0)
        min_inf = torch.finfo(torch.float32).min
        reference_output = torch.Tensor(
            [
                [
                    [
                        [0, 0, 0, 0, 0],
                        [0, 0, 0, 0, 0],
                        [min_inf, min_inf, 0, min_inf, min_inf],
                        [min_inf, min_inf, 0, 0, min_inf],
                        [min_inf, min_inf, 0, 0, 0],
                    ]
                ],
                [
                    [
                        [0, min_inf, min_inf, min_inf, min_inf],
                        [0, 0, min_inf, min_inf, min_inf],
                        [0, 0, 0, min_inf, min_inf],
                        [0, 0, 0, 0, min_inf],
                        [0, 0, 0, 0, 0],
                    ]
                ],
                [
                    [
                        [0, 0, 0, 0, 0],
                        [min_inf, 0, min_inf, min_inf, min_inf],
                        [min_inf, 0, 0, min_inf, min_inf],
                        [min_inf, 0, 0, 0, min_inf],
                        [min_inf, 0, 0, 0, 0],
                    ]
                ],
            ]
        )

        self.assertTrue(torch.equal(reference_output, result))

    @slow
    def test_unmask_unattended_right_padding(self):
        attention_mask = torch.Tensor([[1, 1, 1, 0], [1, 1, 1, 1], [1, 1, 0, 0]]).to(torch.int64)

        attn_mask_converter = AttentionMaskConverter(is_causal=True)
        past_key_values_length = 0
        key_value_length = attention_mask.shape[-1] + past_key_values_length

        expanded_mask = attn_mask_converter.to_4d(
            attention_mask, attention_mask.shape[-1], key_value_length=key_value_length, dtype=torch.float32
        )

        result = AttentionMaskConverter._unmask_unattended(expanded_mask, attention_mask, unmasked_value=0)

        # With right padding only, the mask is expected to be returned unchanged.
        self.assertTrue(torch.equal(expanded_mask, result))

    @slow
    def test_unmask_unattended_random_mask(self):
        attention_mask = torch.Tensor([[1, 0, 1, 0], [1, 0, 1, 1], [1, 1, 0, 1]]).to(torch.int64)

        attn_mask_converter = AttentionMaskConverter(is_causal=True)
        past_key_values_length = 0
        key_value_length = attention_mask.shape[-1] + past_key_values_length

        expanded_mask = attn_mask_converter.to_4d(
            attention_mask, attention_mask.shape[-1], key_value_length=key_value_length, dtype=torch.float32
        )

        result = AttentionMaskConverter._unmask_unattended(expanded_mask, attention_mask, unmasked_value=0)

        # Every sequence attends its first token, so the mask is expected to be returned unchanged.
        self.assertTrue(torch.equal(expanded_mask, result))
1881

1882

1883
@require_torch
class TestAttentionImplementation(unittest.TestCase):
    """Checks how `from_pretrained` validates the `attn_implementation` argument."""

    def test_error_no_sdpa_available(self):
        """Requesting SDPA for a model without SDPA support raises; the default still loads."""
        with self.assertRaises(ValueError) as cm:
            _ = AutoModel.from_pretrained("hf-tiny-model-private/tiny-random-MCTCTModel", attn_implementation="sdpa")

        self.assertIn(
            "does not support an attention implementation through torch.nn.functional.scaled_dot_product_attention",
            str(cm.exception),
        )

        # Loading without an explicit `attn_implementation` must still succeed.
        _ = AutoModel.from_pretrained("hf-tiny-model-private/tiny-random-MCTCTModel")

    def test_error_no_flash_available(self):
        """Requesting Flash Attention 2 for a model without support raises a ValueError."""
        with self.assertRaises(ValueError) as cm:
            _ = AutoModel.from_pretrained(
                "hf-tiny-model-private/tiny-random-MCTCTModel", attn_implementation="flash_attention_2"
            )

        self.assertIn("does not support Flash Attention 2.0", str(cm.exception))

    def test_error_no_flash_available_with_config(self):
        """Same as `test_error_no_flash_available`, but with an explicitly pre-loaded config."""
        # Load the config outside `assertRaises` so that a ValueError raised while fetching the
        # config cannot be mistaken for the error under test.
        config = AutoConfig.from_pretrained("hf-tiny-model-private/tiny-random-MCTCTModel")

        with self.assertRaises(ValueError) as cm:
            _ = AutoModel.from_pretrained(
                "hf-tiny-model-private/tiny-random-MCTCTModel", config=config, attn_implementation="flash_attention_2"
            )

        self.assertIn("does not support Flash Attention 2.0", str(cm.exception))

    def test_error_wrong_attn_implementation(self):
        """An unknown `attn_implementation` value raises and lists the accepted values."""
        with self.assertRaises(ValueError) as cm:
            _ = AutoModel.from_pretrained("hf-tiny-model-private/tiny-random-MCTCTModel", attn_implementation="foo")

        self.assertIn('The only possible arguments are `attn_implementation="eager"', str(cm.exception))

    def test_not_available_flash(self):
        """Without the flash-attn package, requesting Flash Attention 2 raises an ImportError."""
        if is_flash_attn_2_available():
            self.skipTest("Please uninstall flash-attn package to run test_not_available_flash")

        with self.assertRaises(ImportError) as cm:
            _ = AutoModel.from_pretrained(
                "hf-internal-testing/tiny-random-GPTBigCodeModel", attn_implementation="flash_attention_2"
            )

        self.assertIn("the package flash_attn seems to be not installed", str(cm.exception))

    def test_not_available_flash_with_config(self):
        """Same as `test_not_available_flash`, but with an explicitly pre-loaded config."""
        if is_flash_attn_2_available():
            self.skipTest("Please uninstall flash-attn package to run test_not_available_flash_with_config")

        config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-GPTBigCodeModel")

        with self.assertRaises(ImportError) as cm:
            _ = AutoModel.from_pretrained(
                "hf-internal-testing/tiny-random-GPTBigCodeModel",
                config=config,
                attn_implementation="flash_attention_2",
            )

        self.assertIn("the package flash_attn seems to be not installed", str(cm.exception))

    def test_not_available_sdpa(self):
        """When SDPA is unavailable in the installed torch, requesting it raises an ImportError."""
        if is_torch_sdpa_available():
            # NOTE(review): the version bound in this message should match whatever
            # `is_torch_sdpa_available` actually requires — confirm.
            self.skipTest("This test requires torch<=2.0")

        with self.assertRaises(ImportError) as cm:
            _ = AutoModel.from_pretrained(
                "hf-internal-testing/tiny-random-GPTBigCodeModel", attn_implementation="sdpa"
            )

        self.assertIn("PyTorch SDPA requirements in Transformers are not met", str(cm.exception))
1956

1957

1958
@slow
@require_torch_gpu
class Mask4DTestBase(unittest.TestCase):
    """Shared fixtures comparing a batched input against an equivalent packed row with a 4D mask.

    Subclasses must set `self.tokenizer` in `setUp`; `get_test_data` reads it.
    """

    def tearDown(self):
        # Collect garbage and release cached CUDA memory between tests.
        gc.collect()
        torch.cuda.empty_cache()

    def get_test_data(self):
        """Build a batched input and an equivalent packed single-row input.

        The hard-coded 6x6 mask and position ids assume every text encodes to 4 tokens and that
        the texts differ only in their last token (as in the example tensors below).

        Returns:
            input_0: the three encoded texts stacked into one batch.
            input_1: a single row with the first text's tokens minus its last one, followed by the
                last token of each text.
            mask_1: int64 4D mask of shape (1, 1, 6, 6) in which each of the three final tokens
                attends to the shared prefix and to itself only.
            position_ids_1: positions for `input_1`; the three final tokens share one position.
        """
        texts = ["the cat sat", "the cat had", "the cat is"]
        encoded = [self.tokenizer.encode(t) for t in texts]
        input_0 = torch.tensor(encoded, device=torch_device)
        # tensor([[   1,  278, 6635, 3290],
        # [   1,  278, 6635,  750],
        # [   1,  278, 6635,  338]], device='cuda:0')

        # Combining common prefix with the unique ending tokens:
        input_1 = torch.cat([input_0[0][:-1], input_0[:, -1]]).unsqueeze(0)
        # tensor([[   1,  278, 6635, 3290,  750,  338]], device='cuda:0')

        # Creating a 4D mask where each of the last 3 tokens does not attend to the other two.
        mask_1 = torch.tensor(
            [
                [
                    [
                        [1, 0, 0, 0, 0, 0],
                        [1, 1, 0, 0, 0, 0],
                        [1, 1, 1, 0, 0, 0],
                        [1, 1, 1, 1, 0, 0],
                        [1, 1, 1, 0, 1, 0],
                        [1, 1, 1, 0, 0, 1],
                    ]
                ]
            ],
            # Use the same device as the other test tensors instead of hard-coding "cuda:0".
            device=torch_device,
            dtype=torch.int64,
        )

        # Creating a position_ids tensor. note the repeating figures in the end.
        position_ids_1 = torch.tensor([[0, 1, 2, 3, 3, 3]], device=torch_device, dtype=torch.int64)

        return input_0, input_1, mask_1, position_ids_1
1999

2000

2001
@slow
@require_torch_gpu
class Mask4DTestFP32(Mask4DTestBase):
    """4D-mask equivalence checks with the model loaded in float32."""

    def setUp(self):
        """Load the tokenizer and a float32 copy of the model onto the test device."""
        model_name = "JackFram/llama-68m"  # small Llama-like model from FlexFlow
        model_dtype = torch.float32
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=model_dtype).to(torch_device)

    def test_attention(self):
        """comparing outputs of attention layer"""
        input_0, input_1, mask_1, position_ids_1 = self.get_test_data()

        hid_0 = self.model.model.embed_tokens(input_0)
        outs_0 = self.model.model.layers[0].self_attn.forward(hid_0)[0]
        # outs_0.shape == torch.Size([3, 4, 768])

        hid_1 = self.model.model.embed_tokens(input_1)
        outs_1 = self.model.model.layers[0].self_attn.forward(
            hid_1, attention_mask=mask_1.bool(), position_ids=position_ids_1
        )[0]
        # outs_1.shape == torch.Size([1, 6, 768])

        outs_0_last_tokens = outs_0[:, -1, :]  # last tokens in each batch line
        outs_1_last_tokens = outs_1[0, -3:, :]  # last three tokens
        # Same tolerances as the `torch.allclose` defaults, but `assert_close` reports where the
        # tensors differ and, unlike a bare `assert`, is not stripped under `python -O`.
        torch.testing.assert_close(outs_0_last_tokens, outs_1_last_tokens, rtol=1e-5, atol=1e-8)

    def test_inner_model(self):
        """comparing hidden outputs of whole inner model"""
        input_0, input_1, mask_1, position_ids_1 = self.get_test_data()

        # Run the inner model (`self.model.model`, without the LM head) so this test compares hidden
        # states; comparing logits here would just duplicate `test_causal_model_logits`.
        hidden_0 = self.model.model.forward(input_0).last_hidden_state
        hidden_1 = self.model.model.forward(
            input_1, attention_mask=mask_1.bool(), position_ids=position_ids_1
        ).last_hidden_state

        hidden_0_last_tokens = hidden_0[:, -1, :]  # last tokens in each batch line
        hidden_1_last_tokens = hidden_1[0, -3:, :]  # last three tokens
        torch.testing.assert_close(
            hidden_0_last_tokens,
            hidden_1_last_tokens,
        )

    def test_causal_model_logits(self):
        """comparing logits outputs of whole inner model"""
        input_0, input_1, mask_1, position_ids_1 = self.get_test_data()

        logits_0 = self.model.forward(input_0).logits
        logits_1 = self.model.forward(input_1, attention_mask=mask_1.bool(), position_ids=position_ids_1).logits

        logits_0_last_tokens = logits_0[:, -1, :]  # last tokens in each batch line
        logits_1_last_tokens = logits_1[0, -3:, :]  # last three tokens
        torch.testing.assert_close(
            logits_0_last_tokens,
            logits_1_last_tokens,
        )
2055

2056

2057
@slow
@require_torch_gpu
class Mask4DTestFP16(Mask4DTestBase):
    """4D-mask equivalence checks with the model loaded in float16, using relaxed tolerances."""

    # Reuse the attention-layer comparison from the float32 suite unchanged.
    test_attention = Mask4DTestFP32.test_attention

    def setUp(self):
        """Load the tokenizer and a float16 copy of the model onto the test device."""
        checkpoint = "JackFram/llama-68m"  # small Llama-like model from FlexFlow
        self.tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        self.model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.float16).to(torch_device)

    def test_causal_model_logits(self):
        """Final-token logits of the batched and packed inputs agree within FP16 tolerances.

        Besides the numeric comparison, the 128 highest-scoring token ids must appear in the same
        order for each of the three sequences.
        """
        batched_ids, packed_ids, packed_mask, packed_positions = self.get_test_data()

        # One row per batch line (its last token) vs. the three trailing tokens of the packed row.
        batched_last = self.model.forward(batched_ids).logits[:, -1, :]
        packed_last = self.model.forward(
            packed_ids, attention_mask=packed_mask.bool(), position_ids=packed_positions
        ).logits[0, -3:, :]

        batched_ranking = batched_last.sort(descending=True).indices
        packed_ranking = packed_last.sort(descending=True).indices

        # Relaxed absolute/relative tolerances account for half-precision rounding.
        torch.testing.assert_close(batched_last, packed_last, atol=0.02, rtol=0.001)

        top_k = 128
        for ranking_a, ranking_b in zip(batched_ranking, packed_ranking):
            self.assertTrue(torch.equal(ranking_a[:top_k], ranking_b[:top_k]))
2087

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.