pytorch

Fork
0
/
onnx_exporter.cc 
1460 lines · 47.9 KB
1
#include "caffe2/onnx/onnx_exporter.h"
2
#include "caffe2/core/logging.h"
3
#include "caffe2/core/memonger.h"
4
#include "caffe2/core/tensor_impl.h"
5
#include "caffe2/onnx/helper.h"
6
#include "caffe2/proto/caffe2_legacy.pb.h"
7
#include "caffe2/utils/map_utils.h"
8
#include "caffe2/utils/proto_utils.h"
9
#include "caffe2/utils/string_utils.h"
10

11
#include <numeric>
12
#include <unordered_set>
13

14
namespace caffe2 {
15
namespace onnx {
16

17
namespace {
18
// rewrite padding attributes
19
void ApplyTrans(
20
    std::unordered_map<std::string, AttributeProto>* attrs,
21
    bool global,
22
    const std::string& k,
23
    int dim = 2,
24
    const std::string& ks = "") {
25
  std::string ks2 = ks.empty() ? (k + "s") : ks;
26
  std::string k_h, k_w, k_t, k_l, k_b, k_r;
27
  if (dim == 2) {
28
    k_h = k + "_h";
29
    k_w = k + "_w";
30
  } else {
31
    k_t = k + "_t";
32
    k_l = k + "_l";
33
    k_b = k + "_b";
34
    k_r = k + "_r";
35
  }
36

37
  std::vector<int64_t> vals;
38
  if (dim == 2 && attrs->count(k_h) && attrs->count(k_w)) {
39
    auto it = attrs->find(k_h);
40
    vals.push_back(it->second.i());
41
    attrs->erase(it);
42
    it = attrs->find(k_w);
43
    vals.push_back(it->second.i());
44
    attrs->erase(it);
45
  } else if (
46
      dim == 4 && attrs->count(k_t) && attrs->count(k_b) && attrs->count(k_l) &&
47
      attrs->count(k_r)) {
48
    auto it = attrs->find(k_t);
49
    vals.push_back(it->second.i());
50
    attrs->erase(it);
51
    it = attrs->find(k_l);
52
    vals.push_back(it->second.i());
53
    attrs->erase(it);
54
    it = attrs->find(k_b);
55
    vals.push_back(it->second.i());
56
    attrs->erase(it);
57
    it = attrs->find(k_r);
58
    vals.push_back(it->second.i());
59
    attrs->erase(it);
60
  } else if (attrs->count(k)) {
61
    auto it = attrs->find(k);
62
    auto tmp = it->second.i();
63
    for (int i = 0; i < dim; ++i) {
64
      vals.push_back(tmp);
65
    }
66
    attrs->erase(it);
67
  }
68

69
  if (!vals.empty() && !global) {
70
    attrs->emplace(ks2, MakeAttribute(ks2, vals));
71
  }
72
}
73

74
// Product of shape.dims(i) over the half-open index range [start, end).
// Returns 1 for an empty range.
int64_t DimProd(const caffe2::TensorShape& shape, int start, int end) {
  int64_t prod = 1;
  for (int idx = start; idx != end; ++idx) {
    prod *= shape.dims(idx);
  }
  return prod;
}
81

82
// Build a 1-D INT64 ONNX tensor holding `shape`, suitable as the second
// input of a Reshape node. The tensor gets a fresh dummy name and stores
// its values as raw host-endian bytes.
TensorProto CreateOnnxShapeTensor(
    std::shared_ptr<DummyName> dummy,
    const std::vector<int64_t>& shape) {
  TensorProto shape_tensor;
  shape_tensor.set_name(dummy->NewDummyName());
  shape_tensor.set_data_type(TensorProto::INT64);
  shape_tensor.add_dims(shape.size());
  const auto* raw_bytes = reinterpret_cast<const char*>(shape.data());
  shape_tensor.mutable_raw_data()->assign(
      raw_bytes, sizeof(int64_t) * shape.size());
  return shape_tensor;
}
94

95
// Versioned SSA blob name: "<n>_<version>".
std::string SsaName(const std::string& n, int version) {
  return n + "_" + std::to_string(version);
}
98

99
// Create an ONNX "Shape" node reading `input` and writing `output`.
NodeProto AddShapeNode(const std::string& input, const std::string& output) {
  NodeProto node;
  node.set_op_type("Shape");
  node.add_input(input);
  node.add_output(output);
  return node;
}
106

107
void collectExternalsFromIfOpSubnet(
108
    const NetDef* net,
109
    std::vector<std::string>* input,
110
    std::vector<std::string>* output) {
111
  std::set<std::string> in_input, in_output;
112
  for (const auto& op : net->op()) {
113
    for (const auto& blob : op.input()) {
114
      in_input.emplace(blob);
115
    }
116
    for (const auto& blob : op.output()) {
117
      in_output.emplace(blob);
118
    }
119
  }
120

121
  for (const auto& blob : in_input) {
122
    if (!in_output.count(blob)) {
123
      input->push_back(blob);
124
    }
125
  }
126
  for (const auto& blob : in_output) {
127
    if (!in_input.count(blob)) {
128
      output->push_back(blob);
129
    }
130
  }
131
}
132

133
// SSA-rewrite an If/AsyncIf operator in place: lift the subnets' external
// blobs up into the op's own input/output lists, then rename blobs inside
// both subnets to their current SSA versions. `blob_versions` tracks the
// latest version per blob; `is_initialized_tensor` marks blobs pre-filled by
// initializer ops so they are version-bumped exactly once overall.
void ssaRewriteForIfOp(
    OperatorDef* op,
    std::unordered_map<std::string, int>* blob_versions,
    std::set<std::string>* is_initialized_tensor) {
  // Get all the "external" inputs and outputs of the subnet
  // Since then_net and else_net has same external input/output, we only collect
  // external input/output from one of its subnet And perform the rewrite to
  // both then_net and else_net
  std::vector<std::string> if_external_input;
  std::vector<std::string> if_external_output;

  std::unordered_set<std::string> if_inputs, if_outputs;
  for (const auto& input : op->input()) {
    if_inputs.insert(input);
  }
  for (const auto& output : op->output()) {
    if_outputs.insert(output);
  }

  ArgumentHelper helper(*op);
  Argument *then_arg = nullptr, *else_arg = nullptr;
  NetDef* target_net = nullptr;
  bool has_then = false, has_else = false;

  if (helper.HasSingleArgumentOfType<NetDef>("then_net")) {
    then_arg = GetMutableArgument("then_net", false, op);
    target_net = then_arg->mutable_n();
    has_then = true;
  }
  if (helper.HasSingleArgumentOfType<NetDef>("else_net")) {
    else_arg = GetMutableArgument("else_net", false, op);
    // Only scan else_net for externals when there is no then_net (both
    // branches share the same external blob set).
    if (!has_then) {
      target_net = else_arg->mutable_n();
    }
    has_else = true;
  }

  if (has_then || has_else) {
    collectExternalsFromIfOpSubnet(
        target_net, &if_external_input, &if_external_output);

    // Add inputs/outputs of the sub_net to the inputs/outputs of the op
    for (const auto& input : if_external_input) {
      if (if_inputs.count(input) == 0) {
        op->add_input(input);
      }
    }
    for (const auto& output : if_external_output) {
      if (if_outputs.count(output) == 0) {
        op->add_output(output);
      }
    }
    std::map<string, string> oldname_to_newname;

    // Build oldname_to_newname map
    for (auto& input : if_external_input) {
      // Inputs keep their current version; unversioned inputs are primary
      // net inputs and stay unrenamed.
      const auto it = blob_versions->find(input);
      if (it != blob_versions->end()) {
        oldname_to_newname[input] = SsaName(input, it->second);
      }
    }
    for (auto& output : if_external_output) {
      auto it = blob_versions->find(output);
      if (it != blob_versions->end()) {
        // Bump the version unless the blob was pre-filled by an initializer
        // op, in which case the single allowed bump is consumed here.
        if (is_initialized_tensor->count(output) == 0) {
          it->second += 1;
        } else {
          is_initialized_tensor->erase(output);
        }
        oldname_to_newname[output] = SsaName(output, it->second);
      } else {
        // First definition of this blob: start at version 0.
        blob_versions->emplace(output, 0);
        oldname_to_newname[output] = SsaName(output, 0);
      }
    }

    // Apply the same rename map to both branches.
    if (has_then) {
      rewriteSubnet(then_arg, oldname_to_newname);
    }
    if (has_else) {
      rewriteSubnet(else_arg, oldname_to_newname);
    }
  }
}
217

218
void revertRenamedExternalOutput(
219
    OperatorDef* op,
220
    const std::unordered_map<std::string, std::string>&
221
        renamed_external_outputs) {
222
  for (auto& input : *(op->mutable_input())) {
223
    const auto it = renamed_external_outputs.find(input);
224
    if (it != renamed_external_outputs.end()) {
225
      input = it->second;
226
    }
227
  }
228
  for (auto& output : *(op->mutable_output())) {
229
    const auto it = renamed_external_outputs.find(output);
230
    if (it != renamed_external_outputs.end()) {
231
      output = it->second;
232
    }
233
  }
234
}
235

236
void revertRenamedExternalOutputForIfOp(
237
    OperatorDef* if_op,
238
    const std::unordered_map<std::string, std::string>&
239
        renamed_external_outputs) {
240
  ArgumentHelper helper(*if_op);
241
  Argument *then_arg = nullptr, *else_arg = nullptr;
242

243
  revertRenamedExternalOutput(if_op, renamed_external_outputs);
244

245
  if (helper.HasSingleArgumentOfType<NetDef>("then_net")) {
246
    then_arg = GetMutableArgument("then_net", false, if_op);
247
    NetDef* net = then_arg->mutable_n();
248
    for (auto& op : *(net->mutable_op())) {
249
      revertRenamedExternalOutput(&op, renamed_external_outputs);
250
    }
251
  }
252
  if (helper.HasSingleArgumentOfType<NetDef>("else_net")) {
253
    else_arg = GetMutableArgument("else_net", false, if_op);
254
    NetDef* net = else_arg->mutable_n();
255
    for (auto& op : *(net->mutable_op())) {
256
      revertRenamedExternalOutput(&op, renamed_external_outputs);
257
    }
258
  }
259
}
260
} // namespace
261

262
// Map a caffe2 tensor element type onto the equivalent ONNX TensorProto
// data type. Unsupported types fall back to FLOAT with a warning (lossy;
// callers that must not silently coerce should check beforehand).
::ONNX_NAMESPACE::TensorProto::DataType Caffe2TypeToOnnxType(
    caffe2::TensorProto::DataType t) {
// The enumerator names are identical in both protos, so a macro keeps the
// case table trivially in sync.
#define CAFFE2_TO_ONNX_TYPE(x)   \
  case (caffe2::TensorProto::x): \
    return ::ONNX_NAMESPACE::TensorProto::x
  switch (t) {
    CAFFE2_TO_ONNX_TYPE(FLOAT);
    CAFFE2_TO_ONNX_TYPE(BOOL);
    CAFFE2_TO_ONNX_TYPE(INT8);
    CAFFE2_TO_ONNX_TYPE(UINT8);
    CAFFE2_TO_ONNX_TYPE(UINT16);
    CAFFE2_TO_ONNX_TYPE(INT16);
    CAFFE2_TO_ONNX_TYPE(INT32);
    CAFFE2_TO_ONNX_TYPE(INT64);
    CAFFE2_TO_ONNX_TYPE(FLOAT16);
    default:
      LOG(WARNING) << "Unsupported Caffe2 tensor type: " << t
                   << ", fallback to FLOAT";
      return ::ONNX_NAMESPACE::TensorProto::FLOAT;
  }
#undef CAFFE2_TO_ONNX_TYPE
}
284

285
void rewriteSubnet(
286
    Argument* arg,
287
    std::map<std::string, std::string> oldname_to_newname) {
288
  NetDef* net = arg->mutable_n();
289
  // clear external inputs and outputs since they're no longer valid
290
  net->mutable_external_input()->Clear();
291
  net->mutable_external_output()->Clear();
292
  for (auto& op : *(net->mutable_op())) {
293
    for (auto& input : *(op.mutable_input())) {
294
      if (oldname_to_newname.find(input) != oldname_to_newname.end()) {
295
        input = oldname_to_newname[input];
296
      }
297
    }
298
    for (auto& output : *(op.mutable_output())) {
299
      if (oldname_to_newname.find(output) != oldname_to_newname.end()) {
300
        output = oldname_to_newname[output];
301
      }
302
    }
303
  }
304
}
305

306
std::unordered_map<std::string, std::string> SsaRewrite(
307
    caffe2::NetDef* init_net,
308
    caffe2::NetDef* pred_net,
309
    bool PreserveInPlaceOps) {
310
  std::unordered_map<std::string, std::string> input_mapping;
311
  std::unordered_map<std::string, int> blob_versions;
312

313
  if (init_net) {
314
    // No ssa rewrite is done for init net. The reason being that the output
315
    // blobs of init net are what becomes the input blobs of pred_net. Since
316
    // inputs of pred_net are not renamed we are not renaming the output of
317
    // init_net. Furthermore, the assumption made is that init_net is simple net
318
    // with each operator producing the one output and thus not renaming
319
    // translates to not renaming the outputs of the init_net. Create identical
320
    // mapping for now. This shall be removed eventually.
321
    for (const auto& name : init_net->external_input()) {
322
      input_mapping.emplace(name, name);
323
    }
324
    blob_versions.clear();
325
  }
326

327
  std::set<std::string> is_initialized_tensor;
328
  if (pred_net) {
329
    // Ssa rewriting modifies the net, check if the net passes schema check
330
    run_schema_check(*pred_net);
331

332
    std::unordered_set<std::string> external_outputs;
333
    for (const auto& input : pred_net->external_input()) {
334
      // Create identical mapping for now. This shall be removed eventually.
335
      input_mapping.emplace(input, input);
336
    }
337
    for (const auto& output : pred_net->external_output()) {
338
      external_outputs.emplace(output);
339
    }
340
    for (auto& op : *pred_net->mutable_op()) {
341
      // Special SSA Rewrite for subnet of If Operator
342
      // This needs to happen first because the inputs/outputs of If/AsyncIf
343
      // may get modified inside ssaRewriteForIfOp
344
      if (op.type() == "If" || op.type() == "AsyncIf") {
345
        ssaRewriteForIfOp(&op, &blob_versions, &is_initialized_tensor);
346
      }
347

348
      for (auto& input : *op.mutable_input()) {
349
        const auto it = blob_versions.find(input);
350
        if (it != blob_versions.end()) {
351
          input = SsaName(input, it->second);
352
        } else {
353
          // Input blob is not versioned yet.
354
          // If it is not versioned yet, it is assumed to be primary input,
355
          // Thus skip renaming it.
356
          continue;
357
        }
358
      }
359

360
      for (int out_idx = 0; out_idx < op.output_size(); out_idx++) {
361
        auto& output = *op.mutable_output(out_idx);
362

363
        // restore in-place settings
364
        bool is_inplace = false;
365
        if (PreserveInPlaceOps) {
366
          for (int in_idx = 0; in_idx < op.input_size(); in_idx++) {
367
            auto* schema = OpSchemaRegistry::Schema(op.type());
368
            if (schema && schema->inplace_enforced(in_idx, out_idx)) {
369
              output = op.input(in_idx);
370
              is_inplace = true;
371
              break;
372
            }
373
          }
374
        }
375
        if (is_inplace) {
376
          continue;
377
        }
378

379
        auto it = blob_versions.find(output);
380
        if (it != blob_versions.end()) {
381
          if (op.type() != "If" && op.type() != "AsyncIf") {
382
            if (is_initialized_tensor.count(output) == 0) {
383
              it->second += 1;
384
            } else {
385
              is_initialized_tensor.erase(output);
386
            }
387
          }
388
          output = SsaName(output, it->second);
389

390
        } else {
391
          blob_versions.emplace(output, 0);
392
          // These filling ops are designed for a by-default value for the
393
          // tensors generated by ops like If. For example, if an If op's
394
          // condition is not satisfied, and it does not have else_net, then it
395
          // will not generate any output blob, which may cause some error in
396
          // the future. Here we would like to ensure these tensors only been
397
          // ssa re-write once but not twice. (One in the filling operator, one
398
          // in If op)
399
          if ((caffe2::StartsWith(op.type(), "GivenTensor") &&
400
               caffe2::EndsWith(op.type(), "Fill")) ||
401
              op.type() == "ConstantFill" ||
402
              op.type() == "Int8GivenTensorFill" ||
403
              op.type() == "Int8GivenIntTensorFill") {
404
            is_initialized_tensor.insert(output);
405
          }
406
          output = SsaName(output, 0);
407
        }
408
      }
409
    }
410

411
    // For all the renamed blobs find if the blob is one of the external
412
    // output. If so add a mapping from it's latest renamed version to its
413
    // original name.
414
    std::unordered_map<std::string, std::string> renamed_external_outputs;
415
    for (const auto& it : blob_versions) {
416
      if (external_outputs.count(it.first)) {
417
        renamed_external_outputs.emplace(
418
            SsaName(it.first, it.second), it.first);
419
      }
420
    }
421

422
    // Use the mapping to find if the input or output of an op was a renamed
423
    // external output. If so replace it with its original name.
424
    for (auto& op : *pred_net->mutable_op()) {
425
      // If/AsyncIf needs special handling
426
      if (op.type() == "If" || op.type() == "AsyncIf") {
427
        revertRenamedExternalOutputForIfOp(&op, renamed_external_outputs);
428
      } else {
429
        revertRenamedExternalOutput(&op, renamed_external_outputs);
430
      }
431
    }
432
  }
433
  // run schema check again
434
  // NOLINTNEXTLINE(clang-analyzer-core.NonNullParamChecker)
435
  run_schema_check(*pred_net);
436

437
  return input_mapping;
438
}
439

440
const std::unordered_map<std::string, std::string>&
441
OnnxExporter::get_renamed_operators() const {
442
  const static std::unordered_map<std::string, std::string> kRenamedOperators{
443
      {"SpatialBN", "BatchNormalization"},
444
      {"Conv1D", "Conv"},
445
      {"Conv2D", "Conv"},
446
      {"Conv3D", "Conv"},
447
      {"ConvTranspose1D", "ConvTranspose"},
448
      {"ConvTranspose2D", "ConvTranspose"},
449
      {"ConvTranspose3D", "ConvTranspose"},
450
      {"MaxPool1D", "MaxPool"},
451
      {"MaxPool2D", "MaxPool"},
452
      {"MaxPool3D", "MaxPool"},
453
      {"AveragePool1D", "AveragePool"},
454
      {"AveragePool2D", "AveragePool"},
455
      {"AveragePool3D", "AveragePool"},
456
      {"Copy", "Identity"}};
457
  return kRenamedOperators;
458
}
459

460
const std::unordered_map<std::string, std::string>&
461
OnnxExporter::get_renamed_attrs() const {
462
  const static std::unordered_map<std::string, std::string> kRenamedAttrs{
463
      {"kernels", "kernel_shape"}};
464
  return kRenamedAttrs;
465
}
466

467
const std::
468
    unordered_map<std::string, std::unordered_map<std::string, std::string>>&
469
    OnnxExporter::get_per_op_renamed_attrs() const {
470
  const static std::
471
      unordered_map<std::string, std::unordered_map<std::string, std::string>>
472
          kPerOpRenamedAttrs = {
473
              {"Squeeze", {{"dims", "axes"}}},
474
              {"Unsqueeze", {{"dims", "axes"}}},
475
              {"Transpose", {{"axes", "perm"}}},
476
              {"ConvTranspose", {{"adjs", "output_padding"}}},
477
              {"Selu", {{"scale", "gamma"}}}};
478

479
  return kPerOpRenamedAttrs;
480
}
481

482
// Dispatch table mapping (already-renamed) op types to dedicated converter
// member functions; ops absent from this table go through the generic
// CommonCaffe2OpToOnnxNodes path.
const std::unordered_map<std::string, OnnxExporter::SpecialOpConverter>&
OnnxExporter::get_special_operators() const {
  const static std::unordered_map<std::string, OnnxExporter::SpecialOpConverter>
      kSpecialOperators = {
          {"ArgMax", &OnnxExporter::CreateArgMaxMinOpNodes},
          {"ArgMin", &OnnxExporter::CreateArgMaxMinOpNodes},
          {"Add", &OnnxExporter::CreateBinaryElementwiseOpNodes},
          {"Sub", &OnnxExporter::CreateBinaryElementwiseOpNodes},
          {"Mul", &OnnxExporter::CreateBinaryElementwiseOpNodes},
          {"Div", &OnnxExporter::CreateBinaryElementwiseOpNodes},
          {"Pow", &OnnxExporter::CreateBinaryElementwiseOpNodes},
          {"And", &OnnxExporter::CreateBinaryElementwiseOpNodes},
          {"Or", &OnnxExporter::CreateBinaryElementwiseOpNodes},
          {"Xor", &OnnxExporter::CreateBinaryElementwiseOpNodes},
          {"Equal", &OnnxExporter::CreateBinaryElementwiseOpNodes},
          {"Greater", &OnnxExporter::CreateBinaryElementwiseOpNodes},
          {"Less", &OnnxExporter::CreateBinaryElementwiseOpNodes},
          {"Cast", &OnnxExporter::CreateCastNodes},
          {"ElementwiseLinear", &OnnxExporter::CreateElementwiseLinearNodes},
          {"Conv", &OnnxExporter::CreateConvPoolNodes},
          {"ConvTranspose", &OnnxExporter::CreateConvPoolNodes},
          {"MaxPool", &OnnxExporter::CreateConvPoolNodes},
          {"AveragePool", &OnnxExporter::CreateConvPoolNodes},
          {"FC", &OnnxExporter::CreateGemmNodes},
          {"Concat", &OnnxExporter::CreateConcatNodes},
          {"MergeDim", &OnnxExporter::CreateMergeDimNodes},
          {"LRN", &OnnxExporter::CreateLrnNodes},
          {"Reshape", &OnnxExporter::CreateReshapeNodes},
          {"Slice", &OnnxExporter::CreateSliceNodes},
          {"ChannelShuffle", &OnnxExporter::CreateChannelShuffleNodes},
          {"ReduceMean", &OnnxExporter::CreateReduceMeanNodes},
          {"ReduceFrontMean", &OnnxExporter::CreateReduceMeanNodes},
          {"ReduceBackMean", &OnnxExporter::CreateReduceMeanNodes},
          {"ResizeNearest", &OnnxExporter::CreateUpsampleNodes}};
  return kSpecialOperators;
}
518

519
void OnnxExporter::CopyCaffe2ArgToOnnxAttr(
520
    AttributeProto* attr,
521
    const std::string& op_type,
522
    const caffe2::Argument& arg) {
523
  std::string name =
524
      caffe2::get_default(get_renamed_attrs(), arg.name(), arg.name());
525
  const auto& per_op_renamed_attr_lut = get_per_op_renamed_attrs();
526
  const auto it = per_op_renamed_attr_lut.find(op_type);
527
  if (it != per_op_renamed_attr_lut.end()) {
528
    // Per-op attribute renames override the global attribute renames
529
    name = caffe2::get_default(it->second, arg.name(), name);
530
  }
531
  attr->set_name(name);
532

533
  if (arg.has_f()) {
534
    attr->set_f(arg.f());
535
    attr->set_type(AttributeProto::FLOAT);
536
  } else if (arg.has_i()) {
537
    attr->set_i(arg.i());
538
    attr->set_type(AttributeProto::INT);
539
  } else if (arg.has_s()) {
540
    attr->set_s(arg.s());
541
    attr->set_type(AttributeProto::STRING);
542
  } else if (arg.floats_size()) {
543
    attr->mutable_floats()->CopyFrom(arg.floats());
544
    attr->set_type(AttributeProto::STRINGS);
545
  } else if (arg.ints_size()) {
546
    attr->mutable_ints()->CopyFrom(arg.ints());
547
    attr->set_type(AttributeProto::INTS);
548
  } else if (arg.strings_size()) {
549
    attr->mutable_strings()->CopyFrom(arg.strings());
550
    attr->set_type(AttributeProto::STRINGS);
551
  } else {
552
    CAFFE_THROW(c10::str("Unsupported Caffe2 argument: ", arg.name()));
553
  }
554
}
555

556
// True when the argument is a caffe2-internal knob (e.g. "use_cudnn",
// "is_test", NCHW "order") that must not be exported as an ONNX attribute.
// Only the listed (name, value) combinations are blocked.
bool OnnxExporter::IsBlockListed(const caffe2::Argument& arg) {
  static const std::unordered_map<std::string, std::unordered_set<std::string>>
      kBlockListString = {{"order", {"NCHW"}}};
  static const std::unordered_map<std::string, std::unordered_set<int64_t>>
      kBlockListInt = {
          {"cudnn_exhaustive_search", {0, 1}},
          {"use_cudnn", {0, 1}},
          {"exhaustive_search", {0, 1}},
          {"is_test", {0, 1}},
          {"broadcast", {0, 1}}};

  if (arg.has_i()) {
    const auto it = kBlockListInt.find(arg.name());
    return it != kBlockListInt.end() && it->second.count(arg.i()) > 0;
  }
  if (arg.has_s()) {
    const auto it = kBlockListString.find(arg.name());
    return it != kBlockListString.end() && it->second.count(arg.s()) > 0;
  }
  return false;
}
581

582
// Convert one caffe2 operator to ONNX nodes: resolve the op-type rename
// first, then dispatch to a special converter if one is registered for the
// renamed type, otherwise use the generic conversion.
ConvertedResult OnnxExporter::Caffe2OpToOnnxNodes(
    const caffe2::OperatorDef& def,
    const std::unordered_map<std::string, caffe2::TensorShape>& shapes) {
  const std::string type =
      caffe2::get_default(get_renamed_operators(), def.type(), def.type());

  const auto& special_op_lut = get_special_operators();
  const auto converter_it = special_op_lut.find(type);
  if (converter_it == special_op_lut.end()) {
    return CommonCaffe2OpToOnnxNodes(def);
  }
  return (this->*(converter_it->second))(def, shapes);
}
599

600
// Generic one-op conversion: emit a single NodeProto with the (possibly
// renamed) op type, the operator's inputs/outputs verbatim, and every
// argument that is not blocklisted as caffe2-internal.
ConvertedResult OnnxExporter::CommonCaffe2OpToOnnxNodes(
    const caffe2::OperatorDef& def) {
  ConvertedResult result;
  result.first.emplace_back();
  NodeProto& node = result.first.back();

  node.set_name(def.name());
  node.set_op_type(
      caffe2::get_default(get_renamed_operators(), def.type(), def.type()));

  for (const auto& input_name : def.input()) {
    node.add_input(input_name);
  }
  for (const auto& output_name : def.output()) {
    node.add_output(output_name);
  }
  for (const auto& arg : def.arg()) {
    if (IsBlockListed(arg)) {
      continue;
    }
    CopyCaffe2ArgToOnnxAttr(node.add_attribute(), def.type(), arg);
  }
  return result;
}
623

624
// ArgMax/ArgMin: caffe2 defaults the reduction axis to the last dimension
// when "axis" is omitted, while ONNX defaults it to 0 — so materialize the
// caffe2 default as an explicit attribute.
ConvertedResult OnnxExporter::CreateArgMaxMinOpNodes(
    const caffe2::OperatorDef& def,
    const std::unordered_map<std::string, caffe2::TensorShape>& shapes) {
  auto result = CommonCaffe2OpToOnnxNodes(def);
  auto& nodes = result.first;
  CAFFE_ENFORCE_EQ(nodes.size(), 1);

  if (!ArgumentHelper::HasArgument(def, "axis")) {
    const auto& input_shape = shapes.at(def.input(0));
    nodes.back().add_attribute()->CopyFrom(
        MakeAttribute("axis", input_shape.dims().size() - 1));
  }

  return result;
}
642

643
// Binary elementwise ops (Add/Sub/Mul/.../Less): strip the caffe2-only
// "broadcast" and "axis" arguments (ONNX broadcasting is implicit, numpy
// style). When caffe2's "axis" implies y must be right-padded with
// singleton dimensions, prepend an Unsqueeze node producing a dummy blob
// that replaces y as the second input.
ConvertedResult OnnxExporter::CreateBinaryElementwiseOpNodes(
    const caffe2::OperatorDef& def,
    const std::unordered_map<std::string, caffe2::TensorShape>& shapes) {
  caffe2::OperatorDef mdef(def); // The modified def without broadcast and axis
  const auto& x = mdef.input(0);
  const auto& y = def.input(1); // Refer to the old def, later won't change it.
  const auto& x_shape = shapes.at(x);
  const auto& y_shape = shapes.at(y);
  for (int i = 0; i < mdef.arg_size(); ++i) {
    const auto& arg = mdef.arg(i);
    if (arg.name() == "broadcast") {
      ArgumentHelper::RemoveArgument(mdef, i);
      break;
    }
  }
  std::vector<int64_t> axes;
  for (int i = 0; i < mdef.arg_size(); ++i) {
    const auto& arg = mdef.arg(i);
    if (arg.name() == "axis") {
      int64_t axis = arg.i();
      // Only act when y (starting at `axis` of x) does not already align
      // with x's trailing dimensions.
      if (x_shape.dims().size() - axis != y_shape.dims().size()) {
        // The upper bound (excluded) of expanded y.
        int64_t end_dim =
            y_shape.dims().size() - 1 - axis + x_shape.dims().size();
        // Unsqueeze axes [rank(y), end_dim): append trailing singleton dims.
        axes.resize(end_dim - y_shape.dims().size());
        std::iota(axes.begin(), axes.end(), y_shape.dims().size());
        mdef.set_input(1, dummy_->NewDummyName());
      }
      ArgumentHelper::RemoveArgument(mdef, i);
      break;
    }
  }

  auto result = CommonCaffe2OpToOnnxNodes(mdef);
  if (axes.size() != 0) {
    // The Unsqueeze must run before the elementwise node, so insert at the
    // front of the node list.
    result.first.insert(
        result.first.begin(),
        MakeNode(
            "Unsqueeze", {y}, {mdef.input(1)}, {MakeAttribute("axes", axes)}));
  }
  return result;
}
685

686
// Cast: rewrite the caffe2 "to" argument — either a (case-insensitive)
// dtype name string or a caffe2::TensorProto::DataType int — into the ONNX
// integer dtype enum required by ONNX Cast. Unsupported dtypes fail via
// CAFFE_ENFORCE_NE rather than being silently coerced.
ConvertedResult OnnxExporter::CreateCastNodes(
    const caffe2::OperatorDef& def,
    const std::unordered_map<std::string, caffe2::TensorShape>& shapes) {
  auto result = CommonCaffe2OpToOnnxNodes(def);
  // The generic conversion already copied the single "to" argument; patch
  // its payload in place.
  auto* attr = result.first[0].mutable_attribute(0);
  auto onnx_dtype = ::ONNX_NAMESPACE::TensorProto::UNDEFINED;
  const auto& arg = def.arg(0);
  if (arg.has_s()) {
    // String form: normalize to upper case before matching.
    auto c2_dtype = arg.s();
    std::transform(
        c2_dtype.begin(), c2_dtype.end(), c2_dtype.begin(), ::toupper);
    if (c2_dtype == "FLOAT") {
      onnx_dtype = ::ONNX_NAMESPACE::TensorProto::FLOAT;
    } else if (c2_dtype == "INT32") {
      onnx_dtype = ::ONNX_NAMESPACE::TensorProto::INT32;
    } else if (c2_dtype == "BOOL") {
      onnx_dtype = ::ONNX_NAMESPACE::TensorProto::BOOL;
    } else if (c2_dtype == "UINT8") {
      onnx_dtype = ::ONNX_NAMESPACE::TensorProto::UINT8;
    } else if (c2_dtype == "INT8") {
      onnx_dtype = ::ONNX_NAMESPACE::TensorProto::INT8;
    } else if (c2_dtype == "UINT16") {
      onnx_dtype = ::ONNX_NAMESPACE::TensorProto::UINT16;
    } else if (c2_dtype == "INT16") {
      onnx_dtype = ::ONNX_NAMESPACE::TensorProto::INT16;
    } else if (c2_dtype == "INT64") {
      onnx_dtype = ::ONNX_NAMESPACE::TensorProto::INT64;
    } else if (c2_dtype == "FLOAT16") {
      onnx_dtype = ::ONNX_NAMESPACE::TensorProto::FLOAT16;
    } else if (c2_dtype == "DOUBLE") {
      onnx_dtype = ::ONNX_NAMESPACE::TensorProto::DOUBLE;
    } else {
      onnx_dtype = ::ONNX_NAMESPACE::TensorProto::UNDEFINED;
    }
    CAFFE_ENFORCE_NE(
        onnx_dtype,
        ::ONNX_NAMESPACE::TensorProto::UNDEFINED,
        "Casting to '",
        c2_dtype,
        "' dtype is not supported");
    // Switch the attribute payload from string to int.
    attr->clear_s();
    attr->set_type(AttributeProto::INT);
  } else if (arg.has_i()) {
    const auto& c2_dtype = arg.i();
    switch (c2_dtype) {
      case caffe2::TensorProto::FLOAT:
        onnx_dtype = ::ONNX_NAMESPACE::TensorProto::FLOAT;
        break;
      case caffe2::TensorProto::INT32:
        onnx_dtype = ::ONNX_NAMESPACE::TensorProto::INT32;
        break;
      case caffe2::TensorProto::BOOL:
        onnx_dtype = ::ONNX_NAMESPACE::TensorProto::BOOL;
        break;
      case caffe2::TensorProto::UINT8:
        onnx_dtype = ::ONNX_NAMESPACE::TensorProto::UINT8;
        break;
      case caffe2::TensorProto::INT8:
        onnx_dtype = ::ONNX_NAMESPACE::TensorProto::INT8;
        break;
      case caffe2::TensorProto::UINT16:
        onnx_dtype = ::ONNX_NAMESPACE::TensorProto::UINT16;
        break;
      case caffe2::TensorProto::INT16:
        onnx_dtype = ::ONNX_NAMESPACE::TensorProto::INT16;
        break;
      case caffe2::TensorProto::INT64:
        onnx_dtype = ::ONNX_NAMESPACE::TensorProto::INT64;
        break;
      case caffe2::TensorProto::FLOAT16:
        onnx_dtype = ::ONNX_NAMESPACE::TensorProto::FLOAT16;
        break;
      case caffe2::TensorProto::DOUBLE:
        onnx_dtype = ::ONNX_NAMESPACE::TensorProto::DOUBLE;
        break;

      // No ONNX Cast target exists for these caffe2 types.
      case caffe2::TensorProto::STRING:
      case caffe2::TensorProto::BYTE:
      case caffe2::TensorProto::UNDEFINED:
        onnx_dtype = ::ONNX_NAMESPACE::TensorProto::UNDEFINED;
        break;
    }
    CAFFE_ENFORCE_NE(
        onnx_dtype,
        ::ONNX_NAMESPACE::TensorProto::UNDEFINED,
        "Casting to '",
        c2_dtype,
        "' dtype is not supported");
  }
  attr->set_i(onnx_dtype);
  return result;
}
778

779
// ElementwiseLinear(x, w, b) -> y = x * w + b along `axis` (default 1).
// Emitted as Mul + Add; when the multiplied axis is not the last axis of x,
// x is first flattened to 2-D [-1, inner] via Reshape, and the result is
// reshaped back to x's original (runtime) shape.
ConvertedResult OnnxExporter::CreateElementwiseLinearNodes(
    const caffe2::OperatorDef& def,
    const std::unordered_map<std::string, caffe2::TensorShape>& shapes) {
  CAFFE_ENFORCE_EQ(def.input_size(), 3);
  CAFFE_ENFORCE_GE(def.output_size(), 1);
  const auto& x = def.input(0);
  const auto& w = def.input(1);
  const auto& b = def.input(2);
  const auto& y = def.output(0);
  // w and b must be 1-D vectors.
  CAFFE_ENFORCE_EQ(shapes.at(w).dims().size(), 1);
  CAFFE_ENFORCE_EQ(shapes.at(b).dims().size(), 1);

  ConvertedResult result;
  auto& nodes = result.first;
  auto& const_tensors = result.second;
  std::unordered_map<std::string, const caffe2::Argument*> args;
  for (const auto& a : def.arg()) {
    args.emplace(a.name(), &a);
  }

  const auto& x_shape = shapes.at(x);
  const auto it = args.find("axis");
  const int64_t axis = it == args.end() ? 1 : it->second->i();
  // No reshape needed when `axis` is already the last axis of x.
  const bool need_reshape = axis + 1 != x_shape.dims().size();

  auto fma_x_input = x;
  if (need_reshape) {
    // The flattened inner extent must match the w/b vector length.
    const auto inner = DimProd(x_shape, axis, x_shape.dims().size());
    CAFFE_ENFORCE_EQ(shapes.at(w).dims(0), inner);
    CAFFE_ENFORCE_EQ(shapes.at(b).dims(0), inner);

    fma_x_input = dummy_->NewDummyName();
    const_tensors.emplace_back(CreateOnnxShapeTensor(
        dummy_, std::vector<int64_t>{-1, shapes.at(w).dims(0)}));
    nodes.emplace_back(
        MakeNode("Reshape", {x, const_tensors.back().name()}, {fma_x_input}));
  }

  const auto& mul_output = dummy_->NewDummyName();
  nodes.emplace_back(
      MakeNode("Mul", {fma_x_input, w}, {mul_output}, def.name()));

  // Without a trailing reshape the Add writes y directly.
  const auto& fma_y_output = need_reshape ? dummy_->NewDummyName() : y;
  nodes.emplace_back(
      MakeNode("Add", {mul_output, b}, {fma_y_output}, def.name()));

  if (need_reshape) {
    // Restore x's runtime shape: Shape(x) feeds the final Reshape.
    const auto shape = dummy_->NewDummyName();
    nodes.emplace_back(MakeNode("Shape", {x}, {shape}));
    nodes.emplace_back(MakeNode("Reshape", {fma_y_output, shape}, {y}));
  }

  return result;
}
833

834
ConvertedResult OnnxExporter::CreateConvPoolNodes(
    const caffe2::OperatorDef& def,
    const std::unordered_map<std::string, caffe2::TensorShape>& shapes) {
  // Converts a caffe2 Conv/Pool-family op to ONNX, rewriting caffe2's
  // scalar per-side attributes into ONNX list attributes and translating
  // the caffe2 `legacy_pad` argument into ONNX padding semantics.
  auto result = CommonCaffe2OpToOnnxNodes(def);
  auto& nodes = result.first;
  auto& node = nodes.back();

  // Index the node's attributes by name so they can be rewritten in place
  // and re-emitted at the end.
  std::unordered_map<std::string, AttributeProto> attrs;
  for (const auto& attr : node.attribute()) {
    attrs.emplace(attr.name(), attr);
  }

  // Handle global pooling: caffe2's `global_pooling` flag maps to the
  // dedicated GlobalMaxPool/GlobalAveragePool ONNX ops.
  bool global = false;
  if (node.op_type() == "MaxPool" || node.op_type() == "AveragePool") {
    auto it = attrs.find("global_pooling");
    if (it != attrs.end() && it->second.has_i() && it->second.i()) {
      node.set_op_type("Global" + node.op_type());
      global = true;
      attrs.erase(it);
    }
  }

  // Fold caffe2's per-axis scalars (kernel_h/kernel_w, pad_t/l/b/r, ...)
  // into the list-valued ONNX attributes (kernel_shape, pads, strides, ...).
  ApplyTrans(&attrs, global, "kernel", 2, "kernel_shape");
  ApplyTrans(&attrs, global, "stride");
  ApplyTrans(&attrs, global, "dilation");
  ApplyTrans(&attrs, global, "adj");
  ApplyTrans(&attrs, global, "pad", 4);

  // Fix legacy pad attr
  auto it = attrs.find("legacy_pad");
  if (it != attrs.end()) {
    auto legacy_pad_attr = it->second;
    attrs.erase(it);
    // legacy_pad is only handled for the *Pool ops, on 4-D (NCHW) tensors.
    CAFFE_ENFORCE(
        node.op_type().size() >= 4 &&
        (node.op_type().rfind("Pool") == node.op_type().size() - 4));
    const auto& input_size = shapes.at(node.input(0));
    const auto& output_size = shapes.at(node.output(0));
    CAFFE_ENFORCE_EQ(output_size.dims().size(), 4);
    if (!global && // global pool does not care about legacy pad
        legacy_pad_attr.i() !=
            static_cast<int64_t>(caffe2::LegacyPadding::NOTSET)) {
      if (legacy_pad_attr.i() ==
          static_cast<int64_t>(caffe2::LegacyPadding::VALID)) {
        // ONNX forbids combining auto_pad with explicit pads.
        CAFFE_ENFORCE(!attrs.count("pads"));
        attrs.emplace("auto_pad", MakeAttribute("auto_pad", "VALID"));
      } else if (
          legacy_pad_attr.i() ==
          static_cast<int64_t>(caffe2::LegacyPadding::SAME)) {
        CAFFE_ENFORCE(!attrs.count("pads"));
        // default behavior in Caffe2 is SAME_UPPER
        // https://github.com/caffe2/caffe2/blob/master/caffe2/operators/conv_pool_op_base.h#L39
        attrs.emplace("auto_pad", MakeAttribute("auto_pad", "SAME_UPPER"));
      } else if (
          legacy_pad_attr.i() ==
          static_cast<int64_t>(caffe2::LegacyPadding::CAFFE_LEGACY_POOLING)) {
        // The problem here is that, Pool op in Caffe may add an additional
        // pixel, if the last part is smaller than stride. So we use the
        // explicit padding to replace legacy_pad. pad[end] = output_size[start
        // + 2] * stride[start] - pad[start] - 1 + kernel[start] - input[start +
        // 2]; end = start + len(pad) / 2
        LOG(WARNING) << "Converting legacy padding to explicit padding.";
        // NOTE(review): assumes "pads", "strides" and "kernel_shape" were all
        // populated by the ApplyTrans calls above; attrs.at() throws if not.
        auto* pads_attr = attrs.at("pads").mutable_ints();
        auto& strides_attr = attrs.at("strides").ints();
        auto& kernel_shape_attr = attrs.at("kernel_shape").ints();
        // Only the trailing (bottom/right) pads are rewritten; i+2 indexes
        // the "end" half of the 4-element pads list.
        for (int i = 0; i < 2; ++i) {
          int64_t tmp_pad = output_size.dims(i + 2) * strides_attr.Get(i) -
              pads_attr->Get(i) - 1 + kernel_shape_attr.Get(i) -
              input_size.dims(i + 2);
          pads_attr->Set(i + 2, tmp_pad);
        }
      } else {
        LOG(ERROR) << "Don't know how to handle the legacy_pad:"
                   << legacy_pad_attr.i();
        CAFFE_THROW("Failed to handle legacy padding in pool operator!");
      }
    }
  }

  // Re-emit the (possibly rewritten) attribute set back onto the node.
  node.clear_attribute();
  for (const auto& kv : attrs) {
    auto* attr = node.add_attribute();
    attr->CopyFrom(kv.second);
  }

  return result;
}
922

923
ConvertedResult OnnxExporter::CreateLrnNodes(
    const caffe2::OperatorDef& def,
    const std::unordered_map<std::string, caffe2::TensorShape>& shapes) {
  // Start from the generic one-to-one translation of the caffe2 op.
  auto converted = CommonCaffe2OpToOnnxNodes(def);
  auto& onnx_nodes = converted.first;
  CAFFE_ENFORCE_EQ(onnx_nodes.size(), 1);

  // Caffe2's LRN may carry an auxiliary second output; ONNX LRN produces a
  // single output, so drop the extra one when present.
  auto& lrn_node = onnx_nodes.back();
  if (lrn_node.output_size() == 2) {
    lrn_node.mutable_output()->RemoveLast();
  }

  return converted;
}
937

938
ConvertedResult OnnxExporter::CreateConcatNodes(
    const caffe2::OperatorDef& def,
    const std::unordered_map<std::string, caffe2::TensorShape>& shapes) {
  // Converts caffe2 Concat to ONNX Concat, handling two caffe2 extras that
  // ONNX lacks: the `add_axis` argument (stack inputs along a brand-new
  // axis) and the optional second output carrying per-input split sizes.
  caffe2::OperatorDef mdef(def); // The modified def without add_axis
  // In caffe2, we can optionally add an axis specified by `add_axis`
  int add_axis = 0;
  for (int i = 0; i < mdef.arg_size(); ++i) {
    const auto& arg = mdef.arg(i);
    if (arg.name() == "add_axis") {
      add_axis = arg.i();
      ArgumentHelper::RemoveArgument(mdef, i);
      break;
    }
  }

  auto result = CommonCaffe2OpToOnnxNodes(mdef);
  auto& nodes = result.first;
  nodes.reserve(nodes.size() + 3);
  auto& const_tensors = result.second;

  CAFFE_ENFORCE_EQ(nodes.size(), 1);
  auto& node = nodes.back();
  bool explicit_axis = false;
  int axis = -1;
  if (ArgumentHelper::HasArgument(mdef, "axis")) {
    axis = ArgumentHelper::GetSingleArgument(mdef, "axis", -1);
    explicit_axis = true;
  }
  // ONNX Concat requires an axis attribute; default to 1 when the caffe2 op
  // did not specify one.
  if (!explicit_axis) {
    node.add_attribute()->CopyFrom(MakeAttribute("axis", 1));
  }

  // If we have add_axis, we need to add a reshape node
  auto final_output = node.output(0);
  if (add_axis > 0) {
    CAFFE_ENFORCE_GE(axis, 0);
    std::vector<int64_t> dims;
    const auto& shape0 = shapes.at(mdef.input(0));
    // When stacking into a new axis, every input must have the same size
    // along the concat axis.
    for (int i = 1; i < mdef.input_size(); ++i) {
      const auto& shape = shapes.at(mdef.input(i));
      CAFFE_ENFORCE_EQ(shape.dims(axis), shape0.dims(axis));
    }
    for (const auto d : shape0.dims()) {
      dims.push_back(d);
    }
    // The new axis has extent == number of inputs.
    dims.insert(dims.begin() + axis, mdef.input_size());

    // Concat into a temp name, then Reshape to the stacked shape.
    auto concat_output = dummy_->NewDummyName();
    *node.mutable_output(0) = concat_output;
    const_tensors.emplace_back(CreateOnnxShapeTensor(dummy_, dims));
    nodes.emplace_back(MakeNode(
        "Reshape",
        {concat_output, const_tensors.back().name()},
        {final_output}));
  }

  // If we have two output, we need to output the split_info, which can be
  // statically inferred from the input shapes
  if (node.output_size() == 2) {
    std::string second_output = node.output(1);
    node.mutable_output()->RemoveLast();
    std::vector<int32_t> split_info;
    // Account for the extra axis introduced by add_axis when canonicalizing
    // a possibly-negative axis index.
    int adj_size = shapes.at(mdef.input(0)).dims_size() + (add_axis ? 1 : 0);
    int canonical_axis = canonical_axis_index_(axis, adj_size);
    CAFFE_ENFORCE_LT(canonical_axis, adj_size, "Axis not in input ndim range.");
    for (int i = 0; i < mdef.input_size(); ++i) {
      // NOLINTNEXTLINE(performance-inefficient-vector-operation)
      split_info.push_back(
          add_axis ? 1 : shapes.at(mdef.input(i)).dims(canonical_axis));
    }
    // Emit the statically-known split sizes as a Constant node feeding the
    // original second-output name.
    auto split_info_tensor =
        MakeTensor("split_info", split_info, TensorProto::INT32);
    auto cnode = MakeNode("Constant", {}, {second_output});
    cnode.add_attribute()->CopyFrom(MakeAttribute("value", split_info_tensor));
    nodes.emplace_back(std::move(cnode));
  }
  return result;
}
1016

1017
ConvertedResult OnnxExporter::CreateMergeDimNodes(
    const caffe2::OperatorDef& def,
    const std::unordered_map<std::string, caffe2::TensorShape>& shapes) {
  const auto& input_name = def.input(0);
  const auto& output_name = def.output(0);

  ConvertedResult result;
  auto& nodes = result.first;
  auto& const_tensors = result.second;

  {
    const auto rank = shapes.at(input_name).dims().size();
    CAFFE_ENFORCE_GE(rank, 2, "No enough dims to merge.");
    // Target shape {1, -1, 0, 0, ...}: the leading 1/-1 collapse the first
    // two dims, while the trailing zeros tell ONNX Reshape to keep the
    // corresponding input dims unchanged.
    std::vector<int64_t> target_shape(rank, 0);
    target_shape[0] = 1;
    target_shape[1] = -1;
    const_tensors.emplace_back(CreateOnnxShapeTensor(dummy_, target_shape));
  }

  // Reshape, then squeeze away the leading singleton axis.
  const auto reshape_out = dummy_->NewDummyName();
  nodes.emplace_back(MakeNode(
      "Reshape", {input_name, const_tensors.back().name()}, {reshape_out}));

  nodes.emplace_back(MakeNode(
      "Squeeze",
      {reshape_out},
      {output_name},
      std::vector<AttributeProto>{
          MakeAttribute("axes", std::vector<int64_t>{0}),
      }));

  return result;
}
1050

1051
ConvertedResult OnnxExporter::CreateChannelShuffleNodes(
    const caffe2::OperatorDef& def,
    const std::unordered_map<std::string, caffe2::TensorShape>& shapes) {
  const auto& input = def.input(0);
  const auto& output = def.output(0);
  const auto& input_shape = shapes.at(input);
  CAFFE_ENFORCE_EQ(
      input_shape.dims().size(),
      4,
      "Input shape of ChannelShuffle needs to be in NCHW format");

  const auto batch = input_shape.dims(0);
  const auto channels = input_shape.dims(1);
  const auto height = input_shape.dims(2);
  const auto width = input_shape.dims(3);

  // Read the "group" argument; it must be present, non-zero, and divide C.
  int64_t group = 0;
  for (const auto& arg : def.arg()) {
    if (arg.name() == "group") {
      group = arg.i();
      break;
    }
  }
  CAFFE_ENFORCE(group && channels % group == 0);

  ConvertedResult result;
  auto& nodes = result.first;
  auto& const_tensors = result.second;

  // Step 1: reshape NCHW -> (N, G, C/G, H, W).
  const auto grouped = dummy_->NewDummyName();
  const_tensors.emplace_back(CreateOnnxShapeTensor(
      dummy_,
      std::vector<int64_t>{batch, group, channels / group, height, width}));
  nodes.emplace_back(
      MakeNode("Reshape", {input, const_tensors.back().name()}, {grouped}));

  // Step 2: swap the group axis and the per-group channel axis.
  const auto transposed = dummy_->NewDummyName();
  nodes.emplace_back(MakeNode(
      "Transpose",
      {grouped},
      {transposed},
      {MakeAttribute("perm", std::vector<int64_t>{0, 2, 1, 3, 4})}));

  // Step 3: reshape back to the original NCHW shape.
  const_tensors.emplace_back(CreateOnnxShapeTensor(
      dummy_, std::vector<int64_t>{batch, channels, height, width}));
  nodes.emplace_back(MakeNode(
      "Reshape", {transposed, const_tensors.back().name()}, {output}));

  return result;
}
1098

1099
ConvertedResult OnnxExporter::CreateReduceMeanNodes(
    const caffe2::OperatorDef& def,
    const std::unordered_map<std::string, caffe2::TensorShape>& shapes) {
  // Handles ReduceMean as well as ReduceFrontMean / ReduceBackMean, all of
  // which map onto a single ONNX ReduceMean node.
  CAFFE_ENFORCE_GE(def.input_size(), 1);
  CAFFE_ENFORCE_LE(def.input_size(), 2);
  CAFFE_ENFORCE_EQ(def.input_size(), 1, "Input \"lengths\" is not supported.");
  CAFFE_ENFORCE_GE(def.output_size(), 1);
  const auto& input = def.input(0);
  const auto& output = def.output(0);
  const auto& input_dims = shapes.at(input).dims();

  ConvertedResult result;
  auto& nodes = result.first;
  // Index the caffe2 arguments by name for quick lookup.
  std::unordered_map<std::string, const caffe2::Argument*> arg_map;
  for (const auto& arg : def.arg()) {
    arg_map.emplace(arg.name(), &arg);
  }

  std::vector<int64_t> reduce_axes;
  int64_t keep_dims = 1;

  if (def.type() == "ReduceMean") {
    // Explicit axes list, or every axis when the argument is absent.
    const auto axes_it = arg_map.find("axes");
    if (axes_it != arg_map.end()) {
      reduce_axes.assign(
          axes_it->second->ints().begin(), axes_it->second->ints().end());
    } else {
      reduce_axes.resize(input_dims.size());
      std::iota(reduce_axes.begin(), reduce_axes.end(), 0);
    }

    // keepdims defaults to 1 unless overridden.
    const auto keep_it = arg_map.find("keepdims");
    if (keep_it != arg_map.end()) {
      keep_dims = keep_it->second->i();
    }
  } else {
    // ReduceFrontMean / ReduceBackMean: reduce a contiguous run of
    // `num_reduce_dim` axes at the front or the back; dims are never kept.
    const auto n_it = arg_map.find("num_reduce_dim");
    const int64_t num_reduce_dim =
        n_it == arg_map.end() ? 1 : n_it->second->i();
    CAFFE_ENFORCE_LE(num_reduce_dim, input_dims.size());
    reduce_axes.resize(num_reduce_dim);

    int64_t first_axis = 0;
    if (def.type() == "ReduceBackMean") {
      first_axis = input_dims.size() - reduce_axes.size();
    }
    std::iota(reduce_axes.begin(), reduce_axes.end(), first_axis);

    keep_dims = 0;
  }

  nodes.emplace_back(MakeNode(
      "ReduceMean",
      {input},
      {output},
      {
          MakeAttribute("axes", reduce_axes),
          MakeAttribute("keepdims", keep_dims),
      },
      def.name()));

  return result;
}
1165

1166
ConvertedResult OnnxExporter::CreateUpsampleNodes(
    const caffe2::OperatorDef& def,
    const std::unordered_map<std::string, caffe2::TensorShape>& shapes) {
  ConvertedResult result;
  auto& nodes = result.first;

  // ONNX Upsample consumes a 4-element per-axis scales input; synthesize it
  // as {1, 1, H_scale, W_scale}, either from the op's arguments or by
  // prepending {1, 1} to the dynamic scales input.
  const auto scales_name = dummy_->NewDummyName();
  if (def.input_size() == 1) {
    // Scales come from the width_scale / height_scale arguments.
    float w_scale = 1.0;
    float h_scale = 1.0;
    for (const auto& arg : def.arg()) {
      if (arg.name() == "width_scale") {
        w_scale = arg.f();
      } else if (arg.name() == "height_scale") {
        h_scale = arg.f();
      }
    }
    CAFFE_ENFORCE_GT(w_scale, 0);
    CAFFE_ENFORCE_GT(h_scale, 0);
    std::vector<float> all_scales = {1, 1, h_scale, w_scale};
    auto scales_tensor =
        MakeTensor("resolved scale tensor", all_scales, TensorProto::FLOAT);

    auto const_node = MakeNode("Constant", {}, {scales_name});
    const_node.add_attribute()->CopyFrom(
        MakeAttribute("value", scales_tensor));
    nodes.emplace_back(const_node);

  } else {
    // Scales arrive as a 2-element tensor input; pad with {1, 1} for the
    // batch and channel axes via a Constant + Concat.
    CAFFE_ENFORCE_EQ(def.input_size(), 2);
    std::vector<float> nc_scales = {1, 1};
    auto nc_tensor = MakeTensor("scale pads", nc_scales, TensorProto::FLOAT);
    const auto nc_name = dummy_->NewDummyName();

    auto const_node = MakeNode("Constant", {}, {nc_name});
    const_node.add_attribute()->CopyFrom(MakeAttribute("value", nc_tensor));
    nodes.emplace_back(const_node);

    auto concat_node =
        MakeNode("Concat", {nc_name, def.input(1)}, {scales_name});
    concat_node.add_attribute()->CopyFrom(MakeAttribute("axis", 0));
    nodes.emplace_back(concat_node);
  }

  // Emit the Upsample node itself, always in nearest-neighbor mode.
  std::vector<std::string> upsample_inputs = {def.input(0), scales_name};
  std::vector<std::string> upsample_outputs(
      def.output().begin(), def.output().end());
  auto upsample_node =
      MakeNode("Upsample", upsample_inputs, upsample_outputs, def.name());
  upsample_node.add_attribute()->CopyFrom(MakeAttribute("mode", "nearest"));
  nodes.emplace_back(upsample_node);
  return result;
}
1217

1218
ConvertedResult OnnxExporter::CreateSliceNodes(
    const caffe2::OperatorDef& def,
    const std::unordered_map<std::string, caffe2::TensorShape>& shapes) {
  // Converts caffe2 Slice (attribute-based starts/ends) to ONNX Slice,
  // normalizing negative `ends` and synthesizing an explicit `axes` list.
  CAFFE_ENFORCE_EQ(
      def.input_size(),
      1,
      "ONNX Slice operator does not support dynamic slice.");
  auto result = CommonCaffe2OpToOnnxNodes(def);
  auto& nodes = result.first;
  CAFFE_ENFORCE_EQ(nodes.size(), 1);
  auto& node = nodes.back();
  const auto& shape = shapes.at(node.input(0));

  // `dims` becomes the ONNX `axes` attribute: 0..len(starts)-1, since the
  // caffe2 starts/ends lists cover every axis in order.
  std::vector<int64_t> dims;
  for (auto& attr : *node.mutable_attribute()) {
    if (attr.name() == "starts") {
      auto len = attr.ints_size();
      if (len) {
        dims.resize(len);
        std::iota(dims.begin(), dims.end(), 0);
      }
    } else if (attr.name() == "ends") {
      // Normalize negative ends: end == -1 is mapped to the full dimension
      // extent (caffe2's "through the last element"), and any other
      // negative end is shifted by one to match ONNX's exclusive-end
      // convention.
      for (int i = 0; i < attr.ints_size(); ++i) {
        auto end = attr.ints(i);
        if (end >= 0) {
          continue;
        }
        if (end == -1) {
          end = shape.dims(i);
        } else {
          ++end;
        }
        attr.set_ints(i, end);
      }
    }
  }
  if (!dims.empty()) {
    node.add_attribute()->CopyFrom(MakeAttribute("axes", dims));
  }

  return result;
}
1260

1261
ConvertedResult OnnxExporter::CreateReshapeNodes(
    const caffe2::OperatorDef& def,
    const std::unordered_map<std::string, caffe2::TensorShape>& shapes) {
  auto result = CommonCaffe2OpToOnnxNodes(def);
  auto& nodes = result.first;
  auto& const_tensors = result.second;
  CAFFE_ENFORCE_EQ(nodes.size(), 1);
  auto& node = nodes.back();

  // ONNX Reshape takes the target shape as a second *input* rather than an
  // attribute: lift the "shape" attribute (if present) into a constant
  // initializer tensor and wire it in.
  const int num_attrs = node.attribute_size();
  int shape_idx = 0;
  for (; shape_idx < num_attrs; ++shape_idx) {
    const auto& attr = node.attribute(shape_idx);
    if (attr.name() != "shape") {
      continue;
    }
    std::vector<int64_t> target_shape(attr.ints().begin(), attr.ints().end());
    const_tensors.emplace_back(CreateOnnxShapeTensor(dummy_, target_shape));
    node.add_input(const_tensors.back().name());
    break;
  }
  // Remove the consumed attribute: swap it to the end and pop (note this
  // may reorder the remaining attributes, matching the original behavior).
  if (shape_idx != num_attrs) {
    if (shape_idx != num_attrs - 1) {
      node.mutable_attribute()->SwapElements(shape_idx, num_attrs - 1);
    }
    node.mutable_attribute()->RemoveLast();
  }

  // Caffe2's Reshape may declare a second output; replace it with an
  // explicit Shape node over the reshaped result.
  // NOTE(review): this yields the post-reshape shape — confirm it matches
  // what consumers of caffe2's second output expect.
  if (node.output_size() == 2) {
    std::string data_output = node.output(0);
    std::string aux_shape_output = node.output(1);
    node.mutable_output()->RemoveLast();
    nodes.emplace_back(AddShapeNode(data_output, aux_shape_output));
  }

  return result;
}
1300

1301
ConvertedResult OnnxExporter::CreateGemmNodes(
    const caffe2::OperatorDef& def,
    const std::unordered_map<std::string, caffe2::TensorShape>& shapes) {
  // Converts caffe2 FC(X, W, b) -> Y to ONNX Gemm. X and W are flattened to
  // 2-D when they have higher rank (split at `axis` / `axis_w`), and Y's
  // leading dims are restored afterwards when axis > 1.
  CAFFE_ENFORCE_EQ(def.input_size(), 3);
  CAFFE_ENFORCE_GE(def.output_size(), 1);
  // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
  auto x = def.input(0);
  auto w = def.input(1);
  const auto& b = def.input(2);
  const auto& y = def.output(0);
  const auto& x_shape = shapes.at(x);
  const auto& w_shape = shapes.at(w);
  CAFFE_ENFORCE_GE(x_shape.dims().size(), 2);
  CAFFE_ENFORCE_GE(w_shape.dims().size(), 2);

  ConvertedResult result;
  auto& nodes = result.first;
  auto& const_tensors = result.second;
  // Index the caffe2 arguments by name.
  std::unordered_map<std::string, const caffe2::Argument*> args;
  for (const auto& a : def.arg()) {
    args.emplace(a.name(), &a);
  }

  // `axis` splits X into outer x inner dims (default 1).
  auto it = args.find("axis");
  int64_t axis = 1;
  bool has_axis = (it != args.end());
  if (has_axis) {
    axis = it->second->i();
  }

  auto gemm_x_input = x;
  if (x_shape.dims().size() > 2) {
    // we need to reshape only when dimension is higher than 2
    const auto inner = DimProd(x_shape, axis, x_shape.dims().size());

    // Flatten X to (-1, inner) so Gemm sees a 2-D input.
    gemm_x_input = dummy_->NewDummyName();
    const_tensors.emplace_back(
        CreateOnnxShapeTensor(dummy_, std::vector<int64_t>{-1, inner}));
    nodes.emplace_back(
        MakeNode("Reshape", {x, const_tensors.back().name()}, {gemm_x_input}));
  }

  // `axis_w` likewise splits W into outer x inner dims (default 1).
  it = args.find("axis_w");
  int64_t axis_w = 1;
  if (it != args.end()) {
    axis_w = it->second->i();
  }
  if (w_shape.dims().size() > 2) {
    // we need to reshape only when dimension is higher than 2
    auto outer = DimProd(w_shape, 0, axis_w);
    auto inner = DimProd(w_shape, axis_w, w_shape.dims().size());
    auto reshaped_w = dummy_->NewDummyName();
    const_tensors.emplace_back(
        CreateOnnxShapeTensor(dummy_, std::vector<int64_t>{outer, inner}));
    nodes.emplace_back(
        MakeNode("Reshape", {w, const_tensors.back().name()}, {reshaped_w}));
    w = reshaped_w;
  }

  // transB=1: Gemm multiplies by W transposed, matching FC's weight layout.
  // When axis > 1, Gemm writes to a temp that is reshaped back to Y below.
  auto gemm_y_output = axis > 1 ? dummy_->NewDummyName() : y;
  nodes.emplace_back(MakeNode(
      "Gemm",
      {gemm_x_input, w, b},
      {gemm_y_output},
      {MakeAttribute("transB", 1L)},
      def.name()));

  // capture the outer shape if needed.
  if (axis > 1) {
    // Recover Y's leading dims from X at runtime:
    // Y.shape = X.shape[:axis] + [-1].
    const auto x_shape_2 = dummy_->NewDummyName();
    nodes.emplace_back(MakeNode("Shape", {x}, {x_shape_2}));

    // Slice out X.shape[0:axis].
    const auto x_shape_outer = dummy_->NewDummyName();
    nodes.emplace_back(MakeNode(
        "Slice",
        {x_shape_2},
        {x_shape_outer},
        std::vector<AttributeProto>{
            MakeAttribute("starts", std::vector<int64_t>{0}),
            MakeAttribute("ends", std::vector<int64_t>{axis}),
        }));

    // Append -1 for the remaining (flattened) dimension and reshape.
    const auto y_shape = dummy_->NewDummyName();
    const_tensors.emplace_back(CreateOnnxShapeTensor(dummy_, {-1}));
    nodes.emplace_back(MakeNode(
        "Concat",
        {x_shape_outer, const_tensors.back().name()},
        {y_shape},
        std::vector<AttributeProto>{
            MakeAttribute("axis", static_cast<int64_t>(0)),
        }));

    nodes.emplace_back(MakeNode("Reshape", {gemm_y_output, y_shape}, {y}));
  }

  return result;
}
1398

1399
void OnnxExporter::InitOpToTensorProto(
1400
    const caffe2::OperatorDef& op,
1401
    TensorProto* tensor) {
1402
  CAFFE_ENFORCE_EQ(op.input_size(), 0);
1403
  CAFFE_ENFORCE_EQ(op.output_size(), 1);
1404

1405
  // Set name
1406
  tensor->set_name(op.output(0));
1407

1408
  const Argument* values = nullptr;
1409
  const Argument* shape = nullptr;
1410
  for (const auto& arg : op.arg()) {
1411
    if (arg.name() == "values") {
1412
      values = &arg;
1413
    } else if (arg.name() == "shape") {
1414
      shape = &arg;
1415
    }
1416
  }
1417

1418
  CAFFE_ENFORCE(values);
1419
  CAFFE_ENFORCE(shape);
1420

1421
  // Set dims
1422
  for (const auto i : shape->ints()) {
1423
    tensor->add_dims(i);
1424
  }
1425

1426
  // Set value
1427
  if (op.type() == "GivenTensorFill") {
1428
    tensor->set_data_type(TensorProto::FLOAT);
1429
    for (const auto i : values->floats()) {
1430
      tensor->add_float_data(i);
1431
    }
1432
  } else if (op.type() == "GivenTensorInt64Fill") {
1433
    tensor->set_data_type(TensorProto::INT64);
1434
    for (const auto i : values->ints()) {
1435
      tensor->add_int64_data(i);
1436
    }
1437
  } else if (op.type() == "GivenTensorIntFill") {
1438
    tensor->set_data_type(TensorProto::INT32);
1439
    for (const auto i : values->ints()) {
1440
      tensor->add_int32_data(i);
1441
    }
1442
  } else if (op.type() == "GivenTensorBoolFill") {
1443
    tensor->set_data_type(TensorProto::INT32);
1444
    for (const auto i : values->ints()) {
1445
      tensor->add_int32_data(i);
1446
    }
1447
  } else if (op.type() == "GivenTensorStringFill") {
1448
    tensor->set_data_type(TensorProto::STRING);
1449
    // TODO: we might need to do two pass to avoid adverse memory allocations
1450
    for (const auto& s : values->strings()) {
1451
      tensor->mutable_raw_data()->append(s);
1452
    }
1453
  } else {
1454
    CAFFE_THROW(
1455
        c10::str("Cannot convert C2 op ", op.type(), "to ONNX TensorProto"));
1456
  }
1457
}
1458

1459
} // namespace onnx
1460
} // namespace caffe2
1461

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.