caffe

Форк
0
/
extract_features.cpp 
183 строки · 6.2 Кб
1
#include <string>
2
#include <vector>
3

4
#include "boost/algorithm/string.hpp"
5
#include "google/protobuf/text_format.h"
6

7
#include "caffe/blob.hpp"
8
#include "caffe/common.hpp"
9
#include "caffe/net.hpp"
10
#include "caffe/proto/caffe.pb.h"
11
#include "caffe/util/db.hpp"
12
#include "caffe/util/format.hpp"
13
#include "caffe/util/io.hpp"
14

15
using caffe::Blob;
16
using caffe::Caffe;
17
using caffe::Datum;
18
using caffe::Net;
19
using std::string;
20
namespace db = caffe::db;
21

22
// Forward declaration: the pipeline template is defined after main().
template<typename Dtype>
int feature_extraction_pipeline(int argc, char** argv);

// Program entry point. Runs the feature extraction pipeline with
// single-precision networks and returns its status as the exit code.
int main(int argc, char** argv) {
  const int exit_status = feature_extraction_pipeline<float>(argc, argv);
  // Double-precision alternative:
  //   feature_extraction_pipeline<double>(argc, argv);
  return exit_status;
}
29

30
template<typename Dtype>
31
int feature_extraction_pipeline(int argc, char** argv) {
32
  ::google::InitGoogleLogging(argv[0]);
33
  const int num_required_args = 7;
34
  if (argc < num_required_args) {
35
    LOG(ERROR)<<
36
    "This program takes in a trained network and an input data layer, and then"
37
    " extract features of the input data produced by the net.\n"
38
    "Usage: extract_features  pretrained_net_param"
39
    "  feature_extraction_proto_file  extract_feature_blob_name1[,name2,...]"
40
    "  save_feature_dataset_name1[,name2,...]  num_mini_batches  db_type"
41
    "  [CPU/GPU] [DEVICE_ID=0]\n"
42
    "Note: you can extract multiple features in one pass by specifying"
43
    " multiple feature blob names and dataset names separated by ','."
44
    " The names cannot contain white space characters and the number of blobs"
45
    " and datasets must be equal.";
46
    return 1;
47
  }
48
  int arg_pos = num_required_args;
49

50
  arg_pos = num_required_args;
51
  if (argc > arg_pos && strcmp(argv[arg_pos], "GPU") == 0) {
52
    LOG(ERROR)<< "Using GPU";
53
    int device_id = 0;
54
    if (argc > arg_pos + 1) {
55
      device_id = atoi(argv[arg_pos + 1]);
56
      CHECK_GE(device_id, 0);
57
    }
58
    LOG(ERROR) << "Using Device_id=" << device_id;
59
    Caffe::SetDevice(device_id);
60
    Caffe::set_mode(Caffe::GPU);
61
  } else {
62
    LOG(ERROR) << "Using CPU";
63
    Caffe::set_mode(Caffe::CPU);
64
  }
65

66
  arg_pos = 0;  // the name of the executable
67
  std::string pretrained_binary_proto(argv[++arg_pos]);
68

69
  // Expected prototxt contains at least one data layer such as
70
  //  the layer data_layer_name and one feature blob such as the
71
  //  fc7 top blob to extract features.
72
  /*
73
   layers {
74
     name: "data_layer_name"
75
     type: DATA
76
     data_param {
77
       source: "/path/to/your/images/to/extract/feature/images_leveldb"
78
       mean_file: "/path/to/your/image_mean.binaryproto"
79
       batch_size: 128
80
       crop_size: 227
81
       mirror: false
82
     }
83
     top: "data_blob_name"
84
     top: "label_blob_name"
85
   }
86
   layers {
87
     name: "drop7"
88
     type: DROPOUT
89
     dropout_param {
90
       dropout_ratio: 0.5
91
     }
92
     bottom: "fc7"
93
     top: "fc7"
94
   }
95
   */
96
  std::string feature_extraction_proto(argv[++arg_pos]);
97
  boost::shared_ptr<Net<Dtype> > feature_extraction_net(
98
      new Net<Dtype>(feature_extraction_proto, caffe::TEST));
99
  feature_extraction_net->CopyTrainedLayersFrom(pretrained_binary_proto);
100

101
  std::string extract_feature_blob_names(argv[++arg_pos]);
102
  std::vector<std::string> blob_names;
103
  boost::split(blob_names, extract_feature_blob_names, boost::is_any_of(","));
104

105
  std::string save_feature_dataset_names(argv[++arg_pos]);
106
  std::vector<std::string> dataset_names;
107
  boost::split(dataset_names, save_feature_dataset_names,
108
               boost::is_any_of(","));
109
  CHECK_EQ(blob_names.size(), dataset_names.size()) <<
110
      " the number of blob names and dataset names must be equal";
111
  size_t num_features = blob_names.size();
112

113
  for (size_t i = 0; i < num_features; i++) {
114
    CHECK(feature_extraction_net->has_blob(blob_names[i]))
115
        << "Unknown feature blob name " << blob_names[i]
116
        << " in the network " << feature_extraction_proto;
117
  }
118

119
  int num_mini_batches = atoi(argv[++arg_pos]);
120

121
  std::vector<boost::shared_ptr<db::DB> > feature_dbs;
122
  std::vector<boost::shared_ptr<db::Transaction> > txns;
123
  const char* db_type = argv[++arg_pos];
124
  for (size_t i = 0; i < num_features; ++i) {
125
    LOG(INFO)<< "Opening dataset " << dataset_names[i];
126
    boost::shared_ptr<db::DB> db(db::GetDB(db_type));
127
    db->Open(dataset_names.at(i), db::NEW);
128
    feature_dbs.push_back(db);
129
    boost::shared_ptr<db::Transaction> txn(db->NewTransaction());
130
    txns.push_back(txn);
131
  }
132

133
  LOG(ERROR)<< "Extracting Features";
134

135
  Datum datum;
136
  std::vector<int> image_indices(num_features, 0);
137
  for (int batch_index = 0; batch_index < num_mini_batches; ++batch_index) {
138
    feature_extraction_net->Forward();
139
    for (int i = 0; i < num_features; ++i) {
140
      const boost::shared_ptr<Blob<Dtype> > feature_blob =
141
        feature_extraction_net->blob_by_name(blob_names[i]);
142
      int batch_size = feature_blob->num();
143
      int dim_features = feature_blob->count() / batch_size;
144
      const Dtype* feature_blob_data;
145
      for (int n = 0; n < batch_size; ++n) {
146
        datum.set_height(feature_blob->height());
147
        datum.set_width(feature_blob->width());
148
        datum.set_channels(feature_blob->channels());
149
        datum.clear_data();
150
        datum.clear_float_data();
151
        feature_blob_data = feature_blob->cpu_data() +
152
            feature_blob->offset(n);
153
        for (int d = 0; d < dim_features; ++d) {
154
          datum.add_float_data(feature_blob_data[d]);
155
        }
156
        string key_str = caffe::format_int(image_indices[i], 10);
157

158
        string out;
159
        CHECK(datum.SerializeToString(&out));
160
        txns.at(i)->Put(key_str, out);
161
        ++image_indices[i];
162
        if (image_indices[i] % 1000 == 0) {
163
          txns.at(i)->Commit();
164
          txns.at(i).reset(feature_dbs.at(i)->NewTransaction());
165
          LOG(ERROR)<< "Extracted features of " << image_indices[i] <<
166
              " query images for feature blob " << blob_names[i];
167
        }
168
      }  // for (int n = 0; n < batch_size; ++n)
169
    }  // for (int i = 0; i < num_features; ++i)
170
  }  // for (int batch_index = 0; batch_index < num_mini_batches; ++batch_index)
171
  // write the last batch
172
  for (int i = 0; i < num_features; ++i) {
173
    if (image_indices[i] % 1000 != 0) {
174
      txns.at(i)->Commit();
175
    }
176
    LOG(ERROR)<< "Extracted features of " << image_indices[i] <<
177
        " query images for feature blob " << blob_names[i];
178
    feature_dbs.at(i)->Close();
179
  }
180

181
  LOG(ERROR)<< "Successfully extracted the features!";
182
  return 0;
183
}
184

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.