# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the PyTorch Chinese-CLIP model. """

import inspect
import os
import tempfile
import unittest

import numpy as np
import requests

from transformers import ChineseCLIPConfig, ChineseCLIPTextConfig, ChineseCLIPVisionConfig
from transformers.models.auto import get_values
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
from transformers.utils import is_torch_available, is_vision_available

from ...test_configuration_common import ConfigTester
from ...test_modeling_common import (
    ModelTesterMixin,
    _config_zero_init,
    floats_tensor,
    ids_tensor,
    random_attention_mask,
)
from ...test_pipeline_mixin import PipelineTesterMixin


if is_torch_available():
    import torch
    from torch import nn

    from transformers import (
        MODEL_FOR_PRETRAINING_MAPPING,
        ChineseCLIPModel,
        ChineseCLIPTextModel,
        ChineseCLIPVisionModel,
    )
    from transformers.models.chinese_clip.modeling_chinese_clip import CHINESE_CLIP_PRETRAINED_MODEL_ARCHIVE_LIST


if is_vision_available():
    from PIL import Image

    from transformers import ChineseCLIPProcessor


class ChineseCLIPTextModelTester:
    def __init__(
        self,
        parent,
        batch_size=13,
        seq_length=7,
        is_training=True,
        use_input_mask=True,
        use_token_type_ids=True,
        use_labels=True,
        vocab_size=99,
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=37,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=16,
        type_sequence_label_size=2,
        initializer_range=0.02,
        num_labels=3,
        num_choices=4,
        scope=None,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_input_mask = use_input_mask
        self.use_token_type_ids = use_token_type_ids
        self.use_labels = use_labels
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.type_sequence_label_size = type_sequence_label_size
        self.initializer_range = initializer_range
        self.num_labels = num_labels
        self.num_choices = num_choices
        self.scope = scope

    def prepare_config_and_inputs(self):
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        input_mask = None
        if self.use_input_mask:
            input_mask = random_attention_mask([self.batch_size, self.seq_length])

        token_type_ids = None
        if self.use_token_type_ids:
            token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)

        sequence_labels = None
        token_labels = None
        choice_labels = None
        if self.use_labels:
            sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
            token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
            choice_labels = ids_tensor([self.batch_size], self.num_choices)

        config = self.get_config()

        return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels

    def get_config(self):
        """
        Returns a tiny configuration by default.
        """
        return ChineseCLIPTextConfig(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            hidden_act=self.hidden_act,
            hidden_dropout_prob=self.hidden_dropout_prob,
            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
            max_position_embeddings=self.max_position_embeddings,
            type_vocab_size=self.type_vocab_size,
            is_decoder=False,
            initializer_range=self.initializer_range,
        )

    def prepare_config_and_inputs_for_decoder(self):
        (
            config,
            input_ids,
            token_type_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
        ) = self.prepare_config_and_inputs()

        config.is_decoder = True
        encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size])
        encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)

        return (
            config,
            input_ids,
            token_type_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
            encoder_hidden_states,
            encoder_attention_mask,
        )

    def create_and_check_model(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        model = ChineseCLIPTextModel(config=config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
        result = model(input_ids, token_type_ids=token_type_ids)
        result = model(input_ids)
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
        self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))

    def create_and_check_model_as_decoder(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
        encoder_hidden_states,
        encoder_attention_mask,
    ):
        config.add_cross_attention = True
        model = ChineseCLIPTextModel(config)
        model.to(torch_device)
        model.eval()
        result = model(
            input_ids,
            attention_mask=input_mask,
            token_type_ids=token_type_ids,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
        )
        result = model(
            input_ids,
            attention_mask=input_mask,
            token_type_ids=token_type_ids,
            encoder_hidden_states=encoder_hidden_states,
        )
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
        self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        (
            config,
            input_ids,
            token_type_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
        ) = config_and_inputs
        inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask}
        return config, inputs_dict


class ChineseCLIPVisionModelTester:
    def __init__(
        self,
        parent,
        batch_size=12,
        image_size=30,
        patch_size=2,
        num_channels=3,
        is_training=True,
        hidden_size=32,
        projection_dim=32,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=37,
        dropout=0.1,
        attention_dropout=0.1,
        initializer_range=0.02,
        scope=None,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.is_training = is_training
        self.hidden_size = hidden_size
        self.projection_dim = projection_dim
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.initializer_range = initializer_range
        self.scope = scope

        # in ViT, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token)
        num_patches = (image_size // patch_size) ** 2
        self.seq_length = num_patches + 1
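        # e.g. with the defaults above: (30 // 2) ** 2 = 225 patches, so seq_length = 226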

    def prepare_config_and_inputs(self):
        pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
        config = self.get_config()

        return config, pixel_values

    def get_config(self):
        return ChineseCLIPVisionConfig(
            image_size=self.image_size,
            patch_size=self.patch_size,
            num_channels=self.num_channels,
            hidden_size=self.hidden_size,
            projection_dim=self.projection_dim,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            dropout=self.dropout,
            attention_dropout=self.attention_dropout,
            initializer_range=self.initializer_range,
        )

    def create_and_check_model(self, config, pixel_values):
        model = ChineseCLIPVisionModel(config=config)
        model.to(torch_device)
        model.eval()
        with torch.no_grad():
            result = model(pixel_values)
        # expected sequence length = num_patches + 1 (we add 1 for the [CLS] token)
        image_size = (self.image_size, self.image_size)
        patch_size = (self.patch_size, self.patch_size)
        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size))
        self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        config, pixel_values = config_and_inputs
        inputs_dict = {"pixel_values": pixel_values}
        return config, inputs_dict


@require_torch
class ChineseCLIPTextModelTest(ModelTesterMixin, unittest.TestCase):
    all_model_classes = (ChineseCLIPTextModel,) if is_torch_available() else ()
    fx_compatible = False

    # special case for ForPreTraining model
    def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
        inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)

        if return_labels:
            if model_class in get_values(MODEL_FOR_PRETRAINING_MAPPING):
                inputs_dict["labels"] = torch.zeros(
                    (self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
                )
                inputs_dict["next_sentence_label"] = torch.zeros(
                    self.model_tester.batch_size, dtype=torch.long, device=torch_device
                )
        return inputs_dict

    def setUp(self):
        self.model_tester = ChineseCLIPTextModelTester(self)
        self.config_tester = ConfigTester(self, config_class=ChineseCLIPTextConfig, hidden_size=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    def test_model_various_embeddings(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        for type in ["absolute", "relative_key", "relative_key_query"]:
            config_and_inputs[0].position_embedding_type = type
            self.model_tester.create_and_check_model(*config_and_inputs)

    def test_model_as_decoder(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
        self.model_tester.create_and_check_model_as_decoder(*config_and_inputs)

    def test_model_as_decoder_with_default_input_mask(self):
        # This regression test was failing with PyTorch < 1.3
        (
            config,
            input_ids,
            token_type_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
            encoder_hidden_states,
            encoder_attention_mask,
        ) = self.model_tester.prepare_config_and_inputs_for_decoder()

        input_mask = None

        self.model_tester.create_and_check_model_as_decoder(
            config,
            input_ids,
            token_type_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
            encoder_hidden_states,
            encoder_attention_mask,
        )

    @slow
    def test_model_from_pretrained(self):
        for model_name in CHINESE_CLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            model = ChineseCLIPTextModel.from_pretrained(model_name)
            self.assertIsNotNone(model)

    def test_training(self):
        pass

    def test_training_gradient_checkpointing(self):
        pass

    @unittest.skip(
        reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant(self):
        pass

    @unittest.skip(
        reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant_false(self):
        pass

    @unittest.skip(reason="ChineseCLIPTextModel has no base class and is not available in MODEL_MAPPING")
    def test_save_load_fast_init_from_base(self):
        pass

    @unittest.skip(reason="ChineseCLIPTextModel has no base class and is not available in MODEL_MAPPING")
    def test_save_load_fast_init_to_base(self):
        pass


@require_torch
class ChineseCLIPVisionModelTest(ModelTesterMixin, unittest.TestCase):
    """
    Here we also overwrite some of the tests of test_modeling_common.py, as CHINESE_CLIP does not use input_ids,
    inputs_embeds, attention_mask and seq_length.
    """

    all_model_classes = (ChineseCLIPVisionModel,) if is_torch_available() else ()
    fx_compatible = False
    test_pruning = False
    test_resize_embeddings = False
    test_head_masking = False

    def setUp(self):
        self.model_tester = ChineseCLIPVisionModelTester(self)
        self.config_tester = ConfigTester(
            self, config_class=ChineseCLIPVisionConfig, has_text_modality=False, hidden_size=37
        )

    def test_config(self):
        self.config_tester.run_common_tests()

    @unittest.skip(reason="CHINESE_CLIP does not use inputs_embeds")
    def test_inputs_embeds(self):
        pass

    def test_model_common_attributes(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            self.assertIsInstance(model.get_input_embeddings(), nn.Module)
            x = model.get_output_embeddings()
            self.assertTrue(x is None or isinstance(x, nn.Linear))

    def test_forward_signature(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            signature = inspect.signature(model.forward)
            # signature.parameters is an OrderedDict, so arg_names order is deterministic
            arg_names = [*signature.parameters.keys()]

            expected_arg_names = ["pixel_values"]
            self.assertListEqual(arg_names[:1], expected_arg_names)

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    def test_training(self):
        pass

    def test_training_gradient_checkpointing(self):
        pass

    @unittest.skip(
        reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant(self):
        pass

    @unittest.skip(
        reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant_false(self):
        pass

    @unittest.skip(reason="ChineseCLIPVisionModel has no base class and is not available in MODEL_MAPPING")
    def test_save_load_fast_init_from_base(self):
        pass

    @unittest.skip(reason="ChineseCLIPVisionModel has no base class and is not available in MODEL_MAPPING")
    def test_save_load_fast_init_to_base(self):
        pass

    @slow
    def test_model_from_pretrained(self):
        for model_name in CHINESE_CLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            model = ChineseCLIPVisionModel.from_pretrained(model_name)
            self.assertIsNotNone(model)


class ChineseCLIPModelTester:
    def __init__(self, parent, text_kwargs=None, vision_kwargs=None, is_training=True):
        if text_kwargs is None:
            text_kwargs = {}
        if vision_kwargs is None:
            vision_kwargs = {}

        self.parent = parent
        self.text_model_tester = ChineseCLIPTextModelTester(parent, **text_kwargs)
        self.vision_model_tester = ChineseCLIPVisionModelTester(parent, **vision_kwargs)
        self.is_training = is_training

    def prepare_config_and_inputs(self):
        (
            config,
            input_ids,
            token_type_ids,
            attention_mask,
            _,
            __,
            ___,
        ) = self.text_model_tester.prepare_config_and_inputs()
        vision_config, pixel_values = self.vision_model_tester.prepare_config_and_inputs()

        config = self.get_config()

        return config, input_ids, token_type_ids, attention_mask, pixel_values

    def get_config(self):
        return ChineseCLIPConfig.from_text_vision_configs(
            self.text_model_tester.get_config(), self.vision_model_tester.get_config(), projection_dim=64
        )

    def create_and_check_model(self, config, input_ids, token_type_ids, attention_mask, pixel_values):
        model = ChineseCLIPModel(config).to(torch_device).eval()
        with torch.no_grad():
            result = model(input_ids, pixel_values, attention_mask, token_type_ids)
        self.parent.assertEqual(
            result.logits_per_image.shape, (self.vision_model_tester.batch_size, self.text_model_tester.batch_size)
        )
        self.parent.assertEqual(
            result.logits_per_text.shape, (self.text_model_tester.batch_size, self.vision_model_tester.batch_size)
        )

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        config, input_ids, token_type_ids, attention_mask, pixel_values = config_and_inputs
        inputs_dict = {
            "input_ids": input_ids,
            "token_type_ids": token_type_ids,
            "attention_mask": attention_mask,
            "pixel_values": pixel_values,
            "return_loss": True,
        }
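        # `return_loss=True` makes `ChineseCLIPModel.forward` also compute and return the
        # contrastive loss alongside the image/text logits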
        return config, inputs_dict


@require_torch
class ChineseCLIPModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
    all_model_classes = (ChineseCLIPModel,) if is_torch_available() else ()
    pipeline_model_mapping = {"feature-extraction": ChineseCLIPModel} if is_torch_available() else {}
    fx_compatible = False
    test_head_masking = False
    test_pruning = False
    test_resize_embeddings = False
    test_attention_outputs = False

    def setUp(self):
        text_kwargs = {"use_labels": False, "batch_size": 12}
        vision_kwargs = {"batch_size": 12}
        self.model_tester = ChineseCLIPModelTester(self, text_kwargs, vision_kwargs)

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    @unittest.skip(reason="Hidden_states is tested in individual model tests")
    def test_hidden_states_output(self):
        pass

    @unittest.skip(reason="Inputs_embeds is tested in individual model tests")
    def test_inputs_embeds(self):
        pass

    @unittest.skip(reason="Retain_grad is tested in individual model tests")
    def test_retain_grad_hidden_states_attentions(self):
        pass

    @unittest.skip(reason="ChineseCLIPModel does not have input/output embeddings")
    def test_model_common_attributes(self):
        pass

    # override as the `logit_scale` parameter initialization is different for CHINESE_CLIP
    def test_initialization(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        configs_no_init = _config_zero_init(config)
        for sub_config_key in ("vision_config", "text_config"):
            sub_config = getattr(configs_no_init, sub_config_key, {})
            setattr(configs_no_init, sub_config_key, _config_zero_init(sub_config))
        for model_class in self.all_model_classes:
            model = model_class(config=configs_no_init)
            for name, param in model.named_parameters():
                if param.requires_grad:
                    # check if `logit_scale` is initialized as per the original implementation
                    if name == "logit_scale":
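                        # the original CLIP implementation initializes logit_scale to
                        # log(1 / 0.07) ≈ 2.6593, i.e. the log of the inverse 0.07 softmax temperature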
                        self.assertAlmostEqual(
                            param.data.item(),
                            np.log(1 / 0.07),
                            delta=1e-3,
                            msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                        )
                    else:
                        self.assertIn(
                            ((param.data.mean() * 1e9).round() / 1e9).item(),
                            [0.0, 1.0],
                            msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                        )

    def _create_and_check_torchscript(self, config, inputs_dict):
        if not self.test_torchscript:
            return

        configs_no_init = _config_zero_init(config)  # to be sure we have no NaN
        configs_no_init.torchscript = True
        configs_no_init.return_dict = False
        for model_class in self.all_model_classes:
            model = model_class(config=configs_no_init)
            model.to(torch_device)
            model.eval()

            try:
                input_ids = inputs_dict["input_ids"]
                pixel_values = inputs_dict["pixel_values"]  # CHINESE_CLIP needs pixel_values
                traced_model = torch.jit.trace(model, (input_ids, pixel_values))
            except RuntimeError:
                self.fail("Couldn't trace module.")

            with tempfile.TemporaryDirectory() as tmp_dir_name:
                pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt")

                try:
                    torch.jit.save(traced_model, pt_file_name)
                except Exception:
                    self.fail("Couldn't save module.")

                try:
                    loaded_model = torch.jit.load(pt_file_name)
                except Exception:
                    self.fail("Couldn't load module.")

                model.to(torch_device)
                model.eval()

                loaded_model.to(torch_device)
                loaded_model.eval()

                model_state_dict = model.state_dict()
                loaded_model_state_dict = loaded_model.state_dict()

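                # TorchScript serializes non-persistent buffers into its state dict, while the
                # eager model's state dict omits them; collect these extra keys so they can be
                # compared against the live buffers below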
                non_persistent_buffers = {}
                for key in loaded_model_state_dict.keys():
                    if key not in model_state_dict.keys():
                        non_persistent_buffers[key] = loaded_model_state_dict[key]

                loaded_model_state_dict = {
                    key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers
                }

                self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))

                model_buffers = list(model.buffers())
                for non_persistent_buffer in non_persistent_buffers.values():
                    found_buffer = False
                    for i, model_buffer in enumerate(model_buffers):
                        if torch.equal(non_persistent_buffer, model_buffer):
                            found_buffer = True
                            break

                    self.assertTrue(found_buffer)
                    model_buffers.pop(i)

                models_equal = True
                for layer_name, p1 in model_state_dict.items():
                    p2 = loaded_model_state_dict[layer_name]
                    if p1.data.ne(p2.data).sum() > 0:
                        models_equal = False

                self.assertTrue(models_equal)

    @slow
    def test_model_from_pretrained(self):
        for model_name in CHINESE_CLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            model = ChineseCLIPModel.from_pretrained(model_name)
            self.assertIsNotNone(model)


# We will verify our results on an image of Pikachu
def prepare_img():
    url = "https://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/pokemon.jpeg"
    im = Image.open(requests.get(url, stream=True).raw)
    return im


@require_vision
@require_torch
class ChineseCLIPModelIntegrationTest(unittest.TestCase):
    @slow
    def test_inference(self):
        model_name = "OFA-Sys/chinese-clip-vit-base-patch16"
        model = ChineseCLIPModel.from_pretrained(model_name).to(torch_device)
        processor = ChineseCLIPProcessor.from_pretrained(model_name)

        image = prepare_img()
        inputs = processor(
            text=["杰尼龟", "妙蛙种子", "小火龙", "皮卡丘"], images=image, padding=True, return_tensors="pt"
        ).to(torch_device)

        # forward pass
        with torch.no_grad():
            outputs = model(**inputs)

        # verify the logits
        self.assertEqual(
            outputs.logits_per_image.shape,
            torch.Size((inputs.pixel_values.shape[0], inputs.input_ids.shape[0])),
        )
        self.assertEqual(
            outputs.logits_per_text.shape,
            torch.Size((inputs.input_ids.shape[0], inputs.pixel_values.shape[0])),
        )

        probs = outputs.logits_per_image.softmax(dim=1)
        expected_probs = torch.tensor([[1.2686e-03, 5.4499e-02, 6.7968e-04, 9.4355e-01]], device=torch_device)
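        # the last label, 皮卡丘 (Pikachu), should match the Pikachu image with probability ≈ 0.94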

        self.assertTrue(torch.allclose(probs, expected_probs, atol=5e-3))
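

# A minimal zero-shot classification sketch mirroring `test_inference` above. It is
# illustrative only and not part of the test suite; `_predict_pokemon_label` is a
# hypothetical helper that picks the candidate label with the highest
# image-to-text probability.
def _predict_pokemon_label():
    model_name = "OFA-Sys/chinese-clip-vit-base-patch16"
    model = ChineseCLIPModel.from_pretrained(model_name).to(torch_device)
    processor = ChineseCLIPProcessor.from_pretrained(model_name)

    candidate_labels = ["杰尼龟", "妙蛙种子", "小火龙", "皮卡丘"]
    inputs = processor(text=candidate_labels, images=prepare_img(), padding=True, return_tensors="pt").to(
        torch_device
    )
    with torch.no_grad():
        probs = model(**inputs).logits_per_image.softmax(dim=1)
    # for the Pikachu image this should return "皮卡丘"
    return candidate_labels[probs.argmax(dim=1).item()]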