1
// Tencent is pleased to support the open source community by making ncnn available.
3
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
5
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6
// in compliance with the License. You may obtain a copy of the License at
8
// https://opensource.org/licenses/BSD-3-Clause
10
// Unless required by applicable law or agreed to in writing, software distributed
11
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
13
// specific language governing permissions and limitations under the License.
17
// Runs a single "Gemm" layer case through test_layer() for one combination of
// problem size (M, N, K), tile sizes, alpha and transpose flags.
// NOTE(review): only part of the body is visible in this excerpt; the comments
// below describe the visible statements only. `pd` is declared on a line that
// is not shown here.
static int test_gemm(int M, int N, int K, int TILE_M, int TILE_N, int TILE_K, float alpha, int transA, int transB, int output_transpose)
21
// Param id 1 is pinned to 1.0 for every case (labelled "beta" by the original author).
pd.set(1, 1.f); // beta
24
// Param id 14 receives the caller's output_transpose flag unchanged.
pd.set(14, output_transpose);
30
// The layer is tested with an empty weight list.
std::vector<ncnn::Mat> weights(0);
32
// Two input blobs are passed to the layer. For each one the two constructor
// arguments are swapped when the matching trans flag is non-zero.
// NOTE(review): presumably ncnn::Mat(w, h), so the stored layout flips with
// the flag — confirm against the Mat constructor.
std::vector<ncnn::Mat> a(2);
33
a[0] = transA ? ncnn::Mat(M, K) : ncnn::Mat(K, M);
34
a[1] = transB ? ncnn::Mat(K, N) : ncnn::Mat(N, K);
39
// Hand the configured params, (empty) weights and inputs to the shared layer test helper.
int ret = test_layer("Gemm", pd, weights, a);
42
// Diagnostic dump of every argument of this case to stderr.
// NOTE(review): presumably guarded by a `ret != 0` check that is not visible here.
fprintf(stderr, "test_gemm failed M=%d N=%d K=%d TILE_M=%d TILE_N=%d TILE_K=%d alpha=%f transA=%d transB=%d output_transpose=%d\n", M, N, K, TILE_M, TILE_N, TILE_K, alpha, transA, transB, output_transpose);
48
// Exercises test_gemm() for one (M, N, K) / tile configuration across all
// eight combinations of transA, transB and output_transpose.
// The calls are chained with `||`, so evaluation stops at the first call that
// returns non-zero. Alpha varies with the transA/transB pair (2.1, 3.1, 4.1, 5.1).
// NOTE(review): the head of the expression (before the first `||`) is not
// visible in this excerpt.
static int test_gemm_0(int M, int N, int K, int TILE_M, int TILE_N, int TILE_K)
51
// output_transpose = 0: all four transA/transB combinations.
|| test_gemm(M, N, K, TILE_M, TILE_N, TILE_K, 2.1f, 0, 0, 0)
52
|| test_gemm(M, N, K, TILE_M, TILE_N, TILE_K, 3.1f, 0, 1, 0)
53
|| test_gemm(M, N, K, TILE_M, TILE_N, TILE_K, 4.1f, 1, 0, 0)
54
|| test_gemm(M, N, K, TILE_M, TILE_N, TILE_K, 5.1f, 1, 1, 0)
55
// output_transpose = 1: the same four combinations with identical alpha values.
|| test_gemm(M, N, K, TILE_M, TILE_N, TILE_K, 2.1f, 0, 0, 1)
56
|| test_gemm(M, N, K, TILE_M, TILE_N, TILE_K, 3.1f, 0, 1, 1)
57
|| test_gemm(M, N, K, TILE_M, TILE_N, TILE_K, 4.1f, 1, 0, 1)
58
|| test_gemm(M, N, K, TILE_M, TILE_N, TILE_K, 5.1f, 1, 1, 1);
103
// Driver fragment: iterates over problem sizes and tile sizes and runs
// test_gemm_0() for each pairing.
// NOTE(review): the enclosing function header and the declarations of `mnk`,
// `tile_mnk`, `M`, `N` and `K` are not visible in this excerpt.
// Entry counts: total byte size divided by sizeof(int), then by 3 values per entry.
// NOTE(review): assumes `mnk` and `tile_mnk` are int arrays with 3 columns — confirm at their declarations.
int mnk_count = sizeof(mnk) / sizeof(int) / 3;
104
int tile_mnk_count = sizeof(tile_mnk) / sizeof(int) / 3;
106
// Outer loop: one iteration per problem-size entry.
for (int i = 0; i < mnk_count; i++)
112
// Inner loop: one iteration per tile-size entry.
for (int j = 0; j < tile_mnk_count; j++)
114
// Unpack the j-th tile triple.
int TILE_M = tile_mnk[j][0];
115
int TILE_N = tile_mnk[j][1];
116
int TILE_K = tile_mnk[j][2];
118
// True when every tile dimension is at least as large as the matching problem dimension.
// NOTE(review): presumably such configurations are skipped (the branch body is not visible) — confirm.
if (TILE_M >= M && TILE_N >= N && TILE_K >= K)
121
// Run the full transpose matrix for this problem size with the current tile sizes.
int ret = test_gemm_0(M, N, K, TILE_M, TILE_N, TILE_K);
127
// Additional run with all three tile sizes fixed at 100.
// NOTE(review): presumably the "no tiling" case, executed once per problem size outside the inner loop — confirm.
int ret = test_gemm_0(M, N, K, 100, 100, 100);