ncnn

Форк
0
/
binaryop_mips.cpp 
645 строк · 20.2 Кб
1
// Tencent is pleased to support the open source community by making ncnn available.
2
//
3
// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
4
//
5
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6
// in compliance with the License. You may obtain a copy of the License at
7
//
8
// https://opensource.org/licenses/BSD-3-Clause
9
//
10
// Unless required by applicable law or agreed to in writing, software distributed
11
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
13
// specific language governing permissions and limitations under the License.
14

15
#include "binaryop_mips.h"
16

17
#if __mips_msa
18
#include <msa.h>
19
#include "msa_mathfun.h"
20
#endif // __mips_msa
21

22
namespace ncnn {
23

24
BinaryOp_mips::BinaryOp_mips()
{
#if __mips_msa
    // the MSA kernels in this file have dedicated elempack == 4 paths
    this->support_packing = true;
#endif // __mips_msa
}
30

31
// Element-wise op over two rows of identical layout:
//   outptr[k] = op(ptr[k], ptr1[k])  for 0 <= k < size
// With MSA, full 4-float groups go through the v4f32 overload of Op and the
// remainder through the float overload.
template<typename Op>
static void binary_op_vector_no_broadcast(const float* ptr, const float* ptr1, float* outptr, int size)
{
    const Op op;

    int k = 0;
#if __mips_msa
    // 4 lanes per step while a full vector remains
    while (k + 4 <= size)
    {
        __builtin_prefetch(ptr + k + 16);
        __builtin_prefetch(ptr1 + k + 16);
        const v4f32 _lhs = (v4f32)__msa_ld_w(ptr + k, 0);
        const v4f32 _rhs = (v4f32)__msa_ld_w(ptr1 + k, 0);
        __msa_st_w((v4i32)op(_lhs, _rhs), outptr + k, 0);
        k += 4;
    }
#endif // __mips_msa
    // scalar remainder (the whole row when MSA is unavailable)
    while (k < size)
    {
        outptr[k] = op(ptr[k], ptr1[k]);
        k++;
    }
}
59

60
// b is a single element reused for every position of a:
//   outptr[k] = op(ptr[k], b)  for 0 <= k < size
// With MSA and elempack == 4, b's 4 lanes are loaded as-is and repeat for every
// packed element of a; otherwise b's first value is splatted to all lanes.
template<typename Op>
static void binary_op_vector_broadcast_b(const float* ptr, const float* ptr1, float* outptr, int size, int elempack)
{
    const Op op;

    // first lane of b, used by the scalar path
    const float b0 = ptr1[0];

    int k = 0;
#if __mips_msa
    const v4f32 _bvec = elempack == 4 ? (v4f32)__msa_ld_w(ptr1, 0) : __msa_fill_w_f32(b0);

    while (k + 4 <= size)
    {
        __builtin_prefetch(ptr + k + 16);
        const v4f32 _avec = (v4f32)__msa_ld_w(ptr + k, 0);
        __msa_st_w((v4i32)op(_avec, _bvec), outptr + k, 0);
        k += 4;
    }
#endif // __mips_msa
    // scalar remainder (the whole row when MSA is unavailable)
    while (k < size)
    {
        outptr[k] = op(ptr[k], b0);
        k++;
    }
}
89

90
// a is a single element reused for every position of b:
//   outptr[k] = op(a, ptr1[k])  for 0 <= k < size
// With MSA and elempack == 4, a's 4 lanes are loaded as-is and repeat for every
// packed element of b; otherwise a's first value is splatted to all lanes.
template<typename Op>
static void binary_op_vector_broadcast_a(const float* ptr, const float* ptr1, float* outptr, int size, int elempack)
{
    const Op op;

    // first lane of a, used by the scalar path
    const float a0 = ptr[0];

    int k = 0;
#if __mips_msa
    const v4f32 _avec = elempack == 4 ? (v4f32)__msa_ld_w(ptr, 0) : __msa_fill_w_f32(a0);

    while (k + 4 <= size)
    {
        __builtin_prefetch(ptr1 + k + 16);
        const v4f32 _bvec = (v4f32)__msa_ld_w(ptr1 + k, 0);
        __msa_st_w((v4i32)op(_avec, _bvec), outptr + k, 0);
        k += 4;
    }
#endif // __mips_msa
    // scalar remainder (the whole row when MSA is unavailable)
    while (k < size)
    {
        outptr[k] = op(a0, ptr1[k]);
        k++;
    }
}
119

120
// b is an unpacked (elempack 1) row of the same width as the packed row a; each
// scalar of b is applied to every lane of the matching element of a:
//   outptr[i * elempack + k] = op(ptr[i * elempack + k], ptr1[i])
// ptr / outptr hold w * elempack floats, ptr1 holds w floats.
// The MSA path covers elempack == 4. Every other case, including builds without
// MSA, uses the portable scalar loop, so the output is never left unwritten.
template<typename Op>
static void binary_op_vector_broadcast_pb(const float* ptr, const float* ptr1, float* outptr, int w, int elempack)
{
    const Op op;

#if __mips_msa
    if (elempack == 4)
    {
        for (int i = 0; i < w; i++)
        {
            __builtin_prefetch(ptr + 16);
            v4f32 _p = (v4f32)__msa_ld_w(ptr, 0);
            v4f32 _b = __msa_fill_w_f32(*ptr1);
            v4f32 _outp = op(_p, _b);
            __msa_st_w((v4i32)_outp, outptr, 0);
            ptr += 4;
            ptr1 += 1;
            outptr += 4;
        }
        return;
    }
#endif // __mips_msa

    // portable fallback: any elempack
    for (int i = 0; i < w; i++)
    {
        const float b = *ptr1;
        for (int k = 0; k < elempack; k++)
        {
            *outptr = op(*ptr, b);
            ptr += 1;
            outptr += 1;
        }
        ptr1 += 1;
    }
}
143

144
// b is one unpacked scalar applied to every float of the packed row a:
//   outptr[k] = op(ptr[k], ptr1[0])  for 0 <= k < w * elempack
// Like the sibling kernels, the MSA loop handles full 4-float groups and a scalar
// loop handles the remainder. Without MSA the scalar loop covers the whole row, so
// the function no longer becomes a silent no-op.
template<typename Op>
static void binary_op_vector_broadcast_pb_b(const float* ptr, const float* ptr1, float* outptr, int w, int elempack)
{
    const Op op;

    const int size = w * elempack;
    const float b = *ptr1;

    int i = 0;
#if __mips_msa
    v4f32 _b = __msa_fill_w_f32(b);
    for (; i + 3 < size; i += 4)
    {
        __builtin_prefetch(ptr + 16);
        v4f32 _p = (v4f32)__msa_ld_w(ptr, 0);
        v4f32 _outp = op(_p, _b);
        __msa_st_w((v4i32)_outp, outptr, 0);
        ptr += 4;
        outptr += 4;
    }
#endif // __mips_msa
    // scalar tail (the whole row when MSA is unavailable)
    for (; i < size; i++)
    {
        *outptr = op(*ptr, b);
        ptr += 1;
        outptr += 1;
    }
}
165

166
// a is a single packed element (elempack lanes); b is an unpacked row of w scalars.
// Each scalar of b is combined with all lanes of a:
//   outptr[i * elempack + k] = op(ptr[k], ptr1[i])
// The MSA path covers elempack == 4. Every other case, including builds without
// MSA, uses the portable scalar loop, so the output is never left unwritten.
template<typename Op>
static void binary_op_vector_broadcast_pb_a(const float* ptr, const float* ptr1, float* outptr, int w, int elempack)
{
    const Op op;

#if __mips_msa
    if (elempack == 4)
    {
        v4f32 _p = (v4f32)__msa_ld_w(ptr, 0);
        for (int i = 0; i < w; i++)
        {
            v4f32 _b = __msa_fill_w_f32(*ptr1);
            v4f32 _outp = op(_p, _b);
            __msa_st_w((v4i32)_outp, outptr, 0);
            ptr1 += 1;
            outptr += 4;
        }
        return;
    }
#endif // __mips_msa

    // portable fallback: any elempack
    for (int i = 0; i < w; i++)
    {
        const float b = ptr1[i];
        for (int k = 0; k < elempack; k++)
        {
            *outptr = op(ptr[k], b);
            outptr += 1;
        }
    }
}
187

188
template<typename Op>
189
static void binary_op_vector(const float* ptr, const float* ptr1, float* outptr, int aw, int bw, int ap, int bp)
190
{
191
    const int w = std::max(aw, bw);
192
    const int elempack = std::max(ap, bp);
193
    const int size = w * elempack;
194

195
    if (ap == bp)
196
    {
197
        if (aw == bw)
198
        {
199
            // no broadcast
200
            return binary_op_vector_no_broadcast<Op>(ptr, ptr1, outptr, size);
201
        }
202

203
        if (bw == 1)
204
        {
205
            // broadcast single b
206
            return binary_op_vector_broadcast_b<Op>(ptr, ptr1, outptr, size, elempack);
207
        }
208

209
        if (aw == 1)
210
        {
211
            // broadcast single a
212
            return binary_op_vector_broadcast_a<Op>(ptr, ptr1, outptr, size, elempack);
213
        }
214
    }
215

216
    if (bp == 1)
217
    {
218
        if (aw == bw)
219
        {
220
            // broadcast pack1 b
221
            return binary_op_vector_broadcast_pb<Op>(ptr, ptr1, outptr, w, elempack);
222
        }
223

224
        if (bw == 1)
225
        {
226
            // broadcast pack1 single b
227
            return binary_op_vector_broadcast_pb_b<Op>(ptr, ptr1, outptr, w, elempack);
228
        }
229

230
        if (aw == 1)
231
        {
232
            // broadcast single a and pack1 b
233
            return binary_op_vector_broadcast_pb_a<Op>(ptr, ptr1, outptr, w, elempack);
234
        }
235
    }
236

237
    // shall never reach here
238
}
239

240
template<typename Op>
241
static int binary_op_scalar_inplace(Mat& a, float b, const Option& opt)
242
{
243
    Op op;
244

245
    const int channels = a.c;
246
    const int size = a.w * a.h * a.d * a.elempack;
247

248
    #pragma omp parallel for num_threads(opt.num_threads)
249
    for (int q = 0; q < channels; q++)
250
    {
251
        float* ptr = a.channel(q);
252

253
        int i = 0;
254
#if __mips_msa
255
        v4f32 _b = __msa_fill_w_f32(b);
256
        for (; i + 3 < size; i += 4)
257
        {
258
            __builtin_prefetch(ptr + 16);
259
            v4f32 _p = (v4f32)__msa_ld_w(ptr, 0);
260
            _p = op(_p, _b);
261
            __msa_st_w((v4i32)_p, ptr, 0);
262
            ptr += 4;
263
        }
264
#endif // __mips_msa
265
        for (; i < size; i++)
266
        {
267
            *ptr = op(*ptr, b);
268
            ptr++;
269
        }
270
    }
271

272
    return 0;
273
}
274

275
namespace BinaryOp_mips_functor {

// Every functor provides a scalar float overload. When MSA is available it also
// provides a v4f32 overload, so one Op type drives both the vector loops and the
// scalar tails of the kernels.
#if __mips_msa
#define MAKE_FUNCTION_MSA_OVERLOAD(IMPL4)                  \
    v4f32 operator()(const v4f32& x, const v4f32& y) const \
    {                                                      \
        return IMPL4;                                      \
    }
#else
// without MSA the vector overload is omitted and IMPL4 is discarded unexpanded
#define MAKE_FUNCTION_MSA_OVERLOAD(IMPL4)
#endif // __mips_msa

#define MAKE_FUNCTION(NAME, IMPL, IMPL4)                       \
    struct NAME                                                \
    {                                                          \
        float operator()(const float& x, const float& y) const \
        {                                                      \
            return IMPL;                                       \
        }                                                      \
        MAKE_FUNCTION_MSA_OVERLOAD(IMPL4)                      \
    };

// clang-format off
// *INDENT-OFF*
MAKE_FUNCTION(binary_op_add, x + y, __msa_fadd_w(x, y))
MAKE_FUNCTION(binary_op_sub, x - y, __msa_fsub_w(x, y))
MAKE_FUNCTION(binary_op_mul, x * y, __msa_fmul_w(x, y))
MAKE_FUNCTION(binary_op_div, x / y, __msa_fdiv_w(x, y))
MAKE_FUNCTION(binary_op_max, std::max(x, y), __msa_fmax_w(x, y))
MAKE_FUNCTION(binary_op_min, std::min(x, y), __msa_fmin_w(x, y))
MAKE_FUNCTION(binary_op_pow, (float)pow(x, y), pow_ps(x, y))
MAKE_FUNCTION(binary_op_rsub, y - x, __msa_fsub_w(y, x))
MAKE_FUNCTION(binary_op_rdiv, y / x, __msa_fdiv_w(y, x))
MAKE_FUNCTION(binary_op_rpow, (float)pow(y, x), pow_ps(y, x))
MAKE_FUNCTION(binary_op_atan2, (float)atan2(x, y), atan2_ps(x, y))
MAKE_FUNCTION(binary_op_ratan2, (float)atan2(y, x), atan2_ps(y, x))
// *INDENT-ON*
// clang-format on

#undef MAKE_FUNCTION
#undef MAKE_FUNCTION_MSA_OVERLOAD

} // namespace BinaryOp_mips_functor
321

322
static void binary_op_vector(const float* ptr, const float* ptr1, float* outptr, int aw, int bw, int ap, int bp, int op_type)
323
{
324
    using namespace BinaryOp_mips_functor;
325

326
    if (op_type == BinaryOp::Operation_ADD) return binary_op_vector<binary_op_add>(ptr, ptr1, outptr, aw, bw, ap, bp);
327
    if (op_type == BinaryOp::Operation_SUB) return binary_op_vector<binary_op_sub>(ptr, ptr1, outptr, aw, bw, ap, bp);
328
    if (op_type == BinaryOp::Operation_MUL) return binary_op_vector<binary_op_mul>(ptr, ptr1, outptr, aw, bw, ap, bp);
329
    if (op_type == BinaryOp::Operation_DIV) return binary_op_vector<binary_op_div>(ptr, ptr1, outptr, aw, bw, ap, bp);
330
    if (op_type == BinaryOp::Operation_MAX) return binary_op_vector<binary_op_max>(ptr, ptr1, outptr, aw, bw, ap, bp);
331
    if (op_type == BinaryOp::Operation_MIN) return binary_op_vector<binary_op_min>(ptr, ptr1, outptr, aw, bw, ap, bp);
332
    if (op_type == BinaryOp::Operation_POW) return binary_op_vector<binary_op_pow>(ptr, ptr1, outptr, aw, bw, ap, bp);
333
    if (op_type == BinaryOp::Operation_RSUB) return binary_op_vector<binary_op_rsub>(ptr, ptr1, outptr, aw, bw, ap, bp);
334
    if (op_type == BinaryOp::Operation_RDIV) return binary_op_vector<binary_op_rdiv>(ptr, ptr1, outptr, aw, bw, ap, bp);
335
    if (op_type == BinaryOp::Operation_RPOW) return binary_op_vector<binary_op_rpow>(ptr, ptr1, outptr, aw, bw, ap, bp);
336
    if (op_type == BinaryOp::Operation_ATAN2) return binary_op_vector<binary_op_atan2>(ptr, ptr1, outptr, aw, bw, ap, bp);
337
    if (op_type == BinaryOp::Operation_RATAN2) return binary_op_vector<binary_op_ratan2>(ptr, ptr1, outptr, aw, bw, ap, bp);
338

339
    // should never reach here
340
}
341

342
static void binary_op_scalar(const Mat& a, float b, Mat& c, int op_type, const Option& opt)
343
{
344
    const int channels = a.c;
345
    const int size = a.w * a.h * a.d * a.elempack;
346

347
    #pragma omp parallel for num_threads(opt.num_threads)
348
    for (int q = 0; q < channels; q++)
349
    {
350
        const float* ptr = a.channel(q);
351
        float* outptr = c.channel(q);
352

353
        binary_op_vector(ptr, &b, outptr, size, 1, 1, 1, op_type);
354
    }
355
}
356

357
static void binary_op_no_broadcast(const Mat& a, const Mat& b, Mat& c, int op_type, const Option& opt)
358
{
359
    const int channels = a.c;
360
    const int size = a.w * a.h * a.d * a.elempack;
361

362
    #pragma omp parallel for num_threads(opt.num_threads)
363
    for (int q = 0; q < channels; q++)
364
    {
365
        const float* ptr = a.channel(q);
366
        const float* ptr1 = b.channel(q);
367
        float* outptr = c.channel(q);
368

369
        binary_op_vector(ptr, ptr1, outptr, size, size, 1, 1, op_type);
370
    }
371
}
372

373
static void binary_op_broadcast(const Mat& a, const Mat& b, Mat& c, int op_type, const Option& opt)
374
{
375
    if (b.w * b.h * b.d * b.c * b.elempack == 1)
376
    {
377
        return binary_op_scalar(a, b[0], c, op_type, opt);
378
    }
379

380
    if (a.dims == b.dims && a.w == b.w && a.h == b.h && a.d == b.d && a.c == b.c && a.elempack == b.elempack)
381
    {
382
        return binary_op_no_broadcast(a, b, c, op_type, opt);
383
    }
384

385
    const int dims = c.dims;
386

387
    if (dims == 2)
388
    {
389
        const int h = c.h;
390

391
        #pragma omp parallel for num_threads(opt.num_threads)
392
        for (int y = 0; y < h; y++)
393
        {
394
            const int y0 = std::min(y, a.h - 1);
395
            const int y1 = std::min(y, b.h - 1);
396

397
            const float* ptr = a.row(y0);
398
            const float* ptr1 = b.row(y1);
399
            float* outptr = c.row(y);
400

401
            binary_op_vector(ptr, ptr1, outptr, a.w, b.w, a.elempack, b.elempack, op_type);
402
        }
403
    }
404

405
    if (dims == 3 || dims == 4)
406
    {
407
        const int channels = c.c;
408

409
        #pragma omp parallel for num_threads(opt.num_threads)
410
        for (int q = 0; q < channels; q++)
411
        {
412
            const int q0 = std::min(q, a.c - 1);
413
            const int q1 = std::min(q, b.c - 1);
414

415
            if (b.d * b.h * b.w == 1)
416
            {
417
                const float* ptr = a.channel(q0);
418
                const float* ptr1 = b.channel(q1);
419
                float* outptr = c.channel(q);
420

421
                binary_op_vector(ptr, ptr1, outptr, a.w * a.h * a.d, 1, a.elempack, b.elempack, op_type);
422
                continue;
423
            }
424

425
            if (b.h * b.w == 1)
426
            {
427
                for (int z = 0; z < c.d; z++)
428
                {
429
                    const int z0 = std::min(z, a.d - 1);
430
                    const int z1 = std::min(z, b.d - 1);
431

432
                    const float* ptr = a.channel(q0).depth(z0);
433
                    const float* ptr1 = b.channel(q1).depth(z1);
434
                    float* outptr = c.channel(q).depth(z);
435

436
                    binary_op_vector(ptr, ptr1, outptr, a.w * a.h, 1, a.elempack, b.elempack, op_type);
437
                }
438
                continue;
439
            }
440

441
            for (int z = 0; z < c.d; z++)
442
            {
443
                const int z0 = std::min(z, a.d - 1);
444
                const int z1 = std::min(z, b.d - 1);
445

446
                for (int y = 0; y < c.h; y++)
447
                {
448
                    const int y0 = std::min(y, a.h - 1);
449
                    const int y1 = std::min(y, b.h - 1);
450

451
                    const float* ptr = a.channel(q0).depth(z0).row(y0);
452
                    const float* ptr1 = b.channel(q1).depth(z1).row(y1);
453
                    float* outptr = c.channel(q).depth(z).row(y);
454

455
                    binary_op_vector(ptr, ptr1, outptr, a.w, b.w, a.elempack, b.elempack, op_type);
456
                }
457
            }
458
        }
459
    }
460
}
461

462
// a = op(a, b) in place for a scalar b, one channel per loop iteration.
static void binary_op_scalar_inplace(Mat& a, float b, int op_type, const Option& opt)
{
    const int channel_count = a.c;
    const int floats_per_channel = a.w * a.h * a.d * a.elempack;

    #pragma omp parallel for num_threads(opt.num_threads)
    for (int q = 0; q < channel_count; q++)
    {
        float* data = a.channel(q);

        // input and output are the same buffer (in-place)
        binary_op_vector(data, &b, data, floats_per_channel, 1, 1, 1, op_type);
    }
}
475

476
static int get_reverse_op_type(int op_type)
477
{
478
    if (op_type == BinaryOp::Operation_SUB) return BinaryOp::Operation_RSUB;
479
    if (op_type == BinaryOp::Operation_DIV) return BinaryOp::Operation_RDIV;
480
    if (op_type == BinaryOp::Operation_POW) return BinaryOp::Operation_RPOW;
481
    if (op_type == BinaryOp::Operation_ATAN2) return BinaryOp::Operation_RATAN2;
482
    if (op_type == BinaryOp::Operation_RSUB) return BinaryOp::Operation_SUB;
483
    if (op_type == BinaryOp::Operation_RDIV) return BinaryOp::Operation_DIV;
484
    if (op_type == BinaryOp::Operation_RPOW) return BinaryOp::Operation_POW;
485
    if (op_type == BinaryOp::Operation_RATAN2) return BinaryOp::Operation_ATAN2;
486
    return op_type;
487
}
488

489
int BinaryOp_mips::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const
490
{
491
    const Mat& A = bottom_blobs[0];
492
    const Mat& B = bottom_blobs[1];
493
    const int outdims = std::max(A.dims, B.dims);
494

495
    Mat A2 = A;
496
    Mat B2 = B;
497
    if (A.dims < outdims)
498
    {
499
        // expand inner axes
500
        if (outdims == 2)
501
        {
502
            if (A.w * A.elempack == B.h * B.elempack)
503
                A2 = A.reshape(1, A.w, opt.workspace_allocator);
504
            else // if (A.w == B.w)
505
            {
506
                A2.dims = 2;
507
                A2.w = A.w * A.elempack;
508
                A2.elempack = 1;
509
                A2.elemsize = A.elemsize / A.elempack;
510
                A2.cstep = A2.w;
511
            }
512
        }
513
        if (outdims == 3 && A.dims == 1)
514
        {
515
            if (A.w * A.elempack == B.c * B.elempack)
516
                A2 = A.reshape(1, 1, A.w, opt.workspace_allocator);
517
            else // if (A.w == B.w)
518
            {
519
                A2.dims = 3;
520
                A2.w = A.w * A.elempack;
521
                A2.elempack = 1;
522
                A2.elemsize = A.elemsize / A.elempack;
523
                A2.cstep = A2.w;
524
            }
525
        }
526
        if (outdims == 3 && A.dims == 2)
527
            A2 = A.reshape(1, A.w, A.h, opt.workspace_allocator);
528
        if (outdims == 4 && A.dims == 1)
529
        {
530
            if (A.w * A.elempack == B.c * B.elempack)
531
                A2 = A.reshape(1, 1, 1, A.w, opt.workspace_allocator);
532
            else // if (A.w == B.w)
533
            {
534
                A2.dims = 4;
535
                A2.w = A.w * A.elempack;
536
                A2.elempack = 1;
537
                A2.elemsize = A.elemsize / A.elempack;
538
                A2.cstep = A2.w;
539
            }
540
        }
541
        if (outdims == 4 && A.dims == 2)
542
            A2 = A.reshape(1, 1, A.w, A.h, opt.workspace_allocator);
543
        if (outdims == 4 && A.dims == 3)
544
            A2 = A.reshape(1, A.w, A.h, A.c, opt.workspace_allocator);
545
    }
546
    if (B.dims < outdims)
547
    {
548
        // expand inner axes
549
        if (outdims == 2)
550
        {
551
            if (B.w * B.elempack == A.h * A.elempack)
552
                B2 = B.reshape(1, B.w, opt.workspace_allocator);
553
            else // if (B.w == A.w)
554
            {
555
                B2.dims = 2;
556
                B2.w = B.w * B.elempack;
557
                B2.elempack = 1;
558
                B2.elemsize = B.elemsize / B.elempack;
559
                B2.cstep = B2.w;
560
            }
561
        }
562
        if (outdims == 3 && B.dims == 1)
563
        {
564
            if (B.w * B.elempack == A.c * A.elempack)
565
                B2 = B.reshape(1, 1, B.w, opt.workspace_allocator);
566
            else // if (B.w == A.w)
567
            {
568
                B2.dims = 3;
569
                B2.w = B.w * B.elempack;
570
                B2.elempack = 1;
571
                B2.elemsize = B.elemsize / B.elempack;
572
                B2.cstep = B2.w;
573
            }
574
        }
575
        if (outdims == 3 && B.dims == 2)
576
            B2 = B.reshape(1, B.w, B.h, opt.workspace_allocator);
577
        if (outdims == 4 && B.dims == 1)
578
        {
579
            if (B.w * B.elempack == A.c * A.elempack)
580
                B2 = B.reshape(1, 1, 1, B.w, opt.workspace_allocator);
581
            else // if (B.w == A.w)
582
            {
583
                B2.dims = 4;
584
                B2.w = B.w * B.elempack;
585
                B2.elempack = 1;
586
                B2.elemsize = B.elemsize / B.elempack;
587
                B2.cstep = B2.w;
588
            }
589
        }
590
        if (outdims == 4 && B.dims == 2)
591
            B2 = B.reshape(1, 1, B.w, B.h, opt.workspace_allocator);
592
        if (outdims == 4 && B.dims == 3)
593
            B2 = B.reshape(1, B.w, B.h, B.c, opt.workspace_allocator);
594
    }
595

596
    const int outw = std::max(A2.w, B2.w);
597
    const int outh = std::max(A2.h, B2.h);
598
    const int outd = std::max(A2.d, B2.d);
599
    const int outc = std::max(A2.c, B2.c);
600
    const size_t out_elemsize = std::max(A2.elemsize, B2.elemsize);
601
    const int out_elempack = std::max(A2.elempack, B2.elempack);
602

603
    Mat& top_blob = top_blobs[0];
604
    if (outdims == 1)
605
    {
606
        top_blob.create(outw, out_elemsize, out_elempack, opt.blob_allocator);
607
    }
608
    if (outdims == 2)
609
    {
610
        top_blob.create(outw, outh, out_elemsize, out_elempack, opt.blob_allocator);
611
    }
612
    if (outdims == 3)
613
    {
614
        top_blob.create(outw, outh, outc, out_elemsize, out_elempack, opt.blob_allocator);
615
    }
616
    if (outdims == 4)
617
    {
618
        top_blob.create(outw, outh, outd, outc, out_elemsize, out_elempack, opt.blob_allocator);
619
    }
620
    if (top_blob.empty())
621
        return -100;
622

623
    const bool a_pack_is_lower = A2.elempack < B2.elempack;
624
    const bool a_pack_is_equal = A2.elempack == B2.elempack;
625
    const bool a_size_is_lower = A2.w * A2.h * A2.d * A2.c * A2.elempack < B2.w * B2.h * B2.d * B2.c * B2.elempack;
626
    if (a_pack_is_lower || (a_pack_is_equal && a_size_is_lower))
627
    {
628
        binary_op_broadcast(B2, A2, top_blob, get_reverse_op_type(op_type), opt);
629
    }
630
    else
631
    {
632
        binary_op_broadcast(A2, B2, top_blob, op_type, opt);
633
    }
634

635
    return 0;
636
}
637

638
// Scalar-operand form: combines every element of the blob with the member scalar b
// using op_type, in place. Always returns 0.
int BinaryOp_mips::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
{
    Mat& blob = bottom_top_blob;

    binary_op_scalar_inplace(blob, b, op_type, opt);

    return 0;
}
644

645
} // namespace ncnn
646

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.