ncnn

Форк
0
/
test_gemm_1.cpp 
133 строки · 3.8 Кб
1
// Tencent is pleased to support the open source community by making ncnn available.
2
//
3
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
4
//
5
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6
// in compliance with the License. You may obtain a copy of the License at
7
//
8
// https://opensource.org/licenses/BSD-3-Clause
9
//
10
// Unless required by applicable law or agreed to in writing, software distributed
11
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
13
// specific language governing permissions and limitations under the License.
14

15
#include "testutil.h"
16

17
static int test_gemm(int M, int N, int K, int TILE_M, int TILE_N, int TILE_K, float alpha, int transA, int transB, int output_transpose)
18
{
19
    ncnn::ParamDict pd;
20
    pd.set(0, alpha);
21
    pd.set(1, 1.f); // beta
22
    pd.set(2, transA);
23
    pd.set(3, transB);
24
    pd.set(14, output_transpose);
25

26
    pd.set(20, TILE_M);
27
    pd.set(21, TILE_N);
28
    pd.set(22, TILE_K);
29

30
    std::vector<ncnn::Mat> weights(0);
31

32
    std::vector<ncnn::Mat> a(2);
33
    a[0] = transA ? ncnn::Mat(M, K) : ncnn::Mat(K, M);
34
    a[1] = transB ? ncnn::Mat(K, N) : ncnn::Mat(N, K);
35

36
    Randomize(a[0]);
37
    Randomize(a[1]);
38

39
    int ret = test_layer("Gemm", pd, weights, a);
40
    if (ret != 0)
41
    {
42
        fprintf(stderr, "test_gemm failed M=%d N=%d K=%d TILE_M=%d TILE_N=%d TILE_K=%d alpha=%f transA=%d transB=%d output_transpose=%d\n", M, N, K, TILE_M, TILE_N, TILE_K, alpha, transA, transB, output_transpose);
43
    }
44

45
    return ret;
46
}
47

48
static int test_gemm_0(int M, int N, int K, int TILE_M, int TILE_N, int TILE_K)
49
{
50
    return 0
51
           || test_gemm(M, N, K, TILE_M, TILE_N, TILE_K, 2.1f, 0, 0, 0)
52
           || test_gemm(M, N, K, TILE_M, TILE_N, TILE_K, 3.1f, 0, 1, 0)
53
           || test_gemm(M, N, K, TILE_M, TILE_N, TILE_K, 4.1f, 1, 0, 0)
54
           || test_gemm(M, N, K, TILE_M, TILE_N, TILE_K, 5.1f, 1, 1, 0)
55
           || test_gemm(M, N, K, TILE_M, TILE_N, TILE_K, 2.1f, 0, 0, 1)
56
           || test_gemm(M, N, K, TILE_M, TILE_N, TILE_K, 3.1f, 0, 1, 1)
57
           || test_gemm(M, N, K, TILE_M, TILE_N, TILE_K, 4.1f, 1, 0, 1)
58
           || test_gemm(M, N, K, TILE_M, TILE_N, TILE_K, 5.1f, 1, 1, 1);
59
}
60

61
int main()
62
{
63
    SRAND(7767517);
64

65
    int mnk[][3] = {
66
        {1, 1, 1},
67
        {2, 2, 2},
68
        {3, 3, 3},
69
        {4, 4, 4},
70
        {5, 5, 5},
71
        {6, 6, 6},
72
        {7, 7, 7},
73
        {8, 8, 8},
74
        {15, 15, 15},
75
        {16, 16, 16},
76
        {24, 24, 24},
77
        {31, 31, 31},
78
        {31, 32, 31},
79
        {32, 31, 32},
80
        {32, 32, 32},
81
        {20, 32, 20},
82
        {40, 40, 40},
83
        {47, 47, 47},
84
        {48, 48, 48},
85
        {52, 52, 52},
86
        {63, 64, 63},
87
        {64, 63, 64},
88
        {64, 64, 64}
89
    };
90

91
    int tile_mnk[][3] = {
92
        {1, 1, 1},
93
        {2, 2, 2},
94
        {4, 4, 4},
95
        {8, 8, 8},
96
        {12, 12, 12},
97
        {16, 16, 16},
98
        {20, 20, 20},
99
        {24, 24, 24},
100
        {28, 28, 28}
101
    };
102

103
    int mnk_count = sizeof(mnk) / sizeof(int) / 3;
104
    int tile_mnk_count = sizeof(tile_mnk) / sizeof(int) / 3;
105

106
    for (int i = 0; i < mnk_count; i++)
107
    {
108
        int M = mnk[i][0];
109
        int N = mnk[i][1];
110
        int K = mnk[i][2];
111

112
        for (int j = 0; j < tile_mnk_count; j++)
113
        {
114
            int TILE_M = tile_mnk[j][0];
115
            int TILE_N = tile_mnk[j][1];
116
            int TILE_K = tile_mnk[j][2];
117

118
            if (TILE_M >= M && TILE_N >= N && TILE_K >= K)
119
                continue;
120

121
            int ret = test_gemm_0(M, N, K, TILE_M, TILE_N, TILE_K);
122
            if (ret != 0)
123
                return 0;
124
        }
125

126
        // test no tiling
127
        int ret = test_gemm_0(M, N, K, 100, 100, 100);
128
        if (ret != 0)
129
            return 0;
130
    }
131

132
    return 0;
133
}
134

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.