ncnn

Форк
0
/
squeezenet.cpp 
108 строк · 2.8 Кб
1
// Tencent is pleased to support the open source community by making ncnn available.
2
//
3
// Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
4
//
5
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6
// in compliance with the License. You may obtain a copy of the License at
7
//
8
// https://opensource.org/licenses/BSD-3-Clause
9
//
10
// Unless required by applicable law or agreed to in writing, software distributed
11
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
13
// specific language governing permissions and limitations under the License.
14

15
#include "net.h"
16

17
#include <algorithm>
18
#if defined(USE_NCNN_SIMPLEOCV)
19
#include "simpleocv.h"
20
#else
21
#include <opencv2/core/core.hpp>
22
#include <opencv2/highgui/highgui.hpp>
23
#endif
24
#include <stdio.h>
25
#include <vector>
26

27
static int detect_squeezenet(const cv::Mat& bgr, std::vector<float>& cls_scores)
28
{
29
    ncnn::Net squeezenet;
30

31
    squeezenet.opt.use_vulkan_compute = true;
32

33
    // the ncnn model https://github.com/nihui/ncnn-assets/tree/master/models
34
    if (squeezenet.load_param("squeezenet_v1.1.param"))
35
        exit(-1);
36
    if (squeezenet.load_model("squeezenet_v1.1.bin"))
37
        exit(-1);
38

39
    ncnn::Mat in = ncnn::Mat::from_pixels_resize(bgr.data, ncnn::Mat::PIXEL_BGR, bgr.cols, bgr.rows, 227, 227);
40

41
    const float mean_vals[3] = {104.f, 117.f, 123.f};
42
    in.substract_mean_normalize(mean_vals, 0);
43

44
    ncnn::Extractor ex = squeezenet.create_extractor();
45

46
    ex.input("data", in);
47

48
    ncnn::Mat out;
49
    ex.extract("prob", out);
50

51
    cls_scores.resize(out.w);
52
    for (int j = 0; j < out.w; j++)
53
    {
54
        cls_scores[j] = out[j];
55
    }
56

57
    return 0;
58
}
59

60
static int print_topk(const std::vector<float>& cls_scores, int topk)
61
{
62
    // partial sort topk with index
63
    int size = cls_scores.size();
64
    std::vector<std::pair<float, int> > vec;
65
    vec.resize(size);
66
    for (int i = 0; i < size; i++)
67
    {
68
        vec[i] = std::make_pair(cls_scores[i], i);
69
    }
70

71
    std::partial_sort(vec.begin(), vec.begin() + topk, vec.end(),
72
                      std::greater<std::pair<float, int> >());
73

74
    // print topk and score
75
    for (int i = 0; i < topk; i++)
76
    {
77
        float score = vec[i].first;
78
        int index = vec[i].second;
79
        fprintf(stderr, "%d = %f\n", index, score);
80
    }
81

82
    return 0;
83
}
84

85
int main(int argc, char** argv)
86
{
87
    if (argc != 2)
88
    {
89
        fprintf(stderr, "Usage: %s [imagepath]\n", argv[0]);
90
        return -1;
91
    }
92

93
    const char* imagepath = argv[1];
94

95
    cv::Mat m = cv::imread(imagepath, 1);
96
    if (m.empty())
97
    {
98
        fprintf(stderr, "cv::imread %s failed\n", imagepath);
99
        return -1;
100
    }
101

102
    std::vector<float> cls_scores;
103
    detect_squeezenet(m, cls_scores);
104

105
    print_topk(cls_scores, 3);
106

107
    return 0;
108
}
109

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.