pytorch-image-models
80 lines · 2.5 KB
1""" Position Embedding Utilities
2
3Hacked together by / Copyright 2022 Ross Wightman
4"""
5import logging6import math7from typing import List, Tuple, Optional, Union8
9import torch10import torch.nn.functional as F11
12from .helpers import to_2tuple13
14_logger = logging.getLogger(__name__)15
16
def resample_abs_pos_embed(
        posemb,
        new_size: List[int],
        old_size: Optional[List[int]] = None,
        num_prefix_tokens: int = 1,
        interpolation: str = 'bicubic',
        antialias: bool = True,
        verbose: bool = False,
):
    """Resample a flat absolute position embedding to a new 2D grid size.

    Args:
        posemb: Position embedding of shape (1, num_prefix_tokens + H*W, embed_dim).
        new_size: Target grid size as (height, width).
        old_size: Source grid size; assumed square if not provided.
        num_prefix_tokens: Number of leading non-spatial tokens (class token, etc.)
            that are carried through unchanged.
        interpolation: Interpolation mode passed to ``F.interpolate``.
        antialias: Apply antialiasing during interpolation.
        verbose: Log the resize when True (outside of torchscript).

    Returns:
        Position embedding resampled to ``new_size`` with prefix tokens preserved.
    """
    num_pos_tokens = posemb.shape[1]
    if old_size is None:
        # assume a square source grid when the caller doesn't specify one
        hw = int(math.sqrt(num_pos_tokens - num_prefix_tokens))
        old_size = hw, hw

    # no-op when the grid is already the requested size
    # FIX: previous check compared token counts and assumed a square target, which
    # skipped resampling for a non-square old_size with an equal token count
    if new_size[0] == old_size[0] and new_size[1] == old_size[1]:
        return posemb

    if num_prefix_tokens:
        posemb_prefix, posemb = posemb[:, :num_prefix_tokens], posemb[:, num_prefix_tokens:]
    else:
        posemb_prefix, posemb = None, posemb

    # do the interpolation in float32; interpolate needs float32
    embed_dim = posemb.shape[-1]
    orig_dtype = posemb.dtype
    posemb = posemb.float()
    posemb = posemb.reshape(1, old_size[0], old_size[1], -1).permute(0, 3, 1, 2)
    posemb = F.interpolate(posemb, size=new_size, mode=interpolation, antialias=antialias)
    posemb = posemb.permute(0, 2, 3, 1).reshape(1, -1, embed_dim)
    posemb = posemb.to(orig_dtype)

    # add back extra (class, etc) prefix tokens
    if posemb_prefix is not None:
        posemb = torch.cat([posemb_prefix, posemb], dim=1)

    if not torch.jit.is_scripting() and verbose:
        _logger.info(f'Resized position embedding: {old_size} to {new_size}.')

    return posemb
59
def resample_abs_pos_embed_nhwc(
        posemb,
        new_size: List[int],
        interpolation: str = 'bicubic',
        antialias: bool = True,
        verbose: bool = False,
):
    """Resample an NHWC-layout absolute position embedding to a new grid size.

    Args:
        posemb: Position embedding of shape (..., H, W, embed_dim).
        new_size: Target grid size as (height, width).
        interpolation: Interpolation mode passed to ``F.interpolate``.
        antialias: Apply antialiasing during interpolation.
        verbose: Log the resize when True (outside of torchscript).

    Returns:
        Position embedding resampled to (1, new_h, new_w, embed_dim).
    """
    if new_size[0] == posemb.shape[-3] and new_size[1] == posemb.shape[-2]:
        return posemb

    # FIX: capture the source grid size before resizing; the log previously read
    # posemb.shape[-3:-1] after interpolation, reporting the new size as the old one
    old_size = posemb.shape[-3], posemb.shape[-2]
    orig_dtype = posemb.dtype
    posemb = posemb.float()  # interpolate needs float32
    # do the interpolation in NCHW layout
    posemb = posemb.reshape(1, posemb.shape[-3], posemb.shape[-2], posemb.shape[-1]).permute(0, 3, 1, 2)
    posemb = F.interpolate(posemb, size=new_size, mode=interpolation, antialias=antialias)
    posemb = posemb.permute(0, 2, 3, 1).to(orig_dtype)

    if not torch.jit.is_scripting() and verbose:
        _logger.info(f'Resized position embedding: {old_size} to {new_size}.')

    return posemb