pytorch

Форк
0
/
make_mnist_db.cc 
144 строки · 4.8 Кб
1
/**
2
 * Copyright (c) 2016-present, Facebook, Inc.
3
 *
4
 * Licensed under the Apache License, Version 2.0 (the "License");
5
 * you may not use this file except in compliance with the License.
6
 * You may obtain a copy of the License at
7
 *
8
 *     http://www.apache.org/licenses/LICENSE-2.0
9
 *
10
 * Unless required by applicable law or agreed to in writing, software
11
 * distributed under the License is distributed on an "AS IS" BASIS,
12
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
 * See the License for the specific language governing permissions and
14
 * limitations under the License.
15
 */
16

17
// This script converts the MNIST dataset to a Caffe2 db (leveldb by default;
// see the --db flag). The MNIST dataset can be downloaded at
//    http://yann.lecun.com/exdb/mnist/
20

21
#include <fstream>  // NOLINT(readability/streams)
22
#include <string>
23

24
#include "caffe2/core/common.h"
25
#include "caffe2/core/db.h"
26
#include "caffe2/core/init.h"
27
#include "caffe2/proto/caffe2_pb.h"
28
#include "caffe2/core/logging.h"
29

30
// Command-line flags consumed by main() and convert_dataset().
C10_DEFINE_string(image_file, "", "The input image file name.");
C10_DEFINE_string(label_file, "", "The label file name.");
C10_DEFINE_string(output_file, "", "The output db name.");
C10_DEFINE_string(db, "leveldb", "The db type.");
// A non-positive value (the default, -1) means "no limit".
C10_DEFINE_int(
    data_limit,
    -1,
    "If set, only output this number of data points.");
// Selects the dims written for the image tensor: (1, rows, cols) when true,
// (rows, cols, 1) when false.
C10_DEFINE_bool(
    channel_first,
    false,
    "If set, write the data as channel-first (CHW order) as the old "
    "Caffe does.");
43

44
namespace caffe2 {
45
// Reverses the byte order of a 32-bit value (0xAABBCCDD -> 0xDDCCBBAA).
//
// The swap is unconditional: it does not inspect the host byte order, so
// applying it twice yields the original value.
uint32_t swap_endian(uint32_t val) {
  // Step 1: swap the two bytes inside each 16-bit half.
  val = ((val << 8) & 0xFF00FF00u) | ((val >> 8) & 0x00FF00FFu);
  // Step 2: swap the two 16-bit halves.
  return (val << 16) | (val >> 16);
}
49

50
void convert_dataset(const char* image_filename, const char* label_filename,
51
        const char* db_path, const int data_limit) {
52
  // Open files
53
  std::ifstream image_file(image_filename, std::ios::in | std::ios::binary);
54
  std::ifstream label_file(label_filename, std::ios::in | std::ios::binary);
55
  CAFFE_ENFORCE(image_file, "Unable to open file ", image_filename);
56
  CAFFE_ENFORCE(label_file, "Unable to open file ", label_filename);
57
  // Read the magic and the meta data
58
  uint32_t magic;
59
  uint32_t num_items;
60
  uint32_t num_labels;
61
  uint32_t rows;
62
  uint32_t cols;
63

64
  image_file.read(reinterpret_cast<char*>(&magic), 4);
65
  magic = swap_endian(magic);
66
  if (magic == 529205256) {
67
    LOG(FATAL) <<
68
        "It seems that you forgot to unzip the mnist dataset. You should "
69
        "first unzip them using e.g. gunzip on Linux.";
70
  }
71
  CAFFE_ENFORCE_EQ(magic, 2051, "Incorrect image file magic.");
72
  label_file.read(reinterpret_cast<char*>(&magic), 4);
73
  magic = swap_endian(magic);
74
  CAFFE_ENFORCE_EQ(magic, 2049, "Incorrect label file magic.");
75
  image_file.read(reinterpret_cast<char*>(&num_items), 4);
76
  num_items = swap_endian(num_items);
77
  label_file.read(reinterpret_cast<char*>(&num_labels), 4);
78
  num_labels = swap_endian(num_labels);
79
  CAFFE_ENFORCE_EQ(num_items, num_labels);
80
  image_file.read(reinterpret_cast<char*>(&rows), 4);
81
  rows = swap_endian(rows);
82
  image_file.read(reinterpret_cast<char*>(&cols), 4);
83
  cols = swap_endian(cols);
84

85
  // leveldb
86
  std::unique_ptr<db::DB> mnist_db(db::CreateDB(FLAGS_db, db_path, db::NEW));
87
  std::unique_ptr<db::Transaction> transaction(mnist_db->NewTransaction());
88
  // Storing to db
89
  char label_value;
90
  std::vector<char> pixels(rows * cols);
91
  int count = 0;
92
  const int kMaxKeyLength = 11;
93
  char key_cstr[kMaxKeyLength];
94

95
  TensorProtos protos;
96
  TensorProto* data = protos.add_protos();
97
  TensorProto* label = protos.add_protos();
98
  data->set_data_type(TensorProto::BYTE);
99
  if (FLAGS_channel_first) {
100
    data->add_dims(1);
101
    data->add_dims(rows);
102
    data->add_dims(cols);
103
  } else {
104
    data->add_dims(rows);
105
    data->add_dims(cols);
106
    data->add_dims(1);
107
  }
108
  label->set_data_type(TensorProto::INT32);
109
  label->add_int32_data(0);
110

111
  LOG(INFO) << "A total of " << num_items << " items.";
112
  LOG(INFO) << "Rows: " << rows << " Cols: " << cols;
113
  for (int item_id = 0; item_id < num_items; ++item_id) {
114
    image_file.read(pixels.data(), rows * cols);
115
    label_file.read(&label_value, 1);
116
    for (int i = 0; i < rows * cols; ++i) {
117
      data->set_byte_data(pixels.data(), rows * cols);
118
    }
119
    label->set_int32_data(0, static_cast<int>(label_value));
120
    snprintf(key_cstr, kMaxKeyLength, "%08d", item_id);
121
    string keystr(key_cstr);
122

123
    // Put in db
124
    transaction->Put(keystr, protos.SerializeAsString());
125
    if (++count % 1000 == 0) {
126
      transaction->Commit();
127
    }
128
    if (data_limit > 0 && count == data_limit) {
129
      LOG(INFO) << "Reached data limit of " << data_limit << ", stop.";
130
      break;
131
    }
132
  }
133
}
134
}  // namespace caffe2
135

136
int main(int argc, char** argv) {
137
  caffe2::GlobalInit(&argc, &argv);
138
  caffe2::convert_dataset(
139
      FLAGS_image_file.c_str(),
140
      FLAGS_label_file.c_str(),
141
      FLAGS_output_file.c_str(),
142
      FLAGS_data_limit);
143
  return 0;
144
}
145

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.