pytorch

Форк
0
/
compare_models_torch.cc 
326 строк · 10.0 Кб
1
/**
 * Copyright (c) 2016-present, Facebook, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <cmath>
#include <iomanip>
#include <iostream>
#include <sstream>
#include <string>
#include <thread>
#include <vector>

#include <ATen/ATen.h>
#include <caffe2/core/timer.h>
#include <caffe2/utils/string_utils.h>
#include <torch/csrc/autograd/grad_mode.h>
#include <torch/csrc/jit/serialization/import.h>
#include <torch/script.h>

#include <c10/mobile/CPUCachingAllocator.h>

// Command line flags for the model comparison tool.

C10_DEFINE_string(
    refmodel,
    "",
    "The reference torch script model to compare against.");
C10_DEFINE_string(
    model,
    "",
    "The torch script model to compare to the reference model.");
C10_DEFINE_string(
    input_dims,
    "",
    "Dimensions of the input tensors, specified as comma separated numbers. "
    "If multiple inputs are needed, use semicolon to separate the dimensions "
    "of different tensors.");
C10_DEFINE_string(
    input_type,
    "",
    "Input type (uint8_t/float/int64). If multiple inputs are needed, use "
    "semicolon to separate the types of different tensors.");
C10_DEFINE_string(
    input_memory_format,
    "contiguous_format",
    "Input memory format (contiguous_format/channels_last)");
C10_DEFINE_int(input_max, 1, "The maximum value inputs should have");
C10_DEFINE_int(input_min, -1, "The minimum value inputs should have");
C10_DEFINE_bool(
    no_inputs,
    false,
    "Whether the model has any input. Will ignore other input arguments if true");
C10_DEFINE_bool(
    use_caching_allocator,
    false,
    "Whether to cache allocations between inference iterations");
// NOTE(review): FLAGS_print_output is not read anywhere in this file.
C10_DEFINE_bool(
    print_output,
    false,
    "Whether to print output with all one input tensor.");
C10_DEFINE_int(iter, 10, "The number of iterations to run.");
C10_DEFINE_int(report_freq, 1000, "An update will be reported every n iterations");
C10_DEFINE_int(pytext_len, 0, "Length of input sequence.");
// NOTE(review): inputs are only converted specially for "vulkan"; any other
// backend name receives the CPU tensors as-is.
C10_DEFINE_string(
    backend,
    "cpu",
    "what backend to use for model (vulkan, cpu, metal) (default=cpu)");
C10_DEFINE_string(
    refbackend,
    "cpu",
    "what backend to use for the reference model (vulkan, cpu, metal) "
    "(default=cpu)");
C10_DEFINE_string(
    tolerance,
    "1e-5",
    "Relative tolerance to use for comparison: the maximum absolute "
    "difference allowed, as a fraction of the largest absolute output value");
C10_DEFINE_int(nthreads, 1, "Number of threads to launch. Useful for checking correct concurrent behaviour.");
C10_DEFINE_bool(
    report_failures,
    true,
    "Whether to report error during failed iterations");

bool checkRtol(
84
    const at::Tensor& diff,
85
    const std::vector<at::Tensor>& inputs,
86
    float tolerance,
87
    bool report) {
88
  float maxValue = 0.0f;
89

90
  for (const auto& tensor : inputs) {
91
    maxValue = fmax(tensor.abs().max().item<float>(), maxValue);
92
  }
93
  float threshold = tolerance * maxValue;
94
  float maxDiff = diff.abs().max().item<float>();
95

96
  bool passed = maxDiff < threshold;
97
  if (!passed && report) {
98
    std::cout << "Check FAILED!      Max diff allowed: "
99
              << std::setw(10) << std::setprecision(5) << threshold
100
              << "     max diff: "
101
              << std::setw(10) << std::setprecision(5) << maxDiff
102
              << std::endl;
103
  }
104

105
  return passed;
106
}
107

108
/**
 * Prints how many of `total` comparisons passed, together with the pass rate
 * as an integer percentage (truncated toward zero).
 *
 * @param passed Number of iterations whose outputs matched within tolerance.
 * @param total  Number of iterations run. A non-positive total reports a 0%
 *               pass rate instead of dividing by zero.
 */
void report_pass_rate(int passed, int total) {
  // Exact integer arithmetic instead of a float round-trip. Widen before
  // multiplying so large iteration counts cannot overflow `int`.
  const long long pass_rate =
      total > 0 ? static_cast<long long>(passed) * 100 / total : 0;
  std::cout << "Output was equal within tolerance " << passed << "/"
            << total
            << " times. Pass rate: " << pass_rate
            << "%" << std::endl;
}

/**
 * Splits `string` on every occurrence of `separator`.
 *
 * Pieces are produced the way std::getline would: a trailing separator does
 * not yield a final empty piece, and an empty input yields no pieces.
 *
 * @param separator    Delimiter character.
 * @param string       Text to split.
 * @param ignore_empty When true (the default), empty pieces are dropped.
 * @return The pieces in their original order.
 */
std::vector<std::string> split(
    char separator,
    const std::string& string,
    bool ignore_empty = true) {
  std::vector<std::string> result;
  std::istringstream stream(string);
  for (std::string token; std::getline(stream, token, separator);) {
    if (ignore_empty && token.empty()) {
      continue;
    }
    result.emplace_back(std::move(token));
  }
  return result;
}

std::vector<c10::IValue> create_inputs(
132
    std::vector<c10::IValue>& refinputs,
133
    std::vector<c10::IValue>& inputs,
134
    std::string& refbackend,
135
    std::string& backend,
136
    const int range_min,
137
    const int range_max) {
138
  if (FLAGS_no_inputs) {
139
    return {};
140
  }
141

142
  CAFFE_ENFORCE_GE(FLAGS_input_dims.size(), 0, "Input dims must be specified.");
143
  CAFFE_ENFORCE_GE(FLAGS_input_type.size(), 0, "Input type must be specified.");
144

145
  std::vector<std::string> input_dims_list = split(';', FLAGS_input_dims);
146
  std::vector<std::string> input_type_list = split(';', FLAGS_input_type);
147
  std::vector<std::string> input_memory_format_list =
148
      split(';', FLAGS_input_memory_format);
149

150
  CAFFE_ENFORCE_GE(
151
      input_dims_list.size(), 0, "Input dims not specified correctly.");
152
  CAFFE_ENFORCE_GE(
153
      input_type_list.size(), 0, "Input type not specified correctly.");
154
  CAFFE_ENFORCE_GE(
155
      input_memory_format_list.size(),
156
      0,
157
      "Input format list not specified correctly.");
158

159
  CAFFE_ENFORCE_EQ(
160
      input_dims_list.size(),
161
      input_type_list.size(),
162
      "Input dims and type should have the same number of items.");
163
  CAFFE_ENFORCE_EQ(
164
      input_dims_list.size(),
165
      input_memory_format_list.size(),
166
      "Input dims and format should have the same number of items.");
167

168
  for (size_t i = 0; i < input_dims_list.size(); ++i) {
169
    auto input_dims_str = split(',', input_dims_list[i]);
170
    std::vector<int64_t> input_dims;
171
    input_dims.reserve(input_dims_str.size());
172
    for (const auto& s : input_dims_str) {
173
      input_dims.push_back(std::stoi(s));
174
    }
175

176
    at::ScalarType input_type;
177
    if (input_type_list[i] == "float") {
178
      input_type = at::ScalarType::Float;
179
    } else if (input_type_list[i] == "uint8_t") {
180
      input_type = at::ScalarType::Byte;
181
    } else if (input_type_list[i] == "int64") {
182
      input_type = at::ScalarType::Long;
183
    } else {
184
      CAFFE_THROW("Unsupported input type: ", input_type_list[i]);
185
    }
186

187
    at::MemoryFormat input_memory_format;
188
    if (input_memory_format_list[i] == "channels_last") {
189
      if (input_dims.size() != 4u) {
190
        CAFFE_THROW(
191
            "channels_last memory format only available on 4D tensors!");
192
      }
193
      input_memory_format = at::MemoryFormat::ChannelsLast;
194
    } else if (input_memory_format_list[i] == "contiguous_format") {
195
      input_memory_format = at::MemoryFormat::Contiguous;
196
    } else {
197
      CAFFE_THROW(
198
          "Unsupported input memory format: ", input_memory_format_list[i]);
199
    }
200

201
    const auto input_tensor = torch::rand(
202
        input_dims,
203
        at::TensorOptions(input_type).memory_format(input_memory_format))*(range_max - range_min) - range_min;
204

205
    if (refbackend == "vulkan") {
206
      refinputs.emplace_back(input_tensor.vulkan());
207
    } else {
208
      refinputs.emplace_back(input_tensor);
209
    }
210

211
    if (backend == "vulkan") {
212
      inputs.emplace_back(input_tensor.vulkan());
213
    } else {
214
      inputs.emplace_back(input_tensor);
215
    }
216
  }
217

218
  if (FLAGS_pytext_len > 0) {
219
    auto stensor = FLAGS_pytext_len * at::ones({1}, torch::kI64);
220
    if (refbackend == "vulkan") {
221
      refinputs.emplace_back(stensor.vulkan());
222
    } else {
223
      refinputs.emplace_back(stensor);
224
    }
225

226
    if (backend == "vulkan") {
227
      inputs.emplace_back(stensor.vulkan());
228
    } else {
229
      inputs.emplace_back(stensor);
230
    }
231
  }
232

233
  return inputs;
234
}
235

236
void run_check(float tolerance) {
237
  torch::jit::Module module = torch::jit::load(FLAGS_model);
238
  torch::jit::Module refmodule = torch::jit::load(FLAGS_refmodel);
239

240
  module.eval();
241
  refmodule.eval();
242

243
  std::thread::id this_id = std::this_thread::get_id();
244
  std::cout << "Running check on thread " << this_id << "." << std::endl;
245

246
  int passed = 0;
247
  for (int i = 0; i < FLAGS_iter; ++i) {
248
    std::vector<c10::IValue> refinputs;
249
    std::vector<c10::IValue> inputs;
250
    create_inputs(
251
        refinputs, inputs,
252
        FLAGS_refbackend, FLAGS_backend,
253
        FLAGS_input_min, FLAGS_input_max);
254

255
    const auto refoutput = refmodule.forward(refinputs).toTensor().cpu();
256
    const auto output = module.forward(inputs).toTensor().cpu();
257

258
    bool check = checkRtol(
259
        refoutput-output,
260
        {refoutput, output},
261
        tolerance,
262
        FLAGS_report_failures);
263

264
    if (check) {
265
      passed += 1;
266
    }
267
    else if (FLAGS_report_failures) {
268
      std::cout << " (Iteration " << i << " failed)" << std::endl;
269
    }
270

271
    if (i > 0 && (i+1) % FLAGS_report_freq == 0) {
272
      report_pass_rate(passed, i+1);
273
    }
274
  }
275
  report_pass_rate(passed, FLAGS_iter);
276
}
277

278
int main(int argc, char** argv) {
279
  c10::SetUsageMessage(
280
      "Run accuracy comparison to a reference model for a pytorch model.\n"
281
      "Example usage:\n"
282
      "./compare_models_torch"
283
      " --refmodel=<ref_model_file>"
284
      " --model=<model_file>"
285
      " --iter=20");
286
  if (!c10::ParseCommandLineFlags(&argc, &argv)) {
287
    std::cerr << "Failed to parse command line flags!" << std::endl;
288
    return 1;
289
  }
290

291
  if (FLAGS_input_min >= FLAGS_input_max) {
292
    std::cerr << "Input min: " << FLAGS_input_min
293
              << " should be less than input max: "
294
              << FLAGS_input_max << std::endl;
295
    return 1;
296
  }
297

298
  std::stringstream ss(FLAGS_tolerance);
299
  float tolerance = 0;
300
  ss >> tolerance;
301
  std::cout << "tolerance: " << tolerance << std::endl;
302

303
  c10::InferenceMode mode;
304
  torch::autograd::AutoGradMode guard(false);
305
  torch::jit::GraphOptimizerEnabledGuard no_optimizer_guard(false);
306

307
  c10::CPUCachingAllocator caching_allocator;
308
  c10::optional<c10::WithCPUCachingAllocatorGuard> caching_allocator_guard;
309
  if (FLAGS_use_caching_allocator) {
310
    caching_allocator_guard.emplace(&caching_allocator);
311
  }
312

313
  std::vector<std::thread> check_threads;
314
  check_threads.reserve(FLAGS_nthreads);
315
  for (int i = 0; i < FLAGS_nthreads; ++i) {
316
    check_threads.emplace_back(std::thread(run_check, tolerance));
317
  }
318

319
  for (std::thread& th : check_threads) {
320
    if (th.joinable()) {
321
      th.join();
322
    }
323
  }
324

325
  return 0;
326
}
327

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.