1
// Tencent is pleased to support the open source community by making ncnn available.
3
// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
5
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6
// in compliance with the License. You may obtain a copy of the License at
8
// https://opensource.org/licenses/BSD-3-Clause
10
// Unless required by applicable law or agreed to in writing, software distributed
11
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
13
// specific language governing permissions and limitations under the License.
15
#include "unaryop_mips.h"
22
#include "msa_mathfun.h"
27
// Constructor: sets the support_packing flag to true.
// Only part of the body is visible here (braces and any further statements are missing).
UnaryOp_mips::UnaryOp_mips()
30
support_packing = true;
35
// Applies an element-wise operation to `a` in place, iterating over channels
// in an OpenMP parallel loop and processing four floats at a time with MSA.
// NOTE(review): the template header, the declarations of `op`, `w`, `h`, `d`,
// `channels` and `i`, the handling of any elements left after the 4-wide loop,
// and the return statement are not visible here — confirm against the full file.
static int unary_op_inplace(Mat& a, const Option& opt)
43
int elempack = a.elempack;
44
// Upper bound for the per-channel element index `i` used below.
int size = w * h * d * elempack;
46
// The channel loop is split across opt.num_threads OpenMP threads.
#pragma omp parallel for num_threads(opt.num_threads)
47
for (int q = 0; q < channels; q++)
49
float* ptr = a.channel(q);
53
// Runs while at least four elements remain (i + 3 < size), advancing by four.
for (; i + 3 < size; i += 4)
55
// Prefetch hint for the address 16 floats past ptr.
__builtin_prefetch(ptr + 16);
56
// Load four floats, apply the functor's vector form, store back to the same address.
v4f32 _p = (v4f32)__msa_ld_w(ptr, 0);
57
_p = op.func_pack4(_p);
58
__msa_st_w((v4i32)_p, ptr, 0);
72
namespace UnaryOp_mips_functor {
76
// Absolute-value functor members (presumably struct unary_op_abs; the struct
// header is not visible here).
// Scalar form: fabs(x) cast back to float.
float func(const float& x) const
78
return (float)fabs(x);
81
// Vector form: clears bit 31 of each 32-bit lane, which is the sign bit of an
// IEEE-754 single-precision value.
v4f32 func_pack4(const v4f32& x) const
83
return (v4f32)__msa_bclri_w((v4u32)x, 31);
90
// Negation functor members (presumably struct unary_op_neg; the struct header
// and the scalar body are not visible here).
float func(const float& x) const
95
// Vector form: flips bit 31 (the float sign bit) of each 32-bit lane.
v4f32 func_pack4(const v4f32& x) const
97
return (v4f32)__msa_bnegi_w((v4u32)x, 31);
104
// Floor functor members (presumably struct unary_op_floor; the struct header
// is not visible here).
// Scalar form: floor(x) cast back to float.
float func(const float& x) const
106
return (float)floor(x);
109
// Vector form: truncate toward zero, then subtract 1 in lanes where the
// truncation ended up above x.
// NOTE(review): the value passes through 32-bit integers, so lanes outside the
// int32 range (or NaN) presumably do not match the scalar path — confirm.
v4f32 func_pack4(const v4f32& x) const
111
// _xi: x converted to int32 with truncation.
v4i32 _xi = __msa_ftrunc_s_w(x);
112
// _mask: all-ones (-1) in lanes where x < (float)_xi, zero elsewhere.
v4i32 _mask = __msa_fclt_w(x, __msa_ffint_s_w(_xi));
113
// Adding the mask subtracts 1 exactly in the flagged lanes.
return __msa_ffint_s_w(__msa_addv_w(_xi, _mask));
114
// Disabled alternative: switch the MSACSR rounding mode and use __msa_frint_w.
// int old_msacsr = __msa_cfcmsa_msacsr();
115
// __msa_ctcmsa_msacsr(old_msacsr | 3); // round towards -inf
116
// v4f32 y = __msa_frint_w(x);
117
// __msa_ctcmsa_msacsr(old_msacsr);
125
// Ceiling functor members (presumably struct unary_op_ceil; the struct header
// is not visible here).
// Scalar form: ceil(x) cast back to float.
float func(const float& x) const
127
return (float)ceil(x);
130
// Vector form: truncate toward zero, then add 1 in lanes where the truncation
// ended up below x.
// NOTE(review): the value passes through 32-bit integers, so lanes outside the
// int32 range (or NaN) presumably do not match the scalar path — confirm.
v4f32 func_pack4(const v4f32& x) const
132
// _xi: x converted to int32 with truncation.
v4i32 _xi = __msa_ftrunc_s_w(x);
133
// _mask: all-ones (-1) in lanes where (float)_xi < x, zero elsewhere.
v4i32 _mask = __msa_fclt_w(__msa_ffint_s_w(_xi), x);
134
// Subtracting the mask adds 1 exactly in the flagged lanes.
return __msa_ffint_s_w(__msa_subv_w(_xi, _mask));
135
// Disabled alternative: switch the MSACSR rounding mode and use __msa_frint_w.
// int old_msacsr = __msa_cfcmsa_msacsr();
136
// __msa_ctcmsa_msacsr((old_msacsr | 3) ^ 1); // round towards +inf
137
// v4f32 y = __msa_frint_w(x);
138
// __msa_ctcmsa_msacsr(old_msacsr);
144
// Functor for squaring. The scalar body is not visible here.
struct unary_op_square
146
float func(const float& x) const
151
// Vector form: multiplies each lane by itself.
v4f32 func_pack4(const v4f32& x) const
153
return __msa_fmul_w(x, x);
160
// Square-root functor members (presumably struct unary_op_sqrt; the struct
// header is not visible here).
// Scalar form: sqrt(x) cast back to float.
float func(const float& x) const
162
return (float)sqrt(x);
165
// Vector form: per-lane square root via __msa_fsqrt_w.
v4f32 func_pack4(const v4f32& x) const
167
return __msa_fsqrt_w(x);
174
// Reciprocal-square-root functor members (presumably struct unary_op_rsqrt;
// the struct header is not visible here).
// Scalar form: 1 / sqrt(x) cast back to float.
float func(const float& x) const
176
return (float)(1.f / sqrt(x));
179
// Vector form: per-lane reciprocal square root via __msa_frsqrt_w.
// NOTE(review): whether this intrinsic matches the scalar path's precision is
// not established here — confirm against the MSA specification.
v4f32 func_pack4(const v4f32& x) const
181
return __msa_frsqrt_w(x);
188
// Exponential functor members (presumably struct unary_op_exp; the struct
// header is not visible here).
// Scalar form: exp(x) cast back to float.
float func(const float& x) const
190
return (float)exp(x);
193
// Vector form; its body is not visible here.
v4f32 func_pack4(const v4f32& x) const
202
// Natural-logarithm functor members (presumably struct unary_op_log; the
// struct header is not visible here).
// Scalar form: log(x) cast back to float.
float func(const float& x) const
204
return (float)log(x);
207
// Vector form; its body is not visible here.
v4f32 func_pack4(const v4f32& x) const
216
// Sine functor members (presumably struct unary_op_sin; the struct header is
// not visible here).
// Scalar form: sin(x) cast back to float.
float func(const float& x) const
218
return (float)sin(x);
221
// Vector form: no vector sine is used; the four lanes are spilled to `tmp`,
// each element is passed through scalar sin, and the result is reloaded.
// `tmp` is declared outside the visible lines (presumably a 4-float buffer).
v4f32 func_pack4(const v4f32& x) const
225
__msa_st_w((v4i32)x, tmp, 0);
226
tmp[0] = sin(tmp[0]);
227
tmp[1] = sin(tmp[1]);
228
tmp[2] = sin(tmp[2]);
229
tmp[3] = sin(tmp[3]);
230
return (v4f32)__msa_ld_w(tmp, 0);
237
// Cosine functor members (presumably struct unary_op_cos; the struct header is
// not visible here).
// Scalar form: cos(x) cast back to float.
float func(const float& x) const
239
return (float)cos(x);
242
// Vector form: the four lanes are spilled to `tmp`, each element is passed
// through scalar cos, and the result is reloaded.
// `tmp` is declared outside the visible lines (presumably a 4-float buffer).
v4f32 func_pack4(const v4f32& x) const
246
__msa_st_w((v4i32)x, tmp, 0);
247
tmp[0] = cos(tmp[0]);
248
tmp[1] = cos(tmp[1]);
249
tmp[2] = cos(tmp[2]);
250
tmp[3] = cos(tmp[3]);
251
return (v4f32)__msa_ld_w(tmp, 0);
258
// Tangent functor members (presumably struct unary_op_tan; the struct header
// is not visible here).
// Scalar form: tan(x) cast back to float.
float func(const float& x) const
260
return (float)tan(x);
263
// Vector form: the four lanes are spilled to `tmp`, each element is passed
// through scalar tan, and the result is reloaded.
// `tmp` is declared outside the visible lines (presumably a 4-float buffer).
v4f32 func_pack4(const v4f32& x) const
267
__msa_st_w((v4i32)x, tmp, 0);
268
tmp[0] = tan(tmp[0]);
269
tmp[1] = tan(tmp[1]);
270
tmp[2] = tan(tmp[2]);
271
tmp[3] = tan(tmp[3]);
272
return (v4f32)__msa_ld_w(tmp, 0);
279
// Arc-sine functor members (presumably struct unary_op_asin; the struct header
// is not visible here).
// Scalar form: asin(x) cast back to float.
float func(const float& x) const
281
return (float)asin(x);
284
// Vector form: the four lanes are spilled to `tmp`, each element is passed
// through scalar asin, and the result is reloaded.
// `tmp` is declared outside the visible lines (presumably a 4-float buffer).
v4f32 func_pack4(const v4f32& x) const
288
__msa_st_w((v4i32)x, tmp, 0);
289
tmp[0] = asin(tmp[0]);
290
tmp[1] = asin(tmp[1]);
291
tmp[2] = asin(tmp[2]);
292
tmp[3] = asin(tmp[3]);
293
return (v4f32)__msa_ld_w(tmp, 0);
300
// Arc-cosine functor members (presumably struct unary_op_acos; the struct
// header is not visible here).
// Scalar form: acos(x) cast back to float.
float func(const float& x) const
302
return (float)acos(x);
305
// Vector form: the four lanes are spilled to `tmp`, each element is passed
// through scalar acos, and the result is reloaded.
// `tmp` is declared outside the visible lines (presumably a 4-float buffer).
v4f32 func_pack4(const v4f32& x) const
309
__msa_st_w((v4i32)x, tmp, 0);
310
tmp[0] = acos(tmp[0]);
311
tmp[1] = acos(tmp[1]);
312
tmp[2] = acos(tmp[2]);
313
tmp[3] = acos(tmp[3]);
314
return (v4f32)__msa_ld_w(tmp, 0);
321
// Arc-tangent functor members (presumably struct unary_op_atan; the struct
// header is not visible here).
// Scalar form: atan(x) cast back to float.
float func(const float& x) const
323
return (float)atan(x);
326
// Vector form: the four lanes are spilled to `tmp`, each element is passed
// through scalar atan, and the result is reloaded.
// `tmp` is declared outside the visible lines (presumably a 4-float buffer).
v4f32 func_pack4(const v4f32& x) const
330
__msa_st_w((v4i32)x, tmp, 0);
331
tmp[0] = atan(tmp[0]);
332
tmp[1] = atan(tmp[1]);
333
tmp[2] = atan(tmp[2]);
334
tmp[3] = atan(tmp[3]);
335
return (v4f32)__msa_ld_w(tmp, 0);
340
// Functor for the reciprocal. The scalar body is not visible here.
struct unary_op_reciprocal
342
float func(const float& x) const
347
// Vector form: per-lane reciprocal via __msa_frcp_w.
// NOTE(review): whether this intrinsic matches the scalar path's precision is
// not established here — confirm against the MSA specification.
v4f32 func_pack4(const v4f32& x) const
349
return __msa_frcp_w(x);
356
// Hyperbolic-tangent functor members (presumably struct unary_op_tanh; the
// struct header is not visible here).
// Scalar form: tanh(x) cast back to float.
float func(const float& x) const
358
return (float)tanh(x);
361
// Vector form; its body is not visible here.
v4f32 func_pack4(const v4f32& x) const
370
// Base-10 logarithm functor members (presumably struct unary_op_log10; the
// struct header is not visible here).
// Scalar form: log10(x) cast back to float.
float func(const float& x) const
372
return (float)log10(x);
375
// Vector form: log_ps(x) scaled by 0.434294481903, which is 1 / ln(10).
// log_ps is defined elsewhere (presumably the natural log from msa_mathfun.h — confirm).
v4f32 func_pack4(const v4f32& x) const
377
return __msa_fmul_w(log_ps(x), __msa_fill_w_f32(0.434294481903));
384
// Rounding functor members (presumably struct unary_op_round; the struct
// header is not visible here).
// Scalar form: two alternative implementations selected by NCNN_GNU_INLINE_ASM.
// NOTE(review): the asm statement wrapper, the #else/#endif lines, the
// restoration of the saved rounding mode and the return statements are not
// visible here — confirm against the full file.
float func(const float& x) const
386
// round to nearest even
387
#if NCNN_GNU_INLINE_ASM
388
// Inline-asm path: add and then subtract a magic constant using add.s / sub.s.
// 12582912.f equals 1.5 * 2^23; while x + magic stays within [2^23, 2^24) the
// float spacing there is 1, so the add itself rounds x to an integer.
// return (x + 12582912.f) - 12582912.f;
390
const float magic = 12582912.f;
392
"add.s %0, %1, %2 \n"
393
"sub.s %0, %0, %2 \n"
400
// Library path: remember the current rounding mode, force round-to-nearest,
// then round with nearbyintf.
int old_rm = fegetround();
401
fesetround(FE_TONEAREST);
403
float y = nearbyintf(x);
411
// Vector form: per-lane round-to-integer via __msa_frint_w.
v4f32 func_pack4(const v4f32& x) const
413
// round towards nearest even by default
414
return __msa_frint_w(x);
421
// Truncation functor members (presumably struct unary_op_trunc; the struct
// header is not visible here).
// Scalar form: truncf(x).
float func(const float& x) const
423
return (float)truncf(x);
426
// Vector form: converts each lane to int32 with truncation and back to float.
// NOTE(review): the value passes through 32-bit integers, so lanes outside the
// int32 range (or NaN) presumably do not match the scalar path — confirm.
v4f32 func_pack4(const v4f32& x) const
428
return __msa_ffint_s_w(__msa_ftrunc_s_w(x));
429
// Disabled alternative: switch the MSACSR rounding mode and use __msa_frint_w.
// int old_msacsr = __msa_cfcmsa_msacsr();
430
// __msa_ctcmsa_msacsr((old_msacsr | 3) ^ 2); // round towards zero
431
// v4f32 y = __msa_frint_w(x);
432
// __msa_ctcmsa_msacsr(old_msacsr);
438
} // namespace UnaryOp_mips_functor
440
// Runs the configured unary operation on bottom_top_blob in place.
// Compares op_type against each Operation_* constant in turn and forwards to
// unary_op_inplace instantiated with the matching functor, returning its result.
// NOTE(review): the braces and whatever is returned when no op_type matches
// are not visible here — confirm against the full file.
int UnaryOp_mips::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
442
// Brings the unary_op_* functor types into scope for the dispatch below.
using namespace UnaryOp_mips_functor;
444
if (op_type == Operation_ABS)
445
return unary_op_inplace<unary_op_abs>(bottom_top_blob, opt);
447
if (op_type == Operation_NEG)
448
return unary_op_inplace<unary_op_neg>(bottom_top_blob, opt);
450
if (op_type == Operation_FLOOR)
451
return unary_op_inplace<unary_op_floor>(bottom_top_blob, opt);
453
if (op_type == Operation_CEIL)
454
return unary_op_inplace<unary_op_ceil>(bottom_top_blob, opt);
456
if (op_type == Operation_SQUARE)
457
return unary_op_inplace<unary_op_square>(bottom_top_blob, opt);
459
if (op_type == Operation_SQRT)
460
return unary_op_inplace<unary_op_sqrt>(bottom_top_blob, opt);
462
if (op_type == Operation_RSQRT)
463
return unary_op_inplace<unary_op_rsqrt>(bottom_top_blob, opt);
465
if (op_type == Operation_EXP)
466
return unary_op_inplace<unary_op_exp>(bottom_top_blob, opt);
468
if (op_type == Operation_LOG)
469
return unary_op_inplace<unary_op_log>(bottom_top_blob, opt);
471
if (op_type == Operation_SIN)
472
return unary_op_inplace<unary_op_sin>(bottom_top_blob, opt);
474
if (op_type == Operation_COS)
475
return unary_op_inplace<unary_op_cos>(bottom_top_blob, opt);
477
if (op_type == Operation_TAN)
478
return unary_op_inplace<unary_op_tan>(bottom_top_blob, opt);
480
if (op_type == Operation_ASIN)
481
return unary_op_inplace<unary_op_asin>(bottom_top_blob, opt);
483
if (op_type == Operation_ACOS)
484
return unary_op_inplace<unary_op_acos>(bottom_top_blob, opt);
486
if (op_type == Operation_ATAN)
487
return unary_op_inplace<unary_op_atan>(bottom_top_blob, opt);
489
if (op_type == Operation_RECIPROCAL)
490
return unary_op_inplace<unary_op_reciprocal>(bottom_top_blob, opt);
492
if (op_type == Operation_TANH)
493
return unary_op_inplace<unary_op_tanh>(bottom_top_blob, opt);
495
if (op_type == Operation_LOG10)
496
return unary_op_inplace<unary_op_log10>(bottom_top_blob, opt);
498
if (op_type == Operation_ROUND)
499
return unary_op_inplace<unary_op_round>(bottom_top_blob, opt);
501
if (op_type == Operation_TRUNC)
502
return unary_op_inplace<unary_op_trunc>(bottom_top_blob, opt);