# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the PyTorch OwlViT model. """

import inspect
import os
import tempfile
import unittest

import numpy as np
import requests

from transformers import OwlViTConfig, OwlViTTextConfig, OwlViTVisionConfig
from transformers.testing_utils import (
    require_torch,
    require_torch_accelerator,
    require_torch_fp16,
    require_vision,
    slow,
    torch_device,
)
from transformers.utils import is_torch_available, is_vision_available

from ...test_configuration_common import ConfigTester
from ...test_modeling_common import (
    ModelTesterMixin,
    _config_zero_init,
    floats_tensor,
    ids_tensor,
    random_attention_mask,
)
from ...test_pipeline_mixin import PipelineTesterMixin


if is_torch_available():
    import torch
    from torch import nn

    from transformers import OwlViTForObjectDetection, OwlViTModel, OwlViTTextModel, OwlViTVisionModel
    from transformers.models.owlvit.modeling_owlvit import OWLVIT_PRETRAINED_MODEL_ARCHIVE_LIST


if is_vision_available():
    from PIL import Image

    from transformers import OwlViTProcessor

class OwlViTVisionModelTester:
    def __init__(
        self,
        parent,
        batch_size=12,
        image_size=32,
        patch_size=2,
        num_channels=3,
        is_training=True,
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=37,
        dropout=0.1,
        attention_dropout=0.1,
        initializer_range=0.02,
        scope=None,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.is_training = is_training
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.initializer_range = initializer_range
        self.scope = scope

        # in ViT, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token)
        num_patches = (image_size // patch_size) ** 2
        self.seq_length = num_patches + 1

    def prepare_config_and_inputs(self):
        pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
        config = self.get_config()

        return config, pixel_values

    def get_config(self):
        return OwlViTVisionConfig(
            image_size=self.image_size,
            patch_size=self.patch_size,
            num_channels=self.num_channels,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            dropout=self.dropout,
            attention_dropout=self.attention_dropout,
            initializer_range=self.initializer_range,
        )

    def create_and_check_model(self, config, pixel_values):
        model = OwlViTVisionModel(config=config).to(torch_device)
        model.eval()

        pixel_values = pixel_values.to(torch.float32)

        with torch.no_grad():
            result = model(pixel_values)
        # expected sequence length = num_patches + 1 (we add 1 for the [CLS] token)
        num_patches = (self.image_size // self.patch_size) ** 2
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size))
        self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        config, pixel_values = config_and_inputs
        inputs_dict = {"pixel_values": pixel_values}
        return config, inputs_dict

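# Quick arithmetic check for the tester defaults above (informational only): with image_size=32 and
# patch_size=2 the vision tower sees (32 // 2) ** 2 = 256 patch tokens plus one [CLS] token, so
# seq_length == 257.
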
@require_torch
class OwlViTVisionModelTest(ModelTesterMixin, unittest.TestCase):
    """
    Here we also overwrite some of the tests of test_modeling_common.py, as OWLVIT does not use input_ids, inputs_embeds,
    attention_mask and seq_length.
    """

    all_model_classes = (OwlViTVisionModel,) if is_torch_available() else ()
    fx_compatible = False
    test_pruning = False
    test_resize_embeddings = False
    test_head_masking = False

    def setUp(self):
        self.model_tester = OwlViTVisionModelTester(self)
        self.config_tester = ConfigTester(
            self, config_class=OwlViTVisionConfig, has_text_modality=False, hidden_size=37
        )

    def test_config(self):
        self.config_tester.run_common_tests()

    @unittest.skip(reason="OWLVIT does not use inputs_embeds")
    def test_inputs_embeds(self):
        pass

    def test_model_common_attributes(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            self.assertIsInstance(model.get_input_embeddings(), (nn.Module))
            x = model.get_output_embeddings()
            self.assertTrue(x is None or isinstance(x, nn.Linear))

    def test_forward_signature(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            signature = inspect.signature(model.forward)
            # signature.parameters is an OrderedDict => so arg_names order is deterministic
            arg_names = [*signature.parameters.keys()]

            expected_arg_names = ["pixel_values"]
            self.assertListEqual(arg_names[:1], expected_arg_names)

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    @unittest.skip(reason="OWL-ViT does not support training yet")
    def test_training(self):
        pass

    @unittest.skip(reason="OWL-ViT does not support training yet")
    def test_training_gradient_checkpointing(self):
        pass

    @unittest.skip(
        reason="This architecture does not seem to compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant(self):
        pass

    @unittest.skip(
        reason="This architecture does not seem to compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant_false(self):
        pass

    @unittest.skip(reason="OwlViTVisionModel has no base class and is not available in MODEL_MAPPING")
    def test_save_load_fast_init_from_base(self):
        pass

    @unittest.skip(reason="OwlViTVisionModel has no base class and is not available in MODEL_MAPPING")
    def test_save_load_fast_init_to_base(self):
        pass

    @slow
    def test_model_from_pretrained(self):
        for model_name in OWLVIT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            model = OwlViTVisionModel.from_pretrained(model_name)
            self.assertIsNotNone(model)

class OwlViTTextModelTester:
    def __init__(
        self,
        parent,
        batch_size=12,
        num_queries=4,
        seq_length=16,
        is_training=True,
        use_input_mask=True,
        use_labels=True,
        vocab_size=99,
        hidden_size=64,
        num_hidden_layers=12,
        num_attention_heads=4,
        intermediate_size=37,
        dropout=0.1,
        attention_dropout=0.1,
        max_position_embeddings=16,
        initializer_range=0.02,
        scope=None,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.num_queries = num_queries
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_input_mask = use_input_mask
        self.use_labels = use_labels
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.scope = scope

    def prepare_config_and_inputs(self):
        input_ids = ids_tensor([self.batch_size * self.num_queries, self.seq_length], self.vocab_size)
        input_mask = None

        if self.use_input_mask:
            input_mask = random_attention_mask([self.batch_size * self.num_queries, self.seq_length])

        if input_mask is not None:
            num_text, seq_length = input_mask.shape

            rnd_start_indices = np.random.randint(1, seq_length - 1, size=(num_text,))
            for idx, start_index in enumerate(rnd_start_indices):
                input_mask[idx, :start_index] = 1
                input_mask[idx, start_index:] = 0

        config = self.get_config()

        return config, input_ids, input_mask

    def get_config(self):
        return OwlViTTextConfig(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            dropout=self.dropout,
            attention_dropout=self.attention_dropout,
            max_position_embeddings=self.max_position_embeddings,
            initializer_range=self.initializer_range,
        )

    def create_and_check_model(self, config, input_ids, input_mask):
        model = OwlViTTextModel(config=config).to(torch_device)
        model.eval()
        with torch.no_grad():
            result = model(input_ids=input_ids, attention_mask=input_mask)

        self.parent.assertEqual(
            result.last_hidden_state.shape, (self.batch_size * self.num_queries, self.seq_length, self.hidden_size)
        )
        self.parent.assertEqual(result.pooler_output.shape, (self.batch_size * self.num_queries, self.hidden_size))

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        config, input_ids, input_mask = config_and_inputs
        inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
        return config, inputs_dict

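# Note on shapes in the text tester above: OWL-ViT consumes several text queries per image, so the
# tester flattens them and feeds `input_ids` of shape (batch_size * num_queries, seq_length) to the
# text tower, which is why the expected hidden-state shapes use batch_size * num_queries.
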
@require_torch
class OwlViTTextModelTest(ModelTesterMixin, unittest.TestCase):
    all_model_classes = (OwlViTTextModel,) if is_torch_available() else ()
    fx_compatible = False
    test_pruning = False
    test_head_masking = False

    def setUp(self):
        self.model_tester = OwlViTTextModelTester(self)
        self.config_tester = ConfigTester(self, config_class=OwlViTTextConfig, hidden_size=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    @unittest.skip(reason="OWL-ViT does not support training yet")
    def test_training(self):
        pass

    @unittest.skip(reason="OWL-ViT does not support training yet")
    def test_training_gradient_checkpointing(self):
        pass

    @unittest.skip(
        reason="This architecture does not seem to compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant(self):
        pass

    @unittest.skip(
        reason="This architecture does not seem to compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant_false(self):
        pass

    @unittest.skip(reason="OWLVIT does not use inputs_embeds")
    def test_inputs_embeds(self):
        pass

    @unittest.skip(reason="OwlViTTextModel has no base class and is not available in MODEL_MAPPING")
    def test_save_load_fast_init_from_base(self):
        pass

    @unittest.skip(reason="OwlViTTextModel has no base class and is not available in MODEL_MAPPING")
    def test_save_load_fast_init_to_base(self):
        pass

    @slow
    def test_model_from_pretrained(self):
        for model_name in OWLVIT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            model = OwlViTTextModel.from_pretrained(model_name)
            self.assertIsNotNone(model)

class OwlViTModelTester:
    def __init__(self, parent, text_kwargs=None, vision_kwargs=None, is_training=True):
        if text_kwargs is None:
            text_kwargs = {}
        if vision_kwargs is None:
            vision_kwargs = {}

        self.parent = parent
        self.text_model_tester = OwlViTTextModelTester(parent, **text_kwargs)
        self.vision_model_tester = OwlViTVisionModelTester(parent, **vision_kwargs)
        self.is_training = is_training
        self.text_config = self.text_model_tester.get_config().to_dict()
        self.vision_config = self.vision_model_tester.get_config().to_dict()

    def prepare_config_and_inputs(self):
        text_config, input_ids, attention_mask = self.text_model_tester.prepare_config_and_inputs()
        vision_config, pixel_values = self.vision_model_tester.prepare_config_and_inputs()
        config = self.get_config()
        return config, input_ids, attention_mask, pixel_values

    def get_config(self):
        return OwlViTConfig.from_text_vision_configs(self.text_config, self.vision_config, projection_dim=64)

    def create_and_check_model(self, config, input_ids, attention_mask, pixel_values):
        model = OwlViTModel(config).to(torch_device).eval()

        with torch.no_grad():
            result = model(
                input_ids=input_ids,
                pixel_values=pixel_values,
                attention_mask=attention_mask,
            )

        image_logits_size = (
            self.vision_model_tester.batch_size,
            self.text_model_tester.batch_size * self.text_model_tester.num_queries,
        )
        text_logits_size = (
            self.text_model_tester.batch_size * self.text_model_tester.num_queries,
            self.vision_model_tester.batch_size,
        )
        self.parent.assertEqual(result.logits_per_image.shape, image_logits_size)
        self.parent.assertEqual(result.logits_per_text.shape, text_logits_size)

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        config, input_ids, attention_mask, pixel_values = config_and_inputs
        inputs_dict = {
            "pixel_values": pixel_values,
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "return_loss": False,
        }
        return config, inputs_dict

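# As in CLIP, `logits_per_image` is the image-text similarity matrix, so its expected shape is
# (vision batch_size, text batch_size * num_queries), and `logits_per_text` is its transpose.
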
@require_torch
class OwlViTModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
    all_model_classes = (OwlViTModel,) if is_torch_available() else ()
    pipeline_model_mapping = (
        {
            "feature-extraction": OwlViTModel,
            "zero-shot-object-detection": OwlViTForObjectDetection,
        }
        if is_torch_available()
        else {}
    )
    fx_compatible = False
    test_head_masking = False
    test_pruning = False
    test_resize_embeddings = False
    test_attention_outputs = False

    def setUp(self):
        self.model_tester = OwlViTModelTester(self)

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    @unittest.skip(reason="Hidden_states is tested in individual model tests")
    def test_hidden_states_output(self):
        pass

    @unittest.skip(reason="Inputs_embeds is tested in individual model tests")
    def test_inputs_embeds(self):
        pass

    @unittest.skip(reason="Retain_grad is tested in individual model tests")
    def test_retain_grad_hidden_states_attentions(self):
        pass

    @unittest.skip(reason="OwlViTModel does not have input/output embeddings")
    def test_model_common_attributes(self):
        pass

    # override as the `logit_scale` parameter initialization is different for OWLVIT
    def test_initialization(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        configs_no_init = _config_zero_init(config)
        for model_class in self.all_model_classes:
            model = model_class(config=configs_no_init)
            for name, param in model.named_parameters():
                if param.requires_grad:
                    # check if `logit_scale` is initialized as per the original implementation
                    if name == "logit_scale":
                        self.assertAlmostEqual(
                            param.data.item(),
                            np.log(1 / 0.07),
                            delta=1e-3,
                            msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                        )
                    else:
                        self.assertIn(
                            ((param.data.mean() * 1e9).round() / 1e9).item(),
                            [0.0, 1.0],
                            msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                        )

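    # For reference: np.log(1 / 0.07) ≈ 2.6593 is the CLIP-style temperature initialization that
    # OWL-ViT reuses for `logit_scale`, so unlike the other parameters it is not expected to be
    # zero-initialized by `_config_zero_init`.
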
    def _create_and_check_torchscript(self, config, inputs_dict):
        if not self.test_torchscript:
            return

        configs_no_init = _config_zero_init(config)  # To be sure we have no Nan
        configs_no_init.torchscript = True
        configs_no_init.return_dict = False
        for model_class in self.all_model_classes:
            model = model_class(config=configs_no_init).to(torch_device)
            model.eval()

            try:
                input_ids = inputs_dict["input_ids"]
                pixel_values = inputs_dict["pixel_values"]  # OWLVIT needs pixel_values
                traced_model = torch.jit.trace(model, (input_ids, pixel_values))
            except RuntimeError:
                self.fail("Couldn't trace module.")

            with tempfile.TemporaryDirectory() as tmp_dir_name:
                pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt")

                try:
                    torch.jit.save(traced_model, pt_file_name)
                except Exception:
                    self.fail("Couldn't save module.")

                try:
                    loaded_model = torch.jit.load(pt_file_name)
                except Exception:
                    self.fail("Couldn't load module.")

            loaded_model = loaded_model.to(torch_device)
            loaded_model.eval()

            model_state_dict = model.state_dict()
            loaded_model_state_dict = loaded_model.state_dict()

            non_persistent_buffers = {}
            for key in loaded_model_state_dict.keys():
                if key not in model_state_dict.keys():
                    non_persistent_buffers[key] = loaded_model_state_dict[key]

            loaded_model_state_dict = {
                key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers
            }

            self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))

            model_buffers = list(model.buffers())
            for non_persistent_buffer in non_persistent_buffers.values():
                found_buffer = False
                for i, model_buffer in enumerate(model_buffers):
                    if torch.equal(non_persistent_buffer, model_buffer):
                        found_buffer = True
                        break

                self.assertTrue(found_buffer)
                model_buffers.pop(i)

            models_equal = True
            for layer_name, p1 in model_state_dict.items():
                p2 = loaded_model_state_dict[layer_name]
                if p1.data.ne(p2.data).sum() > 0:
                    models_equal = False

            self.assertTrue(models_equal)

    def test_load_vision_text_config(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        # Save OwlViTConfig and check if we can load OwlViTVisionConfig from it
        with tempfile.TemporaryDirectory() as tmp_dir_name:
            config.save_pretrained(tmp_dir_name)
            vision_config = OwlViTVisionConfig.from_pretrained(tmp_dir_name)
            self.assertDictEqual(config.vision_config.to_dict(), vision_config.to_dict())

        # Save OwlViTConfig and check if we can load OwlViTTextConfig from it
        with tempfile.TemporaryDirectory() as tmp_dir_name:
            config.save_pretrained(tmp_dir_name)
            text_config = OwlViTTextConfig.from_pretrained(tmp_dir_name)
            self.assertDictEqual(config.text_config.to_dict(), text_config.to_dict())

    @slow
    def test_model_from_pretrained(self):
        for model_name in OWLVIT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            model = OwlViTModel.from_pretrained(model_name)
            self.assertIsNotNone(model)

class OwlViTForObjectDetectionTester:
    def __init__(self, parent, is_training=True):
        self.parent = parent
        self.text_model_tester = OwlViTTextModelTester(parent)
        self.vision_model_tester = OwlViTVisionModelTester(parent)
        self.is_training = is_training
        self.text_config = self.text_model_tester.get_config().to_dict()
        self.vision_config = self.vision_model_tester.get_config().to_dict()

    def prepare_config_and_inputs(self):
        text_config, input_ids, attention_mask = self.text_model_tester.prepare_config_and_inputs()
        vision_config, pixel_values = self.vision_model_tester.prepare_config_and_inputs()
        config = self.get_config()
        return config, pixel_values, input_ids, attention_mask

    def get_config(self):
        return OwlViTConfig.from_text_vision_configs(self.text_config, self.vision_config, projection_dim=64)

    def create_and_check_model(self, config, pixel_values, input_ids, attention_mask):
        model = OwlViTForObjectDetection(config).to(torch_device).eval()
        with torch.no_grad():
            result = model(
                pixel_values=pixel_values,
                input_ids=input_ids,
                attention_mask=attention_mask,
                return_dict=True,
            )

        pred_boxes_size = (
            self.vision_model_tester.batch_size,
            (self.vision_model_tester.image_size // self.vision_model_tester.patch_size) ** 2,
            4,
        )
        pred_logits_size = (
            self.vision_model_tester.batch_size,
            (self.vision_model_tester.image_size // self.vision_model_tester.patch_size) ** 2,
            4,
        )
        pred_class_embeds_size = (
            self.vision_model_tester.batch_size,
            (self.vision_model_tester.image_size // self.vision_model_tester.patch_size) ** 2,
            self.text_model_tester.hidden_size,
        )
        self.parent.assertEqual(result.pred_boxes.shape, pred_boxes_size)
        self.parent.assertEqual(result.logits.shape, pred_logits_size)
        self.parent.assertEqual(result.class_embeds.shape, pred_class_embeds_size)

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        config, pixel_values, input_ids, attention_mask = config_and_inputs
        inputs_dict = {
            "pixel_values": pixel_values,
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }
        return config, inputs_dict

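# Shape notes for the checks above (derived from the tester defaults): each of the
# (image_size // patch_size) ** 2 patch tokens acts as an object query, so `pred_boxes` carries one
# 4-dim box per patch, `logits` carries one score per text query (num_queries=4 here), and
# `class_embeds` projects each patch into the text embedding space of size hidden_size.
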
@require_torch
class OwlViTForObjectDetectionTest(ModelTesterMixin, unittest.TestCase):
    all_model_classes = (OwlViTForObjectDetection,) if is_torch_available() else ()
    fx_compatible = False
    test_head_masking = False
    test_pruning = False
    test_resize_embeddings = False
    test_attention_outputs = False

    def setUp(self):
        self.model_tester = OwlViTForObjectDetectionTester(self)

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    @unittest.skip(reason="Hidden_states is tested in individual model tests")
    def test_hidden_states_output(self):
        pass

    @unittest.skip(reason="Inputs_embeds is tested in individual model tests")
    def test_inputs_embeds(self):
        pass

    @unittest.skip(reason="Retain_grad is tested in individual model tests")
    def test_retain_grad_hidden_states_attentions(self):
        pass

    @unittest.skip(reason="OwlViTModel does not have input/output embeddings")
    def test_model_common_attributes(self):
        pass

    @unittest.skip(reason="Test_initialization is tested in individual model tests")
    def test_initialization(self):
        pass

    @unittest.skip(reason="Test_forward_signature is tested in individual model tests")
    def test_forward_signature(self):
        pass

    @unittest.skip(reason="Test_save_load_fast_init_from_base is tested in individual model tests")
    def test_save_load_fast_init_from_base(self):
        pass

    @unittest.skip(reason="OWL-ViT does not support training yet")
    def test_training(self):
        pass

    @unittest.skip(reason="OWL-ViT does not support training yet")
    def test_training_gradient_checkpointing(self):
        pass

    @unittest.skip(
        reason="This architecture does not seem to compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant(self):
        pass

    @unittest.skip(
        reason="This architecture does not seem to compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant_false(self):
        pass

    def _create_and_check_torchscript(self, config, inputs_dict):
        if not self.test_torchscript:
            return

        configs_no_init = _config_zero_init(config)  # To be sure we have no Nan
        configs_no_init.torchscript = True
        configs_no_init.return_dict = False
        for model_class in self.all_model_classes:
            model = model_class(config=configs_no_init).to(torch_device)
            model.eval()

            try:
                input_ids = inputs_dict["input_ids"]
                pixel_values = inputs_dict["pixel_values"]  # OWLVIT needs pixel_values
                traced_model = torch.jit.trace(model, (input_ids, pixel_values))
            except RuntimeError:
                self.fail("Couldn't trace module.")

            with tempfile.TemporaryDirectory() as tmp_dir_name:
                pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt")

                try:
                    torch.jit.save(traced_model, pt_file_name)
                except Exception:
                    self.fail("Couldn't save module.")

                try:
                    loaded_model = torch.jit.load(pt_file_name)
                except Exception:
                    self.fail("Couldn't load module.")

            loaded_model = loaded_model.to(torch_device)
            loaded_model.eval()

            model_state_dict = model.state_dict()
            loaded_model_state_dict = loaded_model.state_dict()

            non_persistent_buffers = {}
            for key in loaded_model_state_dict.keys():
                if key not in model_state_dict.keys():
                    non_persistent_buffers[key] = loaded_model_state_dict[key]

            loaded_model_state_dict = {
                key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers
            }

            self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))

            model_buffers = list(model.buffers())
            for non_persistent_buffer in non_persistent_buffers.values():
                found_buffer = False
                for i, model_buffer in enumerate(model_buffers):
                    if torch.equal(non_persistent_buffer, model_buffer):
                        found_buffer = True
                        break

                self.assertTrue(found_buffer)
                model_buffers.pop(i)

            models_equal = True
            for layer_name, p1 in model_state_dict.items():
                p2 = loaded_model_state_dict[layer_name]
                if p1.data.ne(p2.data).sum() > 0:
                    models_equal = False

            self.assertTrue(models_equal)

    @slow
    def test_model_from_pretrained(self):
        for model_name in OWLVIT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            model = OwlViTForObjectDetection.from_pretrained(model_name)
            self.assertIsNotNone(model)

# We will verify our results on an image of cute cats
def prepare_img():
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    im = Image.open(requests.get(url, stream=True).raw)
    return im

@require_vision
@require_torch
class OwlViTModelIntegrationTest(unittest.TestCase):
    @slow
    def test_inference(self):
        model_name = "google/owlvit-base-patch32"
        model = OwlViTModel.from_pretrained(model_name).to(torch_device)
        processor = OwlViTProcessor.from_pretrained(model_name)

        image = prepare_img()
        inputs = processor(
            text=[["a photo of a cat", "a photo of a dog"]],
            images=image,
            max_length=16,
            padding="max_length",
            return_tensors="pt",
        ).to(torch_device)

        # forward pass
        with torch.no_grad():
            outputs = model(**inputs)

        # verify the logits
        self.assertEqual(
            outputs.logits_per_image.shape,
            torch.Size((inputs.pixel_values.shape[0], inputs.input_ids.shape[0])),
        )
        self.assertEqual(
            outputs.logits_per_text.shape,
            torch.Size((inputs.input_ids.shape[0], inputs.pixel_values.shape[0])),
        )
        expected_logits = torch.tensor([[3.4613, 0.9403]], device=torch_device)
        self.assertTrue(torch.allclose(outputs.logits_per_image, expected_logits, atol=1e-3))

    @slow
    def test_inference_object_detection(self):
        model_name = "google/owlvit-base-patch32"
        model = OwlViTForObjectDetection.from_pretrained(model_name).to(torch_device)

        processor = OwlViTProcessor.from_pretrained(model_name)

        image = prepare_img()
        inputs = processor(
            text=[["a photo of a cat", "a photo of a dog"]],
            images=image,
            max_length=16,
            padding="max_length",
            return_tensors="pt",
        ).to(torch_device)

        with torch.no_grad():
            outputs = model(**inputs)

        num_queries = int((model.config.vision_config.image_size / model.config.vision_config.patch_size) ** 2)
        self.assertEqual(outputs.pred_boxes.shape, torch.Size((1, num_queries, 4)))

        expected_slice_boxes = torch.tensor(
            [[0.0691, 0.0445, 0.1373], [0.1592, 0.0456, 0.3192], [0.1632, 0.0423, 0.2478]]
        ).to(torch_device)
        self.assertTrue(torch.allclose(outputs.pred_boxes[0, :3, :3], expected_slice_boxes, atol=1e-4))

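    # For reference, a minimal sketch of how the raw detection outputs above are usually turned into
    # thresholded boxes in image coordinates (assuming the processor's post-processing API; not
    # exercised by this test):
    #
    #   target_sizes = torch.tensor([image.size[::-1]], device=torch_device)
    #   results = processor.post_process_object_detection(
    #       outputs=outputs, threshold=0.1, target_sizes=target_sizes
    #   )
    #   boxes, scores = results[0]["boxes"], results[0]["scores"]
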
    @slow
    def test_inference_one_shot_object_detection(self):
        model_name = "google/owlvit-base-patch32"
        model = OwlViTForObjectDetection.from_pretrained(model_name).to(torch_device)

        processor = OwlViTProcessor.from_pretrained(model_name)

        image = prepare_img()
        query_image = prepare_img()
        inputs = processor(
            images=image,
            query_images=query_image,
            max_length=16,
            padding="max_length",
            return_tensors="pt",
        ).to(torch_device)

        with torch.no_grad():
            outputs = model.image_guided_detection(**inputs)

        num_queries = int((model.config.vision_config.image_size / model.config.vision_config.patch_size) ** 2)
        self.assertEqual(outputs.target_pred_boxes.shape, torch.Size((1, num_queries, 4)))

        expected_slice_boxes = torch.tensor(
            [[0.0691, 0.0445, 0.1373], [0.1592, 0.0456, 0.3192], [0.1632, 0.0423, 0.2478]]
        ).to(torch_device)
        self.assertTrue(torch.allclose(outputs.target_pred_boxes[0, :3, :3], expected_slice_boxes, atol=1e-4))

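    # Image-guided outputs have an analogous post-processing step (again only a hedged sketch,
    # assuming the current processor API), e.g.:
    #
    #   results = processor.post_process_image_guided_detection(
    #       outputs=outputs, threshold=0.6, nms_threshold=0.3, target_sizes=target_sizes
    #   )
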
    @slow
    @require_torch_accelerator
    @require_torch_fp16
    def test_inference_one_shot_object_detection_fp16(self):
        model_name = "google/owlvit-base-patch32"
        model = OwlViTForObjectDetection.from_pretrained(model_name, torch_dtype=torch.float16).to(torch_device)

        processor = OwlViTProcessor.from_pretrained(model_name)

        image = prepare_img()
        query_image = prepare_img()
        inputs = processor(
            images=image,
            query_images=query_image,
            max_length=16,
            padding="max_length",
            return_tensors="pt",
        ).to(torch_device)

        with torch.no_grad():
            outputs = model.image_guided_detection(**inputs)

        # No need to check the logits, we just check inference runs fine.
        num_queries = int((model.config.vision_config.image_size / model.config.vision_config.patch_size) ** 2)
        self.assertEqual(outputs.target_pred_boxes.shape, torch.Size((1, num_queries, 4)))
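

# One way to run this suite locally (assuming a standard `transformers` development checkout;
# RUN_SLOW=1 enables the @slow integration tests above, which download checkpoints and images):
#
#   RUN_SLOW=1 python -m pytest tests/models/owlvit/test_modeling_owlvit.py -v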