# coding=utf-8
# Copyright 2021 The OpenAI Team Authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch CLIP model with xformers modifications"""


from typing import Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn

from transformers.activations import ACT2FN
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from transformers.models.clip.configuration_clip import (
    CLIPConfig,
    CLIPTextConfig,
    CLIPVisionConfig,
)

import xformers.ops as xops


logger = logging.get_logger(__name__)

CLIP_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "openai/clip-vit-base-patch32",
    # See all CLIP models at https://huggingface.co/models?filter=clip
]


class CLIPVisionEmbeddings(nn.Module):
    def __init__(self, config: CLIPVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))

        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=False,
        )

        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches + 1
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
        self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)))

    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
        batch_size = pixel_values.shape[0]
        patch_embeds = self.patch_embedding(pixel_values)  # shape = [*, width, grid, grid]
        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)

        class_embeds = self.class_embedding.expand(batch_size, 1, -1)
        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
        embeddings = embeddings + self.position_embedding(self.position_ids)
        return embeddings

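# Shape walkthrough (illustrative only, assuming the standard ViT-B/32 vision config with
# image_size=224, patch_size=32, hidden_size=768): the patch Conv2d produces a 7x7 grid,
# i.e. 49 patch embeddings; prepending the class embedding gives a sequence length of 50,
# so forward() returns a tensor of shape [batch_size, 50, 768].
#
#     embeddings = CLIPVisionEmbeddings(CLIPVisionConfig())  # defaults match ViT-B/32
#     out = embeddings(torch.randn(2, 3, 224, 224))
#     assert out.shape == (2, 50, 768)
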

class CLIPAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper with xformers implementation"""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout

        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        # [bsz, seq_len, embed_dim] -> [bsz, seq_len, num_heads, head_dim], the layout xformers expects
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Input shape: Batch x Time x Channel"""

        bsz, tgt_len, embed_dim = hidden_states.size()

        # project to query/key/value and split into heads
        query_states = self._shape(self.q_proj(hidden_states), -1, bsz)
        key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
        value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        # the memory-efficient kernel never materializes the attention matrix, so no weights are returned
        attn_weights_reshaped = None
        attn_output = xops.memory_efficient_attention(
            query_states,
            key_states,
            value_states,
            p=self.dropout if self.training else 0.0,  # disable attention dropout at eval time
            scale=self.scale,
        )
        attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped

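# Minimal usage sketch (illustrative only) of the xformers kernel used above, assuming a CUDA
# device and fp16 tensors. xops.memory_efficient_attention takes query/key/value shaped
# [batch, seq_len, num_heads, head_dim] and returns the same layout. Because the full attention
# matrix is never materialized, the `attention_mask` / `causal_attention_mask` arguments of
# CLIPAttention.forward are accepted but not applied in this path, and per-head attention
# weights cannot be returned (hence `attn_weights_reshaped = None`).
#
#     q = torch.randn(2, 50, 12, 64, device="cuda", dtype=torch.float16)
#     k, v = torch.randn_like(q), torch.randn_like(q)
#     out = xops.memory_efficient_attention(q, k, v, p=0.0, scale=64 ** -0.5)
#     out = out.reshape(2, 50, 12 * 64)  # back to [batch, seq_len, embed_dim]
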

class CLIPMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states = self.fc2(hidden_states)
        return hidden_states


class CLIPEncoderLayer(nn.Module):
    def __init__(self, config: CLIPConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.self_attn = CLIPAttention(config)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = CLIPMLP(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        causal_attention_mask: torch.Tensor,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            causal_attention_mask (`torch.FloatTensor`): causal attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        """
        residual = hidden_states

        hidden_states = self.layer_norm1(hidden_states)
        hidden_states, attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
        )
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights,)

        return outputs

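# Structure note: CLIPEncoderLayer follows the pre-LayerNorm residual pattern, i.e. each
# sub-block sees normalized inputs while the residual stream itself is left un-normalized:
#
#     hidden_states = hidden_states + self_attn(layer_norm1(hidden_states))
#     hidden_states = hidden_states + mlp(layer_norm2(hidden_states))
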

class CLIPPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = CLIPConfig
    base_model_prefix = "clip"
    supports_gradient_checkpointing = True
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, CLIPEncoder):
            module.gradient_checkpointing = value


CLIP_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
    and behavior.

    Parameters:
        config ([`CLIPConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

CLIP_VISION_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
            [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

CLIP_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
            [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
        return_loss (`bool`, *optional*):
            Whether or not to return the contrastive loss.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


class CLIPEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`CLIPEncoderLayer`].

    Args:
        config: CLIPConfig
    """

    def __init__(self, config: CLIPConfig):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([CLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        inputs_embeds,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated
                vectors than the model's internal embedding lookup matrix.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Causal mask for the text model. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        hidden_states = inputs_embeds
        for idx, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs, output_attentions)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(encoder_layer),
                    hidden_states,
                    attention_mask,
                    causal_attention_mask,
                )
            else:
                layer_outputs = encoder_layer(
                    hidden_states,
                    attention_mask,
                    causal_attention_mask,
                    output_attentions=output_attentions,
                )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=encoder_states,
            attentions=all_attentions,
        )

# Note: the two module-level attributes below appear to be leftovers from the upstream
# CLIPTextModel definition (which is not included in this xformers variant); they are not
# used by the vision-only classes in this file.
config_class = CLIPTextConfig

_no_split_modules = ["CLIPEncoderLayer"]

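# Memory note (illustrative only): because CLIPPreTrainedModel sets
# `supports_gradient_checkpointing = True`, calling `gradient_checkpointing_enable()` on a model
# makes CLIPEncoder wrap each layer call in `torch.utils.checkpoint.checkpoint` during training,
# recomputing activations in the backward pass instead of storing them (more compute, lower peak
# memory). For example:
#
#     vision_model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
#     vision_model.gradient_checkpointing_enable()
#     vision_model.train()
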

class CLIPVisionTransformer(nn.Module):
    def __init__(self, config: CLIPVisionConfig):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = CLIPVisionEmbeddings(config)
        # NOTE: "pre_layrnorm" keeps the (misspelled) attribute name used upstream so that
        # pretrained CLIP checkpoints load without key remapping.
        self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
        self.encoder = CLIPEncoder(config)
        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

    @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPVisionConfig)
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        Returns:

        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        hidden_states = self.embeddings(pixel_values)
        hidden_states = self.pre_layrnorm(hidden_states)

        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = encoder_outputs[0]
        # pool by taking the class (first) token and applying the final LayerNorm
        pooled_output = last_hidden_state[:, 0, :]
        pooled_output = self.post_layernorm(pooled_output)

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )


@add_start_docstrings(
    """The vision model from CLIP without any head or projection on top.""",
    CLIP_START_DOCSTRING,
)
class CLIPVisionModel(CLIPPreTrainedModel):
    config_class = CLIPVisionConfig
    main_input_name = "pixel_values"

    def __init__(self, config: CLIPVisionConfig):
        super().__init__(config)
        self.vision_model = CLIPVisionTransformer(config)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.vision_model.embeddings.patch_embedding

    @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPVisionConfig)
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        Returns:

        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, CLIPVisionModel

        >>> model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
        >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
        >>> pooled_output = outputs.pooler_output  # pooled CLS states
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        return self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )