// This file is written based on the following file:
// https://github.com/Tencent/ncnn/blob/master/examples/yolov5.cpp
// Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.
// ------------------------------------------------------------------------------
// Copyright (C) 2020-2021, Megvii Inc. All rights reserved.

#include "layer.h"
#include "net.h"

#if defined(USE_NCNN_SIMPLEOCV)
#include "simpleocv.h"
#else
#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#endif
#include <float.h>
#include <stdio.h>
#include <vector>

#define YOLOX_NMS_THRESH  0.45 // nms threshold
#define YOLOX_CONF_THRESH 0.25 // threshold of bounding box prob
#define YOLOX_TARGET_SIZE 640  // target image size after resize, might use 416 for the small model

// YOLOX uses the same Focus layer as yolov5
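// Focus (space-to-depth) takes every second pixel of the input and stacks the four
// phase-shifted copies along the channel axis, so a w x h x c blob becomes
// (w/2) x (h/2) x (4*c) without losing information. Output channel p reads input
// channel p % c, starting at row (p / c) % 2 and column (p / c) / 2, and then steps
// by 2 in both directions.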
class YoloV5Focus : public ncnn::Layer
{
public:
    YoloV5Focus()
    {
        one_blob_only = true;
    }

    virtual int forward(const ncnn::Mat& bottom_blob, ncnn::Mat& top_blob, const ncnn::Option& opt) const
    {
        int w = bottom_blob.w;
        int h = bottom_blob.h;
        int channels = bottom_blob.c;

        int outw = w / 2;
        int outh = h / 2;
        int outc = channels * 4;

        top_blob.create(outw, outh, outc, 4u, 1, opt.blob_allocator);
        if (top_blob.empty())
            return -100;

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int p = 0; p < outc; p++)
        {
            const float* ptr = bottom_blob.channel(p % channels).row((p / channels) % 2) + ((p / channels) / 2);
            float* outptr = top_blob.channel(p);

            for (int i = 0; i < outh; i++)
            {
                for (int j = 0; j < outw; j++)
                {
                    *outptr = *ptr;

                    outptr += 1;
                    ptr += 2;
                }

                ptr += w;
            }
        }

        return 0;
    }
};

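// DEFINE_LAYER_CREATOR generates YoloV5Focus_layer_creator, which is passed to
// register_custom_layer() in detect_yolox() below so ncnn can instantiate this
// custom layer when it appears in yolox.param.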
DEFINE_LAYER_CREATOR(YoloV5Focus)

struct Object
{
    cv::Rect_<float> rect;
    int label;
    float prob;
};

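// One entry per output location of the detection head: the cell coordinates
// (grid0, grid1) and the downsampling stride of the feature map it belongs to.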
struct GridAndStride
{
    int grid0;
    int grid1;
    int stride;
};

static inline float intersection_area(const Object& a, const Object& b)
{
    cv::Rect_<float> inter = a.rect & b.rect;
    return inter.area();
}

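// In-place quicksort of the proposals by descending prob; the two partitions are
// recursed into concurrently via OpenMP sections.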
static void qsort_descent_inplace(std::vector<Object>& faceobjects, int left, int right)
{
    int i = left;
    int j = right;
    float p = faceobjects[(left + right) / 2].prob;

    while (i <= j)
    {
        while (faceobjects[i].prob > p)
            i++;

        while (faceobjects[j].prob < p)
            j--;

        if (i <= j)
        {
            // swap
            std::swap(faceobjects[i], faceobjects[j]);

            i++;
            j--;
        }
    }

    #pragma omp parallel sections
    {
        #pragma omp section
        {
            if (left < j) qsort_descent_inplace(faceobjects, left, j);
        }
        #pragma omp section
        {
            if (i < right) qsort_descent_inplace(faceobjects, i, right);
        }
    }
}

static void qsort_descent_inplace(std::vector<Object>& objects)
{
    if (objects.empty())
        return;

    qsort_descent_inplace(objects, 0, objects.size() - 1);
}

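// Greedy non-maximum suppression over a score-sorted list: a box is kept unless it
// overlaps an already-picked box of the same class with IoU above nms_threshold.
// With agnostic = true the class check is skipped and boxes suppress each other
// across classes.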
static void nms_sorted_bboxes(const std::vector<Object>& faceobjects, std::vector<int>& picked, float nms_threshold, bool agnostic = false)
{
    picked.clear();

    const int n = faceobjects.size();

    std::vector<float> areas(n);
    for (int i = 0; i < n; i++)
    {
        areas[i] = faceobjects[i].rect.area();
    }

    for (int i = 0; i < n; i++)
    {
        const Object& a = faceobjects[i];

        int keep = 1;
        for (int j = 0; j < (int)picked.size(); j++)
        {
            const Object& b = faceobjects[picked[j]];

            if (!agnostic && a.label != b.label)
                continue;

            // intersection over union
            float inter_area = intersection_area(a, b);
            float union_area = areas[i] + areas[picked[j]] - inter_area;
            // float IoU = inter_area / union_area
            if (inter_area / union_area > nms_threshold)
                keep = 0;
        }

        if (keep)
            picked.push_back(i);
    }
}

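// Enumerate every cell of every feature map: for stride s the padded input maps to a
// (target_w / s) x (target_h / s) grid. generate_yolox_proposals below assumes this
// enumeration order matches the rows of the flattened output blob.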
static void generate_grids_and_stride(const int target_w, const int target_h, std::vector<int>& strides, std::vector<GridAndStride>& grid_strides)
{
    for (int i = 0; i < (int)strides.size(); i++)
    {
        int stride = strides[i];
        int num_grid_w = target_w / stride;
        int num_grid_h = target_h / stride;
        for (int g1 = 0; g1 < num_grid_h; g1++)
        {
            for (int g0 = 0; g0 < num_grid_w; g0++)
            {
                GridAndStride gs;
                gs.grid0 = g0;
                gs.grid1 = g1;
                gs.stride = stride;
                grid_strides.push_back(gs);
            }
        }
    }
}

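// Decode the raw head output into box proposals. feat_blob holds one row per location
// with (num_class + 5) values: [cx, cy, w, h, objectness, per-class scores]. The center
// offset is added to the cell index and multiplied by the stride, width/height go
// through exp() and are scaled by the stride, and the final score is
// objectness * class score, thresholded by prob_threshold.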
static void generate_yolox_proposals(std::vector<GridAndStride> grid_strides, const ncnn::Mat& feat_blob, float prob_threshold, std::vector<Object>& objects)
{
    const int num_grid = feat_blob.h;
    const int num_class = feat_blob.w - 5;
    const int num_anchors = grid_strides.size();

    const float* feat_ptr = feat_blob.channel(0);
    for (int anchor_idx = 0; anchor_idx < num_anchors; anchor_idx++)
    {
        const int grid0 = grid_strides[anchor_idx].grid0;
        const int grid1 = grid_strides[anchor_idx].grid1;
        const int stride = grid_strides[anchor_idx].stride;

        // yolox/models/yolo_head.py decode logic
        //  outputs[..., :2] = (outputs[..., :2] + grids) * strides
        //  outputs[..., 2:4] = torch.exp(outputs[..., 2:4]) * strides
        float x_center = (feat_ptr[0] + grid0) * stride;
        float y_center = (feat_ptr[1] + grid1) * stride;
        float w = exp(feat_ptr[2]) * stride;
        float h = exp(feat_ptr[3]) * stride;
        float x0 = x_center - w * 0.5f;
        float y0 = y_center - h * 0.5f;

        float box_objectness = feat_ptr[4];
        for (int class_idx = 0; class_idx < num_class; class_idx++)
        {
            float box_cls_score = feat_ptr[5 + class_idx];
            float box_prob = box_objectness * box_cls_score;
            if (box_prob > prob_threshold)
            {
                Object obj;
                obj.rect.x = x0;
                obj.rect.y = y0;
                obj.rect.width = w;
                obj.rect.height = h;
                obj.label = class_idx;
                obj.prob = box_prob;

                objects.push_back(obj);
            }

        } // class loop
        feat_ptr += feat_blob.w;

    } // point anchor loop
}

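// Full detection pipeline: load the ncnn model, letterbox-resize the image so its
// longer side is YOLOX_TARGET_SIZE while keeping the aspect ratio, pad the shorter
// side up to a multiple of 32, run the network, decode proposals at strides 8/16/32,
// apply NMS, and map the surviving boxes back to original image coordinates.
// For example, a 1280x720 input gives scale = 640/1280 = 0.5, a 640x360 resize and a
// 24-pixel bottom pad up to 640x384.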
static int detect_yolox(const cv::Mat& bgr, std::vector<Object>& objects)
{
    ncnn::Net yolox;

    yolox.opt.use_vulkan_compute = true;
    // yolox.opt.use_bf16_storage = true;

    // Focus in yolov5
    yolox.register_custom_layer("YoloV5Focus", YoloV5Focus_layer_creator);

    // original pretrained model from https://github.com/Megvii-BaseDetection/YOLOX
    // ncnn model param: https://github.com/Megvii-BaseDetection/YOLOX/releases/download/0.1.1rc0/yolox_s_ncnn.tar.gz
    // NOTE that the newest YOLOX releases remove the normalization (mean subtraction and division by std)
    // from the model itself, which can turn your model outputs into a total mess; please check carefully.
    if (yolox.load_param("yolox.param"))
        exit(-1);
    if (yolox.load_model("yolox.bin"))
        exit(-1);

    int img_w = bgr.cols;
    int img_h = bgr.rows;

    int w = img_w;
    int h = img_h;
    float scale = 1.f;
    if (w > h)
    {
        scale = (float)YOLOX_TARGET_SIZE / w;
        w = YOLOX_TARGET_SIZE;
        h = h * scale;
    }
    else
    {
        scale = (float)YOLOX_TARGET_SIZE / h;
        h = YOLOX_TARGET_SIZE;
        w = w * scale;
    }
    ncnn::Mat in = ncnn::Mat::from_pixels_resize(bgr.data, ncnn::Mat::PIXEL_BGR, img_w, img_h, w, h);

    // pad width and height up to a multiple of 32 (YOLOX_TARGET_SIZE is already a multiple of 32)
    int wpad = (w + 31) / 32 * 32 - w;
    int hpad = (h + 31) / 32 * 32 - h;
    ncnn::Mat in_pad;
    // unlike yolov5, yolox only pads on the bottom and right side,
    // so no extra padding info is needed to decode the box coordinates.
    ncnn::copy_make_border(in, in_pad, 0, hpad, 0, wpad, ncnn::BORDER_CONSTANT, 114.f);

    ncnn::Extractor ex = yolox.create_extractor();

    ex.input("images", in_pad);

    std::vector<Object> proposals;

    {
        ncnn::Mat out;
        ex.extract("output", out);

        static const int stride_arr[] = {8, 16, 32}; // might have stride=64 in YOLOX
        std::vector<int> strides(stride_arr, stride_arr + sizeof(stride_arr) / sizeof(stride_arr[0]));
        std::vector<GridAndStride> grid_strides;
        generate_grids_and_stride(in_pad.w, in_pad.h, strides, grid_strides);
        generate_yolox_proposals(grid_strides, out, YOLOX_CONF_THRESH, proposals);
    }

    // sort all proposals by score from highest to lowest
    qsort_descent_inplace(proposals);

    // apply nms with nms_threshold
    std::vector<int> picked;
    nms_sorted_bboxes(proposals, picked, YOLOX_NMS_THRESH);

    int count = picked.size();

    objects.resize(count);
    for (int i = 0; i < count; i++)
    {
        objects[i] = proposals[picked[i]];

        // adjust offset to the original unpadded image
        float x0 = (objects[i].rect.x) / scale;
        float y0 = (objects[i].rect.y) / scale;
        float x1 = (objects[i].rect.x + objects[i].rect.width) / scale;
        float y1 = (objects[i].rect.y + objects[i].rect.height) / scale;

        // clip
        x0 = std::max(std::min(x0, (float)(img_w - 1)), 0.f);
        y0 = std::max(std::min(y0, (float)(img_h - 1)), 0.f);
        x1 = std::max(std::min(x1, (float)(img_w - 1)), 0.f);
        y1 = std::max(std::min(y1, (float)(img_h - 1)), 0.f);

        objects[i].rect.x = x0;
        objects[i].rect.y = y0;
        objects[i].rect.width = x1 - x0;
        objects[i].rect.height = y1 - y0;
    }

    return 0;
}

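// Draw the detections on a copy of the input image. class_names lists the 80 COCO
// categories; obj.label produced by detect_yolox indexes directly into this list.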
static void draw_objects(const cv::Mat& bgr, const std::vector<Object>& objects)
{
    static const char* class_names[] = {
        "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light",
        "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow",
        "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee",
        "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard",
        "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple",
        "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch",
        "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone",
        "microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear",
        "hair drier", "toothbrush"
    };

    cv::Mat image = bgr.clone();

    for (size_t i = 0; i < objects.size(); i++)
    {
        const Object& obj = objects[i];

        fprintf(stderr, "%d = %.5f at %.2f %.2f %.2f x %.2f\n", obj.label, obj.prob,
                obj.rect.x, obj.rect.y, obj.rect.width, obj.rect.height);

        cv::rectangle(image, obj.rect, cv::Scalar(255, 0, 0));

        char text[256];
        sprintf(text, "%s %.1f%%", class_names[obj.label], obj.prob * 100);

        int baseLine = 0;
        cv::Size label_size = cv::getTextSize(text, cv::FONT_HERSHEY_SIMPLEX, 0.5, 1, &baseLine);

        int x = obj.rect.x;
        int y = obj.rect.y - label_size.height - baseLine;
        if (y < 0)
            y = 0;
        if (x + label_size.width > image.cols)
            x = image.cols - label_size.width;

        cv::rectangle(image, cv::Rect(cv::Point(x, y), cv::Size(label_size.width, label_size.height + baseLine)),
                      cv::Scalar(255, 255, 255), -1);

        cv::putText(image, text, cv::Point(x, y + label_size.height),
                    cv::FONT_HERSHEY_SIMPLEX, 0.5, cv::Scalar(0, 0, 0));
    }

    cv::imshow("image", image);
    cv::waitKey(0);
}

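// Usage: ./yolox [imagepath]
// Expects yolox.param and yolox.bin in the working directory (see detect_yolox above).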
int main(int argc, char** argv)
{
    if (argc != 2)
    {
        fprintf(stderr, "Usage: %s [imagepath]\n", argv[0]);
        return -1;
    }

    const char* imagepath = argv[1];

    cv::Mat m = cv::imread(imagepath, 1);
    if (m.empty())
    {
        fprintf(stderr, "cv::imread %s failed\n", imagepath);
        return -1;
    }

    std::vector<Object> objects;
    detect_yolox(m, objects);

    draw_objects(m, objects);

    return 0;
}