pytorch-image-models

Форк
0
/
space_to_depth.py 
55 строк · 1.7 Кб
1
import torch
2
import torch.nn as nn
3

4

5
class SpaceToDepth(nn.Module):
6
    bs: torch.jit.Final[int]
7

8
    def __init__(self, block_size=4):
9
        super().__init__()
10
        assert block_size == 4
11
        self.bs = block_size
12

13
    def forward(self, x):
14
        N, C, H, W = x.size()
15
        x = x.view(N, C, H // self.bs, self.bs, W // self.bs, self.bs)  # (N, C, H//bs, bs, W//bs, bs)
16
        x = x.permute(0, 3, 5, 1, 2, 4).contiguous()  # (N, bs, bs, C, H//bs, W//bs)
17
        x = x.view(N, C * self.bs * self.bs, H // self.bs, W // self.bs)  # (N, C*bs^2, H//bs, W//bs)
18
        return x
19

20

21
@torch.jit.script
22
class SpaceToDepthJit:
23
    def __call__(self, x: torch.Tensor):
24
        # assuming hard-coded that block_size==4 for acceleration
25
        N, C, H, W = x.size()
26
        x = x.view(N, C, H // 4, 4, W // 4, 4)  # (N, C, H//bs, bs, W//bs, bs)
27
        x = x.permute(0, 3, 5, 1, 2, 4).contiguous()  # (N, bs, bs, C, H//bs, W//bs)
28
        x = x.view(N, C * 16, H // 4, W // 4)  # (N, C*bs^2, H//bs, W//bs)
29
        return x
30

31

32
class SpaceToDepthModule(nn.Module):
33
    def __init__(self, no_jit=False):
34
        super().__init__()
35
        if not no_jit:
36
            self.op = SpaceToDepthJit()
37
        else:
38
            self.op = SpaceToDepth()
39

40
    def forward(self, x):
41
        return self.op(x)
42

43

44
class DepthToSpace(nn.Module):
45

46
    def __init__(self, block_size):
47
        super().__init__()
48
        self.bs = block_size
49

50
    def forward(self, x):
51
        N, C, H, W = x.size()
52
        x = x.view(N, self.bs, self.bs, C // (self.bs ** 2), H, W)  # (N, bs, bs, C//bs^2, H, W)
53
        x = x.permute(0, 3, 4, 1, 5, 2).contiguous()  # (N, C//bs^2, H, bs, W, bs)
54
        x = x.view(N, C // (self.bs ** 2), H * self.bs, W * self.bs)  # (N, C//bs^2, H * bs, W * bs)
55
        return x
56

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.