pytorch-image-models
55 строк · 1.7 Кб
1import torch
2import torch.nn as nn
3
4
class SpaceToDepth(nn.Module):
    """Fold non-overlapping spatial blocks into the channel dimension.

    Maps an input of shape ``(N, C, H, W)`` to
    ``(N, C * block_size**2, H // block_size, W // block_size)``. Input
    channel ``c`` at in-block offset ``(i, j)`` lands in output channel
    ``(i * block_size + j) * C + c``.

    Args:
        block_size: edge length of the square spatial block; must be a
            positive integer. Defaults to 4. ``H`` and ``W`` of the input
            must be divisible by it, otherwise the internal ``view`` call
            fails.

    Raises:
        ValueError: if ``block_size`` is less than 1.
    """
    bs: torch.jit.Final[int]

    def __init__(self, block_size=4):
        super().__init__()
        # The reshape/permute in forward() is valid for any positive block
        # size. Validate with an explicit exception rather than ``assert``,
        # which is stripped when Python runs with -O.
        if block_size < 1:
            raise ValueError(f"block_size must be a positive integer, got {block_size}")
        self.bs = block_size

    def forward(self, x):
        N, C, H, W = x.size()
        bs = self.bs
        x = x.view(N, C, H // bs, bs, W // bs, bs)  # (N, C, H//bs, bs, W//bs, bs)
        # Move the two in-block offset axes in front of the channel axis so
        # they become the slowest-varying part of the new channel index.
        x = x.permute(0, 3, 5, 1, 2, 4).contiguous()  # (N, bs, bs, C, H//bs, W//bs)
        x = x.view(N, C * bs * bs, H // bs, W // bs)  # (N, C*bs^2, H//bs, W//bs)
        return x
19
20
@torch.jit.script
class SpaceToDepthJit:
    """TorchScript space-to-depth with the block size fixed at 4.

    Rearranges an ``(N, C, H, W)`` tensor into
    ``(N, 16 * C, H // 4, W // 4)`` using the same view/permute sequence
    as ``SpaceToDepth`` with a block size of 4.
    """

    def __call__(self, x: torch.Tensor):
        # The block size is a literal 4 rather than a parameter; the original
        # design hard-codes it for acceleration.
        n, c, h, w = x.size()
        out_h = h // 4
        out_w = w // 4
        blocks = x.view(n, c, out_h, 4, out_w, 4)  # (N, C, H/4, 4, W/4, 4)
        # bring both in-block offsets ahead of the channel axis
        blocks = blocks.permute(0, 3, 5, 1, 2, 4).contiguous()  # (N, 4, 4, C, H/4, W/4)
        return blocks.view(n, c * 16, out_h, out_w)  # (N, 16*C, H/4, W/4)
30
31
class SpaceToDepthModule(nn.Module):
    """Space-to-depth layer that selects a TorchScript or eager implementation.

    Args:
        no_jit: when False (the default), the TorchScript ``SpaceToDepthJit``
            op is used; when True, the eager ``SpaceToDepth`` module is used
            instead. Both operate with a block size of 4.
    """

    def __init__(self, no_jit=False):
        super().__init__()
        self.op = SpaceToDepth() if no_jit else SpaceToDepthJit()

    def forward(self, x):
        # delegate to the implementation chosen at construction time
        return self.op(x)
42
43
class DepthToSpace(nn.Module):
    """Unfold channel groups back into spatial blocks (inverse of space-to-depth).

    Maps an input of shape ``(N, C, H, W)`` to
    ``(N, C // block_size**2, H * block_size, W * block_size)``. Input
    channel ``(i * block_size + j) * C_out + c`` becomes output channel
    ``c`` at in-block offset ``(i, j)``. This is the channel layout produced
    by ``SpaceToDepth`` with the same block size.

    Args:
        block_size: edge length of the square spatial block; must be a
            positive integer. ``C`` must be divisible by ``block_size**2``,
            otherwise the internal ``view`` call fails.

    Raises:
        ValueError: if ``block_size`` is less than 1.
    """
    bs: torch.jit.Final[int]

    def __init__(self, block_size):
        super().__init__()
        # Reject non-positive sizes up front; 0 would otherwise only surface
        # later as a ZeroDivisionError inside forward().
        if block_size < 1:
            raise ValueError(f"block_size must be a positive integer, got {block_size}")
        self.bs = block_size

    def forward(self, x):
        N, C, H, W = x.size()
        bs = self.bs
        c_out = C // (bs * bs)  # channels left after unfolding each bs x bs group
        x = x.view(N, bs, bs, c_out, H, W)  # (N, bs, bs, C//bs^2, H, W)
        # interleave each in-block offset with its spatial axis (row, then column)
        x = x.permute(0, 3, 4, 1, 5, 2).contiguous()  # (N, C//bs^2, H, bs, W, bs)
        x = x.view(N, c_out, H * bs, W * bs)  # (N, C//bs^2, H*bs, W*bs)
        return x
56