// This script converts the MNIST dataset to an lmdb (default) or
// leveldb (--backend=leveldb) format used by caffe to load data.
// Usage:
//    convert_mnist_data [FLAGS] input_image_file input_label_file
//                        output_db_file
// The MNIST dataset could be downloaded at
//    http://yann.lecun.com/exdb/mnist/
9
#include <gflags/gflags.h>
10
#include <glog/logging.h>
11
#include <google/protobuf/text_format.h>
13
#if defined(USE_LEVELDB) && defined(USE_LMDB)
14
#include <leveldb/db.h>
15
#include <leveldb/write_batch.h>
22
#include <fstream> // NOLINT(readability/streams)
25
#include "boost/scoped_ptr.hpp"
26
#include "caffe/proto/caffe.pb.h"
27
#include "caffe/util/db.hpp"
28
#include "caffe/util/format.hpp"
30
#if defined(USE_LEVELDB) && defined(USE_LMDB)
32
using namespace caffe; // NOLINT(build/namespaces)
33
using boost::scoped_ptr;
36
DEFINE_string(backend, "lmdb", "The backend for storing the result");
38
// Reverses the byte order of a 32-bit value (0xAABBCCDD -> 0xDDCCBBAA).
//
// Used to convert the 32-bit header fields read from the MNIST files before
// they are interpreted. The swap is unconditional: it does not inspect the
// host's byte order.
uint32_t swap_endian(uint32_t val) {
  // Step 1: swap the two bytes inside each 16-bit half (AABBCCDD -> BBAADDCC).
  val = ((val << 8) & 0xFF00FF00) | ((val >> 8) & 0xFF00FF);
  // Step 2: swap the two 16-bit halves (BBAADDCC -> DDCCBBAA).
  return (val << 16) | (val >> 16);
}
43
void convert_dataset(const char* image_filename, const char* label_filename,
44
const char* db_path, const string& db_backend) {
46
std::ifstream image_file(image_filename, std::ios::in | std::ios::binary);
47
std::ifstream label_file(label_filename, std::ios::in | std::ios::binary);
48
CHECK(image_file) << "Unable to open file " << image_filename;
49
CHECK(label_file) << "Unable to open file " << label_filename;
50
// Read the magic and the meta data
57
image_file.read(reinterpret_cast<char*>(&magic), 4);
58
magic = swap_endian(magic);
59
CHECK_EQ(magic, 2051) << "Incorrect image file magic.";
60
label_file.read(reinterpret_cast<char*>(&magic), 4);
61
magic = swap_endian(magic);
62
CHECK_EQ(magic, 2049) << "Incorrect label file magic.";
63
image_file.read(reinterpret_cast<char*>(&num_items), 4);
64
num_items = swap_endian(num_items);
65
label_file.read(reinterpret_cast<char*>(&num_labels), 4);
66
num_labels = swap_endian(num_labels);
67
CHECK_EQ(num_items, num_labels);
68
image_file.read(reinterpret_cast<char*>(&rows), 4);
69
rows = swap_endian(rows);
70
image_file.read(reinterpret_cast<char*>(&cols), 4);
71
cols = swap_endian(cols);
74
scoped_ptr<db::DB> db(db::GetDB(db_backend));
75
db->Open(db_path, db::NEW);
76
scoped_ptr<db::Transaction> txn(db->NewTransaction());
80
char* pixels = new char[rows * cols];
85
datum.set_channels(1);
86
datum.set_height(rows);
87
datum.set_width(cols);
88
LOG(INFO) << "A total of " << num_items << " items.";
89
LOG(INFO) << "Rows: " << rows << " Cols: " << cols;
90
for (int item_id = 0; item_id < num_items; ++item_id) {
91
image_file.read(pixels, rows * cols);
92
label_file.read(&label, 1);
93
datum.set_data(pixels, rows*cols);
94
datum.set_label(label);
95
string key_str = caffe::format_int(item_id, 8);
96
datum.SerializeToString(&value);
98
txn->Put(key_str, value);
100
if (++count % 1000 == 0) {
104
// write the last batch
105
if (count % 1000 != 0) {
108
LOG(INFO) << "Processed " << count << " files.";
113
int main(int argc, char** argv) {
114
#ifndef GFLAGS_GFLAGS_H_
115
namespace gflags = google;
118
FLAGS_alsologtostderr = 1;
120
gflags::SetUsageMessage("This script converts the MNIST dataset to\n"
121
"the lmdb/leveldb format used by Caffe to load data.\n"
123
" convert_mnist_data [FLAGS] input_image_file input_label_file "
125
"The MNIST dataset could be downloaded at\n"
126
" http://yann.lecun.com/exdb/mnist/\n"
127
"You should gunzip them after downloading,"
128
"or directly use data/mnist/get_mnist.sh\n");
129
gflags::ParseCommandLineFlags(&argc, &argv, true);
131
const string& db_backend = FLAGS_backend;
134
gflags::ShowUsageWithFlagsRestrict(argv[0],
135
"examples/mnist/convert_mnist_data");
137
google::InitGoogleLogging(argv[0]);
138
convert_dataset(argv[1], argv[2], argv[3], db_backend);
143
int main(int argc, char** argv) {
144
LOG(FATAL) << "This example requires LevelDB and LMDB; " <<
145
"compile with USE_LEVELDB and USE_LMDB.";
147
#endif // USE_LEVELDB and USE_LMDB