# pytorch-image-models
# 58 lines · 1.1 KB
from enum import Enum
from typing import Union

import torch

class Format(str, Enum):
    """Tensor layouts handled by the conversion and dimension helpers.

    The ``str`` mix-in means every member compares equal to its plain string
    value (``Format.NHWC == 'NHWC'``), and ``Format('NHWC')`` looks a member
    up by that value.
    """

    # Four-dimensional layouts (dim 0 is left untouched by every helper,
    # presumably the batch dim).
    NCHW = 'NCHW'
    NHWC = 'NHWC'
    # Three-dimensional layouts produced by flattening the spatial dims.
    NCL = 'NCL'
    NLC = 'NLC'


# Annotation alias for arguments that may be either a ``Format`` member or its
# plain string value; get_spatial_dim/get_channel_dim normalise such values
# with ``Format(fmt)``.
FormatT = Union[str, Format]


def get_spatial_dim(fmt: FormatT):
    """Return the indices of the spatial dimensions for layout ``fmt``.

    Args:
        fmt: A ``Format`` member or its string value, e.g. ``'NHWC'``.

    Returns:
        Tuple of dimension indices: ``(1,)`` for NLC, ``(2,)`` for NCL,
        ``(1, 2)`` for NHWC and ``(2, 3)`` for NCHW.

    Raises:
        ValueError: If ``fmt`` is not a valid ``Format`` value.
    """
    layout = Format(fmt)
    spatial_dims_by_layout = {
        Format.NLC: (1,),
        Format.NCL: (2,),
        Format.NHWC: (1, 2),
        Format.NCHW: (2, 3),
    }
    return spatial_dims_by_layout[layout]


def get_channel_dim(fmt: FormatT):
    """Return the index of the channel dimension for layout ``fmt``.

    Args:
        fmt: A ``Format`` member or its string value, e.g. ``'NLC'``.

    Returns:
        ``3`` for NHWC, ``2`` for NLC, and ``1`` for NCHW and NCL.

    Raises:
        ValueError: If ``fmt`` is not a valid ``Format`` value.
    """
    layout = Format(fmt)
    # Channels-last layouts put C in the final dim; the rest keep it at dim 1.
    return {Format.NHWC: 3, Format.NLC: 2}.get(layout, 1)


def nchw_to(x: torch.Tensor, fmt: Format):
    """Convert a tensor from NCHW layout to the layout named by ``fmt``.

    Args:
        x: Input tensor, assumed to be laid out as ``(N, C, H, W)``.
        fmt: Target layout. It is compared with ``==``, so the plain string
            value also matches. ``NCHW``, or any value matching none of the
            branches, returns ``x`` unchanged.

    Returns:
        ``x`` permuted to ``(N, H, W, C)`` for NHWC, reshaped to
        ``(N, H*W, C)`` for NLC, or ``(N, C, H*W)`` for NCL.
    """
    if fmt == Format.NHWC:
        return x.permute(0, 2, 3, 1)
    if fmt == Format.NLC:
        # (N, C, H, W) -> (N, C, H*W) -> (N, H*W, C)
        return x.flatten(2).transpose(1, 2)
    if fmt == Format.NCL:
        return x.flatten(2)
    return x


def nhwc_to(x: torch.Tensor, fmt: Format):
    """Convert a tensor from NHWC layout to the layout named by ``fmt``.

    Args:
        x: Input tensor, assumed to be laid out as ``(N, H, W, C)``.
        fmt: Target layout. It is compared with ``==``, so the plain string
            value also matches. ``NHWC``, or any value matching none of the
            branches, returns ``x`` unchanged.

    Returns:
        ``x`` permuted to ``(N, C, H, W)`` for NCHW, reshaped to
        ``(N, H*W, C)`` for NLC, or ``(N, C, H*W)`` for NCL.
    """
    if fmt == Format.NCHW:
        return x.permute(0, 3, 1, 2)
    if fmt == Format.NLC:
        return x.flatten(1, 2)
    if fmt == Format.NCL:
        # (N, H, W, C) -> (N, H*W, C) -> (N, C, H*W)
        return x.flatten(1, 2).transpose(1, 2)
    return x