1
// Tencent is pleased to support the open source community by making ncnn available.
3
// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
5
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6
// in compliance with the License. You may obtain a copy of the License at
8
// https://opensource.org/licenses/BSD-3-Clause
10
// Unless required by applicable law or agreed to in writing, software distributed
11
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
13
// specific language governing permissions and limitations under the License.
15
#include "binaryop_mips.h"
19
#include "msa_mathfun.h"
24
// Constructor. Sets support_packing = true, declaring that this MIPS
// implementation accepts packed (elempack > 1) blobs.
// NOTE(review): the constructor's braces and any other statements are not
// present in this excerpt.
BinaryOp_mips::BinaryOp_mips()
27
support_packing = true;
32
// Element-wise op on two arrays of the same length:
//   outptr[i] = op(ptr[i], ptr1[i])  for i in [0, size)
// A 4-lane MSA loop (v4f32) handles groups of 4 floats. The final scalar
// statement handles the elements left over when size is not a multiple of 4.
// NOTE(review): the `op`/`i` declarations, pointer advances and the tail
// loop header fall on lines not shown in this excerpt.
static void binary_op_vector_no_broadcast(const float* ptr, const float* ptr1, float* outptr, int size)
38
// Vector path: 4 floats per iteration.
for (; i + 3 < size; i += 4)
40
// Prefetch hint 16 floats ahead on both input streams.
__builtin_prefetch(ptr + 16);
41
__builtin_prefetch(ptr1 + 16);
42
v4f32 _p = (v4f32)__msa_ld_w(ptr, 0);
43
v4f32 _b = (v4f32)__msa_ld_w(ptr1, 0);
44
v4f32 _outp = op(_p, _b);
45
__msa_st_w((v4i32)_outp, outptr, 0);
53
// Scalar tail, using the functor's float overload.
*outptr = op(*ptr, *ptr1);
61
// Element-wise op with a broadcast right operand: outptr[i] = op(ptr[i], b).
// When elempack == 4, ptr1 holds one packed 4-lane value that is reused for
// every 4-float group of ptr. Otherwise the scalar *ptr1 is splatted to all
// four lanes.
// The scalar tail always uses b = ptr1[0]. The dispatcher passes
// size = w * elempack, so with elempack == 4 the tail does not run.
// NOTE(review): the `op`/`i` declarations, pointer advances and the tail
// loop header are not shown in this excerpt.
static void binary_op_vector_broadcast_b(const float* ptr, const float* ptr1, float* outptr, int size, int elempack)
65
const float b = *ptr1;
67
// Broadcast value for the vector path: packed b, or scalar b splatted.
v4f32 _b_128 = (elempack == 4) ? (v4f32)__msa_ld_w(ptr1, 0) : __msa_fill_w_f32(b);
72
// Vector path: 4 floats of ptr per iteration against the fixed _b_128.
for (; i + 3 < size; i += 4)
74
__builtin_prefetch(ptr + 16);
75
v4f32 _p = (v4f32)__msa_ld_w(ptr, 0);
76
v4f32 _outp = op(_p, _b_128);
77
__msa_st_w((v4i32)_outp, outptr, 0);
84
// Scalar tail.
*outptr = op(*ptr, b);
91
// Element-wise op with a broadcast left operand: outptr[i] = op(a, ptr1[i]).
// When elempack == 4, ptr holds one packed 4-lane value that is reused for
// every 4-float group of ptr1. Otherwise the scalar `a` is splatted to all
// four lanes.
// NOTE(review): the declaration of `a` is not shown in this excerpt.
// Presumably it is `const float a = *ptr;`, mirroring broadcast_b; confirm.
static void binary_op_vector_broadcast_a(const float* ptr, const float* ptr1, float* outptr, int size, int elempack)
97
v4f32 _a_128 = (elempack == 4) ? (v4f32)__msa_ld_w(ptr, 0) : __msa_fill_w_f32(a);
102
// Vector path: 4 floats of ptr1 per iteration against the fixed _a_128.
for (; i + 3 < size; i += 4)
104
__builtin_prefetch(ptr1 + 16);
105
v4f32 _b = (v4f32)__msa_ld_w(ptr1, 0);
106
v4f32 _outp = op(_a_128, _b);
107
__msa_st_w((v4i32)_outp, outptr, 0);
112
// Scalar tail for the remaining elements.
for (; i < size; i++)
114
*outptr = op(a, *ptr1);
121
// Mixed-packing kernel: each 4-lane element loaded from ptr is combined with
// one scalar from ptr1, splatted to all four lanes:
//   out_vec = op(ptr_vec, splat(*ptr1))
// NOTE(review): the loop header and pointer advances are not shown in this
// excerpt. Presumably the loop runs over w packed elements, advancing
// ptr/outptr by 4 and ptr1 by 1; confirm. The 4-lane loads assume
// elempack == 4.
static void binary_op_vector_broadcast_pb(const float* ptr, const float* ptr1, float* outptr, int w, int elempack)
131
__builtin_prefetch(ptr + 16);
132
v4f32 _p = (v4f32)__msa_ld_w(ptr, 0);
133
// One scalar of b per packed element of a.
v4f32 _b = __msa_fill_w_f32(*ptr1);
134
v4f32 _outp = op(_p, _b);
135
__msa_st_w((v4i32)_outp, outptr, 0);
145
// Mixed-packing kernel where b is a single pack-1 scalar. The dispatcher
// labels this case "broadcast pack1 single b". *ptr1 is splatted once, then
// applied to all size = w * elempack floats of ptr, 4 lanes at a time.
// NOTE(review): the `op`/`i` declarations and pointer advances are not shown
// in this excerpt.
static void binary_op_vector_broadcast_pb_b(const float* ptr, const float* ptr1, float* outptr, int w, int elempack)
149
const int size = w * elempack;
153
// The broadcast scalar is loop-invariant, so it is splatted once up front.
v4f32 _b = __msa_fill_w_f32(*ptr1);
154
for (; i + 3 < size; i += 4)
156
__builtin_prefetch(ptr + 16);
157
v4f32 _p = (v4f32)__msa_ld_w(ptr, 0);
158
v4f32 _outp = op(_p, _b);
159
__msa_st_w((v4i32)_outp, outptr, 0);
167
// Mixed-packing kernel for "broadcast single a and pack1 b" (the
// dispatcher's label). A 4-lane value loaded from ptr (a) is combined with a
// scalar from ptr1 (b), splatted to all lanes:
//   out_vec = op(a_vec, splat(*ptr1))
// NOTE(review): the loop structure is not shown in this excerpt. Presumably
// _p is loaded once and ptr1/outptr advance per iteration over w; confirm.
static void binary_op_vector_broadcast_pb_a(const float* ptr, const float* ptr1, float* outptr, int w, int elempack)
175
v4f32 _p = (v4f32)__msa_ld_w(ptr, 0);
178
v4f32 _b = __msa_fill_w_f32(*ptr1);
179
v4f32 _outp = op(_p, _b);
180
__msa_st_w((v4i32)_outp, outptr, 0);
189
// Dispatches one row/vector of a binary op (functor Op) to the matching kernel.
//   aw, bw : lengths of a and b, counted in packed elements
//   ap, bp : elempack of a and b
// The output covers w = max(aw, bw) packed elements of elempack =
// max(ap, bp) floats, i.e. size = w * elempack floats.
// Kernels: no_broadcast (same shape), broadcast_b / broadcast_a (single b / a),
// and the *_pb variants for mixed packing (labelled inline below).
// NOTE(review): the branch conditions selecting each kernel are not shown in
// this excerpt.
static void binary_op_vector(const float* ptr, const float* ptr1, float* outptr, int aw, int bw, int ap, int bp)
191
const int w = std::max(aw, bw);
192
const int elempack = std::max(ap, bp);
193
const int size = w * elempack;
200
return binary_op_vector_no_broadcast<Op>(ptr, ptr1, outptr, size);
205
// broadcast single b
206
return binary_op_vector_broadcast_b<Op>(ptr, ptr1, outptr, size, elempack);
211
// broadcast single a
212
return binary_op_vector_broadcast_a<Op>(ptr, ptr1, outptr, size, elempack);
221
return binary_op_vector_broadcast_pb<Op>(ptr, ptr1, outptr, w, elempack);
226
// broadcast pack1 single b
227
return binary_op_vector_broadcast_pb_b<Op>(ptr, ptr1, outptr, w, elempack);
232
// broadcast single a and pack1 b
233
return binary_op_vector_broadcast_pb_a<Op>(ptr, ptr1, outptr, w, elempack);
237
// shall never reach here
241
// In-place a = op(a, b) with a scalar b. Each channel of `a` is treated as a
// flat array of w * h * d * elempack floats, and channels are split across
// OpenMP threads (opt.num_threads).
// NOTE(review): in this excerpt the vector loop loads _p and stores it back
// with no op applied in between; the original line between them is not
// shown. Confirm the full source applies `_p = op(_p, _b);` there.
static int binary_op_scalar_inplace(Mat& a, float b, const Option& opt)
245
const int channels = a.c;
246
const int size = a.w * a.h * a.d * a.elempack;
248
#pragma omp parallel for num_threads(opt.num_threads)
249
for (int q = 0; q < channels; q++)
251
float* ptr = a.channel(q);
255
// Scalar b splatted once per channel for the 4-lane loop.
v4f32 _b = __msa_fill_w_f32(b);
256
for (; i + 3 < size; i += 4)
258
__builtin_prefetch(ptr + 16);
259
v4f32 _p = (v4f32)__msa_ld_w(ptr, 0);
261
__msa_st_w((v4i32)_p, ptr, 0);
265
// Scalar tail for the remaining size % 4 floats.
for (; i < size; i++)
275
// Functors used by the binary_op_vector kernels.
// MAKE_FUNCTION(NAME, IMPL, IMPL4) defines a functor NAME (presumably a
// struct; its opening lines are not shown) with a float operator() that
// evaluates IMPL. The first definition also has a v4f32 operator() that
// evaluates IMPL4 with MSA intrinsics. The second definition provides only
// the float overload, presumably for builds without MSA; the surrounding
// #if/#else lines are not shown, so confirm.
namespace BinaryOp_mips_functor {
278
#define MAKE_FUNCTION(NAME, IMPL, IMPL4) \
281
float operator()(const float& x, const float& y) const \
285
v4f32 operator()(const v4f32& x, const v4f32& y) const \
291
#define MAKE_FUNCTION(NAME, IMPL, IMPL4) \
294
float operator()(const float& x, const float& y) const \
303
// Commutative and plain binary ops: first argument x is the left operand.
MAKE_FUNCTION(binary_op_add, x + y, __msa_fadd_w(x, y))
304
MAKE_FUNCTION(binary_op_sub, x - y, __msa_fsub_w(x, y))
305
MAKE_FUNCTION(binary_op_mul, x * y, __msa_fmul_w(x, y))
306
MAKE_FUNCTION(binary_op_div, x / y, __msa_fdiv_w(x, y))
307
// NOTE(review): the scalar path uses std::max/std::min while the vector path
// uses __msa_fmax_w/__msa_fmin_w. NaN handling may differ between the two,
// so tail elements could disagree with vector lanes on NaN inputs. Confirm
// whether that matters.
MAKE_FUNCTION(binary_op_max, std::max(x, y), __msa_fmax_w(x, y))
308
MAKE_FUNCTION(binary_op_min, std::min(x, y), __msa_fmin_w(x, y))
309
MAKE_FUNCTION(binary_op_pow, (float)pow(x, y), pow_ps(x, y))
310
// Reversed ops swap the operands: rsub = y - x, rdiv = y / x, rpow = pow(y, x).
MAKE_FUNCTION(binary_op_rsub, y - x, __msa_fsub_w(y, x))
311
MAKE_FUNCTION(binary_op_rdiv, y / x, __msa_fdiv_w(y, x))
312
MAKE_FUNCTION(binary_op_rpow, (float)pow(y, x), pow_ps(y, x))
313
MAKE_FUNCTION(binary_op_atan2, (float)atan2(x, y), atan2_ps(x, y))
314
// ratan2 = atan2(y, x), the operand-swapped form of atan2.
MAKE_FUNCTION(binary_op_ratan2, (float)atan2(y, x), atan2_ps(y, x))
320
} // namespace BinaryOp_mips_functor
322
// Runtime dispatch. Maps a BinaryOp::Operation_* value to the matching
// functor in BinaryOp_mips_functor and forwards all shape arguments to the
// templated binary_op_vector<Op>. An op_type not listed here matches none
// of the branches.
static void binary_op_vector(const float* ptr, const float* ptr1, float* outptr, int aw, int bw, int ap, int bp, int op_type)
324
using namespace BinaryOp_mips_functor;
326
if (op_type == BinaryOp::Operation_ADD) return binary_op_vector<binary_op_add>(ptr, ptr1, outptr, aw, bw, ap, bp);
327
if (op_type == BinaryOp::Operation_SUB) return binary_op_vector<binary_op_sub>(ptr, ptr1, outptr, aw, bw, ap, bp);
328
if (op_type == BinaryOp::Operation_MUL) return binary_op_vector<binary_op_mul>(ptr, ptr1, outptr, aw, bw, ap, bp);
329
if (op_type == BinaryOp::Operation_DIV) return binary_op_vector<binary_op_div>(ptr, ptr1, outptr, aw, bw, ap, bp);
330
if (op_type == BinaryOp::Operation_MAX) return binary_op_vector<binary_op_max>(ptr, ptr1, outptr, aw, bw, ap, bp);
331
if (op_type == BinaryOp::Operation_MIN) return binary_op_vector<binary_op_min>(ptr, ptr1, outptr, aw, bw, ap, bp);
332
if (op_type == BinaryOp::Operation_POW) return binary_op_vector<binary_op_pow>(ptr, ptr1, outptr, aw, bw, ap, bp);
333
if (op_type == BinaryOp::Operation_RSUB) return binary_op_vector<binary_op_rsub>(ptr, ptr1, outptr, aw, bw, ap, bp);
334
if (op_type == BinaryOp::Operation_RDIV) return binary_op_vector<binary_op_rdiv>(ptr, ptr1, outptr, aw, bw, ap, bp);
335
if (op_type == BinaryOp::Operation_RPOW) return binary_op_vector<binary_op_rpow>(ptr, ptr1, outptr, aw, bw, ap, bp);
336
if (op_type == BinaryOp::Operation_ATAN2) return binary_op_vector<binary_op_atan2>(ptr, ptr1, outptr, aw, bw, ap, bp);
337
if (op_type == BinaryOp::Operation_RATAN2) return binary_op_vector<binary_op_ratan2>(ptr, ptr1, outptr, aw, bw, ap, bp);
339
// should never reach here
342
// c = op(a, b) with a scalar b. Each channel is processed as one flat array
// of w * h * d * elempack floats, and channels are split across OpenMP
// threads. The call passes aw = size, bw = 1 and pack 1 for both operands,
// so the dispatcher sees a single-b broadcast over `size` floats.
static void binary_op_scalar(const Mat& a, float b, Mat& c, int op_type, const Option& opt)
344
const int channels = a.c;
345
const int size = a.w * a.h * a.d * a.elempack;
347
#pragma omp parallel for num_threads(opt.num_threads)
348
for (int q = 0; q < channels; q++)
350
const float* ptr = a.channel(q);
351
float* outptr = c.channel(q);
353
binary_op_vector(ptr, &b, outptr, size, 1, 1, 1, op_type);
357
// c = op(a, b) for a and b of identical shape and packing. binary_op_broadcast
// checks this precondition before calling. Each channel is processed as one
// flat array of w * h * d * elempack floats (pack 1 from the kernel's point
// of view), and channels are split across OpenMP threads.
static void binary_op_no_broadcast(const Mat& a, const Mat& b, Mat& c, int op_type, const Option& opt)
359
const int channels = a.c;
360
const int size = a.w * a.h * a.d * a.elempack;
362
#pragma omp parallel for num_threads(opt.num_threads)
363
for (int q = 0; q < channels; q++)
365
const float* ptr = a.channel(q);
366
const float* ptr1 = b.channel(q);
367
float* outptr = c.channel(q);
369
binary_op_vector(ptr, ptr1, outptr, size, size, 1, 1, op_type);
373
// c = op(a, b) where b may be smaller than c along some axes.
// A size-1 axis is broadcast by clamping its index, e.g.
// y1 = std::min(y, b.h - 1). Broadcasting along w is left to
// binary_op_vector via the aw/bw arguments.
// Fast paths: a single-element b goes to binary_op_scalar; identical shapes
// go to binary_op_no_broadcast.
static void binary_op_broadcast(const Mat& a, const Mat& b, Mat& c, int op_type, const Option& opt)
375
// b holds exactly one float in total.
if (b.w * b.h * b.d * b.c * b.elempack == 1)
377
return binary_op_scalar(a, b[0], c, op_type, opt);
380
if (a.dims == b.dims && a.w == b.w && a.h == b.h && a.d == b.d && a.c == b.c && a.elempack == b.elempack)
382
return binary_op_no_broadcast(a, b, c, op_type, opt);
385
const int dims = c.dims;
391
// Row-wise path: row y of c combines row min(y, a.h - 1) of a with row
// min(y, b.h - 1) of b.
// NOTE(review): the enclosing condition and the definition of `h` are not
// shown in this excerpt. Presumably this is dims == 2 with h = c.h; confirm.
#pragma omp parallel for num_threads(opt.num_threads)
392
for (int y = 0; y < h; y++)
394
const int y0 = std::min(y, a.h - 1);
395
const int y1 = std::min(y, b.h - 1);
397
const float* ptr = a.row(y0);
398
const float* ptr1 = b.row(y1);
399
float* outptr = c.row(y);
401
binary_op_vector(ptr, ptr1, outptr, a.w, b.w, a.elempack, b.elempack, op_type);
405
// 3-D / 4-D path: one OpenMP task per output channel, with clamped channel
// indices for a and b.
if (dims == 3 || dims == 4)
407
const int channels = c.c;
409
#pragma omp parallel for num_threads(opt.num_threads)
410
for (int q = 0; q < channels; q++)
412
const int q0 = std::min(q, a.c - 1);
413
const int q1 = std::min(q, b.c - 1);
415
// b holds a single pack per channel: process the whole channel in one call.
if (b.d * b.h * b.w == 1)
417
const float* ptr = a.channel(q0);
418
const float* ptr1 = b.channel(q1);
419
float* outptr = c.channel(q);
421
binary_op_vector(ptr, ptr1, outptr, a.w * a.h * a.d, 1, a.elempack, b.elempack, op_type);
427
// NOTE(review): the condition guarding this branch is not shown. Because it
// passes a.w * a.h with bw = 1, it presumably handles b with one pack per
// depth slice (b.h * b.w == 1); confirm.
for (int z = 0; z < c.d; z++)
429
const int z0 = std::min(z, a.d - 1);
430
const int z1 = std::min(z, b.d - 1);
432
const float* ptr = a.channel(q0).depth(z0);
433
const float* ptr1 = b.channel(q1).depth(z1);
434
float* outptr = c.channel(q).depth(z);
436
binary_op_vector(ptr, ptr1, outptr, a.w * a.h, 1, a.elempack, b.elempack, op_type);
441
// General case: walk depth slices and rows with clamped indices.
// Broadcasting along w is handled inside binary_op_vector.
for (int z = 0; z < c.d; z++)
443
const int z0 = std::min(z, a.d - 1);
444
const int z1 = std::min(z, b.d - 1);
446
for (int y = 0; y < c.h; y++)
448
const int y0 = std::min(y, a.h - 1);
449
const int y1 = std::min(y, b.h - 1);
451
const float* ptr = a.channel(q0).depth(z0).row(y0);
452
const float* ptr1 = b.channel(q1).depth(z1).row(y1);
453
float* outptr = c.channel(q).depth(z).row(y);
455
binary_op_vector(ptr, ptr1, outptr, a.w, b.w, a.elempack, b.elempack, op_type);
462
// In-place a = op(a, b) with a scalar b, selected at runtime by op_type.
// Each channel is processed as one flat array of w * h * d * elempack floats,
// and channels are split across OpenMP threads.
// ptr is passed as both input and output. This relies on the selected kernel
// reading each element before writing the same index.
static void binary_op_scalar_inplace(Mat& a, float b, int op_type, const Option& opt)
464
const int channels = a.c;
465
const int size = a.w * a.h * a.d * a.elempack;
467
#pragma omp parallel for num_threads(opt.num_threads)
468
for (int q = 0; q < channels; q++)
470
float* ptr = a.channel(q);
472
binary_op_vector(ptr, &b, ptr, size, 1, 1, 1, op_type);
476
// Returns the op that yields the same result with the operands swapped,
// i.e. op(a, b) == reverse(op)(b, a). forward() uses it when it passes B2
// before A2. Only the non-commutative ops are remapped here.
// NOTE(review): the final return for all other ops is not shown in this
// excerpt. Presumably it returns op_type unchanged (ADD/MUL/MAX/MIN); confirm.
static int get_reverse_op_type(int op_type)
478
if (op_type == BinaryOp::Operation_SUB) return BinaryOp::Operation_RSUB;
479
if (op_type == BinaryOp::Operation_DIV) return BinaryOp::Operation_RDIV;
480
if (op_type == BinaryOp::Operation_POW) return BinaryOp::Operation_RPOW;
481
if (op_type == BinaryOp::Operation_ATAN2) return BinaryOp::Operation_RATAN2;
482
if (op_type == BinaryOp::Operation_RSUB) return BinaryOp::Operation_SUB;
483
if (op_type == BinaryOp::Operation_RDIV) return BinaryOp::Operation_DIV;
484
if (op_type == BinaryOp::Operation_RPOW) return BinaryOp::Operation_POW;
485
if (op_type == BinaryOp::Operation_RATAN2) return BinaryOp::Operation_ATAN2;
489
// Two-input forward: top_blobs[0] = op(bottom_blobs[0], bottom_blobs[1])
// with broadcasting.
// 1. A lower-rank input is promoted to outdims = max(A.dims, B.dims)
//    (copies A2 / B2).
// 2. The output takes the per-axis maximum shape, plus the larger elemsize
//    and elempack.
// 3. The operands are ordered so the first argument to binary_op_broadcast
//    has the higher elempack, or the not-smaller total size at equal
//    elempack. When the operands are swapped, the reversed op is used.
// NOTE(review): several guarding conditions and the declarations of A2/B2
// are not shown in this excerpt.
int BinaryOp_mips::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const
491
const Mat& A = bottom_blobs[0];
492
const Mat& B = bottom_blobs[1];
493
const int outdims = std::max(A.dims, B.dims);
497
// Promote A to outdims. A 1-D A whose unpacked length matches an axis of B
// (B.h for 2-D output, B.c for 3-D/4-D) is reshaped to lie along that axis.
// Otherwise it is unpacked to A.w * A.elempack scalars along w.
// NOTE(review): the line resetting A2.elempack after unpacking is not shown;
// presumably it sets it to 1, confirm. The condition selecting the first
// (2-D) case is also not shown.
if (A.dims < outdims)
502
if (A.w * A.elempack == B.h * B.elempack)
503
A2 = A.reshape(1, A.w, opt.workspace_allocator);
504
else // if (A.w == B.w)
507
A2.w = A.w * A.elempack;
509
A2.elemsize = A.elemsize / A.elempack;
513
if (outdims == 3 && A.dims == 1)
515
if (A.w * A.elempack == B.c * B.elempack)
516
A2 = A.reshape(1, 1, A.w, opt.workspace_allocator);
517
else // if (A.w == B.w)
520
A2.w = A.w * A.elempack;
522
A2.elemsize = A.elemsize / A.elempack;
526
// Higher-rank A: existing axes are shifted outward and a size-1 w axis is
// added in front.
if (outdims == 3 && A.dims == 2)
527
A2 = A.reshape(1, A.w, A.h, opt.workspace_allocator);
528
if (outdims == 4 && A.dims == 1)
530
if (A.w * A.elempack == B.c * B.elempack)
531
A2 = A.reshape(1, 1, 1, A.w, opt.workspace_allocator);
532
else // if (A.w == B.w)
535
A2.w = A.w * A.elempack;
537
A2.elemsize = A.elemsize / A.elempack;
541
if (outdims == 4 && A.dims == 2)
542
A2 = A.reshape(1, 1, A.w, A.h, opt.workspace_allocator);
543
if (outdims == 4 && A.dims == 3)
544
A2 = A.reshape(1, A.w, A.h, A.c, opt.workspace_allocator);
546
// Same promotion for B, matched against A's axes.
if (B.dims < outdims)
551
if (B.w * B.elempack == A.h * A.elempack)
552
B2 = B.reshape(1, B.w, opt.workspace_allocator);
553
else // if (B.w == A.w)
556
B2.w = B.w * B.elempack;
558
B2.elemsize = B.elemsize / B.elempack;
562
if (outdims == 3 && B.dims == 1)
564
if (B.w * B.elempack == A.c * A.elempack)
565
B2 = B.reshape(1, 1, B.w, opt.workspace_allocator);
566
else // if (B.w == A.w)
569
B2.w = B.w * B.elempack;
571
B2.elemsize = B.elemsize / B.elempack;
575
if (outdims == 3 && B.dims == 2)
576
B2 = B.reshape(1, B.w, B.h, opt.workspace_allocator);
577
if (outdims == 4 && B.dims == 1)
579
if (B.w * B.elempack == A.c * A.elempack)
580
B2 = B.reshape(1, 1, 1, B.w, opt.workspace_allocator);
581
else // if (B.w == A.w)
584
B2.w = B.w * B.elempack;
586
B2.elemsize = B.elemsize / B.elempack;
590
if (outdims == 4 && B.dims == 2)
591
B2 = B.reshape(1, 1, B.w, B.h, opt.workspace_allocator);
592
if (outdims == 4 && B.dims == 3)
593
B2 = B.reshape(1, B.w, B.h, B.c, opt.workspace_allocator);
596
// Output shape is the per-axis maximum of the promoted inputs.
const int outw = std::max(A2.w, B2.w);
597
const int outh = std::max(A2.h, B2.h);
598
const int outd = std::max(A2.d, B2.d);
599
const int outc = std::max(A2.c, B2.c);
600
const size_t out_elemsize = std::max(A2.elemsize, B2.elemsize);
601
const int out_elempack = std::max(A2.elempack, B2.elempack);
603
// Allocate the output with the rank given by outdims. The conditions that
// select each create() call are not shown in this excerpt.
Mat& top_blob = top_blobs[0];
606
top_blob.create(outw, out_elemsize, out_elempack, opt.blob_allocator);
610
top_blob.create(outw, outh, out_elemsize, out_elempack, opt.blob_allocator);
614
top_blob.create(outw, outh, outc, out_elemsize, out_elempack, opt.blob_allocator);
618
top_blob.create(outw, outh, outd, outc, out_elemsize, out_elempack, opt.blob_allocator);
620
if (top_blob.empty())
623
// Pass the operand with the larger elempack first (or, at equal elempack,
// the one that is not smaller in total). When B2 goes first, the reversed
// op (e.g. SUB -> RSUB) keeps the result equal to op(A, B).
const bool a_pack_is_lower = A2.elempack < B2.elempack;
624
const bool a_pack_is_equal = A2.elempack == B2.elempack;
625
const bool a_size_is_lower = A2.w * A2.h * A2.d * A2.c * A2.elempack < B2.w * B2.h * B2.d * B2.c * B2.elempack;
626
if (a_pack_is_lower || (a_pack_is_equal && a_size_is_lower))
628
binary_op_broadcast(B2, A2, top_blob, get_reverse_op_type(op_type), opt);
632
binary_op_broadcast(A2, B2, top_blob, op_type, opt);
638
// Single-input forward: applies op in place as bottom_top_blob = op(bottom_top_blob, b).
// NOTE(review): `b` is not declared in this excerpt. Presumably it is the
// layer's scalar parameter (used when the layer is configured with a scalar
// operand); confirm.
int BinaryOp_mips::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
640
binary_op_scalar_inplace(bottom_top_blob, b, op_type, opt);