pytorch

Форк
0
137 строк · 4.1 Кб
1
import numpy as np
2

3
import torch
4

5
from . import benchmark
6

7

8
class Concat2D2InputBench(benchmark.Benchmark):
    """Benchmark a small elementwise add on two 2-D tensors followed by a concat.

    The two inputs have shapes (I1_D1, I1_D2) and (I2_D1, I2_D2) and are
    concatenated along ``concat_dim``.
    """

    def __init__(self, mode, device, dtype, I1_D1, I1_D2, I2_D1, I2_D2, concat_dim):
        super().__init__(mode, device, dtype)
        self.I1_D1 = I1_D1
        self.I1_D2 = I1_D2
        self.I2_D1 = I2_D1
        self.I2_D2 = I2_D2
        self.concat_dim = concat_dim
        # Draw both operands from the framework's randn helper so that the
        # grad requirement follows the benchmark mode.
        shapes = ((I1_D1, I1_D2), (I2_D1, I2_D2))
        self.input1, self.input2 = (
            self.randn(
                list(shape),
                device=device,
                dtype=dtype,
                requires_grad=self.requires_grad,
            )
            for shape in shapes
        )
        self.inputs = [self.input1, self.input2]

    def forward(self, input1, input2):
        # Shift each operand by a tiny epsilon, then concatenate.
        shifted = [self.add(operand, 0.00001) for operand in (input1, input2)]
        return self.cat(tuple(shifted), dim=self.concat_dim)

    def reference(self):
        """Numpy reference: concat of the raw inputs along ``concat_dim``.

        NOTE(review): the tiny +0.00001 shift from ``forward`` is omitted here —
        presumably absorbed by the framework's comparison tolerance; confirm.
        """
        arrays = [self.numpy(t) for t in (self.input1, self.input2)]
        return np.concatenate(arrays, axis=self.concat_dim)

    def config(self):
        # Mirrors the constructor's shape/dim arguments, in order.
        return [self.I1_D1, self.I1_D2, self.I2_D1, self.I2_D2, self.concat_dim]

    @staticmethod
    def module():
        return "concat2d2input"

    def memory_workload(self):
        # Element counts for the "speed of light" and algorithmic traffic
        # estimates; backward roughly doubles both.
        is_fwd = self.mode == "fwd"
        sol_count = 2 if is_fwd else 4
        algorithmic_count = 4 if is_fwd else 8

        elems = self.I1_D1 * self.I1_D2 + self.I2_D1 * self.I2_D2
        return {
            "sol": elems * sol_count,
            "algorithmic": elems * algorithmic_count,
        }

    @staticmethod
    def default_configs():
        # [I1_D1, I1_D2, I2_D1, I2_D2, concat_dim]
        return [
            [1, 160, 1, 14, 1],
            [1, 580, 1, 174, 1],
            [20, 160, 20, 14, 1],
            [20, 580, 20, 174, 1],
            [8, 512, 8, 512, 1],
            [1 << 13, 1060, 1 << 13, 1040, 1],
            [1 << 13, 2000, 1 << 13, 1074, 1],
            [1 << 15, 1060, 1 << 15, 2670, 1],
            [1 << 15, 5120, 1 << 15, 2512, 1],
        ]
70

71

72
# Register with the shared runner so this benchmark is discoverable by name.
benchmark.register_benchmark_class(Concat2D2InputBench)
73

74

75
class ConcatGraphOptBench(benchmark.Benchmark):
    """Benchmark add + concat + relu with JIT graph optimizations enabled.

    Like ``Concat2D2InputBench`` but appends a relu, and turns on CPU fusion
    and conditional-free concat lowering in the JIT so the concat can be
    optimized into the fused graph.
    """

    def __init__(self, mode, device, dtype, I1_D1, I1_D2, I2_D1, I2_D2, concat_dim):
        super().__init__(mode, device, dtype)
        self.I1_D1 = I1_D1
        self.I1_D2 = I1_D2
        self.I2_D1 = I2_D1
        self.I2_D2 = I2_D2
        self.concat_dim = concat_dim
        self.input1 = self.randn(
            [I1_D1, I1_D2], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.input2 = self.randn(
            [I2_D1, I2_D2], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.inputs = [self.input1, self.input2]
        # Process-global JIT switches: allow the fuser on CPU and lower
        # aten::cat without per-output conditionals, which is what this
        # benchmark is meant to exercise.
        torch._C._jit_override_can_fuse_on_cpu(True)
        torch._C._jit_cat_wo_conditionals(True)

    def forward(self, input1, input2):
        x1 = self.add(input1, 0.00001)
        x2 = self.add(input2, 0.00001)
        y = self.cat((x1, x2), dim=self.concat_dim)
        z = self.relu(y)
        return z

    def reference(self):
        """Numpy reference matching ``forward``: concat followed by relu.

        Fix: the previous reference omitted the relu, so it disagreed with
        ``forward`` wherever the concatenated values were negative (common
        with randn inputs). The +0.00001 shift remains omitted, as in the
        sibling benchmark — presumably within comparison tolerance.
        """
        concatenated = np.concatenate(
            (self.numpy(self.input1), self.numpy(self.input2)),
            axis=self.concat_dim,
        )
        return np.maximum(concatenated, 0)  # relu

    def config(self):
        # Mirrors the constructor's shape/dim arguments, in order.
        return [self.I1_D1, self.I1_D2, self.I2_D1, self.I2_D2, self.concat_dim]

    @staticmethod
    def module():
        return "concatGraphOpt"

    def memory_workload(self):
        # Element counts for the "speed of light" and algorithmic traffic
        # estimates; backward roughly doubles both.
        if self.mode == "fwd":
            sol_count = 1 + 1
            algorithmic_count = 3 + 1
        else:
            sol_count = (1 + 1) + (1 + 1)
            algorithmic_count = (3 + 1) + (3 + 1)

        buffer_size = self.I1_D1 * self.I1_D2 + self.I2_D1 * self.I2_D2
        return {
            "sol": buffer_size * sol_count,
            "algorithmic": buffer_size * algorithmic_count,
        }

    @staticmethod
    def default_configs():
        # [I1_D1, I1_D2, I2_D1, I2_D2, concat_dim]
        return [
            [1 << 13, 1060, 1 << 13, 1040, 1],
            [1 << 13, 2000, 1 << 13, 1074, 1],
            [1 << 15, 1060, 1 << 15, 2670, 1],
            [1 << 15, 5120, 1 << 15, 2512, 1],
        ]
135

136

137
# Register with the shared runner so this benchmark is discoverable by name.
benchmark.register_benchmark_class(ConcatGraphOptBench)
138

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.