# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the PyTorch XCLIP model. """


import inspect
import os
import tempfile
import unittest

import numpy as np
from huggingface_hub import hf_hub_download

from transformers import XCLIPConfig, XCLIPTextConfig, XCLIPVisionConfig
from transformers.testing_utils import require_torch, require_torch_multi_gpu, require_vision, slow, torch_device
from transformers.utils import is_torch_available, is_vision_available

from ...test_configuration_common import ConfigTester
from ...test_modeling_common import (
    ModelTesterMixin,
    _config_zero_init,
    floats_tensor,
    ids_tensor,
    random_attention_mask,
)
from ...test_pipeline_mixin import PipelineTesterMixin


if is_torch_available():
    import torch
    from torch import nn

    from transformers import XCLIPModel, XCLIPTextModel, XCLIPVisionModel
    from transformers.models.x_clip.modeling_x_clip import XCLIP_PRETRAINED_MODEL_ARCHIVE_LIST


if is_vision_available():
    from transformers import XCLIPProcessor


class XCLIPVisionModelTester:
    def __init__(
        self,
        parent,
        batch_size=8,
        image_size=30,
        patch_size=2,
        num_channels=3,
        num_frames=8,  # important; the batch size * time must be divisible by the number of frames
        is_training=True,
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=37,
        mit_hidden_size=64,
        dropout=0.1,
        attention_dropout=0.1,
        initializer_range=0.02,
        scope=None,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.num_frames = num_frames
        self.is_training = is_training
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.mit_hidden_size = mit_hidden_size
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.initializer_range = initializer_range
        self.scope = scope

        # in ViT, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token)
        num_patches = (image_size // patch_size) ** 2
        self.seq_length = num_patches + 1
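        # e.g. with the defaults above: (30 // 2) ** 2 = 225 patches, so seq_length = 226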

    def prepare_config_and_inputs(self):
        pixel_values = floats_tensor(
            [self.batch_size * self.num_frames, self.num_channels, self.image_size, self.image_size]
        )
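        # frames are folded into the batch dimension: with the defaults this is an (8 * 8, 3, 30, 30) tensor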
        config = self.get_config()

        return config, pixel_values

    def get_config(self):
        return XCLIPVisionConfig(
            image_size=self.image_size,
            patch_size=self.patch_size,
            num_channels=self.num_channels,
            num_frames=self.num_frames,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            mit_hidden_size=self.mit_hidden_size,
            dropout=self.dropout,
            attention_dropout=self.attention_dropout,
            initializer_range=self.initializer_range,
        )

    def create_and_check_model(self, config, pixel_values):
        model = XCLIPVisionModel(config=config)
        model.to(torch_device)
        model.eval()
        with torch.no_grad():
            result = model(pixel_values)
        # expected sequence length = num_patches + 1 (we add 1 for the [CLS] token)
        image_size = (self.image_size, self.image_size)
        patch_size = (self.patch_size, self.patch_size)
        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
        self.parent.assertEqual(
            result.last_hidden_state.shape, (self.batch_size * self.num_frames, num_patches + 1, self.hidden_size)
        )
        self.parent.assertEqual(result.pooler_output.shape, (self.batch_size * self.num_frames, self.hidden_size))

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        config, pixel_values = config_and_inputs
        inputs_dict = {"pixel_values": pixel_values}
        return config, inputs_dict


@require_torch
class XCLIPVisionModelTest(ModelTesterMixin, unittest.TestCase):
    """
    Here we also overwrite some of the tests of test_modeling_common.py, as X-CLIP does not use input_ids, inputs_embeds,
    attention_mask and seq_length.
    """

    all_model_classes = (XCLIPVisionModel,) if is_torch_available() else ()
    fx_compatible = False
    test_pruning = False
    test_resize_embeddings = False
    test_head_masking = False

    def setUp(self):
        self.model_tester = XCLIPVisionModelTester(self)
        self.config_tester = ConfigTester(
            self, config_class=XCLIPVisionConfig, has_text_modality=False, hidden_size=37
        )

    def test_config(self):
        self.config_tester.run_common_tests()

    @unittest.skip(reason="X-CLIP does not use inputs_embeds")
    def test_inputs_embeds(self):
        pass

    def test_model_common_attributes(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            self.assertIsInstance(model.get_input_embeddings(), (nn.Module))
            x = model.get_output_embeddings()
            self.assertTrue(x is None or isinstance(x, nn.Linear))

    def test_forward_signature(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            signature = inspect.signature(model.forward)
            # signature.parameters is an OrderedDict => so arg_names order is deterministic
            arg_names = [*signature.parameters.keys()]

            expected_arg_names = ["pixel_values"]
            self.assertListEqual(arg_names[:1], expected_arg_names)

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    def test_training(self):
        pass

    def test_training_gradient_checkpointing(self):
        pass

    @unittest.skip(
        reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant(self):
        pass

    @unittest.skip(
        reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant_false(self):
        pass

    @unittest.skip(reason="XCLIPVisionModel has no base class and is not available in MODEL_MAPPING")
    def test_save_load_fast_init_from_base(self):
        pass

    @unittest.skip(reason="XCLIPVisionModel has no base class and is not available in MODEL_MAPPING")
    def test_save_load_fast_init_to_base(self):
        pass

    @slow
    def test_model_from_pretrained(self):
        for model_name in XCLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            model = XCLIPVisionModel.from_pretrained(model_name)
            self.assertIsNotNone(model)

    def test_gradient_checkpointing_backward_compatibility(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            if not model_class.supports_gradient_checkpointing:
                continue

            print("Model class:", model_class)

            config.gradient_checkpointing = True
            model = model_class(config)
            self.assertTrue(model.is_gradient_checkpointing)

    def test_attention_outputs(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.return_dict = True

        # we add 1 here due to the special message token in X-CLIP's vision encoder
        seq_len = getattr(self.model_tester, "seq_length", None) + 1
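        # with this tester's defaults: 225 patches + [CLS] + message token = 227 attention positions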
        encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)

        for model_class in self.all_model_classes:
            inputs_dict["output_attentions"] = True
            inputs_dict["output_hidden_states"] = False
            config.return_dict = True
            model = model_class(config)
            model.to(torch_device)
            model.eval()
            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            self.assertEqual(len(outputs.attentions), self.model_tester.num_hidden_layers)

            # check that output_attentions also work using config
            del inputs_dict["output_attentions"]
            config.output_attentions = True
            model = model_class(config)
            model.to(torch_device)
            model.eval()
            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            self.assertEqual(len(outputs.attentions), self.model_tester.num_hidden_layers)

            self.assertListEqual(
                list(outputs.attentions[0].shape[-3:]),
                [self.model_tester.num_attention_heads, encoder_seq_length, encoder_seq_length],
            )
            out_len = len(outputs)

            # Check attention is always last and order is fine
            inputs_dict["output_attentions"] = True
            inputs_dict["output_hidden_states"] = True
            model = model_class(config)
            model.to(torch_device)
            model.eval()
            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))

            self.assertEqual(out_len + 1, len(outputs))

            self_attentions = outputs.attentions

            self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
            self.assertListEqual(
                list(self_attentions[0].shape[-3:]),
                [self.model_tester.num_attention_heads, encoder_seq_length, encoder_seq_length],
            )

    @require_torch_multi_gpu
    def test_multi_gpu_data_parallel_forward(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        # some params shouldn't be scattered by nn.DataParallel
        # so just remove them if they are present.
        blacklist_non_batched_params = ["head_mask", "decoder_head_mask", "cross_attn_head_mask"]
        for k in blacklist_non_batched_params:
            inputs_dict.pop(k, None)

        # move input tensors to cuda:0
        for k, v in inputs_dict.items():
            if torch.is_tensor(v):
                inputs_dict[k] = v.to(0)

        for model_class in self.all_model_classes:
            model = model_class(config=config)
            model.to(0)
            model.eval()

            # Wrap model in nn.DataParallel
            model = nn.DataParallel(model)
            with torch.no_grad():
                test = self._prepare_for_class(inputs_dict, model_class)
                for k, v in test.items():
                    if isinstance(v, torch.Tensor):
                        print(k, v.shape)
                    else:
                        print(k, v)
                _ = model(**self._prepare_for_class(inputs_dict, model_class))


class XCLIPTextModelTester:
    def __init__(
        self,
        parent,
        batch_size=8,
        seq_length=7,
        is_training=True,
        use_input_mask=True,
        use_labels=True,
        vocab_size=99,
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=37,
        dropout=0.1,
        attention_dropout=0.1,
        max_position_embeddings=512,
        initializer_range=0.02,
        scope=None,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_input_mask = use_input_mask
        self.use_labels = use_labels
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.scope = scope

    def prepare_config_and_inputs(self):
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        input_mask = None
        if self.use_input_mask:
            input_mask = random_attention_mask([self.batch_size, self.seq_length])

        if input_mask is not None:
            batch_size, seq_length = input_mask.shape
            rnd_start_indices = np.random.randint(1, seq_length - 1, size=(batch_size,))
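            # give each row a contiguous prefix of 1s followed by 0s, mimicking right-padded sequences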
            for batch_idx, start_index in enumerate(rnd_start_indices):
                input_mask[batch_idx, :start_index] = 1
                input_mask[batch_idx, start_index:] = 0

        config = self.get_config()

        return config, input_ids, input_mask

    def get_config(self):
        return XCLIPTextConfig(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            dropout=self.dropout,
            attention_dropout=self.attention_dropout,
            max_position_embeddings=self.max_position_embeddings,
            initializer_range=self.initializer_range,
        )

    def create_and_check_model(self, config, input_ids, input_mask):
        model = XCLIPTextModel(config=config)
        model.to(torch_device)
        model.eval()
        with torch.no_grad():
            result = model(input_ids, attention_mask=input_mask)
            result = model(input_ids)
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
        self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        config, input_ids, input_mask = config_and_inputs
        inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
        return config, inputs_dict


@require_torch
class XCLIPTextModelTest(ModelTesterMixin, unittest.TestCase):
    all_model_classes = (XCLIPTextModel,) if is_torch_available() else ()
    fx_compatible = False
    test_pruning = False
    test_head_masking = False

    def setUp(self):
        self.model_tester = XCLIPTextModelTester(self)
        self.config_tester = ConfigTester(self, config_class=XCLIPTextConfig, hidden_size=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    def test_training(self):
        pass

    def test_training_gradient_checkpointing(self):
        pass

    @unittest.skip(
        reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant(self):
        pass

    @unittest.skip(
        reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant_false(self):
        pass

    @unittest.skip(reason="X-CLIP does not use inputs_embeds")
    def test_inputs_embeds(self):
        pass

    @unittest.skip(reason="XCLIPTextModel has no base class and is not available in MODEL_MAPPING")
    def test_save_load_fast_init_from_base(self):
        pass

    @unittest.skip(reason="XCLIPTextModel has no base class and is not available in MODEL_MAPPING")
    def test_save_load_fast_init_to_base(self):
        pass

    @slow
    def test_model_from_pretrained(self):
        for model_name in XCLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            model = XCLIPTextModel.from_pretrained(model_name)
            self.assertIsNotNone(model)


class XCLIPModelTester:
    def __init__(
        self,
        parent,
        text_kwargs=None,
        vision_kwargs=None,
        projection_dim=64,
        mit_hidden_size=64,
        is_training=True,
    ):
        if text_kwargs is None:
            text_kwargs = {}
        if vision_kwargs is None:
            vision_kwargs = {}

        self.parent = parent
        self.projection_dim = projection_dim
        self.mit_hidden_size = mit_hidden_size
        self.text_model_tester = XCLIPTextModelTester(parent, **text_kwargs)
        self.vision_model_tester = XCLIPVisionModelTester(parent, **vision_kwargs)
        self.is_training = is_training

    def prepare_config_and_inputs(self):
        text_config, input_ids, attention_mask = self.text_model_tester.prepare_config_and_inputs()
        vision_config, _ = self.vision_model_tester.prepare_config_and_inputs()
        pixel_values = floats_tensor(
            [
                self.vision_model_tester.batch_size,
                self.vision_model_tester.num_frames,
                self.vision_model_tester.num_channels,
                self.vision_model_tester.image_size,
                self.vision_model_tester.image_size,
            ]
        )
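        # unlike the vision tester, the full model gets a 5D (batch, num_frames, channels, height, width) tensor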

        config = self.get_config()

        return config, input_ids, attention_mask, pixel_values

    def get_config(self):
        return XCLIPConfig.from_text_vision_configs(
            self.text_model_tester.get_config(),
            self.vision_model_tester.get_config(),
            projection_dim=self.projection_dim,
        )

    def create_and_check_model(self, config, input_ids, attention_mask, pixel_values):
        model = XCLIPModel(config).to(torch_device).eval()
        with torch.no_grad():
            result = model(input_ids, pixel_values, attention_mask)
        self.parent.assertEqual(
            result.logits_per_video.shape,
            (self.vision_model_tester.batch_size, self.text_model_tester.batch_size),
        )
        self.parent.assertEqual(
            result.logits_per_text.shape,
            (self.text_model_tester.batch_size, self.vision_model_tester.batch_size),
        )
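        # logits_per_video is (video batch, text batch); logits_per_text has the transposed shape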

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        config, input_ids, attention_mask, pixel_values = config_and_inputs
        inputs_dict = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "pixel_values": pixel_values,
            "return_loss": True,
        }
        return config, inputs_dict


@require_torch
class XCLIPModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
    all_model_classes = (XCLIPModel,) if is_torch_available() else ()
    pipeline_model_mapping = {"feature-extraction": XCLIPModel} if is_torch_available() else {}
    fx_compatible = False
    test_head_masking = False
    test_pruning = False
    test_resize_embeddings = False
    test_attention_outputs = False
    test_torchscript = False
    maxdiff = None

    def setUp(self):
        self.model_tester = XCLIPModelTester(self)

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    @unittest.skip(reason="Hidden_states is tested in individual model tests")
    def test_hidden_states_output(self):
        pass

    @unittest.skip(reason="Inputs_embeds is tested in individual model tests")
    def test_inputs_embeds(self):
        pass

    @unittest.skip(reason="Retain_grad is tested in individual model tests")
    def test_retain_grad_hidden_states_attentions(self):
        pass

    @unittest.skip(reason="XCLIPModel does not have input/output embeddings")
    def test_model_common_attributes(self):
        pass

    @unittest.skip(reason="XCLIPModel does not support feedforward chunking")
    def test_feed_forward_chunking(self):
        pass

    # override as the `logit_scale`, `prompts_generator.alpha` parameters require special treatment
    def test_initialization(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        configs_no_init = _config_zero_init(config)
        for model_class in self.all_model_classes:
            model = model_class(config=configs_no_init)
            for name, param in model.named_parameters():
                if param.requires_grad:
                    # check if `logit_scale` is initialized as per the original implementation
                    if name == "logit_scale":
                        self.assertAlmostEqual(
                            param.data.item(),
                            np.log(1 / 0.07),
                            delta=1e-3,
                            msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                        )
                    elif name == "prompts_generator.alpha":
                        self.assertAlmostEqual(param.data.mean().item(), model.config.prompt_alpha)
                    else:
                        self.assertIn(
                            ((param.data.mean() * 1e9).round() / 1e9).item(),
                            [0.0, 1.0],
                            msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                        )

    def _create_and_check_torchscript(self, config, inputs_dict):
        if not self.test_torchscript:
            return

        configs_no_init = _config_zero_init(config)  # To be sure we have no Nan
        configs_no_init.torchscript = True
        configs_no_init.return_dict = False
        for model_class in self.all_model_classes:
            model = model_class(config=configs_no_init)
            model.to(torch_device)
            model.eval()

            try:
                input_ids = inputs_dict["input_ids"]
                pixel_values = inputs_dict["pixel_values"]  # X-CLIP needs pixel_values
                traced_model = torch.jit.trace(model, (input_ids, pixel_values))
            except RuntimeError:
                self.fail("Couldn't trace module.")

            with tempfile.TemporaryDirectory() as tmp_dir_name:
                pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt")

                try:
                    torch.jit.save(traced_model, pt_file_name)
                except Exception:
                    self.fail("Couldn't save module.")

                try:
                    loaded_model = torch.jit.load(pt_file_name)
                except Exception:
                    self.fail("Couldn't load module.")

            model.to(torch_device)
            model.eval()

            loaded_model.to(torch_device)
            loaded_model.eval()

            model_state_dict = model.state_dict()
            loaded_model_state_dict = loaded_model.state_dict()

            non_persistent_buffers = {}
            for key in loaded_model_state_dict.keys():
                if key not in model_state_dict.keys():
                    non_persistent_buffers[key] = loaded_model_state_dict[key]

            loaded_model_state_dict = {
                key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers
            }

            self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))

            model_buffers = list(model.buffers())
            for non_persistent_buffer in non_persistent_buffers.values():
                found_buffer = False
                for i, model_buffer in enumerate(model_buffers):
                    if torch.equal(non_persistent_buffer, model_buffer):
                        found_buffer = True
                        break

                self.assertTrue(found_buffer)
                model_buffers.pop(i)

            models_equal = True
            for layer_name, p1 in model_state_dict.items():
                p2 = loaded_model_state_dict[layer_name]
                if p1.data.ne(p2.data).sum() > 0:
                    models_equal = False

            self.assertTrue(models_equal)

    def test_load_vision_text_config(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        # Save XCLIPConfig and check if we can load XCLIPVisionConfig from it
        with tempfile.TemporaryDirectory() as tmp_dir_name:
            config.save_pretrained(tmp_dir_name)
            vision_config = XCLIPVisionConfig.from_pretrained(tmp_dir_name)
            self.assertDictEqual(config.vision_config.to_dict(), vision_config.to_dict())

        # Save XCLIPConfig and check if we can load XCLIPTextConfig from it
        with tempfile.TemporaryDirectory() as tmp_dir_name:
            config.save_pretrained(tmp_dir_name)
            text_config = XCLIPTextConfig.from_pretrained(tmp_dir_name)
            self.assertDictEqual(config.text_config.to_dict(), text_config.to_dict())

    @slow
    def test_model_from_pretrained(self):
        for model_name in XCLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            model = XCLIPModel.from_pretrained(model_name)
            self.assertIsNotNone(model)


# We will verify our results on a spaghetti video
def prepare_video():
    file = hf_hub_download(
        repo_id="hf-internal-testing/spaghetti-video", filename="eating_spaghetti_8_frames.npy", repo_type="dataset"
    )
    video = np.load(file)
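    # the array stacks the clip's 8 frames; list() below yields one numpy array per frame for the processor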
    return list(video)


@require_vision
@require_torch
class XCLIPModelIntegrationTest(unittest.TestCase):
    @slow
    def test_inference(self):
        model_name = "microsoft/xclip-base-patch32"
        model = XCLIPModel.from_pretrained(model_name).to(torch_device)
        processor = XCLIPProcessor.from_pretrained(model_name)

        video = prepare_video()
        inputs = processor(
            text=["playing sports", "eating spaghetti", "go shopping"], videos=video, return_tensors="pt", padding=True
        ).to(torch_device)

        # forward pass
        with torch.no_grad():
            outputs = model(**inputs)

        # verify the logits
        self.assertEqual(
            outputs.logits_per_video.shape,
            torch.Size((inputs.pixel_values.shape[0], inputs.input_ids.shape[0])),
        )
        self.assertEqual(
            outputs.logits_per_text.shape,
            torch.Size((inputs.input_ids.shape[0], inputs.pixel_values.shape[0])),
        )

        expected_logits = torch.tensor([[14.0181, 20.2771, 14.4776]], device=torch_device)

        self.assertTrue(torch.allclose(outputs.logits_per_video, expected_logits, atol=1e-3))
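        # for reference, per-video probabilities over the text prompts would be outputs.logits_per_video.softmax(dim=1)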