pytorch-image-models

Форк
0
58 строк · 1.1 Кб
1
from enum import Enum
2
from typing import Union
3

4
import torch
5

6

7
class Format(str, Enum):
8
    NCHW = 'NCHW'
9
    NHWC = 'NHWC'
10
    NCL = 'NCL'
11
    NLC = 'NLC'
12

13

14
FormatT = Union[str, Format]
15

16

17
def get_spatial_dim(fmt: FormatT):
18
    fmt = Format(fmt)
19
    if fmt is Format.NLC:
20
        dim = (1,)
21
    elif fmt is Format.NCL:
22
        dim = (2,)
23
    elif fmt is Format.NHWC:
24
        dim = (1, 2)
25
    else:
26
        dim = (2, 3)
27
    return dim
28

29

30
def get_channel_dim(fmt: FormatT):
31
    fmt = Format(fmt)
32
    if fmt is Format.NHWC:
33
        dim = 3
34
    elif fmt is Format.NLC:
35
        dim = 2
36
    else:
37
        dim = 1
38
    return dim
39

40

41
def nchw_to(x: torch.Tensor, fmt: Format):
42
    if fmt == Format.NHWC:
43
        x = x.permute(0, 2, 3, 1)
44
    elif fmt == Format.NLC:
45
        x = x.flatten(2).transpose(1, 2)
46
    elif fmt == Format.NCL:
47
        x = x.flatten(2)
48
    return x
49

50

51
def nhwc_to(x: torch.Tensor, fmt: Format):
52
    if fmt == Format.NCHW:
53
        x = x.permute(0, 3, 1, 2)
54
    elif fmt == Format.NLC:
55
        x = x.flatten(1, 2)
56
    elif fmt == Format.NCL:
57
        x = x.flatten(1, 2).transpose(1, 2)
58
    return x
59

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.