# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch OwlViT model."""


import inspect
import os
import tempfile
import unittest

import numpy as np
import requests

from transformers import OwlViTConfig, OwlViTTextConfig, OwlViTVisionConfig
from transformers.testing_utils import (
    require_torch,
    require_torch_accelerator,
    require_torch_fp16,
    require_vision,
    slow,
    torch_device,
)
from transformers.utils import is_torch_available, is_vision_available

from ...test_configuration_common import ConfigTester
from ...test_modeling_common import (
    ModelTesterMixin,
    _config_zero_init,
    floats_tensor,
    ids_tensor,
    random_attention_mask,
)
from ...test_pipeline_mixin import PipelineTesterMixin


if is_torch_available():
    import torch
    from torch import nn

    from transformers import OwlViTForObjectDetection, OwlViTModel, OwlViTTextModel, OwlViTVisionModel
    from transformers.models.owlvit.modeling_owlvit import OWLVIT_PRETRAINED_MODEL_ARCHIVE_LIST


if is_vision_available():
    from PIL import Image

    from transformers import OwlViTProcessor


class OwlViTVisionModelTester:
    def __init__(
        self,
        parent,
        batch_size=12,
        image_size=32,
        patch_size=2,
        num_channels=3,
        is_training=True,
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=37,
        dropout=0.1,
        attention_dropout=0.1,
        initializer_range=0.02,
        scope=None,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.is_training = is_training
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.initializer_range = initializer_range
        self.scope = scope

        # in ViT, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token)
        num_patches = (image_size // patch_size) ** 2
        self.seq_length = num_patches + 1
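        # with the defaults above (image_size=32, patch_size=2): num_patches = (32 // 2) ** 2 = 256,
        # so seq_length = 257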

    def prepare_config_and_inputs(self):
        pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
        config = self.get_config()

        return config, pixel_values

    def get_config(self):
        return OwlViTVisionConfig(
            image_size=self.image_size,
            patch_size=self.patch_size,
            num_channels=self.num_channels,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            dropout=self.dropout,
            attention_dropout=self.attention_dropout,
            initializer_range=self.initializer_range,
        )

    def create_and_check_model(self, config, pixel_values):
        model = OwlViTVisionModel(config=config).to(torch_device)
        model.eval()

        pixel_values = pixel_values.to(torch.float32)

        with torch.no_grad():
            result = model(pixel_values)
        # expected sequence length = num_patches + 1 (we add 1 for the [CLS] token)
        num_patches = (self.image_size // self.patch_size) ** 2
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size))
        self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        config, pixel_values = config_and_inputs
        inputs_dict = {"pixel_values": pixel_values}
        return config, inputs_dict


@require_torch
class OwlViTVisionModelTest(ModelTesterMixin, unittest.TestCase):
    """
    Here we also overwrite some of the tests of test_modeling_common.py, as OWLVIT does not use input_ids, inputs_embeds,
    attention_mask and seq_length.
    """

    all_model_classes = (OwlViTVisionModel,) if is_torch_available() else ()
    fx_compatible = False
    test_pruning = False
    test_resize_embeddings = False
    test_head_masking = False

    def setUp(self):
        self.model_tester = OwlViTVisionModelTester(self)
        self.config_tester = ConfigTester(
            self, config_class=OwlViTVisionConfig, has_text_modality=False, hidden_size=37
        )

    def test_config(self):
        self.config_tester.run_common_tests()

    @unittest.skip(reason="OWLVIT does not use inputs_embeds")
    def test_inputs_embeds(self):
        pass

    def test_model_common_attributes(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            self.assertIsInstance(model.get_input_embeddings(), (nn.Module))
            x = model.get_output_embeddings()
            self.assertTrue(x is None or isinstance(x, nn.Linear))

    def test_forward_signature(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            signature = inspect.signature(model.forward)
            # signature.parameters is an OrderedDict => so arg_names order is deterministic
            arg_names = [*signature.parameters.keys()]

            expected_arg_names = ["pixel_values"]
            self.assertListEqual(arg_names[:1], expected_arg_names)

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    @unittest.skip(reason="OWL-ViT does not support training yet")
    def test_training(self):
        pass

    @unittest.skip(reason="OWL-ViT does not support training yet")
    def test_training_gradient_checkpointing(self):
        pass

    @unittest.skip(
        reason="This architecture seems not to compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant(self):
        pass

    @unittest.skip(
        reason="This architecture seems not to compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant_false(self):
        pass

    @unittest.skip(reason="OwlViTVisionModel has no base class and is not available in MODEL_MAPPING")
    def test_save_load_fast_init_from_base(self):
        pass

    @unittest.skip(reason="OwlViTVisionModel has no base class and is not available in MODEL_MAPPING")
    def test_save_load_fast_init_to_base(self):
        pass

    @slow
    def test_model_from_pretrained(self):
        for model_name in OWLVIT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            model = OwlViTVisionModel.from_pretrained(model_name)
            self.assertIsNotNone(model)


class OwlViTTextModelTester:
    def __init__(
        self,
        parent,
        batch_size=12,
        num_queries=4,
        seq_length=16,
        is_training=True,
        use_input_mask=True,
        use_labels=True,
        vocab_size=99,
        hidden_size=64,
        num_hidden_layers=12,
        num_attention_heads=4,
        intermediate_size=37,
        dropout=0.1,
        attention_dropout=0.1,
        max_position_embeddings=16,
        initializer_range=0.02,
        scope=None,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.num_queries = num_queries
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_input_mask = use_input_mask
        self.use_labels = use_labels
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.scope = scope

    def prepare_config_and_inputs(self):
        input_ids = ids_tensor([self.batch_size * self.num_queries, self.seq_length], self.vocab_size)
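        # the text inputs are flattened: each of the batch_size=12 images gets num_queries=4 text
        # prompts, so input_ids has batch_size * num_queries = 48 rows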
        input_mask = None

        if self.use_input_mask:
            input_mask = random_attention_mask([self.batch_size * self.num_queries, self.seq_length])

        if input_mask is not None:
            num_text, seq_length = input_mask.shape

            rnd_start_indices = np.random.randint(1, seq_length - 1, size=(num_text,))
            for idx, start_index in enumerate(rnd_start_indices):
                input_mask[idx, :start_index] = 1
                input_mask[idx, start_index:] = 0
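            # each mask now has a contiguous block of 1s followed by 0s, i.e. it looks like a
            # right-padded sequence with at least one attended token and at least one padded position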

        config = self.get_config()

        return config, input_ids, input_mask

    def get_config(self):
        return OwlViTTextConfig(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            dropout=self.dropout,
            attention_dropout=self.attention_dropout,
            max_position_embeddings=self.max_position_embeddings,
            initializer_range=self.initializer_range,
        )

    def create_and_check_model(self, config, input_ids, input_mask):
        model = OwlViTTextModel(config=config).to(torch_device)
        model.eval()
        with torch.no_grad():
            result = model(input_ids=input_ids, attention_mask=input_mask)

        self.parent.assertEqual(
            result.last_hidden_state.shape, (self.batch_size * self.num_queries, self.seq_length, self.hidden_size)
        )
        self.parent.assertEqual(result.pooler_output.shape, (self.batch_size * self.num_queries, self.hidden_size))

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        config, input_ids, input_mask = config_and_inputs
        inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
        return config, inputs_dict


@require_torch
class OwlViTTextModelTest(ModelTesterMixin, unittest.TestCase):
    all_model_classes = (OwlViTTextModel,) if is_torch_available() else ()
    fx_compatible = False
    test_pruning = False
    test_head_masking = False

    def setUp(self):
        self.model_tester = OwlViTTextModelTester(self)
        self.config_tester = ConfigTester(self, config_class=OwlViTTextConfig, hidden_size=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    @unittest.skip(reason="OWL-ViT does not support training yet")
    def test_training(self):
        pass

    @unittest.skip(reason="OWL-ViT does not support training yet")
    def test_training_gradient_checkpointing(self):
        pass

    @unittest.skip(
        reason="This architecture seems not to compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant(self):
        pass

    @unittest.skip(
        reason="This architecture seems not to compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant_false(self):
        pass

    @unittest.skip(reason="OWLVIT does not use inputs_embeds")
    def test_inputs_embeds(self):
        pass

    @unittest.skip(reason="OwlViTTextModel has no base class and is not available in MODEL_MAPPING")
    def test_save_load_fast_init_from_base(self):
        pass

    @unittest.skip(reason="OwlViTTextModel has no base class and is not available in MODEL_MAPPING")
    def test_save_load_fast_init_to_base(self):
        pass

    @slow
    def test_model_from_pretrained(self):
        for model_name in OWLVIT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            model = OwlViTTextModel.from_pretrained(model_name)
            self.assertIsNotNone(model)


class OwlViTModelTester:
    def __init__(self, parent, text_kwargs=None, vision_kwargs=None, is_training=True):
        if text_kwargs is None:
            text_kwargs = {}
        if vision_kwargs is None:
            vision_kwargs = {}

        self.parent = parent
        self.text_model_tester = OwlViTTextModelTester(parent, **text_kwargs)
        self.vision_model_tester = OwlViTVisionModelTester(parent, **vision_kwargs)
        self.is_training = is_training
        self.text_config = self.text_model_tester.get_config().to_dict()
        self.vision_config = self.vision_model_tester.get_config().to_dict()

    def prepare_config_and_inputs(self):
        text_config, input_ids, attention_mask = self.text_model_tester.prepare_config_and_inputs()
        vision_config, pixel_values = self.vision_model_tester.prepare_config_and_inputs()
        config = self.get_config()
        return config, input_ids, attention_mask, pixel_values

    def get_config(self):
        return OwlViTConfig.from_text_vision_configs(self.text_config, self.vision_config, projection_dim=64)

    def create_and_check_model(self, config, input_ids, attention_mask, pixel_values):
        model = OwlViTModel(config).to(torch_device).eval()

        with torch.no_grad():
            result = model(
                input_ids=input_ids,
                pixel_values=pixel_values,
                attention_mask=attention_mask,
            )

        image_logits_size = (
            self.vision_model_tester.batch_size,
            self.text_model_tester.batch_size * self.text_model_tester.num_queries,
        )
        text_logits_size = (
            self.text_model_tester.batch_size * self.text_model_tester.num_queries,
            self.vision_model_tester.batch_size,
        )
        self.parent.assertEqual(result.logits_per_image.shape, image_logits_size)
        self.parent.assertEqual(result.logits_per_text.shape, text_logits_size)
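        # logits_per_image scores every image against every flattened text prompt, so with the
        # default testers it is (12, 12 * 4) = (12, 48); logits_per_text is the transpose, (48, 12)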

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        config, input_ids, attention_mask, pixel_values = config_and_inputs
        inputs_dict = {
            "pixel_values": pixel_values,
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "return_loss": False,
        }
        return config, inputs_dict


@require_torch
class OwlViTModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
    all_model_classes = (OwlViTModel,) if is_torch_available() else ()
    pipeline_model_mapping = (
        {
            "feature-extraction": OwlViTModel,
            "zero-shot-object-detection": OwlViTForObjectDetection,
        }
        if is_torch_available()
        else {}
    )
    fx_compatible = False
    test_head_masking = False
    test_pruning = False
    test_resize_embeddings = False
    test_attention_outputs = False

    def setUp(self):
        self.model_tester = OwlViTModelTester(self)

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    @unittest.skip(reason="Hidden_states is tested in individual model tests")
    def test_hidden_states_output(self):
        pass

    @unittest.skip(reason="Inputs_embeds is tested in individual model tests")
    def test_inputs_embeds(self):
        pass

    @unittest.skip(reason="Retain_grad is tested in individual model tests")
    def test_retain_grad_hidden_states_attentions(self):
        pass

    @unittest.skip(reason="OwlViTModel does not have input/output embeddings")
    def test_model_common_attributes(self):
        pass

    # override as the `logit_scale` parameter initialization is different for OWLVIT
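    # (like CLIP, OWL-ViT initializes logit_scale to log(1 / 0.07) ~ 2.659 rather than to zero)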
    def test_initialization(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        configs_no_init = _config_zero_init(config)
        for model_class in self.all_model_classes:
            model = model_class(config=configs_no_init)
            for name, param in model.named_parameters():
                if param.requires_grad:
                    # check if `logit_scale` is initialized as per the original implementation
                    if name == "logit_scale":
                        self.assertAlmostEqual(
                            param.data.item(),
                            np.log(1 / 0.07),
                            delta=1e-3,
                            msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                        )
                    else:
                        self.assertIn(
                            ((param.data.mean() * 1e9).round() / 1e9).item(),
                            [0.0, 1.0],
                            msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                        )

    def _create_and_check_torchscript(self, config, inputs_dict):
        if not self.test_torchscript:
            return

        configs_no_init = _config_zero_init(config)  # To be sure we have no NaNs
        configs_no_init.torchscript = True
        configs_no_init.return_dict = False
        for model_class in self.all_model_classes:
            model = model_class(config=configs_no_init).to(torch_device)
            model.eval()

            try:
                input_ids = inputs_dict["input_ids"]
                pixel_values = inputs_dict["pixel_values"]  # OWLVIT needs pixel_values
                traced_model = torch.jit.trace(model, (input_ids, pixel_values))
            except RuntimeError:
                self.fail("Couldn't trace module.")

            with tempfile.TemporaryDirectory() as tmp_dir_name:
                pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt")

                try:
                    torch.jit.save(traced_model, pt_file_name)
                except Exception:
                    self.fail("Couldn't save module.")

                try:
                    loaded_model = torch.jit.load(pt_file_name)
                except Exception:
                    self.fail("Couldn't load module.")

            loaded_model = loaded_model.to(torch_device)
            loaded_model.eval()

            model_state_dict = model.state_dict()
            loaded_model_state_dict = loaded_model.state_dict()

            non_persistent_buffers = {}
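            # keys that only exist in the traced model's state dict correspond to buffers the eager
            # model registers as non-persistent; they are collected here, excluded from the
            # key-by-key comparison below, and then matched against the eager model's buffers by value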
            for key in loaded_model_state_dict.keys():
                if key not in model_state_dict.keys():
                    non_persistent_buffers[key] = loaded_model_state_dict[key]

            loaded_model_state_dict = {
                key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers
            }

            self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))

            model_buffers = list(model.buffers())
            for non_persistent_buffer in non_persistent_buffers.values():
                found_buffer = False
                for i, model_buffer in enumerate(model_buffers):
                    if torch.equal(non_persistent_buffer, model_buffer):
                        found_buffer = True
                        break

                self.assertTrue(found_buffer)
                model_buffers.pop(i)

            models_equal = True
            for layer_name, p1 in model_state_dict.items():
                p2 = loaded_model_state_dict[layer_name]
                if p1.data.ne(p2.data).sum() > 0:
                    models_equal = False

            self.assertTrue(models_equal)

    def test_load_vision_text_config(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        # Save OwlViTConfig and check if we can load OwlViTVisionConfig from it
        with tempfile.TemporaryDirectory() as tmp_dir_name:
            config.save_pretrained(tmp_dir_name)
            vision_config = OwlViTVisionConfig.from_pretrained(tmp_dir_name)
            self.assertDictEqual(config.vision_config.to_dict(), vision_config.to_dict())

        # Save OwlViTConfig and check if we can load OwlViTTextConfig from it
        with tempfile.TemporaryDirectory() as tmp_dir_name:
            config.save_pretrained(tmp_dir_name)
            text_config = OwlViTTextConfig.from_pretrained(tmp_dir_name)
            self.assertDictEqual(config.text_config.to_dict(), text_config.to_dict())

    @slow
    def test_model_from_pretrained(self):
        for model_name in OWLVIT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            model = OwlViTModel.from_pretrained(model_name)
            self.assertIsNotNone(model)


class OwlViTForObjectDetectionTester:
    def __init__(self, parent, is_training=True):
        self.parent = parent
        self.text_model_tester = OwlViTTextModelTester(parent)
        self.vision_model_tester = OwlViTVisionModelTester(parent)
        self.is_training = is_training
        self.text_config = self.text_model_tester.get_config().to_dict()
        self.vision_config = self.vision_model_tester.get_config().to_dict()

    def prepare_config_and_inputs(self):
        text_config, input_ids, attention_mask = self.text_model_tester.prepare_config_and_inputs()
        vision_config, pixel_values = self.vision_model_tester.prepare_config_and_inputs()
        config = self.get_config()
        return config, pixel_values, input_ids, attention_mask

    def get_config(self):
        return OwlViTConfig.from_text_vision_configs(self.text_config, self.vision_config, projection_dim=64)

    def create_and_check_model(self, config, pixel_values, input_ids, attention_mask):
        model = OwlViTForObjectDetection(config).to(torch_device).eval()
        with torch.no_grad():
            result = model(
                pixel_values=pixel_values,
                input_ids=input_ids,
                attention_mask=attention_mask,
                return_dict=True,
            )

        pred_boxes_size = (
            self.vision_model_tester.batch_size,
            (self.vision_model_tester.image_size // self.vision_model_tester.patch_size) ** 2,
            4,
        )
        pred_logits_size = (
            self.vision_model_tester.batch_size,
            (self.vision_model_tester.image_size // self.vision_model_tester.patch_size) ** 2,
            4,
        )
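        # note: the trailing 4 in pred_boxes_size is the number of box coordinates, while the
        # trailing 4 in pred_logits_size is the number of text queries per image
        # (text batch 12 * num_queries 4 / image batch 12 = 4); they only coincide here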
        pred_class_embeds_size = (
            self.vision_model_tester.batch_size,
            (self.vision_model_tester.image_size // self.vision_model_tester.patch_size) ** 2,
            self.text_model_tester.hidden_size,
        )
        self.parent.assertEqual(result.pred_boxes.shape, pred_boxes_size)
        self.parent.assertEqual(result.logits.shape, pred_logits_size)
        self.parent.assertEqual(result.class_embeds.shape, pred_class_embeds_size)

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        config, pixel_values, input_ids, attention_mask = config_and_inputs
        inputs_dict = {
            "pixel_values": pixel_values,
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }
        return config, inputs_dict


@require_torch
class OwlViTForObjectDetectionTest(ModelTesterMixin, unittest.TestCase):
    all_model_classes = (OwlViTForObjectDetection,) if is_torch_available() else ()
    fx_compatible = False
    test_head_masking = False
    test_pruning = False
    test_resize_embeddings = False
    test_attention_outputs = False

    def setUp(self):
        self.model_tester = OwlViTForObjectDetectionTester(self)

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    @unittest.skip(reason="Hidden_states is tested in individual model tests")
    def test_hidden_states_output(self):
        pass

    @unittest.skip(reason="Inputs_embeds is tested in individual model tests")
    def test_inputs_embeds(self):
        pass

    @unittest.skip(reason="Retain_grad is tested in individual model tests")
    def test_retain_grad_hidden_states_attentions(self):
        pass

    @unittest.skip(reason="OwlViTModel does not have input/output embeddings")
    def test_model_common_attributes(self):
        pass

    @unittest.skip(reason="Test_initialization is tested in individual model tests")
    def test_initialization(self):
        pass

    @unittest.skip(reason="Test_forward_signature is tested in individual model tests")
    def test_forward_signature(self):
        pass

    @unittest.skip(reason="Test_save_load_fast_init_from_base is tested in individual model tests")
    def test_save_load_fast_init_from_base(self):
        pass

    @unittest.skip(reason="OWL-ViT does not support training yet")
    def test_training(self):
        pass

    @unittest.skip(reason="OWL-ViT does not support training yet")
    def test_training_gradient_checkpointing(self):
        pass

    @unittest.skip(
        reason="This architecture seems not to compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant(self):
        pass

    @unittest.skip(
        reason="This architecture seems not to compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant_false(self):
        pass

    def _create_and_check_torchscript(self, config, inputs_dict):
        if not self.test_torchscript:
            return

        configs_no_init = _config_zero_init(config)  # To be sure we have no NaNs
        configs_no_init.torchscript = True
        configs_no_init.return_dict = False
        for model_class in self.all_model_classes:
            model = model_class(config=configs_no_init).to(torch_device)
            model.eval()

            try:
                input_ids = inputs_dict["input_ids"]
                pixel_values = inputs_dict["pixel_values"]  # OWLVIT needs pixel_values
                traced_model = torch.jit.trace(model, (input_ids, pixel_values))
            except RuntimeError:
                self.fail("Couldn't trace module.")

            with tempfile.TemporaryDirectory() as tmp_dir_name:
                pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt")

                try:
                    torch.jit.save(traced_model, pt_file_name)
                except Exception:
                    self.fail("Couldn't save module.")

                try:
                    loaded_model = torch.jit.load(pt_file_name)
                except Exception:
                    self.fail("Couldn't load module.")

            loaded_model = loaded_model.to(torch_device)
            loaded_model.eval()

            model_state_dict = model.state_dict()
            loaded_model_state_dict = loaded_model.state_dict()

            non_persistent_buffers = {}
            for key in loaded_model_state_dict.keys():
                if key not in model_state_dict.keys():
                    non_persistent_buffers[key] = loaded_model_state_dict[key]

            loaded_model_state_dict = {
                key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers
            }

            self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))

            model_buffers = list(model.buffers())
            for non_persistent_buffer in non_persistent_buffers.values():
                found_buffer = False
                for i, model_buffer in enumerate(model_buffers):
                    if torch.equal(non_persistent_buffer, model_buffer):
                        found_buffer = True
                        break

                self.assertTrue(found_buffer)
                model_buffers.pop(i)

            models_equal = True
            for layer_name, p1 in model_state_dict.items():
                p2 = loaded_model_state_dict[layer_name]
                if p1.data.ne(p2.data).sum() > 0:
                    models_equal = False

            self.assertTrue(models_equal)

    @slow
    def test_model_from_pretrained(self):
        for model_name in OWLVIT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            model = OwlViTForObjectDetection.from_pretrained(model_name)
            self.assertIsNotNone(model)


# We will verify our results on an image of cute cats
def prepare_img():
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
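    # COCO val2017 image 000000039769 (two cats lying on a couch); fetching it requires network access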
    im = Image.open(requests.get(url, stream=True).raw)
    return im


@require_vision
@require_torch
class OwlViTModelIntegrationTest(unittest.TestCase):
    @slow
    def test_inference(self):
        model_name = "google/owlvit-base-patch32"
        model = OwlViTModel.from_pretrained(model_name).to(torch_device)
        processor = OwlViTProcessor.from_pretrained(model_name)

        image = prepare_img()
        inputs = processor(
            text=[["a photo of a cat", "a photo of a dog"]],
            images=image,
            max_length=16,
            padding="max_length",
            return_tensors="pt",
        ).to(torch_device)

        # forward pass
        with torch.no_grad():
            outputs = model(**inputs)

        # verify the logits
        self.assertEqual(
            outputs.logits_per_image.shape,
            torch.Size((inputs.pixel_values.shape[0], inputs.input_ids.shape[0])),
        )
        self.assertEqual(
            outputs.logits_per_text.shape,
            torch.Size((inputs.input_ids.shape[0], inputs.pixel_values.shape[0])),
        )
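        # here a single image is scored against two text prompts, so logits_per_image is (1, 2)
        # and logits_per_text is (2, 1)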
        expected_logits = torch.tensor([[3.4613, 0.9403]], device=torch_device)
        self.assertTrue(torch.allclose(outputs.logits_per_image, expected_logits, atol=1e-3))

    @slow
    def test_inference_object_detection(self):
        model_name = "google/owlvit-base-patch32"
        model = OwlViTForObjectDetection.from_pretrained(model_name).to(torch_device)

        processor = OwlViTProcessor.from_pretrained(model_name)

        image = prepare_img()
        inputs = processor(
            text=[["a photo of a cat", "a photo of a dog"]],
            images=image,
            max_length=16,
            padding="max_length",
            return_tensors="pt",
        ).to(torch_device)

        with torch.no_grad():
            outputs = model(**inputs)

        num_queries = int((model.config.vision_config.image_size / model.config.vision_config.patch_size) ** 2)
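        # with the 768x768 input of google/owlvit-base-patch32 this is (768 / 32) ** 2 = 576 patch
        # tokens, i.e. one predicted box per image patch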
        self.assertEqual(outputs.pred_boxes.shape, torch.Size((1, num_queries, 4)))

        expected_slice_boxes = torch.tensor(
            [[0.0691, 0.0445, 0.1373], [0.1592, 0.0456, 0.3192], [0.1632, 0.0423, 0.2478]]
        ).to(torch_device)
        self.assertTrue(torch.allclose(outputs.pred_boxes[0, :3, :3], expected_slice_boxes, atol=1e-4))

    @slow
    def test_inference_one_shot_object_detection(self):
        model_name = "google/owlvit-base-patch32"
        model = OwlViTForObjectDetection.from_pretrained(model_name).to(torch_device)

        processor = OwlViTProcessor.from_pretrained(model_name)

        image = prepare_img()
        query_image = prepare_img()
        inputs = processor(
            images=image,
            query_images=query_image,
            max_length=16,
            padding="max_length",
            return_tensors="pt",
        ).to(torch_device)

        with torch.no_grad():
            outputs = model.image_guided_detection(**inputs)

        num_queries = int((model.config.vision_config.image_size / model.config.vision_config.patch_size) ** 2)
        self.assertEqual(outputs.target_pred_boxes.shape, torch.Size((1, num_queries, 4)))

        expected_slice_boxes = torch.tensor(
            [[0.0691, 0.0445, 0.1373], [0.1592, 0.0456, 0.3192], [0.1632, 0.0423, 0.2478]]
        ).to(torch_device)
        self.assertTrue(torch.allclose(outputs.target_pred_boxes[0, :3, :3], expected_slice_boxes, atol=1e-4))

    @slow
    @require_torch_accelerator
    @require_torch_fp16
    def test_inference_one_shot_object_detection_fp16(self):
        model_name = "google/owlvit-base-patch32"
        model = OwlViTForObjectDetection.from_pretrained(model_name, torch_dtype=torch.float16).to(torch_device)

        processor = OwlViTProcessor.from_pretrained(model_name)

        image = prepare_img()
        query_image = prepare_img()
        inputs = processor(
            images=image,
            query_images=query_image,
            max_length=16,
            padding="max_length",
            return_tensors="pt",
        ).to(torch_device)

        with torch.no_grad():
            outputs = model.image_guided_detection(**inputs)

        # No need to check the logits, we just check inference runs fine.
        num_queries = int((model.config.vision_config.image_size / model.config.vision_config.patch_size) ** 2)
        self.assertEqual(outputs.target_pred_boxes.shape, torch.Size((1, num_queries, 4)))