# coding=utf-8
# Copyright 2021 The OpenAI Team Authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch CLIP model with xformers modifications"""


from typing import Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn

from transformers.activations import ACT2FN
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from transformers.models.clip.configuration_clip import (
    CLIPConfig,
    CLIPTextConfig,
    CLIPVisionConfig,
)

import xformers.ops as xops


logger = logging.get_logger(__name__)

CLIP_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "openai/clip-vit-base-patch32",
    # See all CLIP models at https://huggingface.co/models?filter=clip
]


class CLIPVisionEmbeddings(nn.Module):
    def __init__(self, config: CLIPVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))

        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=False,
        )

        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches + 1
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
        self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)))

    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
        batch_size = pixel_values.shape[0]
        patch_embeds = self.patch_embedding(pixel_values)  # shape = [*, width, grid, grid]
        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)

        class_embeds = self.class_embedding.expand(batch_size, 1, -1)
        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
        embeddings = embeddings + self.position_embedding(self.position_ids)
        return embeddings
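

# Shape walkthrough for CLIPVisionEmbeddings (a sketch, assuming the default
# openai/clip-vit-base-patch32 sizes: image_size=224, patch_size=32, hidden_size=768):
#   pixel_values                  -> (B, 3, 224, 224)
#   patch_embedding(pixel_values) -> (B, 768, 7, 7)    # 224 // 32 = 7 patches per side
#   flatten(2).transpose(1, 2)    -> (B, 49, 768)
#   cat with class_embedding      -> (B, 50, 768)      # num_positions = 49 + 1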


class CLIPAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper with xformers implementation"""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" f" {self.num_heads}).")
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout

        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""

        bsz, tgt_len, embed_dim = hidden_states.size()

        # get query proj
        query_states = self._shape(self.q_proj(hidden_states), -1, bsz)
        key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
        value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        attn_weights_reshaped = None
        attn_output = xops.memory_efficient_attention(
            query_states,
            key_states,
            value_states,
            p=self.dropout,
            scale=self.scale,
        )
        attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped
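
    # Note on the xformers path (a sketch of the semantics, not upstream HF code):
    # xops.memory_efficient_attention expects q/k/v shaped (batch, seq_len, num_heads,
    # head_dim), which is why _shape above skips the heads-first transpose used by the
    # stock HF CLIPAttention. The fused kernel never materializes the full attention
    # matrix, so attention weights cannot be returned: `output_attentions` is effectively
    # ignored and the second return value is always None. The `attention_mask` and
    # `causal_attention_mask` arguments are likewise accepted but unused, which is fine
    # for the vision tower (no padding, no causality). Roughly equivalent eager math:
    #
    #   q = query_states.transpose(1, 2)                         # (B, H, S, D)
    #   k = key_states.transpose(1, 2)
    #   v = value_states.transpose(1, 2)
    #   attn = torch.softmax((q @ k.transpose(-2, -1)) * self.scale, dim=-1)
    #   out = (attn @ v).transpose(1, 2).reshape(bsz, tgt_len, embed_dim)
    #
    # Also note that `p=self.dropout` is passed unconditionally; with CLIP's default
    # attention_dropout of 0.0 this is a no-op, but a nonzero value would apply dropout
    # in eval mode as well.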


class CLIPMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states = self.fc2(hidden_states)
        return hidden_states


class CLIPEncoderLayer(nn.Module):
    def __init__(self, config: CLIPConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.self_attn = CLIPAttention(config)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = CLIPMLP(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        causal_attention_mask: torch.Tensor,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            causal_attention_mask (`torch.FloatTensor`): causal attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        """
        residual = hidden_states

        hidden_states = self.layer_norm1(hidden_states)
        hidden_states, attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
        )
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights,)

        return outputs


class CLIPPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = CLIPConfig
    base_model_prefix = "clip"
    supports_gradient_checkpointing = True
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, CLIPEncoder):
            module.gradient_checkpointing = value
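
    # Usage sketch (an assumption about the surrounding transformers version, not code
    # defined in this file): gradient checkpointing for the encoder is toggled through
    # the standard PreTrainedModel helper, which routes into _set_gradient_checkpointing
    # above, e.g.:
    #
    #   model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
    #   model.gradient_checkpointing_enable()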


CLIP_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
    and behavior.

    Parameters:
        config ([`CLIPConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

CLIP_VISION_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
            [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

CLIP_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
            [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
        return_loss (`bool`, *optional*):
            Whether or not to return the contrastive loss.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


class CLIPEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`CLIPEncoderLayer`].

    Args:
        config: CLIPConfig
    """

    def __init__(self, config: CLIPConfig):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([CLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        inputs_embeds,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
                than the model's internal embedding lookup matrix.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Causal mask for the text model. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        hidden_states = inputs_embeds
        for idx, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs, output_attentions)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(encoder_layer),
                    hidden_states,
                    attention_mask,
                    causal_attention_mask,
                )
            else:
                layer_outputs = encoder_layer(
                    hidden_states,
                    attention_mask,
                    causal_attention_mask,
                    output_attentions=output_attentions,
                )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=encoder_states,
            attentions=all_attentions,
        )
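
    # Note: the two class attributes below appear to be carried over from the text-model
    # variant of this file; CLIPEncoder itself never reads `config_class` or
    # `_no_split_modules`, so they are effectively inert here.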
387

388
    config_class = CLIPTextConfig
389

390
    _no_split_modules = ["CLIPEncoderLayer"]
391

392

class CLIPVisionTransformer(nn.Module):
    def __init__(self, config: CLIPVisionConfig):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = CLIPVisionEmbeddings(config)
        self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
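        # Note: "pre_layrnorm" (sic) matches the attribute name used by the upstream
        # HuggingFace CLIP implementation, so pretrained checkpoint keys load unchanged.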
        self.encoder = CLIPEncoder(config)
        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

    @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPVisionConfig)
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        Returns:

        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        hidden_states = self.embeddings(pixel_values)
        hidden_states = self.pre_layrnorm(hidden_states)

        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = encoder_outputs[0]
        pooled_output = last_hidden_state[:, 0, :]
        pooled_output = self.post_layernorm(pooled_output)

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )


@add_start_docstrings(
    """The vision model from CLIP without any head or projection on top.""",
    CLIP_START_DOCSTRING,
)
class CLIPVisionModel(CLIPPreTrainedModel):
    config_class = CLIPVisionConfig
    main_input_name = "pixel_values"

    def __init__(self, config: CLIPVisionConfig):
        super().__init__(config)
        self.vision_model = CLIPVisionTransformer(config)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.vision_model.embeddings.patch_embedding

    @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPVisionConfig)
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        Returns:

        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, CLIPVisionModel

        >>> model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
        >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
        >>> pooled_output = outputs.pooler_output  # pooled CLS states
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        return self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
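
# Runtime note (a sketch under stated assumptions, not part of this module's API):
# xops.memory_efficient_attention dispatches to fused GPU kernels, so this xformers
# variant of the vision tower is normally run on CUDA, typically in fp16/bf16, e.g.:
#
#   model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32").to("cuda", dtype=torch.float16)
#   pixel_values = inputs["pixel_values"].to("cuda", dtype=torch.float16)  # `inputs` as in the docstring example above
#   with torch.no_grad():
#       outputs = model(pixel_values=pixel_values)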