pytorch

Форк
0
68 строк · 1.7 Кб
1
import numpy as np
2

3
from . import benchmark
4

5

6
class MatMulBench(benchmark.Benchmark):
7
    def __init__(self, mode, device, dtype, B, M, N, K):
8
        super().__init__(mode, device, dtype)
9
        self.B = B
10
        self.M = M
11
        self.N = N
12
        self.K = K
13
        self.d1 = self.rand(
14
            [B, M, N], device=device, dtype=dtype, requires_grad=self.requires_grad
15
        )
16
        self.d2 = self.rand(
17
            [B, N, K], device=device, dtype=dtype, requires_grad=self.requires_grad
18
        )
19
        self.inputs = [self.d1, self.d2]
20

21
    def forward(self, d1, d2):
22
        y = self.matmul(d1, d2)
23
        return y
24

25
    def reference(self):
26
        return np.matmul(self.numpy(self.d1), self.numpy(self.d2))
27

28
    def config(self):
29
        return [self.B, self.M, self.N, self.K]
30

31
    @staticmethod
32
    def module():
33
        return "batch_matmul"
34

35
    def memory_workload(self):
36
        if self.mode == "fwd":
37
            sol_count = 1
38
            algorithmic_count = 1
39
        else:
40
            sol_count = 1 + 1
41
            algorithmic_count = 1 + (1 + 1)
42

43
        buffer_size = (
44
            self.B * self.M * self.N
45
            + self.B * self.M * self.N
46
            + self.B * self.N * self.K
47
        )
48
        return {
49
            "sol": buffer_size * sol_count,
50
            "algorithmic": buffer_size * algorithmic_count,
51
        }
52

53
    def compute_workload(self):
54
        if self.mode == "fwd":
55
            count = 1
56
        else:
57
            count = 1 + (1 + 1)
58

59
        op_count = 2 * self.B * self.M * self.N * self.K
60

61
        return op_count * count
62

63
    @staticmethod
64
    def default_configs():
65
        return [[128, 64, 128, 256]]
66

67

68
benchmark.register_benchmark_class(MatMulBench)
69

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.