# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the PyTorch Chinese-CLIP model. """

import inspect
import os
import tempfile
import unittest

import numpy as np
import requests

from transformers import ChineseCLIPConfig, ChineseCLIPTextConfig, ChineseCLIPVisionConfig
from transformers.models.auto import get_values
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
from transformers.utils import is_torch_available, is_vision_available

from ...test_configuration_common import ConfigTester
from ...test_modeling_common import (
    ModelTesterMixin,
    _config_zero_init,
    floats_tensor,
    ids_tensor,
    random_attention_mask,
)
from ...test_pipeline_mixin import PipelineTesterMixin


if is_torch_available():
    import torch
    from torch import nn

    from transformers import (
        MODEL_FOR_PRETRAINING_MAPPING,
        ChineseCLIPModel,
        ChineseCLIPTextModel,
        ChineseCLIPVisionModel,
    )
    from transformers.models.chinese_clip.modeling_chinese_clip import CHINESE_CLIP_PRETRAINED_MODEL_ARCHIVE_LIST


if is_vision_available():
    from PIL import Image

    from transformers import ChineseCLIPProcessor


class ChineseCLIPTextModelTester:
    def __init__(
        self,
        parent,
        batch_size=13,
        seq_length=7,
        is_training=True,
        use_input_mask=True,
        use_token_type_ids=True,
        use_labels=True,
        vocab_size=99,
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=37,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=16,
        type_sequence_label_size=2,
        initializer_range=0.02,
        num_labels=3,
        num_choices=4,
        scope=None,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_input_mask = use_input_mask
        self.use_token_type_ids = use_token_type_ids
        self.use_labels = use_labels
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.type_sequence_label_size = type_sequence_label_size
        self.initializer_range = initializer_range
        self.num_labels = num_labels
        self.num_choices = num_choices
        self.scope = scope

    def prepare_config_and_inputs(self):
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        input_mask = None
        if self.use_input_mask:
            input_mask = random_attention_mask([self.batch_size, self.seq_length])

        token_type_ids = None
        if self.use_token_type_ids:
            token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)

        sequence_labels = None
        token_labels = None
        choice_labels = None
        if self.use_labels:
            sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
            token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
            choice_labels = ids_tensor([self.batch_size], self.num_choices)

        config = self.get_config()

        return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels

    def get_config(self):
        """
        Returns a tiny configuration by default.
        """
        return ChineseCLIPTextConfig(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            hidden_act=self.hidden_act,
            hidden_dropout_prob=self.hidden_dropout_prob,
            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
            max_position_embeddings=self.max_position_embeddings,
            type_vocab_size=self.type_vocab_size,
            is_decoder=False,
            initializer_range=self.initializer_range,
        )

    def prepare_config_and_inputs_for_decoder(self):
        (
            config,
            input_ids,
            token_type_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
        ) = self.prepare_config_and_inputs()

        config.is_decoder = True
        encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size])
        encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)

        return (
            config,
            input_ids,
            token_type_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
            encoder_hidden_states,
            encoder_attention_mask,
        )

    def create_and_check_model(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        model = ChineseCLIPTextModel(config=config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
        result = model(input_ids, token_type_ids=token_type_ids)
        result = model(input_ids)
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
        self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))

    def create_and_check_model_as_decoder(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
        encoder_hidden_states,
        encoder_attention_mask,
    ):
        config.add_cross_attention = True
        model = ChineseCLIPTextModel(config)
        model.to(torch_device)
        model.eval()
        result = model(
            input_ids,
            attention_mask=input_mask,
            token_type_ids=token_type_ids,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
        )
        result = model(
            input_ids,
            attention_mask=input_mask,
            token_type_ids=token_type_ids,
            encoder_hidden_states=encoder_hidden_states,
        )
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
        self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        (
            config,
            input_ids,
            token_type_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
        ) = config_and_inputs
        inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask}
        return config, inputs_dict


class ChineseCLIPVisionModelTester:
    def __init__(
        self,
        parent,
        batch_size=12,
        image_size=30,
        patch_size=2,
        num_channels=3,
        is_training=True,
        hidden_size=32,
        projection_dim=32,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=37,
        dropout=0.1,
        attention_dropout=0.1,
        initializer_range=0.02,
        scope=None,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.is_training = is_training
        self.hidden_size = hidden_size
        self.projection_dim = projection_dim
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.initializer_range = initializer_range
        self.scope = scope

        # in ViT, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token)
        num_patches = (image_size // patch_size) ** 2
        self.seq_length = num_patches + 1

    def prepare_config_and_inputs(self):
        pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
        config = self.get_config()

        return config, pixel_values

    def get_config(self):
        return ChineseCLIPVisionConfig(
            image_size=self.image_size,
            patch_size=self.patch_size,
            num_channels=self.num_channels,
            hidden_size=self.hidden_size,
            projection_dim=self.projection_dim,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            dropout=self.dropout,
            attention_dropout=self.attention_dropout,
            initializer_range=self.initializer_range,
        )

    def create_and_check_model(self, config, pixel_values):
        model = ChineseCLIPVisionModel(config=config)
        model.to(torch_device)
        model.eval()
        with torch.no_grad():
            result = model(pixel_values)
        # expected sequence length = num_patches + 1 (we add 1 for the [CLS] token)
        image_size = (self.image_size, self.image_size)
        patch_size = (self.patch_size, self.patch_size)
        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size))
        self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        config, pixel_values = config_and_inputs
        inputs_dict = {"pixel_values": pixel_values}
        return config, inputs_dict


@require_torch
class ChineseCLIPTextModelTest(ModelTesterMixin, unittest.TestCase):
    all_model_classes = (ChineseCLIPTextModel,) if is_torch_available() else ()
    fx_compatible = False

    # special case for ForPreTraining model
    def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
        inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)

        if return_labels:
            if model_class in get_values(MODEL_FOR_PRETRAINING_MAPPING):
                inputs_dict["labels"] = torch.zeros(
                    (self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
                )
                inputs_dict["next_sentence_label"] = torch.zeros(
                    self.model_tester.batch_size, dtype=torch.long, device=torch_device
                )
        return inputs_dict

    def setUp(self):
        self.model_tester = ChineseCLIPTextModelTester(self)
        self.config_tester = ConfigTester(self, config_class=ChineseCLIPTextConfig, hidden_size=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    def test_model_various_embeddings(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        for type in ["absolute", "relative_key", "relative_key_query"]:
            config_and_inputs[0].position_embedding_type = type
            self.model_tester.create_and_check_model(*config_and_inputs)

    def test_model_as_decoder(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
        self.model_tester.create_and_check_model_as_decoder(*config_and_inputs)

    def test_model_as_decoder_with_default_input_mask(self):
        # This regression test was failing with PyTorch < 1.3
        (
            config,
            input_ids,
            token_type_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
            encoder_hidden_states,
            encoder_attention_mask,
        ) = self.model_tester.prepare_config_and_inputs_for_decoder()

        input_mask = None

        self.model_tester.create_and_check_model_as_decoder(
            config,
            input_ids,
            token_type_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
            encoder_hidden_states,
            encoder_attention_mask,
        )

    @slow
    def test_model_from_pretrained(self):
        for model_name in CHINESE_CLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            model = ChineseCLIPTextModel.from_pretrained(model_name)
            self.assertIsNotNone(model)

    def test_training(self):
        pass

    def test_training_gradient_checkpointing(self):
        pass

    @unittest.skip(
        reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant(self):
        pass

    @unittest.skip(
        reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant_false(self):
        pass

    @unittest.skip(reason="ChineseCLIPTextModel has no base class and is not available in MODEL_MAPPING")
    def test_save_load_fast_init_from_base(self):
        pass

    @unittest.skip(reason="ChineseCLIPTextModel has no base class and is not available in MODEL_MAPPING")
    def test_save_load_fast_init_to_base(self):
        pass


@require_torch
class ChineseCLIPVisionModelTest(ModelTesterMixin, unittest.TestCase):
    """
    Here we also overwrite some of the tests of test_modeling_common.py, as CHINESE_CLIP does not use input_ids, inputs_embeds,
    attention_mask and seq_length.
    """

    all_model_classes = (ChineseCLIPVisionModel,) if is_torch_available() else ()
    fx_compatible = False
    test_pruning = False
    test_resize_embeddings = False
    test_head_masking = False

    def setUp(self):
        self.model_tester = ChineseCLIPVisionModelTester(self)
        self.config_tester = ConfigTester(
            self, config_class=ChineseCLIPVisionConfig, has_text_modality=False, hidden_size=37
        )

    def test_config(self):
        self.config_tester.run_common_tests()

    @unittest.skip(reason="CHINESE_CLIP does not use inputs_embeds")
    def test_inputs_embeds(self):
        pass

    def test_model_common_attributes(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            self.assertIsInstance(model.get_input_embeddings(), (nn.Module))
            x = model.get_output_embeddings()
            self.assertTrue(x is None or isinstance(x, nn.Linear))

    def test_forward_signature(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            signature = inspect.signature(model.forward)
            # signature.parameters is an OrderedDict => so arg_names order is deterministic
            arg_names = [*signature.parameters.keys()]

            expected_arg_names = ["pixel_values"]
            self.assertListEqual(arg_names[:1], expected_arg_names)

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    def test_training(self):
        pass

    def test_training_gradient_checkpointing(self):
        pass

    @unittest.skip(
        reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant(self):
        pass

    @unittest.skip(
        reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant_false(self):
        pass

    @unittest.skip(reason="ChineseCLIPVisionModel has no base class and is not available in MODEL_MAPPING")
    def test_save_load_fast_init_from_base(self):
        pass

    @unittest.skip(reason="ChineseCLIPVisionModel has no base class and is not available in MODEL_MAPPING")
    def test_save_load_fast_init_to_base(self):
        pass

    @slow
    def test_model_from_pretrained(self):
        for model_name in CHINESE_CLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            model = ChineseCLIPVisionModel.from_pretrained(model_name)
            self.assertIsNotNone(model)


class ChineseCLIPModelTester:
    def __init__(self, parent, text_kwargs=None, vision_kwargs=None, is_training=True):
        if text_kwargs is None:
            text_kwargs = {}
        if vision_kwargs is None:
            vision_kwargs = {}

        self.parent = parent
        self.text_model_tester = ChineseCLIPTextModelTester(parent, **text_kwargs)
        self.vision_model_tester = ChineseCLIPVisionModelTester(parent, **vision_kwargs)
        self.is_training = is_training

    def prepare_config_and_inputs(self):
        (
            config,
            input_ids,
            token_type_ids,
            attention_mask,
            _,
            __,
            ___,
        ) = self.text_model_tester.prepare_config_and_inputs()
        vision_config, pixel_values = self.vision_model_tester.prepare_config_and_inputs()

        config = self.get_config()

        return config, input_ids, token_type_ids, attention_mask, pixel_values

    def get_config(self):
        return ChineseCLIPConfig.from_text_vision_configs(
            self.text_model_tester.get_config(), self.vision_model_tester.get_config(), projection_dim=64
        )

    def create_and_check_model(self, config, input_ids, token_type_ids, attention_mask, pixel_values):
        model = ChineseCLIPModel(config).to(torch_device).eval()
        with torch.no_grad():
            result = model(input_ids, pixel_values, attention_mask, token_type_ids)
        self.parent.assertEqual(
            result.logits_per_image.shape, (self.vision_model_tester.batch_size, self.text_model_tester.batch_size)
        )
        self.parent.assertEqual(
            result.logits_per_text.shape, (self.text_model_tester.batch_size, self.vision_model_tester.batch_size)
        )

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        config, input_ids, token_type_ids, attention_mask, pixel_values = config_and_inputs
        inputs_dict = {
            "input_ids": input_ids,
            "token_type_ids": token_type_ids,
            "attention_mask": attention_mask,
            "pixel_values": pixel_values,
            "return_loss": True,
        }
        return config, inputs_dict


@require_torch
class ChineseCLIPModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
    all_model_classes = (ChineseCLIPModel,) if is_torch_available() else ()
    pipeline_model_mapping = {"feature-extraction": ChineseCLIPModel} if is_torch_available() else {}
    fx_compatible = False
    test_head_masking = False
    test_pruning = False
    test_resize_embeddings = False
    test_attention_outputs = False

    def setUp(self):
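        # note: text and vision testers share the same batch size so image-text pairs line up
        # for the contrastive loss computed when return_loss=True in the common inputs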
        text_kwargs = {"use_labels": False, "batch_size": 12}
        vision_kwargs = {"batch_size": 12}
        self.model_tester = ChineseCLIPModelTester(self, text_kwargs, vision_kwargs)

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    @unittest.skip(reason="Hidden_states is tested in individual model tests")
    def test_hidden_states_output(self):
        pass

    @unittest.skip(reason="Inputs_embeds is tested in individual model tests")
    def test_inputs_embeds(self):
        pass

    @unittest.skip(reason="Retain_grad is tested in individual model tests")
    def test_retain_grad_hidden_states_attentions(self):
        pass

    @unittest.skip(reason="ChineseCLIPModel does not have input/output embeddings")
    def test_model_common_attributes(self):
        pass

    # override as the `logit_scale` parameter initialization is different for CHINESE_CLIP
    def test_initialization(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        configs_no_init = _config_zero_init(config)
        for sub_config_key in ("vision_config", "text_config"):
            sub_config = getattr(configs_no_init, sub_config_key, {})
            setattr(configs_no_init, sub_config_key, _config_zero_init(sub_config))
        for model_class in self.all_model_classes:
            model = model_class(config=configs_no_init)
            for name, param in model.named_parameters():
                if param.requires_grad:
                    # check if `logit_scale` is initialized as per the original implementation
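                    # the original CLIP recipe sets logit_scale = np.log(1 / 0.07) (about 4.6052)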
                    if name == "logit_scale":
609
                        self.assertAlmostEqual(
610
                            param.data.item(),
611
                            np.log(1 / 0.07),
612
                            delta=1e-3,
613
                            msg=f"Parameter {name} of model {model_class} seems not properly initialized",
614
                        )
615
                    else:
616
                        self.assertIn(
617
                            ((param.data.mean() * 1e9).round() / 1e9).item(),
618
                            [0.0, 1.0],
619
                            msg=f"Parameter {name} of model {model_class} seems not properly initialized",
620
                        )
621

622
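    # override as tracing ChineseCLIPModel requires both input_ids and pixel_values (see the trace call below)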
    def _create_and_check_torchscript(self, config, inputs_dict):
        if not self.test_torchscript:
            return

        configs_no_init = _config_zero_init(config)  # To be sure we have no Nan
        configs_no_init.torchscript = True
        configs_no_init.return_dict = False
        for model_class in self.all_model_classes:
            model = model_class(config=configs_no_init)
            model.to(torch_device)
            model.eval()

            try:
                input_ids = inputs_dict["input_ids"]
                pixel_values = inputs_dict["pixel_values"]  # CHINESE_CLIP needs pixel_values
                traced_model = torch.jit.trace(model, (input_ids, pixel_values))
            except RuntimeError:
                self.fail("Couldn't trace module.")

            with tempfile.TemporaryDirectory() as tmp_dir_name:
                pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt")

                try:
                    torch.jit.save(traced_model, pt_file_name)
                except Exception:
                    self.fail("Couldn't save module.")

                try:
                    loaded_model = torch.jit.load(pt_file_name)
                except Exception:
                    self.fail("Couldn't load module.")

            model.to(torch_device)
            model.eval()

            loaded_model.to(torch_device)
            loaded_model.eval()

            model_state_dict = model.state_dict()
            loaded_model_state_dict = loaded_model.state_dict()

            non_persistent_buffers = {}
            for key in loaded_model_state_dict.keys():
                if key not in model_state_dict.keys():
                    non_persistent_buffers[key] = loaded_model_state_dict[key]

            loaded_model_state_dict = {
                key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers
            }

            self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))

            model_buffers = list(model.buffers())
            for non_persistent_buffer in non_persistent_buffers.values():
                found_buffer = False
                for i, model_buffer in enumerate(model_buffers):
                    if torch.equal(non_persistent_buffer, model_buffer):
                        found_buffer = True
                        break

                self.assertTrue(found_buffer)
                model_buffers.pop(i)

            models_equal = True
            for layer_name, p1 in model_state_dict.items():
                p2 = loaded_model_state_dict[layer_name]
                if p1.data.ne(p2.data).sum() > 0:
                    models_equal = False

            self.assertTrue(models_equal)

    @slow
    def test_model_from_pretrained(self):
        for model_name in CHINESE_CLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            model = ChineseCLIPModel.from_pretrained(model_name)
            self.assertIsNotNone(model)


# We will verify our results on an image of Pikachu
def prepare_img():
    url = "https://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/pokemon.jpeg"
    im = Image.open(requests.get(url, stream=True).raw)
    return im


@require_vision
@require_torch
class ChineseCLIPModelIntegrationTest(unittest.TestCase):
    @slow
    def test_inference(self):
        model_name = "OFA-Sys/chinese-clip-vit-base-patch16"
        model = ChineseCLIPModel.from_pretrained(model_name).to(torch_device)
        processor = ChineseCLIPProcessor.from_pretrained(model_name)

        image = prepare_img()
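        # candidate captions: 杰尼龟 (Squirtle), 妙蛙种子 (Bulbasaur), 小火龙 (Charmander), 皮卡丘 (Pikachu)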
        inputs = processor(
            text=["杰尼龟", "妙蛙种子", "小火龙", "皮卡丘"], images=image, padding=True, return_tensors="pt"
        ).to(torch_device)

        # forward pass
        with torch.no_grad():
            outputs = model(**inputs)

        # verify the logits
        self.assertEqual(
            outputs.logits_per_image.shape,
            torch.Size((inputs.pixel_values.shape[0], inputs.input_ids.shape[0])),
        )
        self.assertEqual(
            outputs.logits_per_text.shape,
            torch.Size((inputs.input_ids.shape[0], inputs.pixel_values.shape[0])),
        )

        probs = outputs.logits_per_image.softmax(dim=1)
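        # the Pikachu image is expected to match the last caption (皮卡丘) with ~94% probability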
        expected_probs = torch.tensor([[1.2686e-03, 5.4499e-02, 6.7968e-04, 9.4355e-01]], device=torch_device)

        self.assertTrue(torch.allclose(probs, expected_probs, atol=5e-3))
