pytorch

Форк
0
106 строк · 2.9 Кб
1
from . import benchmark
2

3

4
class ConvImplBench(benchmark.Benchmark):
5
    def __init__(self, case, mode, device, dtype, kernel_size, N, iC, H, W, oC):
6
        super().__init__(mode, device, dtype)
7
        self.case = case
8
        self.kernel_size = kernel_size
9
        self.N = N
10
        self.iC = iC
11
        self.H = H
12
        self.W = W
13
        self.oC = oC
14
        self.data = self.rand(
15
            [N, iC, H, W], device=device, requires_grad=self.requires_grad
16
        )
17
        if case == "conv":
18
            self.groups = 1
19
        elif case == "depthwise_conv":
20
            self.groups = iC
21
        else:
22
            raise ValueError(f"invalid case: {case}")
23

24
        self.conv = self.conv2d_layer(iC, oC, kernel_size, groups=self.groups)
25
        if device != "cpu":
26
            self.to_device(self.conv, device)
27

28
    def forward(self):
29
        y = self.conv(self.data)
30
        return y
31

32
    def config(self):
33
        return [self.kernel_size, self.N, self.iC, self.H, self.W, self.oC]
34

35
    def memory_workload(self):
36
        if self.mode == "fwd":
37
            sol_count = {"i": 1, "o": 1, "k": 1}
38
            algorithmic_count = {"i": 1, "o": 1, "k": 1}
39
        else:
40
            sol_count = {"i": 1 + 1, "o": 1 + 1, "k": 1 + 1}
41
            algorithmic_count = {"i": 1 + (1 + 1), "o": 1 + (1 + 1), "k": 1 + (1 + 1)}
42

43
        buffer_size = {
44
            "i": self.N * self.iC * self.H * self.W,
45
            "o": self.N * self.oC * self.H * self.W,
46
            "k": self.oC
47
            * (self.iC / self.groups)
48
            * self.kernel_size
49
            * self.kernel_size,
50
        }
51
        sol_size = 0
52
        algorithmic_size = 0
53
        for key in sol_count:
54
            sol_size += buffer_size[key] * sol_count[key]
55
            algorithmic_size += buffer_size[key] * algorithmic_count[key]
56
        return {"sol": sol_size, "algorithmic": algorithmic_size}
57

58
    def compute_workload(self):
59
        if self.mode == "fwd":
60
            count = 1
61
        elif self.mode == "both":
62
            count = 1 + (1 + 1)
63
        else:
64
            raise ValueError(f"invalid mode: {self.mode}")
65

66
        op_count = (
67
            self.N
68
            * self.iC
69
            / self.groups
70
            * self.oC
71
            * self.kernel_size
72
            * self.kernel_size
73
            * self.H
74
            * self.W
75
        )
76
        op_count *= 2
77

78
        return op_count * count
79

80
    @staticmethod
81
    def default_configs():
82
        return [
83
            [3, 64, 32, 128, 128, 64],
84
        ]
85

86

87
class ConvBench(ConvImplBench):
88
    def __init__(self, *args):
89
        super().__init__("conv", *args)
90

91
    @staticmethod
92
    def module():
93
        return "conv"
94

95

96
class DepthwiseConvBench(ConvImplBench):
97
    def __init__(self, *args):
98
        super().__init__("depthwise_conv", *args)
99

100
    @staticmethod
101
    def module():
102
        return "depthwise_conv"
103

104

105
benchmark.register_benchmark_class(ConvBench)
106
benchmark.register_benchmark_class(DepthwiseConvBench)
107

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.