transformers

Форк
0
/
test_modeling_tf_common.py 
1874 строки · 89.2 Кб
1
# coding=utf-8
2
# Copyright 2019 HuggingFace Inc.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16

17
from __future__ import annotations
18

19
import copy
20
import inspect
21
import json
22
import os
23
import random
24
import tempfile
25
import unittest
26
from importlib import import_module
27
from math import isnan
28
from typing import List, Tuple
29

30
from datasets import Dataset
31

32
from transformers import is_tf_available, is_torch_available
33
from transformers.models.auto import get_values
34
from transformers.testing_utils import (  # noqa: F401
35
    CaptureLogger,
36
    _tf_gpu_memory_limit,
37
    is_pt_tf_cross_test,
38
    require_tf,
39
    require_tf2onnx,
40
    slow,
41
    torch_device,
42
)
43
from transformers.utils import CONFIG_NAME, GENERATION_CONFIG_NAME, logging
44
from transformers.utils.generic import ModelOutput
45

46

47
logger = logging.get_logger(__name__)
48

49

50
if is_tf_available():
    # Everything below needs TensorFlow, so it only runs when `is_tf_available()` is truthy.
    # Note that `np` is bound here as well, i.e. only on the TF-available path.
    import numpy as np
    import tensorflow as tf

    from transformers import (
        TF_MODEL_FOR_CAUSAL_LM_MAPPING,
        TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING,
        TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
        TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING,
        TF_MODEL_FOR_MASKED_LM_MAPPING,
        TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
        TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
        TF_MODEL_FOR_PRETRAINING_MAPPING,
        TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
        TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING,
        TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
        TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
        TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
        TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
        TFAutoModel,
        TFAutoModelForSequenceClassification,
        TFSharedEmbeddings,
    )
    from transformers.generation import (
        TFBeamSampleDecoderOnlyOutput,
        TFBeamSampleEncoderDecoderOutput,
        TFBeamSearchDecoderOnlyOutput,
        TFBeamSearchEncoderDecoderOutput,
        TFGreedySearchDecoderOnlyOutput,
        TFGreedySearchEncoderDecoderOutput,
        TFSampleDecoderOnlyOutput,
        TFSampleEncoderDecoderOutput,
    )
    from transformers.modeling_tf_utils import keras

    # Switch off TensorFloat-32 execution for the whole module.
    # NOTE(review): presumably so the numerical comparisons in these tests are not loosened by reduced
    # precision - confirm.
    tf.config.experimental.enable_tensor_float_32_execution(False)

    # Optional per-GPU memory cap; skipped when the testing utilities expose no limit (`None`).
    if _tf_gpu_memory_limit is not None:
        gpus = tf.config.list_physical_devices("GPU")
        for gpu in gpus:
            # Restrict TensorFlow to only allocate x GB of memory on the GPUs
            try:
                tf.config.set_logical_device_configuration(
                    gpu, [tf.config.LogicalDeviceConfiguration(memory_limit=_tf_gpu_memory_limit)]
                )
                logical_gpus = tf.config.list_logical_devices("GPU")
                print("Logical GPUs", logical_gpus)
            except RuntimeError as e:
                # Virtual devices must be set before GPUs have been initialized
                # (best effort: the error is printed and the loop moves on to the next GPU)
                print(e)
100

101
if is_torch_available():
102
    import torch
103

104

105
def _config_zero_init(config):
    """Return a deep copy of `config` whose `*_range*` / `*_std*` attributes are all set to 0.0.

    Only instance attributes (entries of the copy's `__dict__`) whose name contains `"_range"` or
    `"_std"` are overwritten. The `config` passed in is left untouched.
    """
    zeroed_config = copy.deepcopy(config)
    scale_attribute_names = [name for name in vars(zeroed_config) if "_range" in name or "_std" in name]
    for name in scale_attribute_names:
        setattr(zeroed_config, name, 0.0)
    return zeroed_config
111

112

113
@require_tf
114
class TFModelTesterMixin:
115
    model_tester = None
116
    all_model_classes = ()
117
    all_generative_model_classes = ()
118
    test_mismatched_shapes = True
119
    test_resize_embeddings = True
120
    test_head_masking = True
121
    is_encoder_decoder = False
122
    has_attentions = True
123

124
    def _prepare_for_class(self, inputs_dict, model_class, return_labels=False) -> dict:
        """Build the inputs to feed `model_class` from the shared tester inputs.

        All work happens on a deep copy, so the caller's `inputs_dict` is never modified.

        Args:
            inputs_dict: Inputs produced by the model tester.
            model_class: The TF model class the inputs are prepared for.
            return_labels: When `True`, add dummy label tensors chosen from the head type of `model_class`
                (looked up in the `TF_MODEL_FOR_*_MAPPING` tables, or by the `ForCTC` name suffix).

        Returns:
            The prepared copy of `inputs_dict`.
        """
        prepared = copy.deepcopy(inputs_dict)
        is_multiple_choice = model_class in get_values(TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING)

        def belongs_to(*mappings):
            # True when `model_class` is registered in at least one of the given auto-mappings.
            return any(model_class in get_values(mapping) for mapping in mappings)

        if is_multiple_choice:
            # Insert a new axis at position 1 and repeat every non-scalar tensor `num_choices` times along it.
            tiled = {}
            for key, value in prepared.items():
                if isinstance(value, tf.Tensor) and value.ndim > 0:
                    repeats = (1, self.model_tester.num_choices) + (1,) * (value.ndim - 1)
                    value = tf.tile(tf.expand_dims(value, 1), repeats)
                tiled[key] = value
            prepared = tiled

        if not return_labels:
            return prepared

        if is_multiple_choice:
            prepared["labels"] = tf.ones(self.model_tester.batch_size, dtype=tf.int32)
        elif belongs_to(TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING, TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING):
            for position_key in ("start_positions", "end_positions"):
                prepared[position_key] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
        elif belongs_to(TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING):
            prepared["labels"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
        elif belongs_to(TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING):
            prepared["next_sentence_label"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
        elif belongs_to(
            TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
            TF_MODEL_FOR_CAUSAL_LM_MAPPING,
            TF_MODEL_FOR_MASKED_LM_MAPPING,
            TF_MODEL_FOR_PRETRAINING_MAPPING,
            TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
            TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
        ) and "labels" in dict(inspect.signature(model_class.call).parameters):
            # Models from these mappings whose `call` has no `labels` argument fall through to the branches below.
            label_shape = (self.model_tester.batch_size, self.model_tester.seq_length)
            prepared["labels"] = tf.zeros(label_shape, dtype=tf.int32)
        elif belongs_to(TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING):
            patches_per_side = self.model_tester.image_size // self.model_tester.patch_size
            mask_shape = (self.model_tester.batch_size, patches_per_side**2)
            prepared["bool_masked_pos"] = tf.zeros(mask_shape, dtype=tf.int32)
        elif belongs_to(TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING):
            # The 4-way unpacking also enforces that `pixel_values` is rank 4; only the last two dims are used.
            _, _, height, width = prepared["pixel_values"].shape
            prepared["labels"] = tf.zeros((self.model_tester.batch_size, height, width), dtype=tf.int32)
        elif model_class.__name__.endswith("ForCTC"):
            # When we have enough CTC models for an AutoClass, we should use their mapping instead of name checks
            label_shape = (self.model_tester.batch_size, self.model_tester.seq_length)
            prepared["labels"] = tf.zeros(label_shape, dtype=tf.int32)

        return prepared
177

178
    def test_initialization(self):
179
        pass
180

181
    def test_save_load(self):
        """Round-trip every model class through `save_pretrained` / `from_pretrained` and compare outputs.

        Also checks that the config file is written, and that a generation config file exists exactly when
        the model reports `can_generate()`.
        """
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            saved_model = model_class(config)
            reference_outputs = saved_model(self._prepare_for_class(inputs_dict, model_class))

            with tempfile.TemporaryDirectory() as save_dir:
                saved_model.save_pretrained(save_dir, saved_model=False)

                # the config file (and the generation config file, if it can generate) should be saved
                config_path = os.path.join(save_dir, CONFIG_NAME)
                generation_config_path = os.path.join(save_dir, GENERATION_CONFIG_NAME)
                self.assertTrue(os.path.exists(config_path))
                self.assertEqual(saved_model.can_generate(), os.path.exists(generation_config_path))

                reloaded_model = model_class.from_pretrained(save_dir)
                reloaded_outputs = reloaded_model(self._prepare_for_class(inputs_dict, model_class))

                self.assert_outputs_same(reloaded_outputs, reference_outputs)
201

202
    def test_save_load_config(self):
        """Rebuild each model from `get_config()`, copy the weights over and compare outputs.

        Also checks that the keras config is JSON-serializable and that `from_config` accepts the model's
        regular `config` object.
        """
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            reference_model = model_class(config)
            reference_outputs = reference_model(self._prepare_for_class(inputs_dict, model_class))

            # make sure that returned config is jsonifiable, which is required by keras
            keras_config = reference_model.get_config()
            json.dumps(keras_config)

            rebuilt_model = model_class.from_config(keras_config)
            # make sure it also accepts a normal config
            model_class.from_config(reference_model.config)

            # A first forward pass builds the model, so that weights exist before `set_weights`
            rebuilt_model(self._prepare_for_class(inputs_dict, model_class))
            rebuilt_model.set_weights(reference_model.get_weights())
            rebuilt_outputs = rebuilt_model(self._prepare_for_class(inputs_dict, model_class))

            self.assert_outputs_same(rebuilt_outputs, reference_outputs)
219

220
    @slow
    def test_saved_model_creation(self):
        """Check that `save_pretrained(..., saved_model=True)` writes a `saved_model/1` directory.

        Only the first entry of `all_model_classes` is exercised, with hidden states, attentions and
        (when the config has it) `use_cache` turned off.
        """
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.output_hidden_states = False
        config.output_attentions = False
        if hasattr(config, "use_cache"):
            config.use_cache = False

        first_model_class = self.all_model_classes[0]
        prepared_inputs = self._prepare_for_class(inputs_dict, first_model_class)

        model = first_model_class(config)
        # One forward pass before saving
        model(prepared_inputs)

        with tempfile.TemporaryDirectory() as save_dir:
            model.save_pretrained(save_dir, saved_model=True)
            expected_export_dir = os.path.join(save_dir, "saved_model", "1")
            self.assertTrue(os.path.exists(expected_export_dir))
240

241
    def test_prepare_serving_output(self):
        """Check that every field returned by `serving_output` is `None`, a tensor, or a tuple of tensors."""
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.output_hidden_states = True
        config.output_attentions = self.has_attentions

        for model_class in self.all_model_classes:
            model = model_class(config)
            model_outputs = model(self._prepare_for_class(inputs_dict, model_class))
            serving_outputs = model.serving_output(model_outputs)

            for _, field in serving_outputs.items():
                # Check that we have one of three possible outputs: None, tuple of tensors or a tensor
                if field is None:
                    continue
                if isinstance(field, tuple):
                    self.assertTrue(all(isinstance(item, tf.Tensor) for item in field))
                else:
                    self.assertIsInstance(field, tf.Tensor)
260

261
    def test_forward_signature(self):
262
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
263

264
        for model_class in self.all_model_classes:
265
            model = model_class(config)
266
            signature = inspect.signature(model.call)
267
            # signature.parameters is an OrderedDict => so arg_names order is deterministic
268
            arg_names = [*signature.parameters.keys()]
269

270
            if model.config.is_encoder_decoder:
271
                expected_arg_names = [
272
                    "input_ids",
273
                    "attention_mask",
274
                    "decoder_input_ids",
275
                    "decoder_attention_mask",
276
                ]
277
                expected_arg_names.extend(["decoder_position_ids"] if "decoder_position_ids" in arg_names else [])
278
                expected_arg_names.extend(
279
                    ["head_mask", "decoder_head_mask"] if "head_mask" and "decoder_head_mask" in arg_names else []
280
                )
281
                expected_arg_names.extend(
282
                    ["cross_attn_head_mask", "encoder_outputs"]
283
                    if "cross_attn_head_mask" in arg_names
284
                    else ["encoder_outputs"]
285
                )
286
                self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
287

288
            else:
289
                expected_arg_names = ["input_ids"]
290
                self.assertListEqual(arg_names[:1], expected_arg_names)
291

292
    def test_onnx_compliancy(self):
293
        if not self.test_onnx:
294
            return
295

296
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
297
        INTERNAL_OPS = [
298
            "Assert",
299
            "AssignVariableOp",
300
            "EmptyTensorList",
301
            "ReadVariableOp",
302
            "ResourceGather",
303
            "TruncatedNormal",
304
            "VarHandleOp",
305
            "VarIsInitializedOp",
306
        ]
307
        onnx_ops = []
308

309
        with open(os.path.join(".", "utils", "tf_ops", "onnx.json")) as f:
310
            onnx_opsets = json.load(f)["opsets"]
311

312
        for i in range(1, self.onnx_min_opset + 1):
313
            onnx_ops.extend(onnx_opsets[str(i)])
314

315
        for model_class in self.all_model_classes:
316
            model_op_names = set()
317

318
            with tf.Graph().as_default() as g:
319
                model = model_class(config)
320
                model.build_in_name_scope()
321

322
                for op in g.get_operations():
323
                    model_op_names.add(op.node_def.op)
324

325
            model_op_names = sorted(model_op_names)
326
            incompatible_ops = []
327

328
            for op in model_op_names:
329
                if op not in onnx_ops and op not in INTERNAL_OPS:
330
                    incompatible_ops.append(op)
331

332
            self.assertEqual(len(incompatible_ops), 0, incompatible_ops)
333

334
    # `tf2onnx` issue page: https://github.com/onnx/tensorflow-onnx/issues/2172
335
    # TODO: undo skip once a fix is done in `tf2onnx`
336
    @unittest.skip("`tf2onnx` broke with TF 2.13")
    @require_tf2onnx
    @slow
    def test_onnx_runtime_optimize(self):
        """Convert the first two model classes with `tf2onnx` and load the result in `onnxruntime`.

        Returns immediately when `self.test_onnx` is falsy. The test is currently skipped unconditionally
        by the `unittest.skip` decorator.
        """
        if not self.test_onnx:
            return

        # Imported lazily so the module can be imported without these optional packages
        import onnxruntime
        import tf2onnx

        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes[:2]:
            model = model_class(config)
            model.build_in_name_scope()

            converted_proto, _ = tf2onnx.convert.from_keras(model, opset=self.onnx_min_opset)
            serialized_model = converted_proto.SerializeToString()

            # Creating the session is the check itself: it must not raise
            onnxruntime.InferenceSession(serialized_model)
355

356
    def test_keras_save_load(self):
        """Save and reload each keras-serializable `*MainLayer` through an HDF5 keras model and compare outputs.

        For every model class named `<Prefix>Model`, the class `<Prefix>MainLayer` is looked up in the model's
        module. It is tested when it directly subclasses `keras.layers.Layer` and has a truthy
        `_keras_serializable` attribute. The layer is wrapped in a functional `keras.Model`, saved to
        `keras_model.h5`, loaded back, and the outputs before and after are compared with `assert_outputs_same`.
        """
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        tf_main_layer_classes = set()
        for model_class in self.all_model_classes:
            module = import_module(model_class.__module__)
            # Matching on the model's own prefix is required, since `modeling_tf_clip.py` has 3 classes whose
            # names end with `MainLayer`.
            main_layer_name = model_class.__name__[: -len("Model")] + "MainLayer"
            if main_layer_name not in dir(module):
                continue
            candidate = getattr(module, main_layer_name)
            if (
                isinstance(candidate, type)
                and keras.layers.Layer in candidate.__bases__
                and getattr(candidate, "_keras_serializable", False)
            ):
                tf_main_layer_classes.add(candidate)

        for main_layer_class in tf_main_layer_classes:
            # Classes needed by keras to deserialize the saved model
            custom_objects = {main_layer_class.__name__: main_layer_class}

            # T5MainLayer needs an embed_tokens parameter when called without the inputs_embeds parameter
            if "T5" in main_layer_class.__name__:
                # Take the same values than in TFT5ModelTester for this shared layer
                shared = TFSharedEmbeddings(99, 32, name="shared")
                # `use_cache` is moved from the inputs into the config (this mutates `inputs_dict`)
                config.use_cache = inputs_dict.pop("use_cache", None)
                main_layer = main_layer_class(config, embed_tokens=shared)
                custom_objects["TFSharedEmbeddings"] = TFSharedEmbeddings
            else:
                main_layer = main_layer_class(config)

            # Symbolic inputs drop the leading dimension of each concrete input tensor
            symbolic_inputs = {
                name: keras.Input(tensor.shape[1:], dtype=tensor.dtype) for name, tensor in inputs_dict.items()
            }

            model = keras.Model(symbolic_inputs, outputs=main_layer(symbolic_inputs))
            outputs = model(inputs_dict)

            with tempfile.TemporaryDirectory() as tmpdirname:
                filepath = os.path.join(tmpdirname, "keras_model.h5")
                model.save(filepath)
                model = keras.models.load_model(filepath, custom_objects=custom_objects)
                # A unittest assertion instead of a bare `assert`, which is stripped under `python -O`
                self.assertIsInstance(model, keras.Model)
                after_outputs = model(inputs_dict)
                self.assert_outputs_same(after_outputs, outputs)
407

408
    def assert_outputs_same(self, after_outputs, outputs):
        """Assert that the first output before and after a save/load round trip agree within 1e-5.

        `after_outputs` may be a tensor, a dict (its first entry is used) or an indexable container (element 0
        is used). `outputs` is always indexed with `[0]`. Shapes must match; NaN entries are dropped from
        each array separately before the maximum absolute difference is computed.
        """
        if isinstance(after_outputs, tf.Tensor):
            first_after = after_outputs
        elif isinstance(after_outputs, dict):
            first_after = after_outputs[list(after_outputs)[0]]
        else:
            first_after = after_outputs[0]

        after_array = first_after.numpy()
        before_array = outputs[0].numpy()
        self.assertEqual(after_array.shape, before_array.shape)

        # Make sure we don't have nans
        after_valid = after_array[~np.isnan(after_array)]
        before_valid = before_array[~np.isnan(before_array)]

        max_diff = np.amax(np.abs(after_valid - before_valid))
        self.assertLessEqual(max_diff, 1e-5)
422

423
    # Don't copy this method to model specific test file!
424
    # TODO: remove this method once the issues are all fixed!
425
    def _make_attention_mask_non_null(self, inputs_dict):
        """Make sure no sequence has all zeros as attention mask.

        Modifies `inputs_dict` in place: for each of `attention_mask`, `encoder_attention_mask` and
        `decoder_attention_mask` that is present, the first position along the last axis is replaced by ones.
        """
        for mask_name in ("attention_mask", "encoder_attention_mask", "decoder_attention_mask"):
            if mask_name not in inputs_dict:
                continue

            mask = inputs_dict[mask_name]

            # Make sure no all 0s attention masks - to avoid failure at this moment.
            # Put `1` at the beginning of sequences to make it still work when combining causal attention masks.
            # TODO: remove this line once a fix regarding large negative values for attention mask is done.
            leading_ones = tf.ones_like(mask[:, :1], dtype=mask.dtype)
            mask = tf.concat([leading_ones, mask[:, 1:]], axis=-1)

            # Here we make the first sequence with all 0s as attention mask.
            # Currently, this will fail for `TFWav2Vec2Model`. This is caused by the different large negative
            # values, like `1e-4`, `1e-9`, `1e-30` and `-inf` for attention mask across models/frameworks.
            # TODO: enable this block once the large negative values thing is cleaned up.
            # (see https://github.com/huggingface/transformers/issues/14859)
            # mask = tf.concat(
            #     [
            #         tf.zeros_like(mask[:1], dtype=tf.int32),
            #         tf.cast(mask[1:], dtype=tf.int32)
            #     ],
            #     axis=0
            # )

            inputs_dict[mask_name] = mask
453

454
    # Don't copy this method to model specific test file!
455
    # TODO: remove this method once the issues are all fixed!
456
    def _postprocessing_to_ignore_test_cases(self, tf_outputs, pt_outputs, model_class):
        """For temporarily ignoring some failed test cases (issues to be fixed).

        Rebuilds both outputs from their non-`None` fields, dropping:
          * `loss` / `losses` when only one framework returned them, for a fixed list of model classes;
          * `past_key_values` for every class whose name starts with `TFGPT2`.

        Returns:
            A `(new_tf_outputs, new_pt_outputs)` pair, each of the same type as the corresponding input.
        """
        kept_tf_keys = {key for key, value in tf_outputs.items() if value is not None}
        kept_pt_keys = {key for key, value in pt_outputs.items() if value is not None}

        class_name = model_class.__name__
        classes_with_loss_mismatch = (
            "TFFlaubertWithLMHeadModel",
            "TFFunnelForPreTraining",
            "TFElectraForPreTraining",
            "TFXLMWithLMHeadModel",
        )

        if class_name in classes_with_loss_mismatch:
            # Only loss fields present on exactly one side are ignored
            one_sided_keys = kept_tf_keys ^ kept_pt_keys
            for loss_key in one_sided_keys & {"loss", "losses"}:
                kept_tf_keys.discard(loss_key)
                kept_pt_keys.discard(loss_key)
        elif class_name.startswith("TFGPT2"):
            # `TFGPT2` has `past_key_values` as a tensor while `GPT2` has it as a tuple.
            kept_tf_keys.discard("past_key_values")
            kept_pt_keys.discard("past_key_values")

        # create new outputs from the remaining fields
        tf_fields = {key: tf_outputs[key] for key in kept_tf_keys}
        pt_fields = {key: pt_outputs[key] for key in kept_pt_keys}

        return type(tf_outputs)(**tf_fields), type(pt_outputs)(**pt_fields)
484

485
    def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=1e-5, name="outputs", attributes=None):
        """Check the outputs from PyTorch and TensorFlow models are close enough. Checks are done in a recursive way.

        Args:
            tf_outputs: Output of the TF model: a `ModelOutput`, a `tuple`/`list`, or a `tf.Tensor`.
            pt_outputs: Output of the PyTorch model, expected to have the same structure as `tf_outputs`.
            model_class: The class of the model that is currently testing. For example, `TFBertModel`,
                `TFBertForMaskedLM`, `TFBertForSequenceClassification`, etc. It is forwarded to
                `_postprocessing_to_ignore_test_cases` and to the recursive calls.
            tol (`float`): Largest absolute element-wise difference that is still accepted.
            name (`str`): The name of the output. For example, `output.hidden_states`, `output.attentions`, etc.
            attributes (`Tuple[str]`): The names of the output's element if the output is a tuple/list with each element
                being a named field in the output.

        Raises:
            ValueError: If `tf_outputs` is not a `ModelOutput`, `tuple`, `list` or `tf.Tensor`.
        """

        self.assertEqual(type(name), str)
        if attributes is not None:
            self.assertEqual(type(attributes), tuple, f"{name}: The argument `attributes` should be a `tuple`")

        # Allow `ModelOutput` (e.g. `CLIPOutput` has `text_model_output` and `vision_model_output`).
        if isinstance(tf_outputs, ModelOutput):
            self.assertTrue(
                isinstance(pt_outputs, ModelOutput),
                f"{name}: `pt_outputs` should be an instance of `ModelOutput` when `tf_outputs` is",
            )

            # Don't copy this block to model specific test file!
            # TODO: remove this method and this line after issues are fixed
            tf_outputs, pt_outputs = self._postprocessing_to_ignore_test_cases(tf_outputs, pt_outputs, model_class)

            tf_keys = [k for k, v in tf_outputs.items() if v is not None]
            pt_keys = [k for k, v in pt_outputs.items() if v is not None]

            self.assertEqual(tf_keys, pt_keys, f"{name}: Output keys differ between TF and PyTorch")

            # convert to the case of `tuple`
            # appending each key to the current (string) `names`
            attributes = tuple(f"{name}.{k}" for k in tf_keys)
            self.check_pt_tf_outputs(
                tf_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, tol=tol, name=name, attributes=attributes
            )

        # Allow `list` (e.g. `TransfoXLModelOutput.mems` is a list of tensors.)
        elif type(tf_outputs) in [tuple, list]:
            self.assertEqual(type(tf_outputs), type(pt_outputs), f"{name}: Output types differ between TF and PyTorch")
            self.assertEqual(len(tf_outputs), len(pt_outputs), f"{name}: Output lengths differ between TF and PyTorch")

            if attributes is not None:
                # case 1: each output has assigned name (e.g. a tuple form of a `ModelOutput`)
                self.assertEqual(
                    len(attributes),
                    len(tf_outputs),
                    f"{name}: The tuple `names` should have the same length as `tf_outputs`",
                )
            else:
                # case 2: each output has no assigned name (e.g. hidden states of each layer) -> add an index to `names`
                attributes = tuple(f"{name}_{idx}" for idx in range(len(tf_outputs)))

            for tf_output, pt_output, attr in zip(tf_outputs, pt_outputs, attributes):
                self.check_pt_tf_outputs(tf_output, pt_output, model_class, tol=tol, name=attr)

        elif isinstance(tf_outputs, tf.Tensor):
            self.assertTrue(
                isinstance(pt_outputs, torch.Tensor), f"{name}: `pt_outputs` should be a tensor when `tf_outputs` is"
            )

            tf_outputs = tf_outputs.numpy()
            pt_outputs = pt_outputs.detach().to("cpu").numpy()

            self.assertEqual(
                tf_outputs.shape, pt_outputs.shape, f"{name}: Output shapes differ between TF and PyTorch"
            )

            # deal with NumPy's scalars to make replacing nan values by 0 work.
            if np.isscalar(tf_outputs):
                tf_outputs = np.array([tf_outputs])
                pt_outputs = np.array([pt_outputs])

            tf_nans = np.isnan(tf_outputs)
            pt_nans = np.isnan(pt_outputs)

            # A position that is NaN in either framework is zeroed in both, so it does not count as a difference
            pt_outputs[tf_nans] = 0
            tf_outputs[tf_nans] = 0
            pt_outputs[pt_nans] = 0
            tf_outputs[pt_nans] = 0

            max_diff = np.amax(np.abs(tf_outputs - pt_outputs))
            # `assertLessEqual` fails only when `max_diff` is strictly greater than `tol`
            self.assertLessEqual(max_diff, tol, f"{name}: Difference between torch and tf is {max_diff} (> {tol}).")
        else:
            raise ValueError(
                "`tf_outputs` should be an instance of `ModelOutput`, a `tuple`, a `list`, or an instance of"
                f" `tf.Tensor`. Got {type(tf_outputs)} instead."
            )
575

576
    def prepare_pt_inputs_from_tf_inputs(self, tf_inputs_dict):
        """Convert a dict of TF inputs into the equivalent dict of PyTorch inputs.

        Booleans are passed through unchanged. `input_values`, `pixel_values`, `input_features` and any
        input with a floating dtype become `torch.float32` tensors; everything else becomes `torch.long`.
        """
        always_float_names = ("input_values", "pixel_values", "input_features")

        pt_inputs_dict = {}
        for name, value in tf_inputs_dict.items():
            if isinstance(value, bool):
                pt_inputs_dict[name] = value
                continue

            # The name check comes first, so `dtype` is only consulted for the remaining inputs
            if name in always_float_names or value.dtype.is_floating:
                target_dtype = torch.float32
            else:
                target_dtype = torch.long
            pt_inputs_dict[name] = torch.from_numpy(value.numpy()).to(target_dtype)

        return pt_inputs_dict
594

595
    def check_pt_tf_models(self, tf_model, pt_model, tf_inputs_dict):
        """Run both models on equivalent inputs and compare the results with `check_pt_tf_outputs`.

        The PyTorch model is moved to `torch_device`, put in eval mode and called without gradient tracking.
        A TF `loss`, when present, is reduced to its mean before the comparison.
        """
        # Convert the inputs, then send every pytorch tensor to the correct device
        pt_inputs_dict = {}
        for key, value in self.prepare_pt_inputs_from_tf_inputs(tf_inputs_dict).items():
            pt_inputs_dict[key] = value.to(device=torch_device) if isinstance(value, torch.Tensor) else value

        # send pytorch model to the correct device
        pt_model.to(torch_device)

        # Check predictions on first output (logits/hidden-states) are close enough given low-level computational differences
        pt_model.eval()

        with torch.no_grad():
            pt_outputs = pt_model(**pt_inputs_dict)
        tf_outputs = tf_model(tf_inputs_dict)

        # tf models returned loss is usually a tensor rather than a scalar.
        # (see `hf_compute_loss`: it uses `keras.losses.Reduction.NONE`)
        # Change it here to a scalar to match PyTorch models' loss
        per_example_loss = getattr(tf_outputs, "loss", None)
        if per_example_loss is not None:
            tf_outputs.loss = tf.math.reduce_mean(per_example_loss)

        self.check_pt_tf_outputs(tf_outputs, pt_outputs, type(tf_model))
621

622
    @is_pt_tf_cross_test
    def test_pt_tf_model_equivalence(self, allow_missing_keys=False):
        """Check that each TF model and its PyTorch counterpart produce matching outputs.

        Weights are transferred in both directions twice, first with the model => model helpers and then
        through checkpoint files on disk. After each transfer the outputs are compared without labels and,
        when the model accepts them, with labels.
        """
        import transformers

        for model_class in self.all_model_classes:
            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

            # Output all for aggressive testing
            config.output_hidden_states = True
            config.output_attentions = self.has_attentions

            # Make sure no sequence has all zeros as attention mask, otherwise some tests fail due to the inconsistency
            # of the usage `1e-4`, `1e-9`, `1e-30`, `-inf`.
            # TODO: Use a uniform value for all models, make sure all tests pass without this processing, and remove it.
            self._make_attention_mask_non_null(inputs_dict)

            # Skip the "TF" at the beginning to get the name of the PyTorch class
            pt_model_class = getattr(transformers, model_class.__name__[2:])

            tf_model = model_class(config)
            pt_model = pt_model_class(config)

            # Not all models accept "labels" in the forward pass (yet :) )
            accepts_labels = "labels" in inspect.signature(model_class.call).parameters
            tf_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
            tf_inputs_dict_with_labels = self._prepare_for_class(
                inputs_dict, model_class, return_labels=accepts_labels
            )

            # For some models (e.g. base models), there is no label returned.
            # Set the input dict to `None` to avoid check outputs twice for the same input dicts.
            if set(tf_inputs_dict_with_labels) == set(tf_inputs_dict):
                tf_inputs_dict_with_labels = None

            # Original test: always check without `labels`; additionally check with `labels` when available
            input_variants = [tf_inputs_dict]
            if tf_inputs_dict_with_labels:
                input_variants.append(tf_inputs_dict_with_labels)

            # Check we can load pt model in tf and vice-versa with model => model functions
            tf_model = transformers.load_pytorch_model_in_tf2_model(
                tf_model, pt_model, tf_inputs=tf_inputs_dict, allow_missing_keys=allow_missing_keys
            )
            pt_model = transformers.load_tf2_model_in_pytorch_model(
                pt_model, tf_model, allow_missing_keys=allow_missing_keys
            )

            for model_inputs in input_variants:
                self.check_pt_tf_models(tf_model, pt_model, model_inputs)

            # Check we can load pt model in tf and vice-versa with checkpoint => model functions
            with tempfile.TemporaryDirectory() as checkpoint_dir:
                pt_checkpoint_path = os.path.join(checkpoint_dir, "pt_model.bin")
                torch.save(pt_model.state_dict(), pt_checkpoint_path)
                tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(
                    tf_model, pt_checkpoint_path, allow_missing_keys=allow_missing_keys
                )

                tf_checkpoint_path = os.path.join(checkpoint_dir, "tf_model.h5")
                tf_model.save_weights(tf_checkpoint_path)
                pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(
                    pt_model, tf_checkpoint_path, allow_missing_keys=allow_missing_keys
                )

            for model_inputs in input_variants:
                self.check_pt_tf_models(tf_model, pt_model, model_inputs)
690

691
    @slow
692
    def test_compile_tf_model(self):
        """Wrap the first two model classes in a Keras functional model, then predict with it and save it."""
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes[:2]:
            # Prepare our model
            model = model_class(config)

            # These are maximally general inputs for the model, with multiple None dimensions
            # Hopefully this will catch any conditionals that fail for flexible shapes
            functional_inputs = {}
            for key, spec in model.input_signature.items():
                if key in model.dummy_inputs:
                    functional_inputs[key] = keras.Input(shape=spec.shape[1:], dtype=spec.dtype, name=key)

            # Only the first model output is exposed by the functional model.
            first_output = model(functional_inputs)[0]

            # Compile extended model
            functional_model = keras.Model(inputs=functional_inputs, outputs=first_output)

            # Check we can pass inputs with the Keras API
            predictions = functional_model.predict(model.dummy_inputs)
            self.assertIsNotNone(predictions)

            # Ensure we can save/export the whole functional model
            with tempfile.TemporaryDirectory() as export_dir:
                functional_model.save(export_dir)
715

716
    def test_keyword_and_dict_args(self):
        """Check that passing the inputs as one dict or as keyword arguments yields the same first output."""
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)

            # First call: the whole input dict as a single positional argument.
            dict_inputs = self._prepare_for_class(inputs_dict, model_class)
            from_dict = model(dict_inputs)

            # Second call: a deep copy of freshly prepared inputs, unpacked as keyword arguments.
            keyword_inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
            from_keywords = model(**keyword_inputs)

            total_abs_diff = np.sum(np.abs(from_dict[0].numpy() - from_keywords[0].numpy()))
            self.assertLess(total_abs_diff, 1e-6)
731

732
    def test_attention_outputs(self):
        """Check the number and trailing shape of returned attentions.

        Attentions are requested through the `output_attentions` input, then through the config, and
        finally together with hidden states, in which case the output must grow by one entry (two
        for encoder-decoder testers). The test is skipped when `has_attentions` is false.
        """
        if not self.has_attentions:
            self.skipTest(reason="Model does not output attentions")

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.return_dict = True

        tester = self.model_tester
        decoder_seq_length = getattr(tester, "decoder_seq_length", tester.seq_length)
        encoder_seq_length = getattr(tester, "encoder_seq_length", tester.seq_length)
        decoder_key_length = getattr(tester, "key_length", decoder_seq_length)
        encoder_key_length = getattr(tester, "key_length", encoder_seq_length)

        def check_decoder_attentions_output(outputs):
            out_len = len(outputs)
            # differentiation due to newly added cross_attentions
            self.assertEqual(min(out_len % 2, out_len % 5), 0)
            decoder_attentions = outputs.decoder_attentions
            self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
            expected_shape = [self.model_tester.num_attention_heads, decoder_seq_length, decoder_key_length]
            self.assertListEqual(list(decoder_attentions[0].shape[-3:]), expected_shape)

        def check_encoder_attentions_output(outputs):
            raw_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
            attentions = [t.numpy() for t in raw_attentions]
            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
            expected_shape = [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length]
            self.assertListEqual(list(attentions[0].shape[-3:]), expected_shape)

        def build_and_run(model_class):
            # Instantiate a new model from the current config and run it on freshly prepared inputs.
            model = model_class(config)
            return model, model(self._prepare_for_class(inputs_dict, model_class))

        for model_class in self.all_model_classes:
            # Request attentions through the inputs.
            inputs_dict["output_attentions"] = True
            config.output_hidden_states = False
            _, outputs = build_and_run(model_class)
            out_len = len(outputs)
            self.assertEqual(config.output_hidden_states, False)
            check_encoder_attentions_output(outputs)

            if self.is_encoder_decoder:
                _, outputs = build_and_run(model_class)
                self.assertEqual(config.output_hidden_states, False)
                check_decoder_attentions_output(outputs)

            # Check that output attentions can also be changed via the config
            del inputs_dict["output_attentions"]
            config.output_attentions = True
            _, outputs = build_and_run(model_class)
            self.assertEqual(config.output_hidden_states, False)
            check_encoder_attentions_output(outputs)

            # Check attention is always last and order is fine
            inputs_dict["output_attentions"] = True
            config.output_hidden_states = True
            model, outputs = build_and_run(model_class)

            expected_len = out_len + (2 if self.is_encoder_decoder else 1)
            self.assertEqual(expected_len, len(outputs))
            self.assertEqual(model.config.output_hidden_states, True)
            check_encoder_attentions_output(outputs)
795

796
    def test_headmasking(self):
        """Check that a head mask zeroes the attention of masked heads and leaves the others non-zero.

        The mask disables the first head of the first layer and every head except the last one of
        the last layer; all remaining heads stay enabled. For encoder-decoder models the same mask is
        also passed as `decoder_head_mask` and `cross_attn_head_mask` when `call` accepts them.
        Returns early without checking anything when `test_head_masking` is false.
        """
        if not self.test_head_masking:
            return

        # Seed the module-level generator. Seeding a freshly created `random.Random()` instance would
        # be a no-op, because that instance is discarded right after the call.
        random.seed(42)
        try:
            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        finally:
            # Go back to non-deterministic seeding even if input preparation raises.
            random.seed()

        inputs_dict["output_attentions"] = True
        config.output_hidden_states = True
        configs_no_init = _config_zero_init(config)  # To be sure we have no Nan

        def prepare_layer_head_mask(i, attention_heads, num_hidden_layers):
            # First layer: mask only head 0. Last layer: keep only the last head. Otherwise keep all heads.
            if i == 0:
                return tf.concat(
                    (tf.zeros(1, dtype=tf.float32), tf.ones(attention_heads - 1, dtype=tf.float32)), 0
                )
            elif i == num_hidden_layers - 1:
                return tf.concat(
                    (tf.zeros(attention_heads - 1, dtype=tf.float32), tf.ones(1, dtype=tf.float32)), 0
                )
            else:
                return tf.ones(attention_heads, dtype=tf.float32)

        def check_attentions_validity(attentions):
            # Check we don't have more than 25% nans (arbitrary)
            for t in attentions:
                self.assertLess(
                    (tf.math.reduce_sum(tf.cast(tf.math.is_nan(t), tf.float32))).numpy(), (tf.size(t) / 4).numpy()
                )

            # Replace the remaining NaNs with zeros (the test is less complete)
            attentions = [tf.where(tf.math.is_nan(t), 0.0, t) for t in attentions]

            self.assertAlmostEqual(tf.math.reduce_sum(attentions[0][..., 0, :, :]).numpy(), 0.0)
            self.assertNotEqual(tf.math.reduce_sum(attentions[0][..., -1, :, :]).numpy(), 0.0)
            if len(attentions) > 2:  # encoder-decoder models have only 2 layers in each module
                self.assertNotEqual(tf.math.reduce_sum(attentions[1][..., 0, :, :]).numpy(), 0.0)
            self.assertAlmostEqual(tf.math.reduce_sum(attentions[-1][..., -2, :, :]).numpy(), 0.0)
            self.assertNotEqual(tf.math.reduce_sum(attentions[-1][..., -1, :, :]).numpy(), 0.0)

        for model_class in self.all_model_classes:
            model = model_class(config=configs_no_init)

            # Prepare head_mask: one row per layer, one column per attention head
            head_mask = tf.stack(
                [
                    prepare_layer_head_mask(i, config.num_attention_heads, config.num_hidden_layers)
                    for i in range(config.num_hidden_layers)
                ],
                0,
            )

            inputs = self._prepare_for_class(inputs_dict, model_class).copy()
            inputs["head_mask"] = head_mask

            arg_names = []
            if model.config.is_encoder_decoder:
                arg_names = [*inspect.signature(model.call).parameters.keys()]
                # necessary differentiation because of T5 model
                for mask_name in ("decoder_head_mask", "cross_attn_head_mask"):
                    if mask_name in arg_names:
                        inputs[mask_name] = head_mask

            outputs = model(**inputs, return_dict=True)

            if model.config.is_encoder_decoder:
                check_attentions_validity(outputs.encoder_attentions)
                check_attentions_validity(outputs.decoder_attentions)
                if "cross_attn_head_mask" in arg_names:
                    check_attentions_validity(outputs.cross_attentions)
            else:
                check_attentions_validity(outputs.attentions)
868

869
    def test_hidden_states_output(self):
        """Check the number and trailing shape of returned hidden states.

        Hidden states are requested first through the `output_hidden_states` input and then through
        the config. The expected count is the tester's `expected_num_hidden_layers` when defined,
        otherwise `num_hidden_layers + 1`.
        """
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        def run_and_check(cfg, inputs, model_class):
            model = model_class(cfg)
            outputs = model(self._prepare_for_class(inputs, model_class))
            expected_num_layers = getattr(
                self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
            )

            def assert_states(states):
                self.assertEqual(len(states), expected_num_layers)
                self.assertListEqual(
                    list(states[0].shape[-2:]),
                    [self.model_tester.seq_length, self.model_tester.hidden_size],
                )

            # Encoder-decoder models expose two groups of hidden states, the others a single one.
            if model.config.is_encoder_decoder:
                state_groups = (outputs.encoder_hidden_states, outputs.decoder_hidden_states)
            else:
                state_groups = (outputs.hidden_states,)

            self.assertEqual(cfg.output_attentions, False)
            for states in state_groups:
                assert_states(states)

        for model_class in self.all_model_classes:
            # Requested through the inputs ...
            inputs_dict["output_hidden_states"] = True
            run_and_check(config, inputs_dict, model_class)

            # ... then through the config only.
            del inputs_dict["output_hidden_states"]
            config.output_hidden_states = True
            run_and_check(config, inputs_dict, model_class)
910

911
    def test_model_common_attributes(self):
        """Check the types returned by the embedding and bias accessors of every model class.

        Input embeddings must always be a Keras layer. Models listed in the causal-LM, masked-LM or
        seq2seq-LM mappings, or whose `get_lm_head()` is not None, must expose a Keras layer as output
        embeddings and, when a bias is returned, a dict of `tf.Variable`. Speech seq2seq models must
        expose output embeddings but no bias. Every other model must expose neither.
        """
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
        text_in_text_out_models = (
            get_values(TF_MODEL_FOR_CAUSAL_LM_MAPPING)
            + get_values(TF_MODEL_FOR_MASKED_LM_MAPPING)
            + get_values(TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING)
        )
        speech_in_text_out_models = get_values(TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING)

        for model_class in self.all_model_classes:
            model = model_class(config)
            self.assertIsInstance(model.get_input_embeddings(), keras.layers.Layer)

            legacy_text_in_text_out = model.get_lm_head() is not None
            if model_class in text_in_text_out_models or legacy_text_in_text_out:
                out_embeddings = model.get_output_embeddings()
                self.assertIsInstance(out_embeddings, keras.layers.Layer)
                bias = model.get_bias()
                if bias is not None:
                    self.assertIsInstance(bias, dict)
                    for v in bias.values():
                        self.assertIsInstance(v, tf.Variable)
            elif model_class in speech_in_text_out_models:
                out_embeddings = model.get_output_embeddings()
                self.assertIsInstance(out_embeddings, keras.layers.Layer)
                bias = model.get_bias()
                self.assertIsNone(bias)
            else:
                out_embeddings = model.get_output_embeddings()
                # Use a unittest assertion: a bare `assert` is removed when Python runs with `-O`.
                self.assertIsNone(out_embeddings)
                bias = model.get_bias()
                self.assertIsNone(bias)
943

944
    def test_determinism(self):
        """Check that two inference-mode calls on the same model differ by at most 1e-5 in their first output."""
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)

            # Run the model twice, each time on freshly prepared inputs.
            runs = []
            for _ in range(2):
                prepared = self._prepare_for_class(inputs_dict, model_class)
                runs.append(model(prepared, training=False)[0])

            out_1, out_2 = (run.numpy() for run in runs)

            # NaN entries are dropped from each output before the comparison.
            out_1 = out_1[~np.isnan(out_1)]
            out_2 = out_2[~np.isnan(out_2)]

            self.assertLessEqual(np.amax(np.abs(out_1 - out_2)), 1e-5)
959

960
    def test_model_outputs_equivalence(self):
        """Check that `return_dict=False` and `return_dict=True` produce the same values.

        Each model is called with both settings and the tuple output is compared element-wise with
        `to_tuple()` of the dict output. The comparison is repeated with hidden states and (when
        `has_attentions` is true) attentions requested, and again with labels when `call` accepts them.
        """
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        def recursive_check(tuple_object, dict_object):
            # Walk nested lists/tuples in lockstep, skip `None` entries, compare everything else exactly.
            if isinstance(tuple_object, (list, tuple)):
                for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object):
                    recursive_check(tuple_iterable_value, dict_iterable_value)
            elif tuple_object is None:
                return
            else:
                # `reduce_all` collapses the element-wise comparison to a single boolean for any rank.
                outputs_match = bool(tf.math.reduce_all(tf.equal(tuple_object, dict_object)))
                if not outputs_match:
                    # The difference is only computed when the comparison has already failed.
                    self.fail(
                        "Tuple and dict output are not equal. Difference:"
                        f" {tf.math.reduce_max(tf.abs(tuple_object - dict_object))}"
                    )

        def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs=None):
            additional_kwargs = {} if additional_kwargs is None else additional_kwargs
            tuple_output = model(tuple_inputs, return_dict=False, **additional_kwargs)
            dict_output = model(dict_inputs, return_dict=True, **additional_kwargs).to_tuple()
            # This call must sit at this level: nested inside `recursive_check` it would never run.
            recursive_check(tuple_output, dict_output)

        hidden_states_kwargs = {"output_hidden_states": True}
        attentions_kwargs = {"output_attentions": True}

        for model_class in self.all_model_classes:
            model = model_class(config)

            kwargs_without_labels = [{}, hidden_states_kwargs]
            if self.has_attentions:
                kwargs_without_labels.append(attentions_kwargs)

            for extra_kwargs in kwargs_without_labels:
                tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
                dict_inputs = self._prepare_for_class(inputs_dict, model_class)
                check_equivalence(model, tuple_inputs, dict_inputs, extra_kwargs)

            # Not all models accept "labels" in the forward pass (yet :) )
            if "labels" in inspect.signature(model.call).parameters:
                kwargs_with_labels = [{}, hidden_states_kwargs]
                if self.has_attentions:
                    kwargs_with_labels.append(attentions_kwargs)
                    kwargs_with_labels.append({"output_hidden_states": True, "output_attentions": True})

                for extra_kwargs in kwargs_with_labels:
                    tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
                    dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
                    check_equivalence(model, tuple_inputs, dict_inputs, extra_kwargs)
1020

1021
    def test_inputs_embeds(self):
        """Check that every model can be called with embeddings in place of input ids.

        The id entries are removed from a deep copy of the inputs and replaced by the output of the
        model's input-embedding layer. Only successful execution is checked, not the values.
        """
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            inputs = copy.deepcopy(inputs_dict)

            if self.is_encoder_decoder:
                encoder_input_ids = inputs["input_ids"]
                # Fall back to the encoder ids when no decoder ids were prepared.
                decoder_input_ids = inputs.get("decoder_input_ids", encoder_input_ids)
                del inputs["input_ids"]
                inputs.pop("decoder_input_ids", None)

                inputs["inputs_embeds"] = model.get_input_embeddings()(encoder_input_ids)
                inputs["decoder_inputs_embeds"] = model.get_input_embeddings()(decoder_input_ids)
            else:
                input_ids = inputs.pop("input_ids")
                inputs["inputs_embeds"] = model.get_input_embeddings()(input_ids)

            inputs = self._prepare_for_class(inputs, model_class)

            model(inputs)
1047

1048
    def test_numpy_arrays_inputs(self):
        """Check that NumPy inputs give the same outputs whether passed as one dict or as keyword arguments."""
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        def prepare_numpy_arrays(inputs_dict):
            # Tensors are converted with `.numpy()`; any other value is wrapped with `np.array`.
            # The value `v` is converted here, never the key `k`.
            return {k: (v.numpy() if tf.is_tensor(v) else np.array(v)) for k, v in inputs_dict.items()}

        for model_class in self.all_model_classes:
            model = model_class(config)

            inputs = self._prepare_for_class(inputs_dict, model_class)
            inputs_np = prepare_numpy_arrays(inputs)

            output_for_dict_input = model(inputs_np)
            output_for_kw_input = model(**inputs_np)
            self.assert_outputs_same(output_for_dict_input, output_for_kw_input)
1070

1071
    def test_valid_input_signature_and_dummies(self):
        """Check that every key of `input_signature` and `dummy_inputs` is a parameter of the model's `call`."""
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            call_args = inspect.signature(model.call).parameters

            # Each attribute is read only when its turn comes: signature keys first, then dummy inputs.
            for attribute in ("input_signature", "dummy_inputs"):
                for key in getattr(model, attribute):
                    self.assertIn(key, call_args)
1080

1081
    def test_resize_token_embeddings(self):
        """Check `resize_token_embeddings` for a smaller size, a larger size and `None`.

        After resizing, the input embeddings (and, when present, the bias and output embeddings) must
        have the requested size and the rows compared pairwise with the old weights must be unchanged.
        Returns early without checking anything when `test_resize_embeddings` is false.
        """
        # TODO (joao): after the embeddings refactor is complete, rework this test so as to rely exclusively on
        # keras.layers.Embedding

        if not self.test_resize_embeddings:
            return
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        def _get_word_embedding_weight(model, embedding_layer):
            if not isinstance(embedding_layer, keras.layers.Embedding):
                return model._get_word_embedding_weight(embedding_layer)
            # builds the embeddings layer
            model.build_in_name_scope()
            return embedding_layer.embeddings

        def embedding_state(model):
            # Input embedding weight, bias and output embedding weight, fetched in that order.
            return (
                _get_word_embedding_weight(model, model.get_input_embeddings()),
                model.get_bias(),
                _get_word_embedding_weight(model, model.get_output_embeddings()),
            )

        def rows_unchanged(old_rows, new_rows):
            # Every zipped pair is visited; any non-zero absolute difference marks the weights as changed.
            unchanged = True
            for old_row, new_row in zip(old_rows, new_rows):
                if tf.math.reduce_sum(tf.math.abs(old_row - new_row)) > 0:
                    unchanged = False
            return unchanged

        for model_class in self.all_model_classes:
            for size in [config.vocab_size - 10, config.vocab_size + 10, None]:
                # build the embeddings
                model = model_class(config=copy.deepcopy(config))  # `resize_token_embeddings` mutates `config`
                old_input_embeddings, old_bias, old_output_embeddings = embedding_state(model)

                # reshape the embeddings
                model.resize_token_embeddings(size)
                new_input_embeddings, new_bias, new_output_embeddings = embedding_state(model)

                # check that the resized embeddings size matches the desired size.
                assert_size = config.vocab_size if size is None else size
                self.assertEqual(new_input_embeddings.shape[0], assert_size)

                # check that weights remain the same after resizing
                self.assertTrue(rows_unchanged(old_input_embeddings.value(), new_input_embeddings.value()))

                if old_bias is not None and new_bias is not None:
                    for old_weight, new_weight in zip(old_bias.values(), new_bias.values()):
                        self.assertEqual(new_weight.shape[-1], assert_size)
                        self.assertTrue(rows_unchanged(tf.squeeze(old_weight), tf.squeeze(new_weight)))

                if old_output_embeddings is not None and new_output_embeddings is not None:
                    self.assertEqual(new_output_embeddings.shape[0], assert_size)
                    self.assertEqual(new_output_embeddings.shape[1], old_output_embeddings.shape[1])
                    self.assertTrue(rows_unchanged(old_output_embeddings.value(), new_output_embeddings.value()))
1140

1141
    # TODO (Joao): this test is not slow, but it's tagged as such to keep track of failures on the scheduled CI runs,
1142
    # while passing push CI. Fix the underlying issues and remove the tag.
1143
    @slow
1144
    def test_save_load_after_resize_token_embeddings(self):
        """Check that a model with enlarged token embeddings gives the same outputs after save and reload.

        The vocabulary is extended by 10 tokens, the model is run on ids drawn only from the new
        range, saved with `save_pretrained`, reloaded with `from_pretrained`, and run again.
        Returns early without checking anything when `test_resize_embeddings` is false.
        """
        if not self.test_resize_embeddings:
            return
        config, original_inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            # create a model with resized (expanded) embeddings
            new_tokens_size = 10
            old_total_size = config.vocab_size
            new_total_size = old_total_size + new_tokens_size
            model = model_class(config=copy.deepcopy(config))  # `resize_token_embeddings` mutates `config`
            model.build_in_name_scope()
            model.resize_token_embeddings(new_total_size)

            # fetch the output for an input exclusively made of new members of the vocabulary
            inputs_dict = copy.deepcopy(original_inputs_dict)
            if "input_ids" in inputs_dict:
                ids_feat_name = "input_ids"
            elif "decoder_input_ids" in inputs_dict:
                ids_feat_name = "decoder_input_ids"
            else:
                # `self.fail` still fails the test when Python runs with `-O`, unlike a bare `assert`.
                self.fail("No input ids feature found in the inputs dict")

            # Draw ids in [0, new_tokens_size) and shift them past the old vocabulary.
            new_vocab_input_ids = ids_tensor(inputs_dict[ids_feat_name].shape, new_tokens_size)
            new_vocab_input_ids += old_total_size

            # Every id feature that is present receives the same new-vocabulary ids.
            for ids_key in ("input_ids", "decoder_input_ids"):
                if ids_key in inputs_dict:
                    inputs_dict[ids_key] = new_vocab_input_ids

            prepared_inputs = self._prepare_for_class(inputs_dict, model_class)
            outputs = model(**prepared_inputs)

            # save and load the model
            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname, saved_model=False)
                model = model_class.from_pretrained(tmpdirname)
                restored_model_outputs = model(**prepared_inputs)

                # check that the output for the restored model is the same
                self.assert_outputs_same(restored_model_outputs, outputs)
1186

1187
    @unittest.skipIf(
1188
        not is_tf_available() or len(tf.config.list_physical_devices("GPU")) == 0,
1189
        reason="This test always passes on CPU.",
1190
    )
1191
    def test_embeddings_out_of_bounds_raise_exception(self):
        """Check that calling a model with out-of-range token ids raises `tf.errors.InvalidArgumentError`.

        Returns early without checking anything when `test_resize_embeddings` is false.
        """
        # TF embeddings layers don't raise an exception when an index is out of bounds on GPU, so we manually raise it.
        # This test should only fail on GPU for models where we haven't added the safety check.
        if not self.test_resize_embeddings:
            return
        config, original_inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config=config)
            inputs_dict = copy.deepcopy(original_inputs_dict)

            # Scale every id feature that is present by 1e9.
            for ids_key in ("input_ids", "decoder_input_ids"):
                if ids_key in inputs_dict:
                    inputs_dict[ids_key] = inputs_dict[ids_key] * int(1e9)

            prepared_inputs = self._prepare_for_class(inputs_dict, model_class)
            with self.assertRaises(tf.errors.InvalidArgumentError):
                model(**prepared_inputs)
1208

1209
    def test_lm_head_model_random_no_beam_search_generate(self):
        """Exercise `generate` without beam search on every generative model class.

        Covers sampling with and without `input_ids`, the `ValueError` expected when several
        sequences are requested without sampling, and the exclusion of `bad_words_ids` from the
        generated tokens.
        """
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        input_ids = inputs_dict.get("input_ids")

        # iterate over all generative models
        for model_class in self.all_generative_model_classes:
            model = model_class(config)

            if config.bos_token_id is None:
                # if bos token id is not defined model needs input_ids
                with self.assertRaises(ValueError):
                    model.generate(do_sample=True, max_length=5)
                # num_return_sequences = 1
                sampled = model.generate(input_ids, do_sample=True)
                self._check_generated_ids(sampled)
            elif model_class.__name__ != "TFSpeech2TextForConditionalGeneration":
                # Models with non-text inputs won't work here; num_return_sequences = 1
                sampled = model.generate(do_sample=True, max_length=5)
                self._check_generated_ids(sampled)

            # generating multiple sequences when no beam search generation
            # is not allowed as it would always generate the same sequences
            with self.assertRaises(ValueError):
                model.generate(input_ids, do_sample=False, num_return_sequences=2)

            # num_return_sequences > 1, sample
            sampled_pair = model.generate(input_ids, do_sample=True, num_return_sequences=2)
            self._check_generated_ids(sampled_pair)

            # check bad words tokens language generation
            # create list of 1-seq bad token and list of 2-seq of bad tokens
            bad_words_ids = [self._generate_random_bad_tokens(length, model) for length in (1, 2)]
            output_tokens = model.generate(
                input_ids, do_sample=True, bad_words_ids=bad_words_ids, num_return_sequences=2
            )

            # only count generated tokens
            prompt_length = input_ids.shape[-1]
            generated_ids = output_tokens[:, prompt_length:]
            self.assertFalse(self._check_match_tokens(generated_ids.numpy().tolist(), bad_words_ids))
1244

1245
    def test_lm_head_model_no_beam_search_generate_dict_outputs(self):
        """Check the output classes returned by greedy and sampling `generate` calls.

        Both calls request scores, hidden states and attentions with
        `return_dict_in_generate=True`; the expected output class depends on whether the
        model is an encoder-decoder.
        """
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        input_ids = inputs_dict.get("input_ids")
        if input_ids is None:
            input_ids = inputs_dict.get("input_features")

        # Options shared by the greedy and the sampling call
        dict_output_kwargs = {
            "output_scores": True,
            "output_hidden_states": True,
            "output_attentions": True,
            "return_dict_in_generate": True,
        }

        for model_class in self.all_generative_model_classes:
            model = model_class(config)
            output_greedy = model.generate(input_ids, do_sample=False, **dict_output_kwargs)
            output_sample = model.generate(input_ids, do_sample=True, **dict_output_kwargs)

            if model.config.is_encoder_decoder:
                greedy_cls, sample_cls = TFGreedySearchEncoderDecoderOutput, TFSampleEncoderDecoderOutput
            else:
                greedy_cls, sample_cls = TFGreedySearchDecoderOnlyOutput, TFSampleDecoderOnlyOutput

            self.assertIsInstance(output_greedy, greedy_cls)
            self.assertIsInstance(output_sample, sample_cls)
1277

1278
    def test_lm_head_model_random_beam_search_generate(self):
        """Exercise `generate` with `num_beams=2` on every generative model class.

        Checks that generated token ids are valid for sampled and non-sampled beam search,
        that asking for more return sequences than beams raises `ValueError`, and that
        `bad_words_ids` never show up in the newly generated tokens.
        """
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        input_ids = inputs_dict.get("input_ids")

        for model_class in self.all_generative_model_classes:
            model = model_class(config)

            if config.bos_token_id is None:
                # no bos token: the model needs input_ids; single returned sequence
                first_output = model.generate(input_ids, do_sample=True, num_beams=2)
            else:
                # single returned sequence
                first_output = model.generate(do_sample=True, max_length=5, num_beams=2)
            self._check_generated_ids(first_output)

            # More returned sequences than beams is not possible
            with self.assertRaises(ValueError):
                model.generate(input_ids, do_sample=False, num_return_sequences=3, num_beams=2)

            # several returned sequences, sampling
            sampled_beams = model.generate(input_ids, do_sample=True, num_beams=2, num_return_sequences=2)
            self._check_generated_ids(sampled_beams)

            # several returned sequences, no sampling
            searched_beams = model.generate(input_ids, do_sample=False, num_beams=2, num_return_sequences=2)
            self._check_generated_ids(searched_beams)

            # bad words: one single-token entry followed by one two-token entry
            bad_words_ids = [self._generate_random_bad_tokens(length, model) for length in (1, 2)]
            output_tokens = model.generate(
                input_ids, do_sample=False, bad_words_ids=bad_words_ids, num_beams=2, num_return_sequences=2
            )
            # drop the prompt, keep only what was generated
            prompt_length = input_ids.shape[-1]
            new_tokens = output_tokens[:, prompt_length:].numpy().tolist()
            self.assertFalse(self._check_match_tokens(new_tokens, bad_words_ids))
1317

1318
    def test_lm_head_model_beam_search_generate_dict_outputs(self):
        """Check the output classes returned by beam-search and beam-sample `generate` calls.

        Both calls use `num_beams=2` and request scores, hidden states and attentions with
        `return_dict_in_generate=True`; the expected output class depends on whether the
        model is an encoder-decoder.
        """
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        input_ids = inputs_dict.get("input_ids")
        if input_ids is None:
            input_ids = inputs_dict.get("input_features")

        # Options shared by both calls
        beam_kwargs = {
            "num_beams": 2,
            "output_scores": True,
            "output_hidden_states": True,
            "output_attentions": True,
            "return_dict_in_generate": True,
        }

        for model_class in self.all_generative_model_classes:
            model = model_class(config)
            output_beam_search = model.generate(input_ids, do_sample=False, **beam_kwargs)
            output_beam_sample = model.generate(input_ids, do_sample=True, **beam_kwargs)

            if model.config.is_encoder_decoder:
                search_cls, sample_cls = TFBeamSearchEncoderDecoderOutput, TFBeamSampleEncoderDecoderOutput
            else:
                search_cls, sample_cls = TFBeamSearchDecoderOnlyOutput, TFBeamSampleDecoderOnlyOutput

            self.assertIsInstance(output_beam_search, search_cls)
            self.assertIsInstance(output_beam_sample, sample_cls)
1352

1353
    def test_loss_computation(self):
        """Check the loss shape when inputs are passed as kwargs, with masked labels, as a dict and as a tuple.

        The loss must either have the size of the first dimension of the first added label
        (labels sorted by name, descending) or be of shape `[1]`. Model classes for which
        `_prepare_for_class` adds no label key, or whose output has no `loss`, are skipped.
        """
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        main_input_candidates = {"input_ids", "pixel_values", "input_features", "input_values"}

        for model_class in self.all_model_classes:
            model = model_class(config)

            # The number of elements in the loss should match the number of elements in the label
            with_labels = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
            added_label_names = sorted(with_labels.keys() - inputs_dict.keys(), reverse=True)
            if not added_label_names:
                continue  # This test is only for models with easily-separable labels
            expected_loss_size = with_labels[added_label_names[0]].shape.as_list()[:1]

            # 1) main input passed positionally, everything else as kwargs
            with_labels = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
            input_name = main_input_candidates.intersection(set(with_labels)).pop()
            model_input = with_labels.pop(input_name)

            outputs = model(model_input, **with_labels)
            if not (isinstance(outputs, ModelOutput) and hasattr(outputs, "loss")):
                continue

            loss_shape = outputs.loss.shape.as_list()
            self.assertTrue(loss_shape == expected_loss_size or loss_shape == [1])

            # 2) some label positions masked with -100
            with_labels = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
            input_name = main_input_candidates.intersection(set(with_labels)).pop()
            model_input = with_labels.pop(input_name)
            if "labels" in with_labels:
                labels = with_labels["labels"].numpy()
                if len(labels.shape) > 1 and labels.shape[1] != 1:
                    # mask every position of the first example
                    labels[0] = -100
                    with_labels["labels"] = tf.convert_to_tensor(labels)
                    loss = model(model_input, **with_labels)[0]
                    loss_shape = loss.shape.as_list()
                    self.assertTrue(loss_shape == expected_loss_size or loss_shape == [1])
                    self.assertTrue(not np.any(np.isnan(loss.numpy())))

            # 3) everything passed as one dict
            with_labels = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
            loss_shape = model(with_labels)[0].shape.as_list()
            self.assertTrue(loss_shape == expected_loss_size or loss_shape == [1])

            # 4) everything passed as one tuple laid out like the `call` signature
            with_labels = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)

            # Keys that `_prepare_for_class` added on top of the raw inputs
            label_keys = with_labels.keys() - inputs_dict.keys()
            call_params = inspect.signature(model.call).parameters
            param_names = list(call_params.keys())

            # Position in the tuple -> name of the tensor that belongs there
            slot_to_name = {0: input_name}
            for label_key in label_keys:
                slot_to_name[param_names.index(label_key)] = label_key

            # Start from the signature defaults, then overwrite the slots we have tensors for
            positional = [call_params[name].default for name in param_names if name != "kwargs"]
            for slot, tensor_name in sorted(slot_to_name.items()):
                positional[slot] = with_labels[tensor_name]

            # The last positional entry is left out when calling the model
            loss_shape = model(tuple(positional)[:-1])[0].shape.as_list()
            self.assertTrue(loss_shape == expected_loss_size or loss_shape == [1])
1427

1428
    def check_keras_fit_results(self, val_loss1, val_loss2, atol=1e-2, rtol=1e-3):
        """Assert that two validation losses agree within `atol` / `rtol` (via `np.allclose`)."""
        losses_match = np.allclose(val_loss1, val_loss2, atol=atol, rtol=rtol)
        self.assertTrue(losses_match)
1430

1431
    @slow
    def test_keras_fit(self):
        """Fit each model twice, with labels inside the input dict and with labels passed separately.

        Both runs use a zero learning rate and identical starting weights, so their
        validation losses are expected to match (see `check_keras_fit_results`), their
        history keys must be identical, and every training key must have a `val_` twin.
        """
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        # Head masks are dropped; "return_loss" is dropped because train_step covers it when using fit()
        dropped_keys = ("head_mask", "decoder_head_mask", "cross_attn_head_mask", "return_loss")
        # Class-name suffixes for which an accuracy metric is attached
        accuracy_suffixes = (
            "ForPreTraining",
            "ForCausalLM",
            "ForMaskedLM",
            "ForQuestionAnswering",
            "ForMultipleChoice",
            "ForSequenceClassification",
            "ForTokenClassification",
            "ForNextSentencePrediction",
            "LMHeadModel",
        )
        possible_label_cols = {
            "labels",
            "label",
            "label_ids",
            "start_positions",
            "start_position",
            "end_positions",
            "end_position",
            "next_sentence_label",
        }

        for model_class in self.all_model_classes:
            model = model_class(config)
            raw_inputs = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
            prepared_for_class = {key: val for key, val in raw_inputs.items() if key not in dropped_keys}
            if "labels" in prepared_for_class and "decoder_input_ids" in prepared_for_class:
                del prepared_for_class["decoder_input_ids"]

            if model.__class__.__name__.endswith(accuracy_suffixes):
                metrics = [keras.metrics.SparseCategoricalAccuracy()]
            else:
                metrics = []

            sample_weight = None
            if hasattr(self.model_tester, "batch_size"):
                sample_weight = tf.convert_to_tensor([0.5] * self.model_tester.batch_size, dtype=tf.float32)

            # Build the model so we can snapshot its weights and inspect its outputs
            outputs = model(prepared_for_class)
            if getattr(outputs, "loss", None) is None:
                continue
            model_weights = model.get_weights()

            # Run eagerly to save some expensive compilation time
            model.compile(optimizer=keras.optimizers.SGD(0.0), run_eagerly=True, metrics=metrics)
            fit_kwargs = {
                "sample_weight": sample_weight,
                "steps_per_epoch": 1,
                "validation_steps": 1,
                "shuffle": False,
            }

            # First run: labels travel inside the input dict
            history1 = model.fit(prepared_for_class, validation_data=prepared_for_class, **fit_kwargs)
            val_loss1 = history1.history["val_loss"][0]
            self.assertFalse(isnan(val_loss1))
            accuracy1 = {key: val[0] for key, val in history1.history.items() if key.endswith("accuracy")}

            label_names = possible_label_cols.intersection(set(prepared_for_class))
            if len(label_names) == 0:
                # The remaining checks only make sense for models with separate inputs and labels,
                # not for models that don't clearly distinguish between the two (e.g. CLIP)
                return
            labels = {}
            inputs_minus_labels = {}
            for key, val in prepared_for_class.items():
                if key in label_names:
                    labels[key] = val
                else:
                    inputs_minus_labels[key] = val
            self.assertGreater(len(inputs_minus_labels), 0)

            # Restore the starting weights even though the learning rate was zero:
            # BatchNorm updates weights by means other than gradient descent.
            model.set_weights(model_weights)

            # Second run: labels passed as a separate argument
            history2 = model.fit(
                inputs_minus_labels,
                labels,
                validation_data=(inputs_minus_labels, labels),
                **fit_kwargs,
            )
            val_loss2 = history2.history["val_loss"][0]
            self.assertFalse(isnan(val_loss2))
            accuracy2 = {key: val[0] for key, val in history2.history.items() if key.endswith("accuracy")}

            self.check_keras_fit_results(val_loss1, val_loss2)
            self.assertEqual(history1.history.keys(), history2.history.keys())
            for key in history1.history.keys():
                if key.startswith("val_"):
                    continue
                self.assertTrue("val_" + key in history1.history.keys(), "Outputs differ in train/test step!")
            if metrics:
                self.assertTrue(len(accuracy1) == len(accuracy2) > 0, "Missing metrics!")
1532

1533
    def test_int_support(self):
        """Check that models accept both int64 and int32 integer inputs, and advertise int32.

        Each model is called once with every integer tensor cast to `tf.int64` and once with
        `tf.int32`; no output is asserted, the calls only have to succeed. Afterwards the
        integer entries of `dummy_inputs` and `input_signature` must be `tf.int32`.
        """
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        def is_int_tensor(value):
            return isinstance(value, tf.Tensor) and value.dtype.is_integer

        def cast_int_tensors(tensors, dtype):
            return {key: tf.cast(value, dtype) if is_int_tensor(value) else value for key, value in tensors.items()}

        for model_class in self.all_model_classes:
            wants_labels = "labels" in inspect.signature(model_class.call).parameters
            prepared_for_class = self._prepare_for_class(
                inputs_dict.copy(),
                model_class,
                return_labels=wants_labels,
            )
            if not any(is_int_tensor(value) for value in prepared_for_class.values()):
                return  # No integer inputs means no need for this test

            prepared_for_class = cast_int_tensors(prepared_for_class, tf.int64)
            model = model_class(config)
            model(**prepared_for_class)  # No assertion, this only has to run without raising
            int32_prepared_for_class = cast_int_tensors(prepared_for_class, tf.int32)
            model(**int32_prepared_for_class)  # No assertion, this only has to run without raising

            # The model accepts all int inputs; now confirm that its dummies are int32
            for tensor in model.dummy_inputs.values():
                self.assertTrue(
                    isinstance(tensor, tf.Tensor) or keras.backend.is_keras_tensor(tensor),
                    "Dummy inputs should be tf.Tensor!",
                )
                if tensor.dtype.is_integer:
                    self.assertTrue(tensor.dtype == tf.int32, "Integer dummy inputs should be tf.int32!")

            # Also confirm that the input_signature uses int32
            for tensor_spec in model.input_signature.values():
                if tensor_spec.dtype.is_integer:
                    self.assertTrue(tensor_spec.dtype == tf.int32, "Input signatures should use tf.int32 for ints!")
1571

1572
    def test_generate_with_headmasking(self):
        """Check that an all-zero head mask zeroes the matching attention weights during `generate`.

        Only encoder-decoder models whose `call` accepts `head_mask`, `decoder_head_mask`
        and `cross_attn_head_mask` are exercised. Each mask is applied on its own and the
        corresponding attentions (`encoder_attentions`, `decoder_attentions`,
        `cross_attentions`) must sum to zero.
        """
        attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_generative_model_classes:
            # We want to test only encoder-decoder models
            if not config.is_encoder_decoder:
                continue

            model = model_class(config)

            head_masking = {
                "head_mask": tf.zeros((config.encoder_layers, config.encoder_attention_heads)),
                "decoder_head_mask": tf.zeros((config.decoder_layers, config.decoder_attention_heads)),
                "cross_attn_head_mask": tf.zeros((config.decoder_layers, config.decoder_attention_heads)),
            }

            # Skip models that do NOT accept every head-mask argument. The previous check was
            # inverted and skipped exactly the models that do support head masking.
            signature = inspect.signature(model.call)
            if not set(head_masking).issubset(signature.parameters):
                continue

            input_ids = inputs_dict["input_ids"]
            # `max_length` must be an integer: prompt length plus a few new tokens
            # (previously the whole `input_ids` tensor was added to 5).
            max_length = input_ids.shape[-1] + 5

            for attn_name, (name, mask) in zip(attention_names, head_masking.items()):
                out = model.generate(
                    input_ids,
                    num_beams=1,
                    max_length=max_length,
                    output_attentions=True,
                    return_dict_in_generate=True,
                    **{name: mask},
                )
                # We check the state of decoder_attentions and cross_attentions just from the last step
                attn_weights = out[attn_name] if attn_name == attention_names[0] else out[attn_name][-1]
                self.assertEqual(sum([tf.reduce_sum(w).numpy() for w in attn_weights]), 0.0)
1605

1606
    def test_load_with_mismatched_shapes(self):
        """Check `from_pretrained` behaviour when the checkpoint shapes do not match the requested config.

        Runs only when `self.test_mismatched_shapes` is set, and only for sequence
        classification model classes. Loading with a different `num_labels` / `vocab_size`
        must raise `ValueError` unless `ignore_mismatched_sizes=True` is passed, in which
        case a "the shapes did not match" message is logged and the model is usable.
        """
        if not self.test_mismatched_shapes:
            return
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        sequence_classification_classes = get_values(TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING)

        for model_class in self.all_model_classes:
            if model_class not in sequence_classification_classes:
                continue

            with self.subTest(msg=f"Testing {model_class}"), tempfile.TemporaryDirectory() as tmp_dir:
                model = model_class(config)
                inputs = self._prepare_for_class(inputs_dict, model_class)
                model(**inputs)  # build the weights before saving
                model.save_pretrained(tmp_dir)

                # Fails when we don't set ignore_mismatched_sizes=True
                with self.assertRaises(ValueError):
                    TFAutoModelForSequenceClassification.from_pretrained(tmp_dir, num_labels=42)
                with self.assertRaises(ValueError):
                    TFAutoModel.from_pretrained(tmp_dir, vocab_size=10)

                tf_utils_logger = logging.get_logger("transformers.modeling_tf_utils")
                with CaptureLogger(tf_utils_logger) as captured:
                    new_model = TFAutoModelForSequenceClassification.from_pretrained(
                        tmp_dir, num_labels=42, ignore_mismatched_sizes=True
                    )
                self.assertIn("the shapes did not match", captured.out)

                logits = new_model(**inputs).logits
                self.assertEqual(logits.shape[1], 42)

                with CaptureLogger(tf_utils_logger) as captured:
                    new_model_without_prefix = TFAutoModel.from_pretrained(
                        tmp_dir, vocab_size=10, ignore_mismatched_sizes=True
                    )
                self.assertIn("the shapes did not match", captured.out)

                # Although TF models always have a prefix pointing to `MainLayer`,
                # this "without prefix" check is kept for consistency between tf and pt tests.
                input_ids = ids_tensor((2, 8), 10)
                if self.is_encoder_decoder:
                    new_model_without_prefix(input_ids, decoder_input_ids=input_ids)
                else:
                    new_model_without_prefix(input_ids)
1651

1652
    def test_model_main_input_name(self):
1653
        for model_class in self.all_model_classes:
1654
            model_signature = inspect.signature(getattr(model_class, "call"))
1655
            # The main input is the name of the argument after `self`
1656
            observed_main_input_name = list(model_signature.parameters.keys())[1]
1657
            self.assertEqual(model_class.main_input_name, observed_main_input_name)
1658

1659
    def test_dataset_conversion(self):
        """Check `prepare_tf_dataset` on a `datasets.Dataset` built from the model inputs.

        An extra column named "extra_unwanted_column" is added to the dataset; the produced
        batches must not contain it while keeping every other column and every row. The
        batch is then fed to the model, and (when `call` accepts `labels`) used for one
        `train_on_batch` step.
        """
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        for model_class in self.all_model_classes:
            model = model_class(config)

            # --- inference path: no labels ---
            prepared = self._prepare_for_class(inputs_dict, model_class, return_labels=False)
            if "labels" in prepared:
                return  # This is some kinda funky decoder model that needs labels in its forward pass
            tf_inputs_dict = {}
            for key, val in prepared.items():
                if "head_mask" not in key and isinstance(val, tf.Tensor):
                    tf_inputs_dict[key] = val
            tf_inputs_dict["extra_unwanted_column"] = list(tf_inputs_dict.values())[0]  # Use a random other tensor
            input_dataset = Dataset.from_dict(tf_inputs_dict)
            num_rows = len(input_dataset)
            tf_dataset = model.prepare_tf_dataset(
                input_dataset, batch_size=num_rows, drop_remainder=False, shuffle=False
            )
            test_batch = next(iter(tf_dataset))
            if isinstance(test_batch, tf.Tensor):
                self.assertEqual(len(test_batch), num_rows)  # Assert we didn't lose any data
            elif isinstance(test_batch, dict):
                # Assert we discarded the unwanted extra column but kept everything else
                self.assertEqual(len(test_batch), len(input_dataset.features) - 1)
                self.assertNotIn("extra_unwanted_column", test_batch)
                for tensor in test_batch.values():
                    self.assertTrue(isinstance(tensor, tf.Tensor))
                    self.assertEqual(len(tensor), num_rows)  # Assert we didn't lose any data
            model(test_batch, training=False)

            # --- training path: only when `call` accepts labels ---
            if "labels" not in inspect.signature(model_class.call).parameters:
                continue
            prepared = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
            if "labels" not in prepared:
                return  # This model isn't giving us labels after all, don't try training with it
            tf_inputs_dict = {key: val for key, val in prepared.items() if "head_mask" not in key}
            tf_inputs_dict["extra_unwanted_column"] = list(tf_inputs_dict.values())[0]  # Use a random other tensor
            input_dataset = Dataset.from_dict(tf_inputs_dict)
            tf_dataset = model.prepare_tf_dataset(
                input_dataset, batch_size=len(input_dataset), drop_remainder=False, shuffle=False
            )
            test_batch, test_batch_labels = next(iter(tf_dataset))
            self.assertGreater(len(test_batch_labels), 0)  # Assert the labels are present
            # A bare tensor counts as a single column
            feature_columns = 1 if isinstance(test_batch, tf.Tensor) else len(test_batch)
            label_columns = 1 if isinstance(test_batch_labels, tf.Tensor) else len(test_batch_labels)
            # Assert we discarded the unwanted extra column but kept everything else
            self.assertEqual(feature_columns + label_columns, len(input_dataset.features) - 1)
            for batch_part in (test_batch, test_batch_labels):
                if isinstance(batch_part, dict):
                    self.assertNotIn("extra_unwanted_column", batch_part)
            model.compile(optimizer="sgd", run_eagerly=True)
            model.train_on_batch(test_batch, test_batch_labels)
1710

1711
    def _test_xla_generate(self, **generate_kwargs):
        """Compare `generate` with and without XLA compilation for every generative model class.

        Args:
            **generate_kwargs: forwarded to both `generate` calls. Must contain
                `max_new_tokens`, which is also used to enlarge the config's
                sequence-length limits when they are too small.

        Models reporting `supports_xla_generation` must produce matching sequences (a small
        fraction of mismatches is tolerated for batches of 10 or more); all other models
        are expected to raise `ValueError`.
        """

        def _generate_and_check_results(model, inputs_dict):
            if "input_ids" in inputs_dict:
                inputs = inputs_dict["input_ids"]
                # Make sure there are no pad tokens in the prompt, which may trigger unwanted behavior.
                # The replacement only happens when a pad token is defined: calling `tf.where` with a
                # `None` replacement raises, which could be mistaken for the "XLA unsupported" ValueError.
                pad_token_id = model.generation_config.pad_token_id
                if pad_token_id is not None:
                    # pick a neighbouring id that stays non-negative
                    new_pad_token = pad_token_id + 1 if pad_token_id == 0 else pad_token_id - 1
                    inputs = tf.where(inputs != pad_token_id, inputs, new_pad_token)
            elif "input_features" in inputs_dict:
                inputs = inputs_dict["input_features"]
            else:
                raise ValueError("No valid generate input found in inputs_dict")

            generated = model.generate(inputs, **generate_kwargs).numpy()
            generate_xla = tf.function(model.generate, jit_compile=True)
            generated_xla = generate_xla(inputs, **generate_kwargs).numpy()

            # Due to numerical instability, fail only if more than 10% of the input sequences give
            # different outputs between the XLA and non-XLA versions. With fewer than 10 examples,
            # be strict and do not allow any difference.
            diff = [[], []]
            for _generated, _generated_xla in zip(generated.tolist(), generated_xla.tolist()):
                if _generated != _generated_xla:
                    diff[0].append(_generated)
                    diff[1].append(_generated_xla)
            ratio = len(diff[0]) / len(generated)
            if ratio > 0.1 or (len(diff[0]) > 0 and len(generated) < 10):
                # The lists differ by construction: this reports the mismatching sequences
                self.assertListEqual(diff[0], diff[1])

        for model_class in self.all_generative_model_classes:
            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
            config.eos_token_id = None  # Generate until max length
            config.do_sample = False

            # fix config for models with additional sequence-length limiting settings
            for var_name in ["max_position_embeddings", "max_target_positions"]:
                attr = getattr(config, var_name, None)
                if attr is not None and attr < generate_kwargs["max_new_tokens"]:
                    try:
                        setattr(config, var_name, generate_kwargs["max_new_tokens"])
                    except NotImplementedError:
                        # xlnet will raise an exception when trying to set
                        # max_position_embeddings.
                        pass

            model = model_class(config)

            if model.supports_xla_generation:
                _generate_and_check_results(model, inputs_dict)
            else:
                with self.assertRaises(ValueError):
                    _generate_and_check_results(model, inputs_dict)
1768

1769
    def test_xla_generate_fast(self):
1770
        """
1771
        Basic quick test for generate-compatible classes that confirms that XLA-generated tokens are the same as their
1772
        non XLA counterparts.
1773

1774
        Either the model supports XLA generation and passes the inner test, or it raises an appropriate exception
1775
        """
1776
        self._test_xla_generate(num_beams=1, num_return_sequences=1, max_new_tokens=3)
1777

1778
    @slow
    def test_xla_generate_contrastive(self):
        """
        Slow, harder variant of `test_xla_generate_fast` for contrastive search (`penalty_alpha=0.5`, `top_k=4`,
        16 new tokens). Contrastive search directly manipulates the model cache and other outputs; this test
        makes sure they are in a valid format that XLA supports as well.

        A model either supports XLA generation and passes the inner comparison, or raises the expected exception.
        """
        contrastive_kwargs = {
            "num_beams": 1,
            "num_return_sequences": 1,
            "max_new_tokens": 16,
            "penalty_alpha": 0.5,
            "top_k": 4,
        }
        self._test_xla_generate(**contrastive_kwargs)
1788

1789
    @slow
    def test_xla_generate_slow(self):
        """
        Slow, harder variant of `test_xla_generate_fast`: several long sequences are requested with beam search
        (8 beams, 2 returned sequences, 128 new tokens), with and without XLA. Both outputs should match; a
        failure here indicates that the model may need further analysis before being used for XLA generation.

        A model either supports XLA generation and passes the inner comparison, or raises the expected exception.
        """
        beam_kwargs = {"num_beams": 8, "num_return_sequences": 2, "max_new_tokens": 128}
        self._test_xla_generate(**beam_kwargs)
1799

1800
    def _generate_random_bad_tokens(self, num_bad_tokens, model):
        """Draw `num_bad_tokens` random token ids that are not the model's bos, pad or eos token.

        Ids are drawn one at a time with `ids_tensor` over `self.model_tester.vocab_size`;
        the returned list may contain repeated ids.
        """
        model_config = model.config
        # special tokens cannot be bad tokens
        special_tokens = [
            token_id
            for token_id in (model_config.bos_token_id, model_config.pad_token_id, model_config.eos_token_id)
            if token_id is not None
        ]

        bad_tokens = []
        while len(bad_tokens) < num_bad_tokens:
            candidate = tf.squeeze(ids_tensor((1, 1), self.model_tester.vocab_size), 0).numpy()[0]
            if candidate in special_tokens:
                continue  # draw again
            bad_tokens.append(candidate)
        return bad_tokens
1817

1818
    def _check_generated_ids(self, output_ids):
1819
        for token_id in output_ids[0].numpy().tolist():
1820
            self.assertGreaterEqual(token_id, 0)
1821
            self.assertLess(token_id, self.model_tester.vocab_size)
1822

1823
    def _check_match_tokens(self, generated_ids, bad_words_ids):
1824
        # for all bad word tokens
1825
        for bad_word_ids in bad_words_ids:
1826
            # for all slices in batch
1827
            for generated_ids_slice in generated_ids:
1828
                # for all word idx
1829
                for i in range(len(bad_word_ids), len(generated_ids_slice)):
1830
                    # if tokens match
1831
                    if generated_ids_slice[i - len(bad_word_ids) : i] == bad_word_ids:
1832
                        return True
1833
        return False
1834

1835

1836
def ids_tensor(shape, vocab_size, rng=None, name=None, dtype=None):
    """Build a tensor of the given shape filled with random ids drawn from `[0, vocab_size - 1]`.

    Values come from `rng` (a fresh `random.Random()` when omitted). The tensor dtype is
    `dtype`, or `tf.int32` when no dtype is given. `name` is accepted but not used here.
    """
    generator = random.Random() if rng is None else rng

    # number of scalars needed to fill the requested shape
    num_values = 1
    for dim in shape:
        num_values *= dim

    values = [generator.randint(0, vocab_size - 1) for _ in range(num_values)]

    target_dtype = tf.int32 if dtype is None else dtype
    return tf.constant(values, shape=shape, dtype=target_dtype)
1852

1853

1854
def random_attention_mask(shape, rng=None, name=None, dtype=None):
    """Build a random 0/1 attention mask whose last position is always 1.

    Args:
        shape: shape of the mask; the code slices it along two dimensions.
        rng: optional random generator, forwarded to `ids_tensor` so that a seeded
            generator yields a reproducible mask (it was previously ignored).
        name: forwarded to `ids_tensor`.
        dtype: dtype of the mask, forwarded to `ids_tensor`.
    """
    attn_mask = ids_tensor(shape, vocab_size=2, rng=rng, name=name, dtype=dtype)
    # make sure that at least one token is attended to for each batch
    attn_mask = tf.concat([attn_mask[:, :-1], tf.ones_like(attn_mask[:, -1:], dtype=dtype)], axis=-1)
    return attn_mask
1859

1860

1861
def floats_tensor(shape, scale=1.0, rng=None, name=None, dtype=None):
    """Build a tensor of the given shape filled with random floats drawn as `rng.random() * scale`.

    Values come from `rng` (a fresh `random.Random()` when omitted). The tensor dtype is
    `dtype`, or `tf.float32` when no dtype is given. `name` is accepted but not used here.
    """
    generator = random.Random() if rng is None else rng

    # number of scalars needed to fill the requested shape
    num_values = 1
    for dim in shape:
        num_values *= dim

    values = [generator.random() * scale for _ in range(num_values)]

    target_dtype = tf.float32 if dtype is None else dtype
    flat = tf.constant(values, dtype=target_dtype)
    return tf.reshape(flat, shape=shape)
1875

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.