# pytorch
# 106 lines · 2.9 KB
1from . import benchmark2
3
class ConvImplBench(benchmark.Benchmark):
    """Benchmark a single 2-D convolution layer.

    ``case`` selects the grouping of the layer: ``"conv"`` uses one group,
    ``"depthwise_conv"`` uses one group per input channel.  The input is a
    random tensor of shape ``[N, iC, H, W]``.
    """

    def __init__(self, case, mode, device, dtype, kernel_size, N, iC, H, W, oC):
        """Create the input tensor and the convolution layer.

        Args:
            case: ``"conv"`` or ``"depthwise_conv"``.
            mode, device, dtype: forwarded to the base benchmark class.
            kernel_size: kernel size passed to ``conv2d_layer``.
            N, iC, H, W: dimensions of the input tensor, in that order.
            oC: number of output channels.

        Raises:
            ValueError: if ``case`` is not one of the two supported values.
        """
        super().__init__(mode, device, dtype)
        self.case = case
        self.kernel_size = kernel_size
        self.N = N
        self.iC = iC
        self.H = H
        self.W = W
        self.oC = oC
        # NOTE(review): self.requires_grad is not set in this class; presumably
        # the base-class __init__ derives it from `mode` -- confirm.
        self.data = self.rand(
            [N, iC, H, W], device=device, requires_grad=self.requires_grad
        )
        if case == "conv":
            self.groups = 1
        elif case == "depthwise_conv":
            self.groups = iC
        else:
            raise ValueError(f"invalid case: {case}")

        self.conv = self.conv2d_layer(iC, oC, kernel_size, groups=self.groups)
        if device != "cpu":
            self.to_device(self.conv, device)

    def _channels_per_group(self):
        """Return the number of input channels handled by each group."""
        # groups is either 1 or iC, so the division is always exact; floor
        # division keeps the workload figures integral instead of floats.
        return self.iC // self.groups

    def forward(self):
        """Apply the convolution layer to the stored input and return the result."""
        y = self.conv(self.data)
        return y

    def config(self):
        """Return ``[kernel_size, N, iC, H, W, oC]`` for this instance."""
        return [self.kernel_size, self.N, self.iC, self.H, self.W, self.oC]

    def memory_workload(self):
        """Return element counts of memory traffic as ``{"sol", "algorithmic"}``.

        Each of the three buffers (input ``"i"``, output ``"o"``, kernel
        ``"k"``) is weighted by a per-mode access count and the weighted
        sizes are summed.  Sizes are numbers of elements, not bytes.
        """
        if self.mode == "fwd":
            sol_count = {"i": 1, "o": 1, "k": 1}
            algorithmic_count = {"i": 1, "o": 1, "k": 1}
        else:
            # Any mode other than "fwd" uses the larger counts (no validation
            # here, unlike compute_workload).
            sol_count = {"i": 1 + 1, "o": 1 + 1, "k": 1 + 1}
            algorithmic_count = {"i": 1 + (1 + 1), "o": 1 + (1 + 1), "k": 1 + (1 + 1)}

        buffer_size = {
            "i": self.N * self.iC * self.H * self.W,
            # NOTE(review): the output is sized with the input's H x W, i.e.
            # it assumes the layer preserves spatial size -- confirm against
            # conv2d_layer.
            "o": self.N * self.oC * self.H * self.W,
            "k": self.oC
            * self._channels_per_group()
            * self.kernel_size
            * self.kernel_size,
        }
        sol_size = 0
        algorithmic_size = 0
        for key in sol_count:
            sol_size += buffer_size[key] * sol_count[key]
            algorithmic_size += buffer_size[key] * algorithmic_count[key]
        return {"sol": sol_size, "algorithmic": algorithmic_size}

    def compute_workload(self):
        """Return the operation count for one benchmark iteration.

        Raises:
            ValueError: if ``self.mode`` is neither ``"fwd"`` nor ``"both"``.
        """
        if self.mode == "fwd":
            count = 1
        elif self.mode == "both":
            # Presumably one forward pass plus two backward computations
            # (gradients w.r.t. input and weight) -- TODO confirm.
            count = 1 + (1 + 1)
        else:
            raise ValueError(f"invalid mode: {self.mode}")

        op_count = (
            self.N
            * self._channels_per_group()
            * self.oC
            * self.kernel_size
            * self.kernel_size
            * self.H
            * self.W
        )
        # Factor of two: presumably a multiply and an add per accumulation.
        op_count *= 2

        return op_count * count

    @staticmethod
    def default_configs():
        """Return the default configurations, ordered as in ``config()``."""
        return [
            [3, 64, 32, 128, 128, 64],
        ]
86
class ConvBench(ConvImplBench):
    """ConvImplBench fixed to the ``"conv"`` case."""

    def __init__(self, *args):
        # Pin the case; every other positional argument is forwarded as-is.
        case = "conv"
        super().__init__(case, *args)

    @staticmethod
    def module():
        """Return the identifier string of this benchmark, ``"conv"``."""
        return "conv"
95
class DepthwiseConvBench(ConvImplBench):
    """ConvImplBench fixed to the ``"depthwise_conv"`` case."""

    def __init__(self, *args):
        # Pin the case; every other positional argument is forwarded as-is.
        case = "depthwise_conv"
        super().__init__(case, *args)

    @staticmethod
    def module():
        """Return the identifier string of this benchmark, ``"depthwise_conv"``."""
        return "depthwise_conv"
104
# Register both benchmark variants with the benchmark module, in this order.
benchmark.register_benchmark_class(ConvBench)
benchmark.register_benchmark_class(DepthwiseConvBench)