pytorch-image-models

Форк
0
80 строк · 2.5 Кб
1
""" Position Embedding Utilities
2

3
Hacked together by / Copyright 2022 Ross Wightman
4
"""
5
import logging
6
import math
7
from typing import List, Tuple, Optional, Union
8

9
import torch
10
import torch.nn.functional as F
11

12
from .helpers import to_2tuple
13

14
_logger = logging.getLogger(__name__)
15

16

17
def resample_abs_pos_embed(
        posemb,
        new_size: List[int],
        old_size: Optional[List[int]] = None,
        num_prefix_tokens: int = 1,
        interpolation: str = 'bicubic',
        antialias: bool = True,
        verbose: bool = False,
):
    """Resample an absolute (learned) position embedding to a new grid size.

    Args:
        posemb: Position embedding tensor of shape (1, num_prefix_tokens + H * W, embed_dim).
        new_size: Target grid size as (height, width).
        old_size: Source grid size as (height, width); assumed square if not provided.
        num_prefix_tokens: Number of leading non-spatial tokens (class/distill token)
            that are carried through unchanged.
        interpolation: Interpolation mode passed to ``F.interpolate``.
        antialias: Apply anti-aliasing during interpolation.
        verbose: Log the resize at INFO level (skipped under torchscript).

    Returns:
        Position embedding of shape (1, num_prefix_tokens + new_H * new_W, embed_dim),
        cast back to the input's original dtype. Returned unchanged (same object)
        when the grid size already matches.
    """
    # sort out sizes, assume square grid if old size not provided
    num_pos_tokens = posemb.shape[1]
    if old_size is None:
        hw = int(math.sqrt(num_pos_tokens - num_prefix_tokens))
        old_size = hw, hw

    # no-op when the grid is already the requested size.
    # NOTE: the previous check compared token counts and only handled square new
    # sizes, which could silently skip a needed resample when a non-square
    # old_size with a matching token count was given (e.g. (1, 4) -> (2, 2)).
    if old_size[0] == new_size[0] and old_size[1] == new_size[1]:
        return posemb

    # split off prefix (class, etc) tokens, they are not spatial
    if num_prefix_tokens:
        posemb_prefix, posemb = posemb[:, :num_prefix_tokens], posemb[:, num_prefix_tokens:]
    else:
        posemb_prefix, posemb = None, posemb

    # do the interpolation in NCHW layout
    embed_dim = posemb.shape[-1]
    orig_dtype = posemb.dtype
    posemb = posemb.float()  # interpolate needs float32
    posemb = posemb.reshape(1, old_size[0], old_size[1], -1).permute(0, 3, 1, 2)
    posemb = F.interpolate(posemb, size=new_size, mode=interpolation, antialias=antialias)
    posemb = posemb.permute(0, 2, 3, 1).reshape(1, -1, embed_dim)
    posemb = posemb.to(orig_dtype)

    # add back extra (class, etc) prefix tokens
    if posemb_prefix is not None:
        posemb = torch.cat([posemb_prefix, posemb], dim=1)

    if not torch.jit.is_scripting() and verbose:
        _logger.info(f'Resized position embedding: {old_size} to {new_size}.')

    return posemb
58

59

60
def resample_abs_pos_embed_nhwc(
        posemb,
        new_size: List[int],
        interpolation: str = 'bicubic',
        antialias: bool = True,
        verbose: bool = False,
):
    """Resample an absolute NHWC position embedding to a new spatial size.

    Args:
        posemb: Position embedding tensor of shape (..., H, W, embed_dim).
        new_size: Target spatial size as (height, width).
        interpolation: Interpolation mode passed to ``F.interpolate``.
        antialias: Apply anti-aliasing during interpolation.
        verbose: Log the resize at INFO level (skipped under torchscript).

    Returns:
        Tensor of shape (1, new_H, new_W, embed_dim), cast back to the input's
        original dtype. Returned unchanged (same object) when the spatial size
        already matches.
    """
    if new_size[0] == posemb.shape[-3] and new_size[1] == posemb.shape[-2]:
        return posemb

    # capture the source size now: posemb is rebound to the resized tensor
    # below, so reading its shape afterwards would log new_size as the old
    # size (previous code logged "{new} to {new}")
    old_size = posemb.shape[-3], posemb.shape[-2]
    orig_dtype = posemb.dtype
    posemb = posemb.float()  # interpolate needs float32

    # do the interpolation in NCHW layout
    posemb = posemb.reshape(1, posemb.shape[-3], posemb.shape[-2], posemb.shape[-1]).permute(0, 3, 1, 2)
    posemb = F.interpolate(posemb, size=new_size, mode=interpolation, antialias=antialias)
    posemb = posemb.permute(0, 2, 3, 1).to(orig_dtype)

    if not torch.jit.is_scripting() and verbose:
        _logger.info(f'Resized position embedding: {old_size} to {new_size}.')

    return posemb
81

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.