24
#include "caffe2/core/common.h"
25
#include "caffe2/core/db.h"
26
#include "caffe2/core/init.h"
27
#include "caffe2/proto/caffe2_pb.h"
28
#include "caffe2/core/logging.h"
30
// Command-line flags for the MNIST -> caffe2-DB conversion tool.
// All three file flags are required at runtime; FLAGS_db selects the
// backend passed to db::CreateDB (default "leveldb").
C10_DEFINE_string(image_file, "", "The input image file name.");
31
C10_DEFINE_string(label_file, "", "The label file name.");
32
C10_DEFINE_string(output_file, "", "The output db name.");
33
C10_DEFINE_string(db, "leveldb", "The db type.");
37
"If set, only output this number of data points.");
41
"If set, write the data as channel-first (CHW order) as the old "
45
// Reverse the byte order of a 32-bit value. The MNIST idx files store
// all multi-byte integers big-endian, so every header field read on a
// little-endian host must pass through this before use.
uint32_t swap_endian(uint32_t val) {
  // First swap each adjacent byte pair, then swap the two 16-bit halves;
  // the composition reverses all four bytes.
  const uint32_t pair_swapped =
      ((val << 8) & 0xFF00FF00u) | ((val >> 8) & 0x00FF00FFu);
  return (pair_swapped << 16) | (pair_swapped >> 16);
}
50
// Read the raw MNIST idx-format image and label files and write each
// (image, label) pair as a serialized TensorProtos record into a newly
// created caffe2 DB of type FLAGS_db at db_path. data_limit > 0 caps the
// number of records written; otherwise all num_items records are written.
// NOTE(review): several declarations (magic, num_items, num_labels, rows,
// cols, label_value, count, protos) sit on lines dropped from this view.
void convert_dataset(const char* image_filename, const char* label_filename,
51
const char* db_path, const int data_limit) {
53
// Open both inputs in binary mode and fail loudly if either is missing.
std::ifstream image_file(image_filename, std::ios::in | std::ios::binary);
54
std::ifstream label_file(label_filename, std::ios::in | std::ios::binary);
55
CAFFE_ENFORCE(image_file, "Unable to open file ", image_filename);
56
CAFFE_ENFORCE(label_file, "Unable to open file ", label_filename);
64
// idx headers are big-endian; byte-swap every 4-byte field after reading.
image_file.read(reinterpret_cast<char*>(&magic), 4);
65
magic = swap_endian(magic);
66
// 529205256 == 0x1F8B0808: the file still begins with a gzip stream
// header, i.e. the user pointed us at the compressed download.
if (magic == 529205256) {
68
"It seems that you forgot to unzip the mnist dataset. You should "
69
"first unzip them using e.g. gunzip on Linux.";
71
// 2051 (0x0803) / 2049 (0x0801) are the idx3/idx1 magic numbers of the
// MNIST image and label files respectively.
CAFFE_ENFORCE_EQ(magic, 2051, "Incorrect image file magic.");
72
label_file.read(reinterpret_cast<char*>(&magic), 4);
73
magic = swap_endian(magic);
74
CAFFE_ENFORCE_EQ(magic, 2049, "Incorrect label file magic.");
75
// Item counts must agree between the two files.
image_file.read(reinterpret_cast<char*>(&num_items), 4);
76
num_items = swap_endian(num_items);
77
label_file.read(reinterpret_cast<char*>(&num_labels), 4);
78
num_labels = swap_endian(num_labels);
79
CAFFE_ENFORCE_EQ(num_items, num_labels);
80
// Image geometry (28x28 for standard MNIST, but taken from the header).
image_file.read(reinterpret_cast<char*>(&rows), 4);
81
rows = swap_endian(rows);
82
image_file.read(reinterpret_cast<char*>(&cols), 4);
83
cols = swap_endian(cols);
86
// Create a fresh output DB and a write transaction on it.
std::unique_ptr<db::DB> mnist_db(db::CreateDB(FLAGS_db, db_path, db::NEW));
87
std::unique_ptr<db::Transaction> transaction(mnist_db->NewTransaction());
90
// Scratch buffer holding exactly one image worth of pixel bytes.
std::vector<char> pixels(rows * cols);
92
// "%08d" with an int key needs at most 10 digits + NUL == 11 chars.
const int kMaxKeyLength = 11;
93
char key_cstr[kMaxKeyLength];
96
// One TensorProtos is reused for every record: protos[0] holds the image
// bytes, protos[1] the int32 label.
TensorProto* data = protos.add_protos();
97
TensorProto* label = protos.add_protos();
98
data->set_data_type(TensorProto::BYTE);
99
if (FLAGS_channel_first) {
101
// NOTE(review): both branches visibly add (rows, cols); the differing
// channel-dimension add_dims line of each branch was dropped from this
// view — confirm against the full file.
data->add_dims(rows);
102
data->add_dims(cols);
104
data->add_dims(rows);
105
data->add_dims(cols);
108
label->set_data_type(TensorProto::INT32);
109
// Placeholder value; overwritten via set_int32_data(0, ...) per record.
label->add_int32_data(0);
111
LOG(INFO) << "A total of " << num_items << " items.";
112
LOG(INFO) << "Rows: " << rows << " Cols: " << cols;
113
// Main loop: one image + one label byte per record.
for (int item_id = 0; item_id < num_items; ++item_id) {
114
image_file.read(pixels.data(), rows * cols);
115
label_file.read(&label_value, 1);
116
// NOTE(review): this inner loop re-sets the identical byte blob
// rows*cols times; a single set_byte_data call would appear to suffice —
// confirm intent before changing.
for (int i = 0; i < rows * cols; ++i) {
117
data->set_byte_data(pixels.data(), rows * cols);
119
label->set_int32_data(0, static_cast<int>(label_value));
120
// Keys are 8-digit zero-padded item indices so DB iteration order
// matches dataset order.
snprintf(key_cstr, kMaxKeyLength, "%08d", item_id);
121
string keystr(key_cstr);
124
transaction->Put(keystr, protos.SerializeAsString());
125
// Commit in batches of 1000 to bound transaction size.
if (++count % 1000 == 0) {
126
transaction->Commit();
128
// Optional early stop once the requested number of records is written.
if (data_limit > 0 && count == data_limit) {
129
LOG(INFO) << "Reached data limit of " << data_limit << ", stop.";
136
// Entry point: GlobalInit parses the C10 flags defined above, then the
// conversion runs with the user-supplied paths.
// NOTE(review): the remaining convert_dataset arguments and the closing
// of main() fall outside this view.
int main(int argc, char** argv) {
137
caffe2::GlobalInit(&argc, &argv);
138
caffe2::convert_dataset(
139
FLAGS_image_file.c_str(),
140
FLAGS_label_file.c_str(),
141
FLAGS_output_file.c_str(),