pytorch
137 lines · 4.1 KB
1import numpy as np
2
3import torch
4
5from . import benchmark
6
7
class Concat2D2InputBench(benchmark.Benchmark):
    """Benchmark a small elementwise add feeding a two-input 2D concat.

    The two inputs have shapes ``[I1_D1, I1_D2]`` and ``[I2_D1, I2_D2]`` and
    are concatenated along ``concat_dim``.
    """

    def __init__(self, mode, device, dtype, I1_D1, I1_D2, I2_D1, I2_D2, concat_dim):
        super().__init__(mode, device, dtype)
        self.I1_D1 = I1_D1
        self.I1_D2 = I1_D2
        self.I2_D1 = I2_D1
        self.I2_D2 = I2_D2
        self.concat_dim = concat_dim
        self.input1 = self.randn(
            [I1_D1, I1_D2], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.input2 = self.randn(
            [I2_D1, I2_D2], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.inputs = [self.input1, self.input2]

    def forward(self, input1, input2):
        # The adds give the concat non-trivial producers, so fusion across
        # add + cat is exercised rather than a bare copy.
        x1 = self.add(input1, 0.00001)
        x2 = self.add(input2, 0.00001)
        y = self.cat((x1, x2), dim=self.concat_dim)
        return y

    def reference(self):
        # BUGFIX: mirror forward() exactly. The original omitted the
        # +0.00001 applied to each input before the concat, so the numpy
        # reference did not match the benchmarked computation.
        return np.concatenate(
            (
                self.numpy(self.input1) + 0.00001,
                self.numpy(self.input2) + 0.00001,
            ),
            axis=self.concat_dim,
        )

    def config(self):
        return [self.I1_D1, self.I1_D2, self.I2_D1, self.I2_D2, self.concat_dim]

    @staticmethod
    def module():
        return "concat2d2input"

    def memory_workload(self):
        # NOTE(review): "sol" presumably models ideal traffic (read each
        # input once, write the output once) and "algorithmic" counts the
        # intermediate add results as extra read/write pairs; non-"fwd"
        # modes double both for the backward pass — confirm against the
        # benchmark harness.
        if self.mode == "fwd":
            sol_count = 1 + 1
            algorithmic_count = 3 + 1
        else:
            sol_count = (1 + 1) + (1 + 1)
            algorithmic_count = (3 + 1) + (3 + 1)

        buffer_size = self.I1_D1 * self.I1_D2 + self.I2_D1 * self.I2_D2
        return {
            "sol": buffer_size * sol_count,
            "algorithmic": buffer_size * algorithmic_count,
        }

    @staticmethod
    def default_configs():
        # [I1_D1, I1_D2, I2_D1, I2_D2, concat_dim]
        return [
            [1, 160, 1, 14, 1],
            [1, 580, 1, 174, 1],
            [20, 160, 20, 14, 1],
            [20, 580, 20, 174, 1],
            [8, 512, 8, 512, 1],
            [1 << 13, 1060, 1 << 13, 1040, 1],
            [1 << 13, 2000, 1 << 13, 1074, 1],
            [1 << 15, 1060, 1 << 15, 2670, 1],
            [1 << 15, 5120, 1 << 15, 2512, 1],
        ]
70
71
# Make this benchmark discoverable by the shared benchmark runner.
benchmark.register_benchmark_class(Concat2D2InputBench)
73
74
class ConcatGraphOptBench(benchmark.Benchmark):
    """Benchmark add -> 2-input 2D concat -> relu with JIT graph optimization.

    Identical inputs/shapes to ``Concat2D2InputBench`` but with a relu
    consumer after the concat, and with the CPU fuser plus the
    cat-without-conditionals lowering enabled at construction time.
    """

    def __init__(self, mode, device, dtype, I1_D1, I1_D2, I2_D1, I2_D2, concat_dim):
        super().__init__(mode, device, dtype)
        self.I1_D1 = I1_D1
        self.I1_D2 = I1_D2
        self.I2_D1 = I2_D1
        self.I2_D2 = I2_D2
        self.concat_dim = concat_dim
        self.input1 = self.randn(
            [I1_D1, I1_D2], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.input2 = self.randn(
            [I2_D1, I2_D2], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.inputs = [self.input1, self.input2]
        # Enable the TE fuser on CPU and the conditional-free cat lowering
        # so the fused graph path is what actually gets benchmarked.
        torch._C._jit_override_can_fuse_on_cpu(True)
        torch._C._jit_cat_wo_conditionals(True)

    def forward(self, input1, input2):
        x1 = self.add(input1, 0.00001)
        x2 = self.add(input2, 0.00001)
        y = self.cat((x1, x2), dim=self.concat_dim)
        z = self.relu(y)
        return z

    def reference(self):
        # BUGFIX: mirror forward() exactly. The original concatenated the
        # raw inputs and skipped the relu entirely, so the reference was
        # wrong for every negative element (inputs are randn) and off by
        # 0.00001 elsewhere.
        y = np.concatenate(
            (
                self.numpy(self.input1) + 0.00001,
                self.numpy(self.input2) + 0.00001,
            ),
            axis=self.concat_dim,
        )
        return np.maximum(y, 0)

    def config(self):
        return [self.I1_D1, self.I1_D2, self.I2_D1, self.I2_D2, self.concat_dim]

    @staticmethod
    def module():
        return "concatGraphOpt"

    def memory_workload(self):
        # NOTE(review): "sol" presumably models ideal traffic (read each
        # input once, write the output once) and "algorithmic" counts the
        # intermediates as extra read/write pairs; non-"fwd" modes double
        # both for the backward pass — confirm against the harness.
        if self.mode == "fwd":
            sol_count = 1 + 1
            algorithmic_count = 3 + 1
        else:
            sol_count = (1 + 1) + (1 + 1)
            algorithmic_count = (3 + 1) + (3 + 1)

        buffer_size = self.I1_D1 * self.I1_D2 + self.I2_D1 * self.I2_D2
        return {
            "sol": buffer_size * sol_count,
            "algorithmic": buffer_size * algorithmic_count,
        }

    @staticmethod
    def default_configs():
        # [I1_D1, I1_D2, I2_D1, I2_D2, concat_dim] — large shapes only,
        # since this benchmark targets the fused-graph path.
        return [
            [1 << 13, 1060, 1 << 13, 1040, 1],
            [1 << 13, 2000, 1 << 13, 1074, 1],
            [1 << 15, 1060, 1 << 15, 2670, 1],
            [1 << 15, 5120, 1 << 15, 2512, 1],
        ]
135
136
# Make this benchmark discoverable by the shared benchmark runner.
benchmark.register_benchmark_class(ConcatGraphOptBench)
138