# coding=utf-8
# Copyright 2019 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import gc
import glob
import json
import os
import os.path
import sys
import tempfile
import unittest
import unittest.mock as mock
import uuid
from pathlib import Path

import requests
from huggingface_hub import HfApi, HfFolder, delete_repo
from huggingface_hub.file_download import http_get
from pytest import mark
from requests.exceptions import HTTPError

from transformers import (
    AutoConfig,
    AutoModel,
    AutoModelForSequenceClassification,
    OwlViTForObjectDetection,
    PretrainedConfig,
    is_torch_available,
    logging,
)
from transformers.testing_utils import (
    TOKEN,
    USER,
    CaptureLogger,
    LoggingLevel,
    TestCasePlus,
    is_staging_test,
    require_accelerate,
    require_flax,
    require_safetensors,
    require_tf,
    require_torch,
    require_torch_accelerator,
    require_torch_gpu,
    require_torch_multi_accelerator,
    require_usr_bin_time,
    slow,
    torch_device,
)
from transformers.utils import (
    SAFE_WEIGHTS_INDEX_NAME,
    SAFE_WEIGHTS_NAME,
    WEIGHTS_INDEX_NAME,
    WEIGHTS_NAME,
)
from transformers.utils.import_utils import (
    is_flash_attn_2_available,
    is_flax_available,
    is_tf_available,
    is_torch_sdpa_available,
    is_torchdynamo_available,
)


sys.path.append(str(Path(__file__).parent.parent / "utils"))

from test_module.custom_configuration import CustomConfig, NoSuperInitConfig  # noqa E402

if is_torch_available():
    import torch
    from safetensors.torch import save_file as safe_save_file
    from test_module.custom_modeling import CustomModel, NoSuperInitModel
    from torch import nn

    from transformers import (
        BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
        AutoModelForCausalLM,
        AutoTokenizer,
        BertConfig,
        BertModel,
        CLIPTextModel,
        PreTrainedModel,
        T5Config,
        T5ForConditionalGeneration,
    )
    from transformers.modeling_attn_mask_utils import (
        AttentionMaskConverter,
        _create_4d_causal_attention_mask,
        _prepare_4d_attention_mask,
        _prepare_4d_causal_attention_mask,
    )
    from transformers.modeling_utils import shard_checkpoint

    # Fake pretrained models for tests
    class BaseModel(PreTrainedModel):
        base_model_prefix = "base"
        config_class = PretrainedConfig

        def __init__(self, config):
            super().__init__(config)
            self.linear = nn.Linear(5, 5)
            self.linear_2 = nn.Linear(5, 5)

        def forward(self, x):
            return self.linear_2(self.linear(x))

    class BaseModelWithTiedWeights(PreTrainedModel):
        config_class = PretrainedConfig

        def __init__(self, config):
            super().__init__(config)
            self.linear = nn.Linear(5, 5)
            self.linear_2 = nn.Linear(5, 5)

        def forward(self, x):
            return self.linear_2(self.linear(x))

        def tie_weights(self):
            self.linear_2.weight = self.linear.weight

    class ModelWithHead(PreTrainedModel):
        base_model_prefix = "base"
        config_class = PretrainedConfig

        def _init_weights(self, module):
            pass

        def __init__(self, config):
            super().__init__(config)
            self.base = BaseModel(config)
            # linear is a common name between Base and Head on purpose.
            self.linear = nn.Linear(5, 5)
            self.linear2 = nn.Linear(5, 5)

        def forward(self, x):
            return self.linear2(self.linear(self.base(x)))

    class ModelWithHeadAndTiedWeights(PreTrainedModel):
        base_model_prefix = "base"
        config_class = PretrainedConfig

        def _init_weights(self, module):
            pass

        def __init__(self, config):
            super().__init__(config)
            self.base = BaseModel(config)
            self.decoder = nn.Linear(5, 5)

        def forward(self, x):
            return self.decoder(self.base(x))

        def tie_weights(self):
            self.decoder.weight = self.base.linear.weight

    class Prepare4dCausalAttentionMaskModel(nn.Module):
        def forward(self, inputs_embeds):
            batch_size, seq_length, _ = inputs_embeds.shape
            past_key_values_length = 4
            attention_mask = _prepare_4d_causal_attention_mask(
                None, (batch_size, seq_length), inputs_embeds, past_key_values_length
            )
            return attention_mask

    class Create4dCausalAttentionMaskModel(nn.Module):
        def forward(self, inputs_embeds):
            batch_size, seq_length, _ = inputs_embeds.shape
            past_key_values_length = 4
            attention_mask = _create_4d_causal_attention_mask(
                (batch_size, seq_length),
                dtype=inputs_embeds.dtype,
                device=inputs_embeds.device,
                past_key_values_length=past_key_values_length,
            )
            return attention_mask

    class Prepare4dAttentionMaskModel(nn.Module):
        def forward(self, mask, inputs_embeds):
            attention_mask = _prepare_4d_attention_mask(mask, dtype=inputs_embeds.dtype)
            return attention_mask


if is_flax_available():
    from transformers import FlaxBertModel

if is_tf_available():
    from transformers import TFBertModel


TINY_T5 = "patrickvonplaten/t5-tiny-random"
TINY_BERT_FOR_TOKEN_CLASSIFICATION = "hf-internal-testing/tiny-bert-for-token-classification"
TINY_MISTRAL = "hf-internal-testing/tiny-random-MistralForCausalLM"


def check_models_equal(model1, model2):
    models_are_equal = True
    for model1_p, model2_p in zip(model1.parameters(), model2.parameters()):
        if model1_p.data.ne(model2_p.data).sum() > 0:
            models_are_equal = False

    return models_are_equal


@require_torch
class ModelUtilsTest(TestCasePlus):
    @slow
    def test_model_from_pretrained(self):
        for model_name in BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            config = BertConfig.from_pretrained(model_name)
            self.assertIsNotNone(config)
            self.assertIsInstance(config, PretrainedConfig)

            model = BertModel.from_pretrained(model_name)
            model, loading_info = BertModel.from_pretrained(model_name, output_loading_info=True)
            self.assertIsNotNone(model)
            self.assertIsInstance(model, PreTrainedModel)

            self.assertEqual(len(loading_info["missing_keys"]), 0)
            self.assertEqual(len(loading_info["unexpected_keys"]), 8)
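            # The 8 unexpected keys are, presumably, the pre-training head weights
            # (cls.predictions.* and cls.seq_relationship.*) present in the checkpoint
            # but unused by the bare BertModel.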
            self.assertEqual(len(loading_info["mismatched_keys"]), 0)
            self.assertEqual(len(loading_info["error_msgs"]), 0)

            config = BertConfig.from_pretrained(model_name, output_attentions=True, output_hidden_states=True)

            # Not sure this is the intended behavior. TODO fix Lysandre & Thom
            config.name_or_path = model_name

            model = BertModel.from_pretrained(model_name, output_attentions=True, output_hidden_states=True)
            self.assertEqual(model.config.output_hidden_states, True)
            self.assertEqual(model.config, config)

    def test_model_from_pretrained_subfolder(self):
        config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
        model = BertModel(config)

        subfolder = "bert"
        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(os.path.join(tmp_dir, subfolder))

            with self.assertRaises(OSError):
                _ = BertModel.from_pretrained(tmp_dir)

            model_loaded = BertModel.from_pretrained(tmp_dir, subfolder=subfolder)

        self.assertTrue(check_models_equal(model, model_loaded))

    def test_model_from_pretrained_subfolder_sharded(self):
        config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
        model = BertModel(config)

        subfolder = "bert"
        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(os.path.join(tmp_dir, subfolder), max_shard_size="10KB")

            with self.assertRaises(OSError):
                _ = BertModel.from_pretrained(tmp_dir)

            model_loaded = BertModel.from_pretrained(tmp_dir, subfolder=subfolder)

        self.assertTrue(check_models_equal(model, model_loaded))

    def test_model_from_pretrained_hub_subfolder(self):
        subfolder = "bert"
        model_id = "hf-internal-testing/tiny-random-bert-subfolder"
        with self.assertRaises(OSError):
            _ = BertModel.from_pretrained(model_id)

        model = BertModel.from_pretrained(model_id, subfolder=subfolder)

        self.assertIsNotNone(model)

    def test_model_from_pretrained_hub_subfolder_sharded(self):
        subfolder = "bert"
        model_id = "hf-internal-testing/tiny-random-bert-sharded-subfolder"
        with self.assertRaises(OSError):
            _ = BertModel.from_pretrained(model_id)

        model = BertModel.from_pretrained(model_id, subfolder=subfolder)

        self.assertIsNotNone(model)

    def test_model_from_pretrained_with_different_pretrained_model_name(self):
        model = T5ForConditionalGeneration.from_pretrained(TINY_T5)
        self.assertIsNotNone(model)

        logger = logging.get_logger("transformers.configuration_utils")
        with LoggingLevel(logging.WARNING):
            with CaptureLogger(logger) as cl:
                BertModel.from_pretrained(TINY_T5)
        self.assertTrue("You are using a model of type t5 to instantiate a model of type bert" in cl.out)

    @require_accelerate
    def test_model_from_pretrained_with_none_quantization_config(self):
        # Needs a device_map to enter the low_cpu_mem branch. We also load AutoModelForSequenceClassification
        # deliberately to enter the missing keys branch.
        model = AutoModelForSequenceClassification.from_pretrained(
            TINY_MISTRAL, device_map="auto", quantization_config=None
        )
        self.assertIsNotNone(model)

    def test_model_from_config_torch_dtype(self):
        # test that the model can be instantiated with dtype of user's choice - as long as it's a
        # float dtype. To make it happen config.torch_dtype needs to be set before instantiating the
        # model from the config object.

        config = T5Config.from_pretrained(TINY_T5)
        model = AutoModel.from_config(config)
        # XXX: isn't supported
        # model = T5ForConditionalGeneration.from_config(config)
        self.assertEqual(model.dtype, torch.float32)

        model = AutoModel.from_config(config, torch_dtype=torch.float16)
        self.assertEqual(model.dtype, torch.float16)

        # torch.set_default_dtype() supports only float dtypes, so will fail with non-float type
        with self.assertRaises(ValueError):
            model = AutoModel.from_config(config, torch_dtype=torch.int64)

    def test_model_from_pretrained_torch_dtype(self):
        # test that the model can be instantiated with dtype of either
        # 1. explicit from_pretrained's torch_dtype argument
        # 2. via autodiscovery by looking at model weights (torch_dtype="auto")
        # so if a model.half() was saved, we want it to be instantiated as such.
        #
        # test an explicit model class, but also AutoModel separately as the latter goes through a different code path
        model_path = self.get_auto_remove_tmp_dir()

        # baseline - we know TINY_T5 is fp32 model
        model = T5ForConditionalGeneration.from_pretrained(TINY_T5)
        self.assertEqual(model.dtype, torch.float32)

        def remove_torch_dtype(model_path):
            file = f"{model_path}/config.json"
            with open(file, "r", encoding="utf-8") as f:
                s = json.load(f)
            s.pop("torch_dtype")
            with open(file, "w", encoding="utf-8") as f:
                json.dump(s, f)

        # test the default fp32 save_pretrained => from_pretrained cycle
        model.save_pretrained(model_path)
        model = T5ForConditionalGeneration.from_pretrained(model_path)
        self.assertEqual(model.dtype, torch.float32)
        # 1. test torch_dtype="auto" via `config.torch_dtype`
        model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype="auto")
        self.assertEqual(model.dtype, torch.float32)
        # 2. test torch_dtype="auto" via auto-derivation
        # now remove the torch_dtype entry from config.json and try "auto" again which should
        # perform auto-derivation from weights
        remove_torch_dtype(model_path)
        model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype="auto")
        self.assertEqual(model.dtype, torch.float32)

        # test forced loading in fp16 (even though the weights are in fp32)
        model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype=torch.float16)
        self.assertEqual(model.dtype, torch.float16)

        # test fp16 save_pretrained, loaded with auto-detection
        model = model.half()
        model.save_pretrained(model_path)
        # 1. test torch_dtype="auto" via `config.torch_dtype`
        model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype="auto")
        self.assertEqual(model.config.torch_dtype, torch.float16)
        self.assertEqual(model.dtype, torch.float16)
        # tests `config.torch_dtype` saving
        with open(f"{model_path}/config.json") as f:
            config_dict = json.load(f)
        self.assertEqual(config_dict["torch_dtype"], "float16")
        # 2. test torch_dtype="auto" via auto-derivation
        # now same with using config info
        remove_torch_dtype(model_path)
        model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype="auto")
        self.assertEqual(model.dtype, torch.float16)

        # 3. now retest that AutoModel behaves the same wrt torch_dtype="auto" as T5ForConditionalGeneration
        model = AutoModel.from_pretrained(model_path, torch_dtype="auto")
        self.assertEqual(model.dtype, torch.float16)

        # test fp16 save_pretrained, loaded with the explicit fp16
        model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype=torch.float16)
        self.assertEqual(model.dtype, torch.float16)

        # test AutoModel separately as it goes through a different path
        # test auto-detection - as currently TINY_T5 doesn't have torch_dtype entry
        model = AutoModel.from_pretrained(TINY_T5, torch_dtype="auto")
        # test that the config object didn't get polluted with torch_dtype="auto"
        # there was a bug that after this call we ended up with config.torch_dtype=="auto"
        self.assertNotEqual(model.config.torch_dtype, "auto")
        # now test the outcome
        self.assertEqual(model.dtype, torch.float32)
        model = AutoModel.from_pretrained(TINY_T5, torch_dtype=torch.float16)
        self.assertEqual(model.dtype, torch.float16)

        # test model whose first param is not of a floating type, but int
        model = AutoModel.from_pretrained(TINY_BERT_FOR_TOKEN_CLASSIFICATION, torch_dtype="auto")
        self.assertEqual(model.dtype, torch.float32)

    def test_no_super_init_config_and_model(self):
        config = NoSuperInitConfig(attribute=32)
        model = NoSuperInitModel(config)

        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir)

            new_model = NoSuperInitModel.from_pretrained(tmp_dir)

        for p1, p2 in zip(model.parameters(), new_model.parameters()):
            self.assertTrue(torch.equal(p1, p2))

    def test_shard_checkpoint(self):
        # This is the model we will use, total size 340,000 bytes.
        model = torch.nn.Sequential(
            torch.nn.Linear(100, 200, bias=False),  # size 80,000
            torch.nn.Linear(200, 200, bias=False),  # size 160,000
            torch.nn.Linear(200, 100, bias=False),  # size 80,000
            torch.nn.Linear(100, 50, bias=False),  # size 20,000
        )
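        # The sizes above assume fp32 weights (4 bytes per parameter),
        # e.g. 100 * 200 * 4 = 80,000 bytes for the first layer.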
        state_dict = model.state_dict()

        with self.subTest("No shard when max size is bigger than model size"):
            shards, index = shard_checkpoint(state_dict)
            self.assertIsNone(index)
            self.assertDictEqual(shards, {WEIGHTS_NAME: state_dict})

        with self.subTest("Test sharding, no weights bigger than max size"):
            shards, index = shard_checkpoint(state_dict, max_shard_size="300kB")
            # Split is first two layers then last two.
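            # (80,000 + 160,000 = 240,000 bytes fits under the 300,000-byte limit,
            # but adding the third layer would exceed it.)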
            self.assertDictEqual(
                index,
                {
                    "metadata": {"total_size": 340000},
                    "weight_map": {
                        "0.weight": "pytorch_model-00001-of-00002.bin",
                        "1.weight": "pytorch_model-00001-of-00002.bin",
                        "2.weight": "pytorch_model-00002-of-00002.bin",
                        "3.weight": "pytorch_model-00002-of-00002.bin",
                    },
                },
            )

            shard1 = {"0.weight": state_dict["0.weight"], "1.weight": state_dict["1.weight"]}
            shard2 = {"2.weight": state_dict["2.weight"], "3.weight": state_dict["3.weight"]}
            self.assertDictEqual(
                shards, {"pytorch_model-00001-of-00002.bin": shard1, "pytorch_model-00002-of-00002.bin": shard2}
            )

        with self.subTest("Test sharding with weights bigger than max size"):
            shards, index = shard_checkpoint(state_dict, max_shard_size="100kB")
            # Split is first layer, second layer then last 2.
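            # (The 160,000-byte second layer alone exceeds the 100,000-byte limit, so a
            # weight bigger than max_shard_size always gets a shard of its own; the last
            # two layers together are exactly 100,000 bytes and share a shard.)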
            self.assertDictEqual(
                index,
                {
                    "metadata": {"total_size": 340000},
                    "weight_map": {
                        "0.weight": "pytorch_model-00001-of-00003.bin",
                        "1.weight": "pytorch_model-00002-of-00003.bin",
                        "2.weight": "pytorch_model-00003-of-00003.bin",
                        "3.weight": "pytorch_model-00003-of-00003.bin",
                    },
                },
            )

            shard1 = {"0.weight": state_dict["0.weight"]}
            shard2 = {"1.weight": state_dict["1.weight"]}
            shard3 = {"2.weight": state_dict["2.weight"], "3.weight": state_dict["3.weight"]}
            self.assertDictEqual(
                shards,
                {
                    "pytorch_model-00001-of-00003.bin": shard1,
                    "pytorch_model-00002-of-00003.bin": shard2,
                    "pytorch_model-00003-of-00003.bin": shard3,
                },
            )

    def test_checkpoint_sharding_local_bin(self):
        model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")

        with tempfile.TemporaryDirectory() as tmp_dir:
            # We use the same folder for various sizes to make sure a new save erases the old checkpoint.
            for max_size in ["50kB", "50kiB", "100kB", "100kiB", "200kB", "200kiB"]:
                model.save_pretrained(tmp_dir, max_shard_size=max_size, safe_serialization=False)

                # Get each shard file and its size
                shard_to_size = {}
                for shard in os.listdir(tmp_dir):
                    if shard.endswith(".bin"):
                        shard_file = os.path.join(tmp_dir, shard)
                        shard_to_size[shard_file] = os.path.getsize(shard_file)

                index_file = os.path.join(tmp_dir, WEIGHTS_INDEX_NAME)
                # Check there is an index but no regular weight file
                self.assertTrue(os.path.isfile(index_file))
                self.assertFalse(os.path.isfile(os.path.join(tmp_dir, WEIGHTS_NAME)))

                # Check a file is bigger than max_size only when it has a single weight
                for shard_file, size in shard_to_size.items():
                    if max_size.endswith("kiB"):
                        max_size_int = int(max_size[:-3]) * 2**10
                    else:
                        max_size_int = int(max_size[:-2]) * 10**3
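                    # e.g. "50kB" -> 50 * 10**3 = 50,000 bytes, while "50kiB" -> 50 * 2**10 = 51,200 bytes.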
                    # Note: pickle adds some junk so the size of the file can end up being slightly bigger than
                    # the size asked for (since we count parameters)
                    if size >= max_size_int + 50000:
                        state_dict = torch.load(shard_file)
                        self.assertEqual(len(state_dict), 1)

                # Check the index and the shard files found match
                with open(index_file, "r", encoding="utf-8") as f:
                    index = json.loads(f.read())

                all_shards = set(index["weight_map"].values())
                shards_found = {f for f in os.listdir(tmp_dir) if f.endswith(".bin")}
                self.assertSetEqual(all_shards, shards_found)

                # Finally, check the model can be reloaded
                new_model = BertModel.from_pretrained(tmp_dir)
                for p1, p2 in zip(model.parameters(), new_model.parameters()):
                    self.assertTrue(torch.allclose(p1, p2))

    def test_checkpoint_sharding_from_hub(self):
        model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded")
        # the model above is the same as the model below, just a sharded version.
        ref_model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
        for p1, p2 in zip(model.parameters(), ref_model.parameters()):
            self.assertTrue(torch.allclose(p1, p2))

    def test_checkpoint_variant_local_bin(self):
        model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")

        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir, variant="v2", safe_serialization=False)

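            # With variant="v2", "pytorch_model.bin" becomes "pytorch_model.v2.bin".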
            weights_name = ".".join(WEIGHTS_NAME.split(".")[:-1] + ["v2"] + ["bin"])

            weights_file = os.path.join(tmp_dir, weights_name)
            self.assertTrue(os.path.isfile(weights_file))
            self.assertFalse(os.path.isfile(os.path.join(tmp_dir, WEIGHTS_NAME)))

            with self.assertRaises(EnvironmentError):
                _ = BertModel.from_pretrained(tmp_dir)

            new_model = BertModel.from_pretrained(tmp_dir, variant="v2")

            for p1, p2 in zip(model.parameters(), new_model.parameters()):
                self.assertTrue(torch.allclose(p1, p2))

    def test_checkpoint_variant_local_sharded_bin(self):
        model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")

        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir, variant="v2", max_shard_size="50kB", safe_serialization=False)

            weights_index_name = ".".join(WEIGHTS_INDEX_NAME.split(".")[:-1] + ["v2"] + ["json"])
            weights_index_file = os.path.join(tmp_dir, weights_index_name)
            self.assertTrue(os.path.isfile(weights_index_file))
            self.assertFalse(os.path.isfile(os.path.join(tmp_dir, WEIGHTS_INDEX_NAME)))

            for i in range(1, 5):
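                # Variant shard files are named e.g. "pytorch_model.v2-00001-of-00005.bin".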
                weights_name = ".".join(WEIGHTS_NAME.split(".")[:-1] + [f"v2-0000{i}-of-00005"] + ["bin"])
                weights_name_file = os.path.join(tmp_dir, weights_name)
                self.assertTrue(os.path.isfile(weights_name_file))

            with self.assertRaises(EnvironmentError):
                _ = BertModel.from_pretrained(tmp_dir)

            new_model = BertModel.from_pretrained(tmp_dir, variant="v2")

            for p1, p2 in zip(model.parameters(), new_model.parameters()):
                self.assertTrue(torch.allclose(p1, p2))

    @require_safetensors
    def test_checkpoint_variant_local_safe(self):
        model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")

        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir, variant="v2", safe_serialization=True)

            weights_name = ".".join(SAFE_WEIGHTS_NAME.split(".")[:-1] + ["v2"] + ["safetensors"])

            weights_file = os.path.join(tmp_dir, weights_name)
            self.assertTrue(os.path.isfile(weights_file))
            self.assertFalse(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_NAME)))

            with self.assertRaises(EnvironmentError):
                _ = BertModel.from_pretrained(tmp_dir)

            new_model = BertModel.from_pretrained(tmp_dir, variant="v2")

            for p1, p2 in zip(model.parameters(), new_model.parameters()):
                self.assertTrue(torch.allclose(p1, p2))

    @require_safetensors
    def test_checkpoint_variant_local_sharded_safe(self):
        model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")

        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir, variant="v2", max_shard_size="50kB", safe_serialization=True)

            weights_index_name = ".".join(SAFE_WEIGHTS_INDEX_NAME.split(".")[:-1] + ["v2"] + ["json"])
            weights_index_file = os.path.join(tmp_dir, weights_index_name)
            self.assertTrue(os.path.isfile(weights_index_file))
            self.assertFalse(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)))

            for i in range(1, 5):
                weights_name = ".".join(SAFE_WEIGHTS_NAME.split(".")[:-1] + [f"v2-0000{i}-of-00005"] + ["safetensors"])
                weights_name_file = os.path.join(tmp_dir, weights_name)
                self.assertTrue(os.path.isfile(weights_name_file))

            with self.assertRaises(EnvironmentError):
                _ = BertModel.from_pretrained(tmp_dir)

            new_model = BertModel.from_pretrained(tmp_dir, variant="v2")

            for p1, p2 in zip(model.parameters(), new_model.parameters()):
                self.assertTrue(torch.allclose(p1, p2))

    def test_checkpoint_variant_hub(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            with self.assertRaises(EnvironmentError):
                _ = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert-variant", cache_dir=tmp_dir)
            model = BertModel.from_pretrained(
                "hf-internal-testing/tiny-random-bert-variant", cache_dir=tmp_dir, variant="v2"
            )
        self.assertIsNotNone(model)

    def test_checkpoint_variant_hub_sharded(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            with self.assertRaises(EnvironmentError):
                _ = BertModel.from_pretrained(
                    "hf-internal-testing/tiny-random-bert-variant-sharded", cache_dir=tmp_dir
                )
            model = BertModel.from_pretrained(
                "hf-internal-testing/tiny-random-bert-variant-sharded", cache_dir=tmp_dir, variant="v2"
            )
        self.assertIsNotNone(model)

    @require_safetensors
    def test_checkpoint_variant_hub_safe(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            with self.assertRaises(EnvironmentError):
                _ = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert-variant-safe", cache_dir=tmp_dir)
            model = BertModel.from_pretrained(
                "hf-internal-testing/tiny-random-bert-variant-safe", cache_dir=tmp_dir, variant="v2"
            )
        self.assertIsNotNone(model)

    @require_safetensors
    def test_checkpoint_variant_hub_sharded_safe(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            with self.assertRaises(EnvironmentError):
                _ = BertModel.from_pretrained(
                    "hf-internal-testing/tiny-random-bert-variant-sharded-safe", cache_dir=tmp_dir
                )
            model = BertModel.from_pretrained(
                "hf-internal-testing/tiny-random-bert-variant-sharded-safe", cache_dir=tmp_dir, variant="v2"
            )
        self.assertIsNotNone(model)

    def test_checkpoint_variant_save_load_bin(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            model = BertModel.from_pretrained(
                "hf-internal-testing/tiny-random-bert-variant", cache_dir=tmp_dir, variant="v2"
            )
            weights_name = ".".join(WEIGHTS_NAME.split(".")[:-1] + ["v2"] + ["bin"])

            model.save_pretrained(tmp_dir, variant="v2", safe_serialization=False)
            # saving will create a variant checkpoint
            self.assertTrue(os.path.isfile(os.path.join(tmp_dir, weights_name)))

            model.save_pretrained(tmp_dir, safe_serialization=False)
            # saving shouldn't delete variant checkpoints
            weights_name = ".".join(WEIGHTS_NAME.split(".")[:-1] + ["v2"] + ["bin"])
            self.assertTrue(os.path.isfile(os.path.join(tmp_dir, weights_name)))

            # there should be a normal checkpoint
            self.assertTrue(os.path.isfile(os.path.join(tmp_dir, WEIGHTS_NAME)))

        self.assertIsNotNone(model)

    @require_accelerate
    @mark.accelerate_tests
    def test_from_pretrained_low_cpu_mem_usage_functional(self):
        # test that we can use `from_pretrained(..., low_cpu_mem_usage=True)` with normal and
        # sharded models

        mnames = [
            "hf-internal-testing/tiny-random-bert-sharded",
            "hf-internal-testing/tiny-random-bert",
        ]
        for mname in mnames:
            _ = BertModel.from_pretrained(mname, low_cpu_mem_usage=True)

    @require_usr_bin_time
    @require_accelerate
    @mark.accelerate_tests
    def test_from_pretrained_low_cpu_mem_usage_measured(self):
        # test that `from_pretrained(..., low_cpu_mem_usage=True)` uses less cpu memory than default

        mname = "google-bert/bert-base-cased"

        preamble = "from transformers import AutoModel"
        one_liner_str = f'{preamble}; AutoModel.from_pretrained("{mname}", low_cpu_mem_usage=False)'
        max_rss_normal = self.python_one_liner_max_rss(one_liner_str)
        # print(f"{max_rss_normal=}")

        one_liner_str = f'{preamble}; AutoModel.from_pretrained("{mname}", low_cpu_mem_usage=True)'
        max_rss_low_mem = self.python_one_liner_max_rss(one_liner_str)
        # print(f"{max_rss_low_mem=}")

        diff_bytes = max_rss_normal - max_rss_low_mem
        diff_percent = diff_bytes / max_rss_low_mem
        # print(f"{diff_bytes=}, {diff_percent=}")
        # ideally we would compare that the diff is close to ~1x checkpoint size in bytes, but
        # measuring cpu memory on linux is very tricky and inconsistent, so instead let's check that
        # it's at least 15% less cpu memory consumed

        self.assertGreater(
            diff_percent,
            0.15,
            "should use less CPU memory for low_cpu_mem_usage=True, "
            f"but got max_rss_normal={max_rss_normal} and max_rss_low_mem={max_rss_low_mem}",
        )

        # if you want to compare things manually, let's first look at the size of the model in bytes
        # model = BertModel.from_pretrained(mname, low_cpu_mem_usage=False)
        # total_numel = sum(dict((p.data_ptr(), p.numel()) for p in model.parameters()).values())
        # total_bytes = total_numel * 4  # 420MB
        # Now the diff_bytes should be very close to total_bytes, but the reports are inconsistent.
        # The easiest way to test this is to switch the model and torch.load to do all the work on
        # gpu - that way one can measure exactly the total and peak memory used. Perhaps once we add
        # functionality to load models directly on gpu, this test can be rewritten to use torch's
        # cuda memory tracking and then we should be able to do a much more precise test.

    @require_accelerate
    @mark.accelerate_tests
    @require_torch_multi_accelerator
    @slow
    def test_model_parallelism_gpt2(self):
        device_map = {"transformer.wte": 0, "transformer.wpe": 0, "lm_head": 0, "transformer.ln_f": 1}
        for i in range(12):
            device_map[f"transformer.h.{i}"] = 0 if i <= 5 else 1
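        # i.e. transformer blocks 0-5 land on device 0 and blocks 6-11 on device 1.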

        model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2", device_map=device_map)

        tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
        inputs = tokenizer("Hello, my name is", return_tensors="pt")
        output = model.generate(inputs["input_ids"].to(0))

        text_output = tokenizer.decode(output[0].tolist())
        self.assertEqual(text_output, "Hello, my name is John. I'm a writer, and I'm a writer. I'm")

    @require_accelerate
    @mark.accelerate_tests
    @require_torch_accelerator
    def test_from_pretrained_disk_offload_task_model(self):
        model = AutoModel.from_pretrained("hf-internal-testing/tiny-random-gpt2")
        device_map = {
            "transformer.wte": 0,
            "transformer.wpe": 0,
            "transformer.h.0": "cpu",
            "transformer.h.1": "cpu",
            "transformer.h.2": "cpu",
            "transformer.h.3": "disk",
            "transformer.h.4": "disk",
            "transformer.ln_f": 0,
            "lm_head": 0,
        }
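        # Weights for modules mapped to "disk" are written out to `offload_folder`
        # by accelerate and loaded back lazily during the forward pass.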
        with tempfile.TemporaryDirectory() as tmp_dir:
            inputs = torch.tensor([[1, 2, 3]]).to(0)

            model.save_pretrained(tmp_dir)
            new_model = AutoModelForCausalLM.from_pretrained(tmp_dir).to(0)
            outputs1 = new_model.to(0)(inputs)

            offload_folder = os.path.join(tmp_dir, "offload")
            new_model_with_offload = AutoModelForCausalLM.from_pretrained(
                tmp_dir, device_map=device_map, offload_folder=offload_folder
            )
            outputs2 = new_model_with_offload(inputs)

            self.assertTrue(torch.allclose(outputs1.logits.cpu(), outputs2.logits.cpu()))

            # With state dict temp offload
            offload_folder = os.path.join(tmp_dir, "offload")
            new_model_with_offload = AutoModelForCausalLM.from_pretrained(
                tmp_dir,
                device_map=device_map,
                offload_folder=offload_folder,
                offload_state_dict=True,
            )
            outputs2 = new_model_with_offload(inputs)

            self.assertTrue(torch.allclose(outputs1.logits.cpu(), outputs2.logits.cpu()))

    @require_accelerate
    @mark.accelerate_tests
    @require_torch_accelerator
    def test_from_pretrained_disk_offload_derived_to_base_model(self):
        derived_model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")

        device_map = {
            "wte": 0,
            "wpe": 0,
            "h.0": "cpu",
            "h.1": "cpu",
            "h.2": "cpu",
            "h.3": "disk",
            "h.4": "disk",
            "ln_f": 0,
        }
        with tempfile.TemporaryDirectory() as tmp_dir:
            inputs = torch.tensor([[1, 2, 3]]).to(0)
            derived_model.save_pretrained(tmp_dir, use_safetensors=True)
            base_model = AutoModel.from_pretrained(tmp_dir)
            outputs1 = base_model.to(0)(inputs)

            # with disk offload
            offload_folder = os.path.join(tmp_dir, "offload")
            base_model_with_offload = AutoModel.from_pretrained(
                tmp_dir, device_map=device_map, offload_folder=offload_folder
            )
            outputs2 = base_model_with_offload(inputs)
            self.assertTrue(torch.allclose(outputs1[0].cpu(), outputs2[0].cpu()))

            # With state dict temp offload
            new_model_with_offload = AutoModel.from_pretrained(
                tmp_dir,
                device_map=device_map,
                offload_folder=offload_folder,
                offload_state_dict=True,
            )
            outputs2 = new_model_with_offload(inputs)
            self.assertTrue(torch.allclose(outputs1[0].cpu(), outputs2[0].cpu()))

    @slow
    @require_torch
    def test_from_pretrained_non_contiguous_checkpoint(self):
        # See: https://github.com/huggingface/transformers/pull/28414
        # Tiny models on the Hub have contiguous weights, unlike google/owlvit
        model = OwlViTForObjectDetection.from_pretrained("fxmarty/owlvit-tiny-non-contiguous-weight")
        self.assertTrue(model.owlvit.visual_projection.weight.is_contiguous())

        model = OwlViTForObjectDetection.from_pretrained(
            "fxmarty/owlvit-tiny-non-contiguous-weight", device_map="auto"
        )
        self.assertTrue(model.owlvit.visual_projection.weight.is_contiguous())

        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir, safe_serialization=False)
            model.save_pretrained(tmp_dir, safe_serialization=True)

    def test_cached_files_are_used_when_internet_is_down(self):
        # A mock response for an HTTP head request to emulate server down
        response_mock = mock.Mock()
        response_mock.status_code = 500
        response_mock.headers = {}
        response_mock.raise_for_status.side_effect = HTTPError
        response_mock.json.return_value = {}

        # Download this model to make sure it's in the cache.
        _ = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")

        # Under the mock environment we get a 500 error when trying to reach the model.
        with mock.patch("requests.Session.request", return_value=response_mock) as mock_head:
            _ = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
            # This checks that the fake head request was indeed called.
            mock_head.assert_called()

    def test_load_from_one_file(self):
        try:
            tmp_file = tempfile.mktemp()
            with open(tmp_file, "wb") as f:
                http_get(
                    "https://huggingface.co/hf-internal-testing/tiny-random-bert/resolve/main/pytorch_model.bin", f
                )

            config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
            _ = BertModel.from_pretrained(tmp_file, config=config)
        finally:
            os.remove(tmp_file)

    def test_legacy_load_from_url(self):
        # This test is for deprecated behavior and can be removed in v5
        config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
        _ = BertModel.from_pretrained(
            "https://huggingface.co/hf-internal-testing/tiny-random-bert/resolve/main/pytorch_model.bin", config=config
        )

    @require_safetensors
    def test_use_safetensors(self):
        # Should not raise anymore
        AutoModel.from_pretrained("hf-internal-testing/tiny-random-RobertaModel", use_safetensors=True)

        # test that an error is raised if only safetensors weights are available and use_safetensors=False
        with self.assertRaises(OSError) as env_error:
            BertModel.from_pretrained("hf-internal-testing/tiny-random-bert-safetensors", use_safetensors=False)

        self.assertTrue("does not appear to have a file named pytorch_model.bin" in str(env_error.exception))

        # test that only the bin weights are downloaded if both are available and use_safetensors=False
        with tempfile.TemporaryDirectory() as tmp_dir:
            CLIPTextModel.from_pretrained(
                "hf-internal-testing/diffusers-stable-diffusion-tiny-all",
                subfolder="text_encoder",
                use_safetensors=False,
                cache_dir=tmp_dir,
            )

            all_downloaded_files = glob.glob(os.path.join(tmp_dir, "*", "snapshots", "*", "*", "*"))
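            # The hub cache layout is <cache_dir>/models--<org>--<name>/snapshots/<commit>/<subfolder>/<file>,
            # hence the wildcard depth above.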
            self.assertTrue(any(f.endswith("bin") for f in all_downloaded_files))
            self.assertFalse(any(f.endswith("safetensors") for f in all_downloaded_files))

        # test that only the safetensors weights are downloaded if both are available and use_safetensors=True
        with tempfile.TemporaryDirectory() as tmp_dir:
            CLIPTextModel.from_pretrained(
                "hf-internal-testing/diffusers-stable-diffusion-tiny-all",
                subfolder="text_encoder",
                use_safetensors=True,
                cache_dir=tmp_dir,
            )

            all_downloaded_files = glob.glob(os.path.join(tmp_dir, "*", "snapshots", "*", "*", "*"))
            self.assertTrue(any(f.endswith("safetensors") for f in all_downloaded_files))
            self.assertFalse(any(f.endswith("bin") for f in all_downloaded_files))

    @require_safetensors
    def test_safetensors_save_and_load(self):
        model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir, safe_serialization=True)
            # No pytorch_model.bin file, only a model.safetensors
            self.assertTrue(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_NAME)))
            self.assertFalse(os.path.isfile(os.path.join(tmp_dir, WEIGHTS_NAME)))

            new_model = BertModel.from_pretrained(tmp_dir)

            # Check models are equal
            for p1, p2 in zip(model.parameters(), new_model.parameters()):
                self.assertTrue(torch.allclose(p1, p2))

    @require_safetensors
    def test_safetensors_load_from_hub(self):
        safetensors_model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert-safetensors")
        pytorch_model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")

        # Check models are equal
        for p1, p2 in zip(safetensors_model.parameters(), pytorch_model.parameters()):
            self.assertTrue(torch.allclose(p1, p2))

    @require_safetensors
    def test_safetensors_save_and_load_sharded(self):
        model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir, safe_serialization=True, max_shard_size="100kB")
            # No pytorch_model.bin index file, only a model.safetensors index
            self.assertFalse(os.path.isfile(os.path.join(tmp_dir, WEIGHTS_INDEX_NAME)))
            self.assertTrue(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)))
            # No regular weights file
            self.assertFalse(os.path.isfile(os.path.join(tmp_dir, WEIGHTS_NAME)))
            self.assertFalse(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_NAME)))

            new_model = BertModel.from_pretrained(tmp_dir)

            # Check models are equal
            for p1, p2 in zip(model.parameters(), new_model.parameters()):
                self.assertTrue(torch.allclose(p1, p2))

    @require_safetensors
    def test_safetensors_load_from_hub_sharded(self):
        safetensors_model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded-safetensors")
        pytorch_model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded")

        # Check models are equal
        for p1, p2 in zip(safetensors_model.parameters(), pytorch_model.parameters()):
            self.assertTrue(torch.allclose(p1, p2))

    def test_base_model_to_head_model_load(self):
        base_model = BaseModel(PretrainedConfig())
        with tempfile.TemporaryDirectory() as tmp_dir:
            base_model.save_pretrained(tmp_dir, safe_serialization=False)

            # Can load a base model in a model with head
            model = ModelWithHead.from_pretrained(tmp_dir)
            for p1, p2 in zip(model.base.parameters(), base_model.parameters()):
                self.assertTrue(torch.allclose(p1, p2))

            # It doesn't work if the state dict has a mix of keys of the head and base without prefix though.
            base_state_dict = base_model.state_dict()
            head_state_dict = model.state_dict()
            base_state_dict["linear2.weight"] = head_state_dict["linear2.weight"]
            base_state_dict["linear2.bias"] = head_state_dict["linear2.bias"]
            safe_save_file(base_state_dict, os.path.join(tmp_dir, SAFE_WEIGHTS_NAME), metadata={"format": "pt"})

            with self.assertRaisesRegex(
                ValueError, "The state dictionary of the model you are trying to load is corrupted."
            ):
                _ = ModelWithHead.from_pretrained(tmp_dir)

    def test_tied_weights_reload(self):
        # Base
        model = BaseModelWithTiedWeights(PretrainedConfig())
        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir)

            new_model = BaseModelWithTiedWeights.from_pretrained(tmp_dir)
            self.assertIs(new_model.linear.weight, new_model.linear_2.weight)

            state_dict = model.state_dict()
            # Remove tied weight from state_dict -> model should load with no complaint about missing keys
            del state_dict["linear_2.weight"]
            torch.save(state_dict, os.path.join(tmp_dir, WEIGHTS_NAME))
            new_model, load_info = BaseModelWithTiedWeights.from_pretrained(tmp_dir, output_loading_info=True)
            self.assertListEqual(load_info["missing_keys"], [])
            self.assertIs(new_model.linear.weight, new_model.linear_2.weight)

            # With head
            model.save_pretrained(tmp_dir)
            new_model, load_info = ModelWithHeadAndTiedWeights.from_pretrained(tmp_dir, output_loading_info=True)
            self.assertIs(new_model.base.linear.weight, new_model.decoder.weight)
            # Should only complain about the missing bias
            self.assertListEqual(load_info["missing_keys"], ["decoder.bias"])

    def test_unexpected_keys_warnings(self):
        model = ModelWithHead(PretrainedConfig())
        logger = logging.get_logger("transformers.modeling_utils")
        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir)

            # Loading the model with a new class, we don't get a warning for unexpected weights, just an info
            with LoggingLevel(logging.WARNING):
                with CaptureLogger(logger) as cl:
                    _, loading_info = BaseModel.from_pretrained(tmp_dir, output_loading_info=True)
            self.assertNotIn("were not used when initializing ModelWithHead", cl.out)
            self.assertEqual(
                set(loading_info["unexpected_keys"]),
                {"linear.weight", "linear.bias", "linear2.weight", "linear2.bias"},
            )

            # Loading the model with the same class, we do get a warning for unexpected weights
            state_dict = model.state_dict()
            state_dict["added_key"] = copy.deepcopy(state_dict["linear.weight"])
            safe_save_file(state_dict, os.path.join(tmp_dir, SAFE_WEIGHTS_NAME), metadata={"format": "pt"})
            with LoggingLevel(logging.WARNING):
                with CaptureLogger(logger) as cl:
                    _, loading_info = ModelWithHead.from_pretrained(tmp_dir, output_loading_info=True)
            self.assertIn("were not used when initializing ModelWithHead: ['added_key']", cl.out)
            self.assertEqual(loading_info["unexpected_keys"], ["added_key"])

    def test_warn_if_padding_and_no_attention_mask(self):
        logger = logging.get_logger("transformers.modeling_utils")

        with self.subTest("Ensure no warnings when pad_token_id is None."):
            logger.warning_once.cache_clear()
            with LoggingLevel(logging.WARNING):
                with CaptureLogger(logger) as cl:
                    config_no_pad_token = PretrainedConfig()
                    config_no_pad_token.pad_token_id = None
                    model = ModelWithHead(config_no_pad_token)
                    input_ids = torch.tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 0, 0]])
                    model.warn_if_padding_and_no_attention_mask(input_ids, attention_mask=None)
            self.assertNotIn("We strongly recommend passing in an `attention_mask`", cl.out)

        with self.subTest("Ensure no warnings when there is an attention_mask."):
            logger.warning_once.cache_clear()
            with LoggingLevel(logging.WARNING):
                with CaptureLogger(logger) as cl:
                    config = PretrainedConfig()
                    config.pad_token_id = 0
                    model = ModelWithHead(config)
                    input_ids = torch.tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 0, 0]])
                    attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0]])
                    model.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            self.assertNotIn("We strongly recommend passing in an `attention_mask`", cl.out)

        with self.subTest("Ensure no warnings when there are no pad_token_ids in the input_ids."):
            logger.warning_once.cache_clear()
            with LoggingLevel(logging.WARNING):
                with CaptureLogger(logger) as cl:
                    config = PretrainedConfig()
                    config.pad_token_id = 0
                    model = ModelWithHead(config)
                    input_ids = torch.tensor([[1, 345, 232, 328, 740, 140, 1695, 69, 6078, 2341, 25]])
                    model.warn_if_padding_and_no_attention_mask(input_ids, attention_mask=None)
            self.assertNotIn("We strongly recommend passing in an `attention_mask`", cl.out)

        with self.subTest("Ensure a warning is shown when the input_ids start with a pad_token_id."):
            logger.warning_once.cache_clear()
            with LoggingLevel(logging.WARNING):
                with CaptureLogger(logger) as cl:
                    config = PretrainedConfig()
                    config.pad_token_id = 0
                    model = ModelWithHead(config)
                    input_ids = torch.tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 432, 5232]])
                    model.warn_if_padding_and_no_attention_mask(input_ids, attention_mask=None)
            self.assertIn("We strongly recommend passing in an `attention_mask`", cl.out)

        with self.subTest("Ensure a warning is shown when the input_ids end with a pad_token_id."):
            logger.warning_once.cache_clear()
            with LoggingLevel(logging.WARNING):
                with CaptureLogger(logger) as cl:
                    config = PretrainedConfig()
                    config.pad_token_id = 0
                    model = ModelWithHead(config)
                    input_ids = torch.tensor([[432, 345, 232, 328, 740, 140, 1695, 69, 6078, 0, 0]])
                    model.warn_if_padding_and_no_attention_mask(input_ids, attention_mask=None)
            self.assertIn("We strongly recommend passing in an `attention_mask`", cl.out)

        with self.subTest("Ensure that the warning is shown at most once."):
            logger.warning_once.cache_clear()
            with LoggingLevel(logging.WARNING):
                with CaptureLogger(logger) as cl:
                    config = PretrainedConfig()
                    config.pad_token_id = 0
                    model = ModelWithHead(config)
                    input_ids = torch.tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 0, 0]])
                    model.warn_if_padding_and_no_attention_mask(input_ids, attention_mask=None)
                    model.warn_if_padding_and_no_attention_mask(input_ids, attention_mask=None)
            self.assertEqual(cl.out.count("We strongly recommend passing in an `attention_mask`"), 1)

        with self.subTest("Ensure a different warning is shown when the pad_token_id is equal to the bos_token_id."):
            logger.warning_once.cache_clear()
            with LoggingLevel(logging.WARNING):
                with CaptureLogger(logger) as cl:
                    config = PretrainedConfig()
                    config.pad_token_id = 0
                    config.bos_token_id = config.pad_token_id
                    model = ModelWithHead(config)
                    input_ids = torch.tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 0, 0]])
                    model.warn_if_padding_and_no_attention_mask(input_ids, attention_mask=None)
            self.assertIn("You may ignore this warning if your `pad_token_id`", cl.out)

        if not is_torchdynamo_available():
            return
        with self.subTest("Ensure that the warning code is skipped when compiling with torchdynamo."):
            logger.warning_once.cache_clear()
            from torch._dynamo import config, testing

            config = PretrainedConfig()
            config.pad_token_id = 0
            model = ModelWithHead(config)
            input_ids = torch.tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 432, 5232]])

            def f(input_ids):
                model.warn_if_padding_and_no_attention_mask(input_ids, attention_mask=None)

            compile_counter = testing.CompileCounter()
            opt_fn = torch.compile(f, dynamic=True, backend=compile_counter)
            opt_fn(input_ids)
            self.assertEqual(compile_counter.frame_count, 0)
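            # A frame_count of 0 means dynamo compiled nothing for `f`: the warning
            # helper bails out under tracing rather than introducing graph breaks.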

    @require_torch_accelerator
    @slow
    def test_pretrained_low_mem_new_config(self):
        # Check one model (the same one described in the issue).
        model_ids = ["openai-community/gpt2"]

        for model_id in model_ids:
            model_config = AutoConfig.from_pretrained(pretrained_model_name_or_path=model_id)
            model_config.n_layer = 48
            model_config.n_head = 25
            model_config.n_embd = 1600
            model = AutoModelForCausalLM.from_pretrained(
                pretrained_model_name_or_path=model_id,
                config=model_config,
                ignore_mismatched_sizes=True,
                torch_dtype=torch.float16,
                low_cpu_mem_usage=True,
            )
            model_ref = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path=model_id)

            self.assertEqual(model.__class__.__name__, model_ref.__class__.__name__)

    def test_generation_config_is_loaded_with_model(self):
        # Note: `joaogante/tiny-random-gpt2-with-generation-config` has a `generation_config.json` containing a dummy
        # `transformers_version` field set to `foo`. If loading the file fails, this test also fails.

        # 1. Load without further parameters
        model = AutoModelForCausalLM.from_pretrained("joaogante/tiny-random-gpt2-with-generation-config")
        self.assertEqual(model.generation_config.transformers_version, "foo")

        # 2. Load with `device_map`
        model = AutoModelForCausalLM.from_pretrained(
            "joaogante/tiny-random-gpt2-with-generation-config", device_map="auto"
        )
        self.assertEqual(model.generation_config.transformers_version, "foo")

    @require_safetensors
    def test_safetensors_torch_from_torch(self):
        model = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only")

        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir, safe_serialization=True)
            new_model = BertModel.from_pretrained(tmp_dir)

        for p1, p2 in zip(model.parameters(), new_model.parameters()):
            self.assertTrue(torch.equal(p1, p2))

    @require_safetensors
    @require_flax
    def test_safetensors_torch_from_flax(self):
        hub_model = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only")
        model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")

        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir, safe_serialization=True)
            new_model = BertModel.from_pretrained(tmp_dir)

        for p1, p2 in zip(hub_model.parameters(), new_model.parameters()):
            self.assertTrue(torch.equal(p1, p2))

    @require_tf
    @require_safetensors
    def test_safetensors_torch_from_tf(self):
        hub_model = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only")
        model = TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-tf-only")

        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir, safe_serialization=True)
            new_model = BertModel.from_pretrained(tmp_dir)

        for p1, p2 in zip(hub_model.parameters(), new_model.parameters()):
            self.assertTrue(torch.equal(p1, p2))

    @require_safetensors
    def test_safetensors_torch_from_torch_sharded(self):
        model = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only")

        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir, safe_serialization=True, max_shard_size="100kB")
            new_model = BertModel.from_pretrained(tmp_dir)

        for p1, p2 in zip(model.parameters(), new_model.parameters()):
            self.assertTrue(torch.equal(p1, p2))

    def test_modifying_model_config_causes_warning_saving_generation_config(self):
        model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
        model.config.top_k = 1
        with tempfile.TemporaryDirectory() as tmp_dir:
            with self.assertLogs("transformers.modeling_utils", level="WARNING") as logs:
                model.save_pretrained(tmp_dir)
            self.assertEqual(len(logs.output), 1)
            self.assertIn("Your generation config was originally created from the model config", logs.output[0])

@slow
@require_torch
class ModelOnTheFlyConversionTester(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.user = "huggingface-hub-ci"
        cls.token = os.getenv("HUGGINGFACE_PRODUCTION_USER_TOKEN", None)

        if cls.token is None:
            raise ValueError("Cannot run tests as secret isn't set up.")

        cls.api = HfApi(token=cls.token)

    def setUp(self) -> None:
        self.repo_name = f"{self.user}/test-model-on-the-fly-{uuid.uuid4()}"

    def tearDown(self) -> None:
        self.api.delete_repo(self.repo_name)

    def test_safetensors_on_the_fly_conversion(self):
        config = BertConfig(
            vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
        )
        initial_model = BertModel(config)

        initial_model.push_to_hub(self.repo_name, token=self.token, safe_serialization=False)
        converted_model = BertModel.from_pretrained(self.repo_name, use_safetensors=True)

        with self.subTest("Initial and converted models are equal"):
            for p1, p2 in zip(initial_model.parameters(), converted_model.parameters()):
                self.assertTrue(torch.equal(p1, p2))

        with self.subTest("PR was open with the safetensors account"):
            discussions = self.api.get_repo_discussions(self.repo_name)
            discussion = next(discussions)
            self.assertEqual(discussion.author, "SFconvertbot")
            self.assertEqual(discussion.title, "Adding `safetensors` variant of this model")

    def test_safetensors_on_the_fly_conversion_private(self):
        config = BertConfig(
            vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
        )
        initial_model = BertModel(config)

        initial_model.push_to_hub(self.repo_name, token=self.token, safe_serialization=False, private=True)
        converted_model = BertModel.from_pretrained(self.repo_name, use_safetensors=True, token=self.token)

        with self.subTest("Initial and converted models are equal"):
            for p1, p2 in zip(initial_model.parameters(), converted_model.parameters()):
                self.assertTrue(torch.equal(p1, p2))

        with self.subTest("PR was open with the safetensors account"):
            discussions = self.api.get_repo_discussions(self.repo_name, token=self.token)
            discussion = next(discussions)
            self.assertEqual(discussion.author, self.user)
            self.assertEqual(discussion.title, "Adding `safetensors` variant of this model")

    def test_safetensors_on_the_fly_conversion_gated(self):
        config = BertConfig(
            vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
        )
        initial_model = BertModel(config)

        initial_model.push_to_hub(self.repo_name, token=self.token, safe_serialization=False)
        headers = {"Authorization": f"Bearer {self.token}"}
        requests.put(
            f"https://huggingface.co/api/models/{self.repo_name}/settings", json={"gated": "auto"}, headers=headers
        )
        converted_model = BertModel.from_pretrained(self.repo_name, use_safetensors=True, token=self.token)

        with self.subTest("Initial and converted models are equal"):
            for p1, p2 in zip(initial_model.parameters(), converted_model.parameters()):
                self.assertTrue(torch.equal(p1, p2))

        with self.subTest("PR was open with the safetensors account"):
            discussions = self.api.get_repo_discussions(self.repo_name)
            discussion = next(discussions)
            self.assertEqual(discussion.author, "SFconvertbot")
            self.assertEqual(discussion.title, "Adding `safetensors` variant of this model")

    def test_safetensors_on_the_fly_sharded_conversion(self):
        config = BertConfig(
            vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
        )
        initial_model = BertModel(config)

        initial_model.push_to_hub(self.repo_name, token=self.token, safe_serialization=False, max_shard_size="200kb")
        converted_model = BertModel.from_pretrained(self.repo_name, use_safetensors=True)

        with self.subTest("Initial and converted models are equal"):
            for p1, p2 in zip(initial_model.parameters(), converted_model.parameters()):
                self.assertTrue(torch.equal(p1, p2))

        with self.subTest("PR was open with the safetensors account"):
            discussions = self.api.get_repo_discussions(self.repo_name)
            discussion = next(discussions)
            self.assertEqual(discussion.author, "SFconvertbot")
            self.assertEqual(discussion.title, "Adding `safetensors` variant of this model")

    def test_safetensors_on_the_fly_sharded_conversion_private(self):
        config = BertConfig(
            vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
        )
        initial_model = BertModel(config)

        initial_model.push_to_hub(
            self.repo_name, token=self.token, safe_serialization=False, max_shard_size="200kb", private=True
        )
        converted_model = BertModel.from_pretrained(self.repo_name, use_safetensors=True, token=self.token)

        with self.subTest("Initial and converted models are equal"):
            for p1, p2 in zip(initial_model.parameters(), converted_model.parameters()):
                self.assertTrue(torch.equal(p1, p2))

        with self.subTest("PR was open with the safetensors account"):
            discussions = self.api.get_repo_discussions(self.repo_name)
            discussion = next(discussions)
            self.assertEqual(discussion.author, self.user)
            self.assertEqual(discussion.title, "Adding `safetensors` variant of this model")

    def test_safetensors_on_the_fly_sharded_conversion_gated(self):
        config = BertConfig(
            vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
        )
        initial_model = BertModel(config)

        initial_model.push_to_hub(self.repo_name, token=self.token, max_shard_size="200kb", safe_serialization=False)
        headers = {"Authorization": f"Bearer {self.token}"}
        requests.put(
            f"https://huggingface.co/api/models/{self.repo_name}/settings", json={"gated": "auto"}, headers=headers
        )
        converted_model = BertModel.from_pretrained(self.repo_name, use_safetensors=True, token=self.token)

        with self.subTest("Initial and converted models are equal"):
            for p1, p2 in zip(initial_model.parameters(), converted_model.parameters()):
                self.assertTrue(torch.equal(p1, p2))

        with self.subTest("PR was open with the safetensors account"):
            discussions = self.api.get_repo_discussions(self.repo_name)
            discussion = next(discussions)
            self.assertEqual(discussion.author, "SFconvertbot")
            self.assertEqual(discussion.title, "Adding `safetensors` variant of this model")

    @unittest.skip("Edge case, should work once the Space is updated")
    def test_safetensors_on_the_fly_wrong_user_opened_pr(self):
        config = BertConfig(
            vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
        )
        initial_model = BertModel(config)

        initial_model.push_to_hub(self.repo_name, token=self.token, safe_serialization=False, private=True)
        BertModel.from_pretrained(self.repo_name, use_safetensors=True, token=self.token)

        # This should have opened a PR with the user's account
        with self.subTest("PR was opened with the user's account"):
            discussions = self.api.get_repo_discussions(self.repo_name)
            discussion = next(discussions)
            self.assertEqual(discussion.author, self.user)
            self.assertEqual(discussion.title, "Adding `safetensors` variant of this model")

        # We now switch the repo visibility to public
        self.api.update_repo_visibility(self.repo_name, private=False)

        # We once again call from_pretrained, which should call the bot to open a PR
        BertModel.from_pretrained(self.repo_name, use_safetensors=True, token=self.token)

        with self.subTest("PR was opened with the safetensors account"):
            discussions = self.api.get_repo_discussions(self.repo_name)

            bot_opened_pr = None
            bot_opened_pr_title = None

            for discussion in discussions:
                if discussion.author == "SFconvertbot":
                    bot_opened_pr = True
                    bot_opened_pr_title = discussion.title

            self.assertTrue(bot_opened_pr)
            self.assertEqual(bot_opened_pr_title, "Adding `safetensors` variant of this model")

    def test_safetensors_on_the_fly_specific_revision(self):
        config = BertConfig(
            vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
        )
        initial_model = BertModel(config)

        # Push a model on `main`
        initial_model.push_to_hub(self.repo_name, token=self.token, safe_serialization=False)

        # Push a model on a given revision
        initial_model.push_to_hub(self.repo_name, token=self.token, safe_serialization=False, revision="new-branch")

        # Trying to convert the model on that revision should raise
        with self.assertRaises(EnvironmentError):
            BertModel.from_pretrained(self.repo_name, use_safetensors=True, token=self.token, revision="new-branch")


@require_torch
@is_staging_test
class ModelPushToHubTester(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls._token = TOKEN
        HfFolder.save_token(TOKEN)

    @classmethod
    def tearDownClass(cls):
        try:
            delete_repo(token=cls._token, repo_id="test-model")
        except HTTPError:
            pass

        try:
            delete_repo(token=cls._token, repo_id="valid_org/test-model-org")
        except HTTPError:
            pass

        try:
            delete_repo(token=cls._token, repo_id="test-dynamic-model")
        except HTTPError:
            pass

        try:
            delete_repo(token=cls._token, repo_id="test-dynamic-model-with-tags")
        except HTTPError:
            pass

    @unittest.skip("This test is flaky")
    def test_push_to_hub(self):
        config = BertConfig(
            vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
        )
        model = BertModel(config)
        model.push_to_hub("test-model", token=self._token)

        new_model = BertModel.from_pretrained(f"{USER}/test-model")
        for p1, p2 in zip(model.parameters(), new_model.parameters()):
            self.assertTrue(torch.equal(p1, p2))

        # Reset repo
        delete_repo(token=self._token, repo_id="test-model")

        # Push to hub via save_pretrained
        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir, repo_id="test-model", push_to_hub=True, token=self._token)

        new_model = BertModel.from_pretrained(f"{USER}/test-model")
        for p1, p2 in zip(model.parameters(), new_model.parameters()):
            self.assertTrue(torch.equal(p1, p2))

    def test_push_to_hub_with_description(self):
        config = BertConfig(
            vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
        )
        model = BertModel(config)
        COMMIT_DESCRIPTION = """
The commit description supports markdown syntax, see:
```python
>>> from transformers import AutoConfig
>>> config = AutoConfig.from_pretrained("google-bert/bert-base-uncased")
```
"""
        commit_details = model.push_to_hub(
            "test-model", token=self._token, create_pr=True, commit_description=COMMIT_DESCRIPTION
        )
        self.assertEqual(commit_details.commit_description, COMMIT_DESCRIPTION)

    @unittest.skip("This test is flaky")
    def test_push_to_hub_in_organization(self):
        config = BertConfig(
            vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
        )
        model = BertModel(config)
        model.push_to_hub("valid_org/test-model-org", token=self._token)

        new_model = BertModel.from_pretrained("valid_org/test-model-org")
        for p1, p2 in zip(model.parameters(), new_model.parameters()):
            self.assertTrue(torch.equal(p1, p2))

        # Reset repo
        delete_repo(token=self._token, repo_id="valid_org/test-model-org")

        # Push to hub via save_pretrained
        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir, push_to_hub=True, token=self._token, repo_id="valid_org/test-model-org")

        new_model = BertModel.from_pretrained("valid_org/test-model-org")
        for p1, p2 in zip(model.parameters(), new_model.parameters()):
            self.assertTrue(torch.equal(p1, p2))

    def test_push_to_hub_dynamic_model(self):
        CustomConfig.register_for_auto_class()
        CustomModel.register_for_auto_class()
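        # Registering for auto classes makes `push_to_hub` upload the custom
        # code files alongside the weights and record them in `config.auto_map`
        # (checked below).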

        config = CustomConfig(hidden_size=32)
        model = CustomModel(config)

        model.push_to_hub("test-dynamic-model", token=self._token)
        # check that the auto_map was populated with the custom code references
        self.assertDictEqual(
            config.auto_map,
            {"AutoConfig": "custom_configuration.CustomConfig", "AutoModel": "custom_modeling.CustomModel"},
        )

        new_model = AutoModel.from_pretrained(f"{USER}/test-dynamic-model", trust_remote_code=True)
        # Can't make an isinstance check because the new_model is from the CustomModel class of a dynamic module
        self.assertEqual(new_model.__class__.__name__, "CustomModel")
        for p1, p2 in zip(model.parameters(), new_model.parameters()):
            self.assertTrue(torch.equal(p1, p2))

        config = AutoConfig.from_pretrained(f"{USER}/test-dynamic-model", trust_remote_code=True)
        new_model = AutoModel.from_config(config, trust_remote_code=True)
        self.assertEqual(new_model.__class__.__name__, "CustomModel")

    def test_push_to_hub_with_tags(self):
        from huggingface_hub import ModelCard

        new_tags = ["tag-1", "tag-2"]

        CustomConfig.register_for_auto_class()
        CustomModel.register_for_auto_class()

        config = CustomConfig(hidden_size=32)
        model = CustomModel(config)

        self.assertIsNone(model.model_tags)

        model.add_model_tags(new_tags)

        self.assertEqual(model.model_tags, new_tags)

        model.push_to_hub("test-dynamic-model-with-tags", token=self._token)

        loaded_model_card = ModelCard.load(f"{USER}/test-dynamic-model-with-tags")
        self.assertEqual(loaded_model_card.data.tags, new_tags)


@require_torch
class AttentionMaskTester(unittest.TestCase):
    def check_non_causal(self, bsz, q_len, kv_len, mask_2d, mask_4d):
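        # Positions disabled in the 2D padding mask (mask_2d == 0) must map to
        # -inf or to the dtype minimum in the 4D float mask; both encodings
        # count as "fully masked" here.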
        mask_indices = (mask_2d != 1)[:, None].broadcast_to((bsz, q_len, kv_len))
        mask_4d_values = mask_4d[:, 0][mask_indices]
        is_inf = mask_4d_values == -float("inf")
        is_min = mask_4d_values == torch.finfo(mask_4d.dtype).min
        assert torch.logical_or(is_inf, is_min).all()

    def check_to_4d(self, mask_converter, q_len, kv_len, additional_mask=None, bsz=3):
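        # Build a 2D padding mask (optionally zeroing a few positions), expand
        # it to 4D, and compare the number of non-zero (masked) entries against
        # a closed-form count.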
        mask_2d = torch.ones((bsz, kv_len), device=torch_device, dtype=torch.long)

        if additional_mask is not None:
            for bsz_idx, seq_idx in additional_mask:
                mask_2d[bsz_idx, seq_idx] = 0

        mask_4d = mask_converter.to_4d(mask_2d, query_length=q_len, key_value_length=kv_len, dtype=torch.float32)

        assert mask_4d.shape == (bsz, 1, q_len, kv_len)

        # make sure there are no overflows
        assert mask_4d.min() != float("-inf")

        context = mask_converter.sliding_window
        if mask_converter.is_causal and context is None:
            # k * (k+1) / 2 tokens are masked in triangular masks
            num_tokens_masked = bsz * (q_len * (q_len - 1) // 2)

            if 0 not in mask_2d:
                assert (mask_4d != 0).sum().cpu().item() == num_tokens_masked
            if 0 in mask_2d:
                # at least causal mask + maybe more
                assert (mask_4d != 0).sum().cpu().item() >= num_tokens_masked
                self.check_non_causal(bsz, q_len, kv_len, mask_2d, mask_4d)
        elif not mask_converter.is_causal and context is None:
            if 0 not in mask_2d:
                assert (mask_4d != 0).sum().cpu().item() == 0
            if 0 in mask_2d:
                self.check_non_causal(bsz, q_len, kv_len, mask_2d, mask_4d)
        elif mask_converter.is_causal and context is not None:
            # k * (k+1) / 2 tokens are masked in triangular masks
            num_tokens_masked = (q_len * (q_len - 1) // 2) + self.compute_num_context_mask(kv_len, context, q_len)
            num_tokens_masked = bsz * num_tokens_masked

            if 0 not in mask_2d:
                assert (mask_4d != 0).sum().cpu().item() == num_tokens_masked
            if 0 in mask_2d:
                # at least causal mask + maybe more
                assert (mask_4d != 0).sum().cpu().item() >= num_tokens_masked
                self.check_non_causal(bsz, q_len, kv_len, mask_2d, mask_4d)

    def check_to_causal(self, mask_converter, q_len, kv_len, bsz=3):
        mask_4d = mask_converter.to_causal_4d(
            bsz, query_length=q_len, key_value_length=kv_len, device=torch_device, dtype=torch.float32
        )

        if q_len == 1 and mask_converter.sliding_window is None:
            # no causal mask if q_len is 1
            assert mask_4d is None
            return

        context = mask_converter.sliding_window
        if mask_converter.is_causal and context is None:
            # k * (k+1) / 2 tokens are masked in triangular masks
            num_tokens_masked = bsz * (q_len * (q_len - 1) // 2)

            assert (mask_4d != 0).sum().cpu().item() == num_tokens_masked
        elif not mask_converter.is_causal and context is None:
            assert (mask_4d != 0).sum().cpu().item() == 0
        elif mask_converter.is_causal and context is not None:
            # k * (k+1) / 2 tokens are masked in triangular masks
            num_tokens_masked = (q_len * (q_len - 1) // 2) + self.compute_num_context_mask(kv_len, context, q_len)
            num_tokens_masked = bsz * num_tokens_masked

            assert (mask_4d != 0).sum().cpu().item() == num_tokens_masked

    def compute_num_context_mask(self, kv_len, context, q_len):
        # This function computes the number of extra tokens that are masked
        # by the sliding window, on top of the causal triangle
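        # Worked example: kv_len=7, context=3, q_len=7 gives c_mask_len=4,
        # num_mask_triangle=10, cut_mask_len=0, num_cut_mask=0 -> 10 extra
        # masked tokens. Together with the 21 tokens of the causal triangle,
        # that is 31 masked entries, leaving 1+2+3+3+3+3+3 = 18 attended
        # positions out of 49, as expected for a causal window of size 3.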
        c_mask_len = kv_len - context
        num_mask_triangle = c_mask_len * (c_mask_len + 1) // 2
        cut_mask_len = max(c_mask_len - q_len, 0)
        num_cut_mask = cut_mask_len * (cut_mask_len + 1) // 2
        return num_mask_triangle - num_cut_mask

    def test_2d_to_4d_causal(self):
        mask_converter = AttentionMaskConverter(is_causal=True)

        # auto-regressive use case
        self.check_to_4d(mask_converter, q_len=1, kv_len=7)
        # special auto-regressive case
        self.check_to_4d(mask_converter, q_len=3, kv_len=7)
        # non auto-regressive case
        self.check_to_4d(mask_converter, q_len=7, kv_len=7)

        # same with extra attention masks
        self.check_to_4d(mask_converter, q_len=1, kv_len=7, additional_mask=[(0, 2), (1, 3), (2, 0)])
        self.check_to_4d(mask_converter, q_len=3, kv_len=7, additional_mask=[(0, 2), (1, 3), (2, 0)])
        self.check_to_4d(mask_converter, q_len=7, kv_len=7, additional_mask=[(0, 2), (1, 3), (2, 0)])

        # check that the mask does not overflow on causal masked tokens
        self.check_to_4d(mask_converter, q_len=7, kv_len=7, additional_mask=[(0, 0), (1, 0), (1, 1)])

    def test_2d_to_4d(self):
        mask_converter = AttentionMaskConverter(is_causal=False)

        # non auto-regressive case
        self.check_to_4d(mask_converter, q_len=7, kv_len=7)

        # same with extra attention masks
        self.check_to_4d(mask_converter, q_len=7, kv_len=7, additional_mask=[(0, 2), (1, 3), (2, 0)])

    def test_2d_to_4d_causal_sliding(self):
        mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=5)

        # auto-regressive use case
        self.check_to_4d(mask_converter, q_len=1, kv_len=7)
        # special auto-regressive case
        self.check_to_4d(mask_converter, q_len=3, kv_len=7)
        # non auto-regressive case
        self.check_to_4d(mask_converter, q_len=7, kv_len=7)

        # same with extra attention masks
        self.check_to_4d(mask_converter, q_len=1, kv_len=7, additional_mask=[(0, 2), (1, 3), (2, 0)])
        self.check_to_4d(mask_converter, q_len=3, kv_len=7, additional_mask=[(0, 2), (1, 3), (2, 0)])
        self.check_to_4d(mask_converter, q_len=7, kv_len=7, additional_mask=[(0, 2), (1, 3), (2, 0)])

    def test_causal_mask(self):
        mask_converter = AttentionMaskConverter(is_causal=True)

        # auto-regressive use case
        self.check_to_causal(mask_converter, q_len=1, kv_len=7)
        # special auto-regressive case
        self.check_to_causal(mask_converter, q_len=3, kv_len=7)
        # non auto-regressive case
        self.check_to_causal(mask_converter, q_len=7, kv_len=7)

    def test_causal_mask_sliding(self):
        mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=3)

        # auto-regressive use case
        self.check_to_causal(mask_converter, q_len=1, kv_len=7)
        # special auto-regressive case
        self.check_to_causal(mask_converter, q_len=3, kv_len=7)
        # non auto-regressive case
        self.check_to_causal(mask_converter, q_len=7, kv_len=7)

    def test_torch_compile_fullgraph(self):
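        # `fullgraph=True` makes torch.compile error out on any graph break,
        # so this doubles as a regression test that the mask helpers are fully
        # traceable.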
        model = Prepare4dCausalAttentionMaskModel()

        inputs_embeds = torch.rand([1, 3, 32])
        res_non_compiled = model(inputs_embeds)

        compiled_model = torch.compile(model, fullgraph=True)

        res_compiled = compiled_model(inputs_embeds)

        self.assertTrue(torch.equal(res_non_compiled, res_compiled))

        model = Create4dCausalAttentionMaskModel()

        inputs_embeds = torch.rand(2, 4, 16)
        res_non_compiled = model(inputs_embeds)

        compiled_model = torch.compile(model, fullgraph=True)
        res_compiled = compiled_model(inputs_embeds)

        self.assertTrue(torch.equal(res_non_compiled, res_compiled))

        model = Prepare4dAttentionMaskModel()

        mask = torch.ones(2, 4)
        mask[0, :2] = 0
        inputs_embeds = torch.rand(2, 4, 16)

        res_non_compiled = model(mask, inputs_embeds)

        compiled_model = torch.compile(model, fullgraph=True)
        res_compiled = compiled_model(mask, inputs_embeds)

        self.assertTrue(torch.equal(res_non_compiled, res_compiled))

    @require_torch
    @slow
    def test_unmask_unattended_left_padding(self):
        attention_mask = torch.Tensor([[0, 0, 1], [1, 1, 1], [0, 1, 1]]).to(torch.int64)

        expanded_mask = torch.Tensor(
            [
                [[[0, 0, 0], [0, 0, 0], [0, 0, 1]]],
                [[[1, 0, 0], [1, 1, 0], [1, 1, 1]]],
                [[[0, 0, 0], [0, 1, 0], [0, 1, 1]]],
            ]
        ).to(torch.int64)

        reference_output = torch.Tensor(
            [
                [[[1, 1, 1], [1, 1, 1], [0, 0, 1]]],
                [[[1, 0, 0], [1, 1, 0], [1, 1, 1]]],
                [[[1, 1, 1], [0, 1, 0], [0, 1, 1]]],
            ]
        ).to(torch.int64)

        result = AttentionMaskConverter._unmask_unattended(expanded_mask, attention_mask, unmasked_value=1)

        self.assertTrue(torch.equal(result, reference_output))
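        # Rough intuition (a simplified sketch, covering only the integer
        # masks used above): query rows that attend to nothing would yield NaN
        # in a softmax, so `_unmask_unattended` fills them with
        # `unmasked_value`, e.g.
        #     fully_masked = (expanded_mask == 0).all(dim=-1, keepdim=True)
        #     result = expanded_mask.masked_fill(fully_masked, unmasked_value)
        # Those rows correspond to padding queries whose outputs are discarded.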

        attention_mask = torch.Tensor([[0, 0, 1, 1, 1], [1, 1, 1, 1, 1], [0, 1, 1, 1, 1]]).to(torch.int64)

        attn_mask_converter = AttentionMaskConverter(is_causal=True)
        past_key_values_length = 0
        key_value_length = attention_mask.shape[-1] + past_key_values_length

        expanded_mask = attn_mask_converter.to_4d(
            attention_mask, attention_mask.shape[-1], key_value_length=key_value_length, dtype=torch.float32
        )

        result = AttentionMaskConverter._unmask_unattended(expanded_mask, attention_mask, unmasked_value=0)
        min_inf = torch.finfo(torch.float32).min
        reference_output = torch.Tensor(
            [
                [
                    [
                        [0, 0, 0, 0, 0],
                        [0, 0, 0, 0, 0],
                        [min_inf, min_inf, 0, min_inf, min_inf],
                        [min_inf, min_inf, 0, 0, min_inf],
                        [min_inf, min_inf, 0, 0, 0],
                    ]
                ],
                [
                    [
                        [0, min_inf, min_inf, min_inf, min_inf],
                        [0, 0, min_inf, min_inf, min_inf],
                        [0, 0, 0, min_inf, min_inf],
                        [0, 0, 0, 0, min_inf],
                        [0, 0, 0, 0, 0],
                    ]
                ],
                [
                    [
                        [0, 0, 0, 0, 0],
                        [min_inf, 0, min_inf, min_inf, min_inf],
                        [min_inf, 0, 0, min_inf, min_inf],
                        [min_inf, 0, 0, 0, min_inf],
                        [min_inf, 0, 0, 0, 0],
                    ]
                ],
            ]
        )

        self.assertTrue(torch.equal(reference_output, result))

    @require_torch
    @slow
    def test_unmask_unattended_right_padding(self):
        attention_mask = torch.Tensor([[1, 1, 1, 0], [1, 1, 1, 1], [1, 1, 0, 0]]).to(torch.int64)

        attn_mask_converter = AttentionMaskConverter(is_causal=True)
        past_key_values_length = 0
        key_value_length = attention_mask.shape[-1] + past_key_values_length

        expanded_mask = attn_mask_converter.to_4d(
            attention_mask, attention_mask.shape[-1], key_value_length=key_value_length, dtype=torch.float32
        )

        result = AttentionMaskConverter._unmask_unattended(expanded_mask, attention_mask, unmasked_value=0)

        # With right padding the first token is always attended to, so no
        # query row is fully masked and the mask comes back unchanged
        self.assertTrue(torch.equal(expanded_mask, result))

    @require_torch
    @slow
    def test_unmask_unattended_random_mask(self):
        attention_mask = torch.Tensor([[1, 0, 1, 0], [1, 0, 1, 1], [1, 1, 0, 1]]).to(torch.int64)

        attn_mask_converter = AttentionMaskConverter(is_causal=True)
        past_key_values_length = 0
        key_value_length = attention_mask.shape[-1] + past_key_values_length

        expanded_mask = attn_mask_converter.to_4d(
            attention_mask, attention_mask.shape[-1], key_value_length=key_value_length, dtype=torch.float32
        )

        result = AttentionMaskConverter._unmask_unattended(expanded_mask, attention_mask, unmasked_value=0)

        self.assertTrue(torch.equal(expanded_mask, result))



@require_torch
class TestAttentionImplementation(unittest.TestCase):
    def test_error_no_sdpa_available(self):
        with self.assertRaises(ValueError) as cm:
            _ = AutoModel.from_pretrained("hf-tiny-model-private/tiny-random-MCTCTModel", attn_implementation="sdpa")

        self.assertTrue(
            "does not support an attention implementation through torch.nn.functional.scaled_dot_product_attention"
            in str(cm.exception)
        )

        _ = AutoModel.from_pretrained("hf-tiny-model-private/tiny-random-MCTCTModel")

    def test_error_no_flash_available(self):
        with self.assertRaises(ValueError) as cm:
            _ = AutoModel.from_pretrained(
                "hf-tiny-model-private/tiny-random-MCTCTModel", attn_implementation="flash_attention_2"
            )

        self.assertTrue("does not support Flash Attention 2.0" in str(cm.exception))

    def test_error_no_flash_available_with_config(self):
        with self.assertRaises(ValueError) as cm:
            config = AutoConfig.from_pretrained("hf-tiny-model-private/tiny-random-MCTCTModel")

            _ = AutoModel.from_pretrained(
                "hf-tiny-model-private/tiny-random-MCTCTModel", config=config, attn_implementation="flash_attention_2"
            )

        self.assertTrue("does not support Flash Attention 2.0" in str(cm.exception))

    def test_error_wrong_attn_implementation(self):
        with self.assertRaises(ValueError) as cm:
            _ = AutoModel.from_pretrained("hf-tiny-model-private/tiny-random-MCTCTModel", attn_implementation="foo")

        self.assertTrue('The only possible arguments are `attn_implementation="eager"' in str(cm.exception))

    def test_not_available_flash(self):
        if is_flash_attn_2_available():
            self.skipTest("Please uninstall flash-attn package to run test_not_available_flash")

        with self.assertRaises(ImportError) as cm:
            _ = AutoModel.from_pretrained(
                "hf-internal-testing/tiny-random-GPTBigCodeModel", attn_implementation="flash_attention_2"
            )

        self.assertTrue("the package flash_attn seems to be not installed" in str(cm.exception))

    def test_not_available_flash_with_config(self):
        if is_flash_attn_2_available():
            self.skipTest("Please uninstall flash-attn package to run test_not_available_flash")

        config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-GPTBigCodeModel")

        with self.assertRaises(ImportError) as cm:
            _ = AutoModel.from_pretrained(
                "hf-internal-testing/tiny-random-GPTBigCodeModel",
                config=config,
                attn_implementation="flash_attention_2",
            )

        self.assertTrue("the package flash_attn seems to be not installed" in str(cm.exception))

    def test_not_available_sdpa(self):
        if is_torch_sdpa_available():
            self.skipTest("This test requires torch<=2.0")

        with self.assertRaises(ImportError) as cm:
            _ = AutoModel.from_pretrained(
                "hf-internal-testing/tiny-random-GPTBigCodeModel", attn_implementation="sdpa"
            )

        self.assertTrue("PyTorch SDPA requirements in Transformers are not met" in str(cm.exception))


@slow
@require_torch_gpu
class Mask4DTestBase(unittest.TestCase):
    def tearDown(self):
        gc.collect()
        torch.cuda.empty_cache()

    def get_test_data(self):
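        # Build three sequences that share a common prefix, plus a packed
        # single-row variant whose custom 4D mask and position_ids make it
        # equivalent to the batched version.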
        texts = ["the cat sat", "the cat had", "the cat is"]
        encoded = [self.tokenizer.encode(t) for t in texts]
        input_0 = torch.tensor(encoded, device=torch_device)
        # tensor([[   1,  278, 6635, 3290],
        #         [   1,  278, 6635,  750],
        #         [   1,  278, 6635,  338]], device='cuda:0')

        # Combining common prefix with the unique ending tokens:
        input_1 = torch.cat([input_0[0][:-1], input_0[:, -1]]).unsqueeze(0)
        # tensor([[   1,  278, 6635, 3290,  750,  338]], device='cuda:0')

        # Creating a 4D mask where each of the last 3 tokens does not attend to the others.
        mask_1 = torch.tensor(
            [
                [
                    [
                        [1, 0, 0, 0, 0, 0],
                        [1, 1, 0, 0, 0, 0],
                        [1, 1, 1, 0, 0, 0],
                        [1, 1, 1, 1, 0, 0],
                        [1, 1, 1, 0, 1, 0],
                        [1, 1, 1, 0, 0, 1],
                    ]
                ]
            ],
            device=torch_device,
            dtype=torch.int64,
        )
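        # Rows 3-5 (the unique last tokens "sat", "had", "is") each attend to
        # the shared prefix (positions 0-2) and to themselves only, so the
        # packed endings never see each other.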

        # Creating a position_ids tensor. Note the repeated values at the end,
        # one for each of the packed endings.
        position_ids_1 = torch.tensor([[0, 1, 2, 3, 3, 3]], device=torch_device, dtype=torch.int64)

        return input_0, input_1, mask_1, position_ids_1


@slow
@require_torch_gpu
class Mask4DTestFP32(Mask4DTestBase):
    def setUp(self):
        model_name = "JackFram/llama-68m"  # small Llama-like model from FlexFlow
        model_dtype = torch.float32
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=model_dtype).to(torch_device)

    def test_attention(self):
        """comparing outputs of attention layer"""
        input_0, input_1, mask_1, position_ids_1 = self.get_test_data()

        hid_0 = self.model.model.embed_tokens(input_0)
        outs_0 = self.model.model.layers[0].self_attn.forward(hid_0)[0]
        # outs_0.shape == torch.Size([3, 4, 768])

        hid_1 = self.model.model.embed_tokens(input_1)
        outs_1 = self.model.model.layers[0].self_attn.forward(
            hid_1, attention_mask=mask_1.bool(), position_ids=position_ids_1
        )[0]
        # outs_1.shape == torch.Size([1, 6, 768])

        outs_0_last_tokens = outs_0[:, -1, :]  # last tokens in each batch line
        outs_1_last_tokens = outs_1[0, -3:, :]  # last three tokens
        assert torch.allclose(outs_0_last_tokens, outs_1_last_tokens)

    def test_inner_model(self):
        """comparing hidden states of the whole inner model"""
        input_0, input_1, mask_1, position_ids_1 = self.get_test_data()

        # Compare the inner (decoder) model's hidden states rather than the LM logits
        hid_0 = self.model.model.forward(input_0).last_hidden_state
        hid_1 = self.model.model.forward(
            input_1, attention_mask=mask_1.bool(), position_ids=position_ids_1
        ).last_hidden_state

        hid_0_last_tokens = hid_0[:, -1, :]  # last tokens in each batch line
        hid_1_last_tokens = hid_1[0, -3:, :]  # last three tokens
        torch.testing.assert_close(
            hid_0_last_tokens,
            hid_1_last_tokens,
        )

    def test_causal_model_logits(self):
        """comparing logits outputs of the whole causal model"""
        input_0, input_1, mask_1, position_ids_1 = self.get_test_data()

        logits_0 = self.model.forward(input_0).logits
        logits_1 = self.model.forward(input_1, attention_mask=mask_1.bool(), position_ids=position_ids_1).logits

        logits_0_last_tokens = logits_0[:, -1, :]  # last tokens in each batch line
        logits_1_last_tokens = logits_1[0, -3:, :]  # last three tokens
        torch.testing.assert_close(
            logits_0_last_tokens,
            logits_1_last_tokens,
        )


@slow
@require_torch_gpu
class Mask4DTestFP16(Mask4DTestBase):
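    # Reuse the FP32 attention comparison verbatim; only the dtype set in
    # setUp differs.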
    test_attention = Mask4DTestFP32.test_attention

    def setUp(self):
        model_name = "JackFram/llama-68m"  # small Llama-like model from FlexFlow
        model_dtype = torch.float16
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=model_dtype).to(torch_device)

    def test_causal_model_logits(self):
        """comparing logits outputs of the whole causal model"""
        input_0, input_1, mask_1, position_ids_1 = self.get_test_data()

        logits_0 = self.model.forward(input_0).logits
        logits_1 = self.model.forward(input_1, attention_mask=mask_1.bool(), position_ids=position_ids_1).logits

        logits_0_last_tokens = logits_0[:, -1, :]  # last tokens in each batch line
        logits_1_last_tokens = logits_1[0, -3:, :]  # last three tokens

        indices_0 = logits_0_last_tokens.sort(descending=True).indices
        indices_1 = logits_1_last_tokens.sort(descending=True).indices

        # checking logits, with relaxed tolerances for FP16
        torch.testing.assert_close(logits_0_last_tokens, logits_1_last_tokens, atol=0.02, rtol=0.001)

        # checking the order of the top tokens
        for token_ids_0, token_ids_1 in zip(indices_0, indices_1):
            self.assertTrue(torch.equal(token_ids_0[:128], token_ids_1[:128]))