ncnn

Форк
0
/
yolact.cpp 
549 строк · 15.7 Кб
1
// Tencent is pleased to support the open source community by making ncnn available.
2
//
3
// Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
4
//
5
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6
// in compliance with the License. You may obtain a copy of the License at
7
//
8
// https://opensource.org/licenses/BSD-3-Clause
9
//
10
// Unless required by applicable law or agreed to in writing, software distributed
11
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
13
// specific language governing permissions and limitations under the License.
14

15
#include "net.h"
16

17
#if defined(USE_NCNN_SIMPLEOCV)
18
#include "simpleocv.h"
19
#else
20
#include <opencv2/core/core.hpp>
21
#include <opencv2/highgui/highgui.hpp>
22
#include <opencv2/imgproc/imgproc.hpp>
23
#endif
24
#include <stdio.h>
25
#include <vector>
26

27
struct Object
28
{
29
    cv::Rect_<float> rect;
30
    int label;
31
    float prob;
32
    std::vector<float> maskdata;
33
    cv::Mat mask;
34
};
35

36
static inline float intersection_area(const Object& a, const Object& b)
37
{
38
    cv::Rect_<float> inter = a.rect & b.rect;
39
    return inter.area();
40
}
41

42
static void qsort_descent_inplace(std::vector<Object>& objects, int left, int right)
43
{
44
    int i = left;
45
    int j = right;
46
    float p = objects[(left + right) / 2].prob;
47

48
    while (i <= j)
49
    {
50
        while (objects[i].prob > p)
51
            i++;
52

53
        while (objects[j].prob < p)
54
            j--;
55

56
        if (i <= j)
57
        {
58
            // swap
59
            std::swap(objects[i], objects[j]);
60

61
            i++;
62
            j--;
63
        }
64
    }
65

66
    #pragma omp parallel sections
67
    {
68
        #pragma omp section
69
        {
70
            if (left < j) qsort_descent_inplace(objects, left, j);
71
        }
72
        #pragma omp section
73
        {
74
            if (i < right) qsort_descent_inplace(objects, i, right);
75
        }
76
    }
77
}
78

79
static void qsort_descent_inplace(std::vector<Object>& objects)
80
{
81
    if (objects.empty())
82
        return;
83

84
    qsort_descent_inplace(objects, 0, objects.size() - 1);
85
}
86

87
static void nms_sorted_bboxes(const std::vector<Object>& faceobjects, std::vector<int>& picked, float nms_threshold, bool agnostic = false)
88
{
89
    picked.clear();
90

91
    const int n = faceobjects.size();
92

93
    std::vector<float> areas(n);
94
    for (int i = 0; i < n; i++)
95
    {
96
        areas[i] = faceobjects[i].rect.area();
97
    }
98

99
    for (int i = 0; i < n; i++)
100
    {
101
        const Object& a = faceobjects[i];
102

103
        int keep = 1;
104
        for (int j = 0; j < (int)picked.size(); j++)
105
        {
106
            const Object& b = faceobjects[picked[j]];
107

108
            if (!agnostic && a.label != b.label)
109
                continue;
110

111
            // intersection over union
112
            float inter_area = intersection_area(a, b);
113
            float union_area = areas[i] + areas[picked[j]] - inter_area;
114
            // float IoU = inter_area / union_area
115
            if (inter_area / union_area > nms_threshold)
116
                keep = 0;
117
        }
118

119
        if (keep)
120
            picked.push_back(i);
121
    }
122
}
123

124
static int detect_yolact(const cv::Mat& bgr, std::vector<Object>& objects)
125
{
126
    ncnn::Net yolact;
127

128
    yolact.opt.use_vulkan_compute = true;
129

130
    // original model converted from https://github.com/dbolya/yolact
131
    // yolact_resnet50_54_800000.pth
132
    // the ncnn model https://github.com/nihui/ncnn-assets/tree/master/models
133
    if (yolact.load_param("yolact.param"))
134
        exit(-1);
135
    if (yolact.load_model("yolact.bin"))
136
        exit(-1);
137

138
    const int target_size = 550;
139

140
    int img_w = bgr.cols;
141
    int img_h = bgr.rows;
142

143
    ncnn::Mat in = ncnn::Mat::from_pixels_resize(bgr.data, ncnn::Mat::PIXEL_BGR2RGB, img_w, img_h, target_size, target_size);
144

145
    const float mean_vals[3] = {123.68f, 116.78f, 103.94f};
146
    const float norm_vals[3] = {1.0 / 58.40f, 1.0 / 57.12f, 1.0 / 57.38f};
147
    in.substract_mean_normalize(mean_vals, norm_vals);
148

149
    ncnn::Extractor ex = yolact.create_extractor();
150

151
    ex.input("input.1", in);
152

153
    ncnn::Mat maskmaps;
154
    ncnn::Mat location;
155
    ncnn::Mat mask;
156
    ncnn::Mat confidence;
157

158
    ex.extract("619", maskmaps); // 138x138 x 32
159

160
    ex.extract("816", location);   // 4 x 19248
161
    ex.extract("818", mask);       // maskdim 32 x 19248
162
    ex.extract("820", confidence); // 81 x 19248
163

164
    int num_class = confidence.w;
165
    int num_priors = confidence.h;
166

167
    // make priorbox
168
    ncnn::Mat priorbox(4, num_priors);
169
    {
170
        const int conv_ws[5] = {69, 35, 18, 9, 5};
171
        const int conv_hs[5] = {69, 35, 18, 9, 5};
172

173
        const float aspect_ratios[3] = {1.f, 0.5f, 2.f};
174
        const float scales[5] = {24.f, 48.f, 96.f, 192.f, 384.f};
175

176
        float* pb = priorbox;
177

178
        for (int p = 0; p < 5; p++)
179
        {
180
            int conv_w = conv_ws[p];
181
            int conv_h = conv_hs[p];
182

183
            float scale = scales[p];
184

185
            for (int i = 0; i < conv_h; i++)
186
            {
187
                for (int j = 0; j < conv_w; j++)
188
                {
189
                    // +0.5 because priors are in center-size notation
190
                    float cx = (j + 0.5f) / conv_w;
191
                    float cy = (i + 0.5f) / conv_h;
192

193
                    for (int k = 0; k < 3; k++)
194
                    {
195
                        float ar = aspect_ratios[k];
196

197
                        ar = sqrt(ar);
198

199
                        float w = scale * ar / 550;
200
                        float h = scale / ar / 550;
201

202
                        // This is for backward compatibility with a bug where I made everything square by accident
203
                        // cfg.backbone.use_square_anchors:
204
                        h = w;
205

206
                        pb[0] = cx;
207
                        pb[1] = cy;
208
                        pb[2] = w;
209
                        pb[3] = h;
210

211
                        pb += 4;
212
                    }
213
                }
214
            }
215
        }
216
    }
217

218
    const float confidence_thresh = 0.05f;
219
    const float nms_threshold = 0.5f;
220
    const int keep_top_k = 200;
221

222
    std::vector<std::vector<Object> > class_candidates;
223
    class_candidates.resize(num_class);
224

225
    for (int i = 0; i < num_priors; i++)
226
    {
227
        const float* conf = confidence.row(i);
228
        const float* loc = location.row(i);
229
        const float* pb = priorbox.row(i);
230
        const float* maskdata = mask.row(i);
231

232
        // find class id with highest score
233
        // start from 1 to skip background
234
        int label = 0;
235
        float score = 0.f;
236
        for (int j = 1; j < num_class; j++)
237
        {
238
            float class_score = conf[j];
239
            if (class_score > score)
240
            {
241
                label = j;
242
                score = class_score;
243
            }
244
        }
245

246
        // ignore background or low score
247
        if (label == 0 || score <= confidence_thresh)
248
            continue;
249

250
        // CENTER_SIZE
251
        float var[4] = {0.1f, 0.1f, 0.2f, 0.2f};
252

253
        float pb_cx = pb[0];
254
        float pb_cy = pb[1];
255
        float pb_w = pb[2];
256
        float pb_h = pb[3];
257

258
        float bbox_cx = var[0] * loc[0] * pb_w + pb_cx;
259
        float bbox_cy = var[1] * loc[1] * pb_h + pb_cy;
260
        float bbox_w = (float)(exp(var[2] * loc[2]) * pb_w);
261
        float bbox_h = (float)(exp(var[3] * loc[3]) * pb_h);
262

263
        float obj_x1 = bbox_cx - bbox_w * 0.5f;
264
        float obj_y1 = bbox_cy - bbox_h * 0.5f;
265
        float obj_x2 = bbox_cx + bbox_w * 0.5f;
266
        float obj_y2 = bbox_cy + bbox_h * 0.5f;
267

268
        // clip
269
        obj_x1 = std::max(std::min(obj_x1 * bgr.cols, (float)(bgr.cols - 1)), 0.f);
270
        obj_y1 = std::max(std::min(obj_y1 * bgr.rows, (float)(bgr.rows - 1)), 0.f);
271
        obj_x2 = std::max(std::min(obj_x2 * bgr.cols, (float)(bgr.cols - 1)), 0.f);
272
        obj_y2 = std::max(std::min(obj_y2 * bgr.rows, (float)(bgr.rows - 1)), 0.f);
273

274
        // append object
275
        Object obj;
276
        obj.rect = cv::Rect_<float>(obj_x1, obj_y1, obj_x2 - obj_x1 + 1, obj_y2 - obj_y1 + 1);
277
        obj.label = label;
278
        obj.prob = score;
279
        obj.maskdata = std::vector<float>(maskdata, maskdata + mask.w);
280

281
        class_candidates[label].push_back(obj);
282
    }
283

284
    objects.clear();
285
    for (int i = 0; i < (int)class_candidates.size(); i++)
286
    {
287
        std::vector<Object>& candidates = class_candidates[i];
288

289
        qsort_descent_inplace(candidates);
290

291
        std::vector<int> picked;
292
        nms_sorted_bboxes(candidates, picked, nms_threshold);
293

294
        for (int j = 0; j < (int)picked.size(); j++)
295
        {
296
            int z = picked[j];
297
            objects.push_back(candidates[z]);
298
        }
299
    }
300

301
    qsort_descent_inplace(objects);
302

303
    // keep_top_k
304
    if (keep_top_k < (int)objects.size())
305
    {
306
        objects.resize(keep_top_k);
307
    }
308

309
    // generate mask
310
    for (int i = 0; i < (int)objects.size(); i++)
311
    {
312
        Object& obj = objects[i];
313

314
        cv::Mat mask(maskmaps.h, maskmaps.w, CV_32FC1);
315
        {
316
            mask = cv::Scalar(0.f);
317

318
            for (int p = 0; p < maskmaps.c; p++)
319
            {
320
                const float* maskmap = maskmaps.channel(p);
321
                float coeff = obj.maskdata[p];
322
                float* mp = (float*)mask.data;
323

324
                // mask += m * coeff
325
                for (int j = 0; j < maskmaps.w * maskmaps.h; j++)
326
                {
327
                    mp[j] += maskmap[j] * coeff;
328
                }
329
            }
330
        }
331

332
        cv::Mat mask2;
333
        cv::resize(mask, mask2, cv::Size(img_w, img_h));
334

335
        // crop obj box and binarize
336
        obj.mask = cv::Mat(img_h, img_w, CV_8UC1);
337
        {
338
            obj.mask = cv::Scalar(0);
339

340
            for (int y = 0; y < img_h; y++)
341
            {
342
                if (y < obj.rect.y || y > obj.rect.y + obj.rect.height)
343
                    continue;
344

345
                const float* mp2 = mask2.ptr<const float>(y);
346
                uchar* bmp = obj.mask.ptr<uchar>(y);
347

348
                for (int x = 0; x < img_w; x++)
349
                {
350
                    if (x < obj.rect.x || x > obj.rect.x + obj.rect.width)
351
                        continue;
352

353
                    bmp[x] = mp2[x] > 0.5f ? 255 : 0;
354
                }
355
            }
356
        }
357
    }
358

359
    return 0;
360
}
361

362
// Render detections on a copy of bgr and display it.
// For every object with prob >= 0.15 it:
// - prints the detection to stderr,
// - draws its box and a "name prob%" label,
// - blends its mask into the image at 50% opacity.
// The result is written to result.png and shown in a window until a key is pressed.
// Labels outside the name table are shown as "unknown". Masks whose size does not
// match the image are skipped.
static void draw_objects(const cv::Mat& bgr, const std::vector<Object>& objects)
{
    static const char* class_names[] = {"background",
                                        "person", "bicycle", "car", "motorcycle", "airplane", "bus",
                                        "train", "truck", "boat", "traffic light", "fire hydrant",
                                        "stop sign", "parking meter", "bench", "bird", "cat", "dog",
                                        "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe",
                                        "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee",
                                        "skis", "snowboard", "sports ball", "kite", "baseball bat",
                                        "baseball glove", "skateboard", "surfboard", "tennis racket",
                                        "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl",
                                        "banana", "apple", "sandwich", "orange", "broccoli", "carrot",
                                        "hot dog", "pizza", "donut", "cake", "chair", "couch",
                                        "potted plant", "bed", "dining table", "toilet", "tv", "laptop",
                                        "mouse", "remote", "keyboard", "cell phone", "microwave", "oven",
                                        "toaster", "sink", "refrigerator", "book", "clock", "vase",
                                        "scissors", "teddy bear", "hair drier", "toothbrush"
                                       };
    static const int num_class_names = (int)(sizeof(class_names) / sizeof(class_names[0]));

    static const unsigned char colors[81][3] = {
        {56, 0, 255},
        {226, 255, 0},
        {0, 94, 255},
        {0, 37, 255},
        {0, 255, 94},
        {255, 226, 0},
        {0, 18, 255},
        {255, 151, 0},
        {170, 0, 255},
        {0, 255, 56},
        {255, 0, 75},
        {0, 75, 255},
        {0, 255, 169},
        {255, 0, 207},
        {75, 255, 0},
        {207, 0, 255},
        {37, 0, 255},
        {0, 207, 255},
        {94, 0, 255},
        {0, 255, 113},
        {255, 18, 0},
        {255, 0, 56},
        {18, 0, 255},
        {0, 255, 226},
        {170, 255, 0},
        {255, 0, 245},
        {151, 255, 0},
        {132, 255, 0},
        {75, 0, 255},
        {151, 0, 255},
        {0, 151, 255},
        {132, 0, 255},
        {0, 255, 245},
        {255, 132, 0},
        {226, 0, 255},
        {255, 37, 0},
        {207, 255, 0},
        {0, 255, 207},
        {94, 255, 0},
        {0, 226, 255},
        {56, 255, 0},
        {255, 94, 0},
        {255, 113, 0},
        {0, 132, 255},
        {255, 0, 132},
        {255, 170, 0},
        {255, 0, 188},
        {113, 255, 0},
        {245, 0, 255},
        {113, 0, 255},
        {255, 188, 0},
        {0, 113, 255},
        {255, 0, 0},
        {0, 56, 255},
        {255, 0, 113},
        {0, 255, 188},
        {255, 0, 94},
        {255, 0, 18},
        {18, 255, 0},
        {0, 255, 132},
        {0, 188, 255},
        {0, 245, 255},
        {0, 169, 255},
        {37, 255, 0},
        {255, 0, 151},
        {188, 0, 255},
        {0, 255, 37},
        {0, 255, 0},
        {255, 0, 170},
        {255, 0, 37},
        {255, 75, 0},
        {0, 0, 255},
        {255, 207, 0},
        {255, 0, 226},
        {255, 245, 0},
        {188, 255, 0},
        {0, 255, 18},
        {0, 255, 75},
        {0, 255, 151},
        {255, 56, 0},
        {245, 255, 0}
    };

    cv::Mat image = bgr.clone();

    int color_index = 0;

    for (size_t i = 0; i < objects.size(); i++)
    {
        const Object& obj = objects[i];

        if (obj.prob < 0.15)
            continue;

        fprintf(stderr, "%d = %.5f at %.2f %.2f %.2f x %.2f\n", obj.label, obj.prob,
                obj.rect.x, obj.rect.y, obj.rect.width, obj.rect.height);

        const unsigned char* color = colors[color_index % 81];
        color_index++;

        cv::rectangle(image, obj.rect, cv::Scalar(color[0], color[1], color[2]));

        // guard against labels outside the name table (e.g. a model with more classes)
        const char* class_name = (obj.label >= 0 && obj.label < num_class_names) ? class_names[obj.label] : "unknown";

        char text[256];
        snprintf(text, sizeof(text), "%s %.1f%%", class_name, obj.prob * 100);

        int baseLine = 0;
        cv::Size label_size = cv::getTextSize(text, cv::FONT_HERSHEY_SIMPLEX, 0.5, 1, &baseLine);

        int x = (int)obj.rect.x;
        int y = (int)(obj.rect.y - label_size.height - baseLine);
        if (y < 0)
            y = 0;
        if (x + label_size.width > image.cols)
            x = image.cols - label_size.width;
        // a label wider than the image would otherwise start left of column 0
        if (x < 0)
            x = 0;

        cv::rectangle(image, cv::Rect(cv::Point(x, y), cv::Size(label_size.width, label_size.height + baseLine)),
                      cv::Scalar(255, 255, 255), -1);

        cv::putText(image, text, cv::Point(x, y + label_size.height),
                    cv::FONT_HERSHEY_SIMPLEX, 0.5, cv::Scalar(0, 0, 0));

        // the per-pixel blend below indexes obj.mask with image coordinates
        if (obj.mask.rows != image.rows || obj.mask.cols != image.cols)
            continue;

        // draw mask
        for (int row = 0; row < image.rows; row++)
        {
            const uchar* mp = obj.mask.ptr(row);
            uchar* p = image.ptr(row);
            for (int col = 0; col < image.cols; col++)
            {
                if (mp[col] == 255)
                {
                    p[0] = cv::saturate_cast<uchar>(p[0] * 0.5 + color[0] * 0.5);
                    p[1] = cv::saturate_cast<uchar>(p[1] * 0.5 + color[1] * 0.5);
                    p[2] = cv::saturate_cast<uchar>(p[2] * 0.5 + color[2] * 0.5);
                }
                p += 3;
            }
        }
    }

    cv::imwrite("result.png", image);
    cv::imshow("image", image);
    cv::waitKey(0);
}
525

526
int main(int argc, char** argv)
527
{
528
    if (argc != 2)
529
    {
530
        fprintf(stderr, "Usage: %s [imagepath]\n", argv[0]);
531
        return -1;
532
    }
533

534
    const char* imagepath = argv[1];
535

536
    cv::Mat m = cv::imread(imagepath, 1);
537
    if (m.empty())
538
    {
539
        fprintf(stderr, "cv::imread %s failed\n", imagepath);
540
        return -1;
541
    }
542

543
    std::vector<Object> objects;
544
    detect_yolact(m, objects);
545

546
    draw_objects(m, objects);
547

548
    return 0;
549
}
550

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.