pytorch
56 строк · 1.3 Кб
1import scipy.special2
3from . import benchmark4
5
class SoftmaxBench(benchmark.Benchmark):
    """Benchmark a softmax along the last dimension of an ``[M, N]`` input.

    The measured computation (``forward``) is an elementwise add of a small
    constant followed by ``softmax(dim=-1)``. ``reference`` recomputes the
    expected result with SciPy.
    """

    def __init__(self, mode, device, dtype, M, N):
        """Create the benchmark and allocate its single random input tensor.

        Args:
            mode: Benchmark mode. ``memory_workload`` treats ``"fwd"`` as
                forward-only and any other value as forward + backward.
            device: Device the input is created on.
            dtype: Element type of the input; also passed to the softmax.
            M: Number of rows of the input.
            N: Number of columns; softmax normalizes along this dimension.
        """
        super().__init__(mode, device, dtype)
        self.M = M
        self.N = N
        self.dtype = dtype
        # ``self.randn`` and ``self.requires_grad`` are provided by the base
        # class / tensor engine (not visible here); presumably
        # ``requires_grad`` is derived from ``mode`` -- confirm in
        # benchmark.Benchmark.
        self.inputs = [
            self.randn(
                [M, N], device=device, dtype=dtype, requires_grad=self.requires_grad
            )
        ]

    def forward(self, inputs):
        """Run the measured computation: ``softmax(inputs + 0.001, dim=-1)``.

        Args:
            inputs: Despite the plural name, this is used as a single tensor.
                Presumably the harness unpacks ``self.inputs`` positionally
                (TODO confirm in benchmark.Benchmark).

        Returns:
            The softmax of the shifted input along the last dimension,
            computed in ``self.dtype``.
        """
        # NOTE(review): the add presumably gives the engine an elementwise op
        # to fuse with the softmax. Softmax is invariant to adding a constant
        # along the normalized axis, so ``reference`` need not replicate it.
        x = self.add(inputs, 0.001)
        y = self.softmax(x, dim=-1, dtype=self.dtype)
        return y

    def reference(self):
        """Return the expected ``forward`` result, computed with SciPy.

        ``self.inputs`` is a one-element list, so the tensor itself
        (``self.inputs[0]``) is converted rather than the list. Converting the
        list cannot give the ``[M, N]`` array that ``forward`` produces.
        """
        # NOTE(review): if ``self.numpy`` is a plain ``Tensor.numpy()`` call it
        # will reject tensors that require grad or live off-CPU -- confirm in
        # the tensor engine.
        return scipy.special.softmax(self.numpy(self.inputs[0]), axis=-1)

    def config(self):
        """Return the shape parameters ``[M, N]`` of this configuration."""
        return [self.M, self.N]

    @staticmethod
    def module():
        """Return this benchmark's module name, ``"softmax"``."""
        return "softmax"

    def memory_workload(self):
        """Estimate memory traffic for one benchmark iteration.

        Returns:
            A dict with ``"sol"`` (presumably the "speed-of-light" minimum
            traffic) and ``"algorithmic"`` (traffic of the straightforward
            algorithm). Both values are multiples of ``M * N``; no element-size
            factor is applied here.
        """
        if self.mode == "fwd":
            # Presumably one read of the input plus one write of the output.
            sol_count = 1 + 1
            # Presumably three passes over the input (e.g. max, exp-sum,
            # normalize) plus one output write -- TODO confirm.
            algorithmic_count = 3 + 1
        else:
            # Forward counts are doubled, presumably to cover the backward pass.
            sol_count = (1 + 1) + (1 + 1)
            algorithmic_count = (3 + 1) + (3 + 1)

        buffer_size = self.M * self.N
        return {
            "sol": buffer_size * sol_count,
            "algorithmic": buffer_size * algorithmic_count,
        }

    @staticmethod
    def default_configs():
        """Return the default ``[M, N]`` shapes to benchmark."""
        return [
            [480, 20],
            [1 << 15, 32],
            [128, 1 << 16],
        ]


# Make this benchmark known to the harness.
benchmark.register_benchmark_class(SoftmaxBench)