# pytorch
# 68 lines · 1.7 KB
import numpy as np

from . import benchmark

class MatMulBench(benchmark.Benchmark):
    """Batched matrix-multiply benchmark: ``[B, M, N] @ [B, N, K]``.

    Two operands are created at construction time via ``self.rand`` and
    multiplied with ``self.matmul``. Both helpers, as well as ``self.numpy``,
    ``self.mode`` and ``self.requires_grad``, are not defined in this class;
    presumably they are provided by ``benchmark.Benchmark``.
    """

    def __init__(self, mode, device, dtype, B, M, N, K):
        """Store the problem sizes and allocate the two operands.

        Args:
            mode: Benchmark mode; ``"fwd"`` is treated specially by the
                workload estimators, every other value takes the other branch.
            device: Forwarded to the base class and to ``self.rand``.
            dtype: Forwarded to the base class and to ``self.rand``.
            B: Batch size (leading dimension of both operands).
            M: Second dimension of the left operand.
            N: Shared (contracted) dimension.
            K: Last dimension of the right operand.
        """
        super().__init__(mode, device, dtype)
        self.B = B
        self.M = M
        self.N = N
        self.K = K
        self.d1 = self.rand(
            [B, M, N], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.d2 = self.rand(
            [B, N, K], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.inputs = [self.d1, self.d2]

    def forward(self, d1, d2):
        """Return ``self.matmul(d1, d2)``."""
        y = self.matmul(d1, d2)
        return y

    def reference(self):
        """Return the NumPy ``matmul`` of the two stored operands."""
        return np.matmul(self.numpy(self.d1), self.numpy(self.d2))

    def config(self):
        """Return the problem sizes as ``[B, M, N, K]``."""
        return [self.B, self.M, self.N, self.K]

    @staticmethod
    def module():
        """Return the name of this benchmark, ``"batch_matmul"``."""
        return "batch_matmul"

    def memory_workload(self):
        """Estimate memory traffic as element counts.

        Returns:
            A dict with keys ``"sol"`` and ``"algorithmic"``. Each value is
            the combined element count of both operands and the result,
            scaled by a per-mode pass count. No dtype size is applied here.
        """
        if self.mode == "fwd":
            sol_count = 1
            algorithmic_count = 1
        else:
            # NOTE(review): presumably forward + backward passes - confirm
            # against the base class's definition of the modes.
            sol_count = 1 + 1
            algorithmic_count = 1 + (1 + 1)

        # One term per buffer. The result of [B, M, N] @ [B, N, K] has shape
        # [B, M, K]; the previous code counted B * M * N twice and omitted
        # the result buffer.
        buffer_size = (
            self.B * self.M * self.N  # left operand d1
            + self.B * self.N * self.K  # right operand d2
            + self.B * self.M * self.K  # result
        )
        return {
            "sol": buffer_size * sol_count,
            "algorithmic": buffer_size * algorithmic_count,
        }

    def compute_workload(self):
        """Estimate the operation count.

        Returns:
            ``2 * B * M * N * K`` (presumably one multiply and one add per
            inner-product term), times 1 in ``"fwd"`` mode and 3 otherwise.
        """
        if self.mode == "fwd":
            count = 1
        else:
            count = 1 + (1 + 1)

        op_count = 2 * self.B * self.M * self.N * self.K

        return op_count * count

    @staticmethod
    def default_configs():
        """Return the default size list.

        A single entry, presumably in ``[B, M, N, K]`` order to match
        ``config()``.
        """
        return [[128, 64, 128, 256]]
# Hand the benchmark class to the benchmark module's registration hook.
benchmark.register_benchmark_class(MatMulBench)