pytorch

Форк
0
/
mpscnn_graph_mask.mm 
607 строк · 22.3 Кб
1
#include "mpscnn_graph_mask.h"
2
#include "caffe2/core/operator.h"
3
#include "mpscnn_context.h"
4

5
#import <Metal/Metal.h>
6
#import <MetalPerformanceShaders/MetalPerformanceShaders.h>
7
#import <UIKit/UIDevice.h>
8

9
namespace caffe2 {
10

11
namespace {
12
// Kinds of storage a blob can live in while the net is being rewritten.
enum class StorageType {
  MPSTEMPORARYIMAGE, /* Default for MPSCNN */
  MPSIMAGE,
  CPU,
  INVALID
};

// Returns the enumerator name of `st`, used for logging.
//
// A value that does not match any enumerator (e.g. produced by a cast) is
// reported as "INVALID" instead of running off the end of the function,
// which would be undefined behavior.
std::string asString(StorageType st) {
  // No `default:` label on purpose, so the compiler can still warn when a new
  // enumerator is added without a matching case.
  switch (st) {
  case StorageType::MPSTEMPORARYIMAGE:
    return "MPSTEMPORARYIMAGE";
  case StorageType::MPSIMAGE:
    return "MPSIMAGE";
  case StorageType::CPU:
    return "CPU";
  case StorageType::INVALID:
    return "INVALID";
  }
  return "INVALID";
}
31

32
// True for the two image storage kinds (temporary or regular MPS image),
// false for every other storage type.
bool isImage(StorageType type) {
  switch (type) {
  case StorageType::MPSTEMPORARYIMAGE:
  case StorageType::MPSIMAGE:
    return true;
  default:
    return false;
  }
}
36

37
std::unordered_map<string, std::vector<StorageType>> inputStorageTypeMap = {
38
    {"MPSCNNGenerateProposalsCPP",
39
     std::vector<StorageType>{StorageType::CPU,
40
                              StorageType::CPU,
41
                              StorageType::CPU,
42
                              StorageType::CPU}},
43
    {"MPSCNNRoIWarp",
44
     std::vector<StorageType>{StorageType::MPSTEMPORARYIMAGE,
45
                              StorageType::CPU}},
46
    {"MPSCNNConvRelu",
47
     std::vector<StorageType>{StorageType::MPSTEMPORARYIMAGE,
48
                              StorageType::CPU,
49
                              StorageType::CPU}},
50
    {"MPSCNNFC",
51
     std::vector<StorageType>{StorageType::MPSTEMPORARYIMAGE,
52
                              StorageType::CPU,
53
                              StorageType::CPU}},
54
    {"MPSCNNConv",
55
     std::vector<StorageType>{StorageType::MPSTEMPORARYIMAGE,
56
                              StorageType::CPU,
57
                              StorageType::CPU}},
58
    {"MPSCNNConvTranspose",
59
     std::vector<StorageType>{StorageType::MPSTEMPORARYIMAGE,
60
                              StorageType::CPU,
61
                              StorageType::CPU}},
62
    {"MPSCNNMul",
63
     std::vector<StorageType>{StorageType::MPSTEMPORARYIMAGE,
64
                              StorageType::CPU}},
65
    {"MPSCNNSub",
66
     std::vector<StorageType>{StorageType::MPSTEMPORARYIMAGE,
67
                              StorageType::CPU}},
68
    {"MPSCNNNormalizePlanarYUV",
69
     std::vector<StorageType>{StorageType::MPSTEMPORARYIMAGE,
70
                              StorageType::CPU,
71
                              StorageType::CPU}}};
72
std::unordered_map<string, std::vector<StorageType>> outputStorageTypeMap = {
73
    {"MPSCNNGenerateProposalsCPP", std::vector<StorageType>{StorageType::CPU, StorageType::CPU}}};
74
std::vector<string> opsNeedsSync = {"MPSCNNGenerateProposalsCPP", "CopyFromMPSCNN", "CopyToMPSCNN"};
75

76
struct Analysis {
77
  struct SSA {
78
    using BlobVersions = std::unordered_map<std::string, size_t>;
79
    BlobVersions inVersions;
80
    BlobVersions outVersions;
81
  };
82
  struct BlobInfo {
83
    std::vector<size_t> inUsages; // ids for operator that used the blob
84
    StorageType storageType = StorageType::INVALID; // storage type of the blob
85
    int commandBufferId; // the id for command buffer used by the blob
86
  };
87
  std::vector<SSA> ssa;
88
  // blob name -> blob version -> blob information
89
  std::unordered_map<std::string, std::unordered_map<size_t, BlobInfo>> blobInfoMap;
90
  int currentCommandBufferId = 0;
91
};
92

93
// Walks the operators of `net` in order and records, per operator, which
// version of each blob it reads and writes (appended to `analysis.ssa`), and,
// per blob version, which operators read it and what storage type it has
// (`analysis.blobInfoMap`).
//
// A blob's version starts at 0 and is bumped each time an operator writes a
// name that has already been seen (as an input or as an output).
//
// The storage type of an output is chosen in this order:
//   1. the entry for the operator type in outputStorageTypeMap,
//   2. CPU for CopyFromMPSCNN,
//   3. the kMPSCNNOutputIsTempImageArg argument, if the operator carries one,
//   4. MPSTEMPORARYIMAGE for any other operator whose type contains "MPSCNN",
//   5. CPU otherwise.
//
// Throws std::out_of_range when an operator has more outputs than its
// outputStorageTypeMap entry or its kMPSCNNOutputIsTempImageArg argument
// describes.
void ssaAnalysis(Analysis& analysis, const NetDef& net) {
  Analysis::SSA::BlobVersions frontier;

  for (auto i = 0; i < net.op_size(); ++i) {
    const OperatorDef& op = net.op(i);
    Analysis::SSA versions;

    for (const auto& name : op.input()) {
      // Reading an unseen blob registers it at version 0.
      const size_t version = frontier[name];
      versions.inVersions[name] = version;
      analysis.blobInfoMap[name][version].inUsages.push_back(i);
    }

    // Flags from every kMPSCNNOutputIsTempImageArg argument, concatenated in
    // argument order.
    std::vector<int> isTemporaryImages;
    for (const auto& arg : op.arg()) {
      if (arg.name() == kMPSCNNOutputIsTempImageArg) {
        for (auto k = 0; k < arg.ints_size(); ++k) {
          isTemporaryImages.push_back(arg.ints(k));
        }
      }
    }

    const auto fixedTypes = outputStorageTypeMap.find(op.type());
    for (auto j = 0; j < op.output_size(); j++) {
      const std::string& name = op.output(j);
      const auto seen = frontier.find(name);
      if (seen != frontier.end()) {
        seen->second += 1;
      }
      const size_t version = frontier[name];
      versions.outVersions[name] = version;

      StorageType type = StorageType::CPU;
      if (fixedTypes != outputStorageTypeMap.end()) {
        // Bounds-checked: an operator with more outputs than the table row
        // describes must not read past the end of the row.
        type = fixedTypes->second.at(j);
      } else if (op.type() == "CopyFromMPSCNN") {
        type = StorageType::CPU;
      } else if (!isTemporaryImages.empty()) {
        type = isTemporaryImages.at(j) ? StorageType::MPSTEMPORARYIMAGE
                                       : StorageType::MPSIMAGE;
      } else if (op.type().find("MPSCNN") != std::string::npos) {
        type = StorageType::MPSTEMPORARYIMAGE;
      }
      analysis.blobInfoMap[name][version].storageType = type;

      VLOG(2) << op.type() << " outputBlobTypes:" << name << " " << version
              << " "
              << asString(type);
    }
    analysis.ssa.push_back(versions);
  }
}
144

145
// Renames output `i` of `op` by appending the "_M" suffix to its name.
static void rewriteOutput(OperatorDef* op, int i) {
  const std::string renamed = op->output(i) + "_M";
  op->set_output(i, renamed);
}
149

150
// Renames input `i` of `op` by appending the "_I" suffix to its name.
static void rewriteInput(OperatorDef* op, int i) {
  const std::string renamed = op->input(i) + "_I";
  op->set_input(i, renamed);
}
154

155
// Appends a CopyFromMPSCNN op to `predictNet` that reads `<cpu_blob>_M` and
// writes `cpu_blob`.
static void insertOutputCopyFromMPSCNNOp(NetDef& predictNet, const std::string& cpu_blob) {
  OperatorDef* copyOp = predictNet.add_op();
  copyOp->set_type("CopyFromMPSCNN");
  const std::string source = cpu_blob + "_M";
  copyOp->add_input(source);
  copyOp->add_output(cpu_blob);
}
161

162
// Appends a CopyFromMPSCNN op to `predictNet` that reads `cpu_blob` and
// writes `<cpu_blob>_I`.
static void insertInputCopyFromMPSCNNOp(NetDef& predictNet, const std::string& cpu_blob) {
  OperatorDef* copyOp = predictNet.add_op();
  copyOp->set_type("CopyFromMPSCNN");
  copyOp->add_input(cpu_blob);
  const std::string destination = cpu_blob + "_I";
  copyOp->add_output(destination);
}
168

169
// Appends a CopyToMPSCNN op to `predictNet` that reads `gpu_blob` and writes
// `<gpu_blob>_I`.
static void insertInputCopyToMPSCNNOp(NetDef& predictNet, const std::string& gpu_blob) {
  OperatorDef* copyOp = predictNet.add_op();
  copyOp->set_type("CopyToMPSCNN");
  copyOp->add_input(gpu_blob);
  const std::string destination = gpu_blob + "_I";
  copyOp->add_output(destination);
}
175

176
// Assigns a command buffer id to every output blob version of `def`.
//
// Expects `analysis.ssa` to already hold one entry per operator of `def`
// (see ssaAnalysis). Ids start at 0 and are stored in
// `analysis.blobInfoMap[...].commandBufferId`:
//   - an operator listed in opsNeedsSync starts a new id for its outputs;
//   - for any other operator, every input stored as MPSIMAGE starts a new id,
//     every other input contributes its own id, and the outputs receive the
//     value left after folding the inputs in order (the largest id seen, or
//     the fresh id when the last MPSIMAGE input came later).
//
// An empty net is a no-op, and a first operator without inputs simply skips
// the seeding of the first input blob.
void commandBufferAnalysis(Analysis& analysis, NetDef& def) {
  analysis.currentCommandBufferId = 0;
  if (def.op_size() == 0) {
    return;
  }
  // Seed version 0 of the very first input of the net with the initial id.
  if (def.op(0).input_size() > 0) {
    analysis.blobInfoMap[def.op(0).input(0)][0].commandBufferId =
        analysis.currentCommandBufferId;
  }

  for (auto i = 0; i < def.op_size(); ++i) {
    const OperatorDef& op = def.op(i);
    const bool needsSync =
        std::find(opsNeedsSync.begin(), opsNeedsSync.end(), op.type()) != opsNeedsSync.end();

    int outputCommandBufferId = 0;
    if (needsSync) {
      analysis.currentCommandBufferId += 1;
      outputCommandBufferId = analysis.currentCommandBufferId;
    } else {
      for (const auto& inputBlob : op.input()) {
        const auto version = analysis.ssa[i].inVersions[inputBlob];
        const auto& info = analysis.blobInfoMap[inputBlob][version];
        if (info.storageType == StorageType::MPSIMAGE) {
          analysis.currentCommandBufferId += 1;
          outputCommandBufferId = analysis.currentCommandBufferId;
        } else {
          // Integer max; the ids are ints, so no floating point round trip.
          outputCommandBufferId = std::max(outputCommandBufferId, info.commandBufferId);
        }
      }
    }

    for (const auto& outputBlob : op.output()) {
      const auto version = analysis.ssa[i].outVersions[outputBlob];
      analysis.blobInfoMap[outputBlob][version].commandBufferId = outputCommandBufferId;
      VLOG(2) << "command buffer analysis: " << outputBlob << " " << version << " "
              << outputCommandBufferId;
    }
  }
}
217

218
// Recomputes `analysis` for `net` from scratch: discards the previous SSA and
// blob information, then runs the SSA analysis followed by the command buffer
// analysis.
void analyzeNet(Analysis& analysis, NetDef& net) {
  // Start from a clean slate so results of an earlier net do not leak in.
  analysis.blobInfoMap.clear();
  analysis.ssa.clear();

  ssaAnalysis(analysis, net);
  commandBufferAnalysis(analysis, net);
}
224

225
// Merges CopyFromMPSCNN ops whose (first) input blobs share a command buffer
// into a single op and returns the resulting net.
//
// `analysis` is recomputed for `def` first. For every group of two or more
// such ops, one op is kept and receives the input/output pairs of the others;
// the others are dropped from the returned net. Note that the kept op is
// modified inside `def` itself.
NetDef mergeCopyFromMPSCNN(Analysis& analysis, NetDef& def) {
  analyzeNet(analysis, def);

  // command buffer id -> indices of the CopyFromMPSCNN ops reading from it,
  // aggregated by the command buffer of each op's first input blob.
  std::unordered_map<int, std::vector<size_t>> copiesByCommandBuffer;
  for (auto opIdx = 0; opIdx < def.op_size(); ++opIdx) {
    const auto& candidate = def.op(opIdx);
    if (candidate.type() != "CopyFromMPSCNN") {
      continue;
    }
    const auto& blobName = candidate.input(0);
    const auto version = analysis.ssa[opIdx].inVersions[blobName];
    const auto commandId = analysis.blobInfoMap[blobName][version].commandBufferId;
    VLOG(2) << "Command buffer to ops:" << blobName << " " << version << " " << commandId;
    copiesByCommandBuffer[commandId].push_back(opIdx);
  }

  std::set<size_t> removedOps;
  for (const auto& group : copiesByCommandBuffer) {
    const auto& copyOps = group.second;
    if (copyOps.size() <= 1) {
      continue;
    }
    VLOG(2) << "Merging for command buffer:" << group.first;

    // Let's use the first input as an indicator whether the data is for
    // external output or internal use. If the data is used by an intermediate
    // node we keep the first operator, otherwise we keep the last operator.
    // [LATER] There might be cases when some of the data is for external
    // output and other data is used by intermediate nodes; better heuristics
    // are needed for those cases.
    const std::string firstOutput = def.op(copyOps[0]).output(0);
    bool externalUse = false;
    for (const auto& externalOutput : def.external_output()) {
      if (externalOutput == firstOutput) {
        externalUse = true;
      }
    }

    // External use: keep the last op and drop [0, size - 1).
    // Internal use: keep the first op and drop [1, size).
    const size_t removeStart = externalUse ? 0 : 1;
    const size_t removeEnd = externalUse ? copyOps.size() - 1 : copyOps.size();
    const int keepIndex = static_cast<int>(externalUse ? copyOps.back() : copyOps.front());

    auto* kept = def.mutable_op(keepIndex);
    // Ordered set: duplicates collapse and pairs are appended in sorted order.
    std::set<std::pair<std::string, std::string>> extraPairs;
    for (auto k = removeStart; k < removeEnd; ++k) {
      const auto& dropped = def.op(copyOps[k]);
      if (dropped.input(0) != kept->input(0)) {
        extraPairs.insert(std::make_pair(dropped.input(0), dropped.output(0)));
      }
      removedOps.insert(copyOps[k]);
    }
    for (const auto& inputOutput : extraPairs) {
      kept->add_input(inputOutput.first);
      kept->add_output(inputOutput.second);
    }
  }

  // Rebuild the net without the dropped ops.
  NetDef mdef;
  mdef.CopyFrom(def);
  mdef.clear_op();
  for (auto opIdx = 0; opIdx < def.op_size(); ++opIdx) {
    if (removedOps.count(opIdx) == 0) {
      mdef.add_op()->CopyFrom(def.op(opIdx));
    }
  }
  return mdef;
}
307

308
/* Remove the CopyToMPSCNN ops that has the same input/output version
309
 */
310
NetDef mergeCopyToMPSCNN(Analysis& analysis, NetDef& def) {
311
  std::vector<size_t> opsToRemove;
312
  std::set<std::pair<string, size_t>> copiedBlobs;
313
  for (auto i = 0; i < def.op_size(); ++i) {
314
    auto op = def.op(i);
315
    if (def.op(i).type() == "CopyToMPSCNN") {
316
      auto blobName = op.input(0);
317
      auto version = analysis.ssa[i].inVersions[blobName];
318
      auto pair = make_pair(blobName, version);
319
      if (std::find(copiedBlobs.begin(), copiedBlobs.end(), pair) == copiedBlobs.end()) {
320
        copiedBlobs.insert(pair);
321
      } else {
322
        opsToRemove.push_back(i);
323
      }
324
    }
325
  }
326
  NetDef mdef;
327
  mdef.CopyFrom(def);
328
  mdef.clear_op();
329
  for (auto i = 0; i < def.op_size(); ++i) {
330
    if (std::find(opsToRemove.begin(), opsToRemove.end(), i) == opsToRemove.end()) {
331
      const auto& ogOp = def.op(i);
332
      auto op = mdef.add_op();
333
      op->CopyFrom(ogOp);
334
    }
335
  }
336
  return mdef;
337
}
338

339
// Looks for the first operator (in net order) that reads a temporary image
// whose command buffer id was already recorded as synced, and marks the
// producers of those blobs so that the affected outputs are no longer
// temporary images.
//
// `analysis` is recomputed for `def` first. Returns true when at least one
// such blob was found (and `def` was modified), false otherwise. Only the
// blobs found on a single operator are handled per call because adding the
// argument changes the command buffer ids of later operators.
bool addTempImageArgs(Analysis& analysis, NetDef& def) {
  analyzeNet(analysis, def);

  std::vector<int> syncedCommandIds; // command buffer ids recorded as synced
  std::set<std::pair<std::string, size_t>> mpsImageBlobs; // blobname, version
  bool found = false;

  // Identify the images that are read after their command buffer was synced.
  for (auto opIdx = 0; opIdx < def.op_size() && !found; ++opIdx) {
    const auto& op = def.op(opIdx);
    if (op.type().find("MPSCNN") == std::string::npos) {
      continue;
    }
    const bool needsSync =
        std::find(opsNeedsSync.begin(), opsNeedsSync.end(), op.type()) != opsNeedsSync.end();

    for (auto j = 0; j < op.input_size(); ++j) {
      const auto& inputBlob = op.input(j);
      const auto version = analysis.ssa[opIdx].inVersions[inputBlob];
      const auto& info = analysis.blobInfoMap[inputBlob][version];
      const auto commandId = info.commandBufferId;

      if (needsSync) {
        // Only the command buffer of the first input is recorded.
        syncedCommandIds.push_back(commandId);
        break;
      }

      const bool alreadySynced =
          std::find(syncedCommandIds.begin(), syncedCommandIds.end(), commandId) !=
          syncedCommandIds.end();
      if (alreadySynced && info.storageType == StorageType::MPSTEMPORARYIMAGE) {
        VLOG(2) << "mpsimage blob:" << inputBlob << " " << version << " "
                << "input " << j << " command: " << commandId;
        mpsImageBlobs.insert(std::make_pair(inputBlob, version));
        found = true;
      }
    }
  }

  if (!found) {
    return false;
  }

  // Find the operators producing the collected blobs and add the argument.
  for (auto opIdx = 0; opIdx < def.op_size(); ++opIdx) {
    auto* op = def.mutable_op(opIdx);
    std::vector<int> isTempImages;
    bool producesMarkedBlob = false;
    for (auto j = 0; j < op->output_size(); ++j) {
      const auto& outputBlob = op->output(j);
      const auto version = analysis.ssa[opIdx].outVersions[outputBlob];
      const bool marked = mpsImageBlobs.count(std::make_pair(outputBlob, version)) != 0;
      producesMarkedBlob = producesMarkedBlob || marked;
      isTempImages.push_back(marked ? 0 : 1);
    }
    if (!producesMarkedBlob) {
      continue;
    }
    // NOTE(review): a new argument is appended even when the operator already
    // carries one with this name - confirm that repeats are intended.
    auto& arg = *(op->add_arg());
    arg.set_name(kMPSCNNOutputIsTempImageArg);
    for (const int flag : isTempImages) {
      arg.add_ints(flag);
    }
  }
  return true;
}
402

403
// Returns a copy of `def` in which copy operators are inserted wherever the
// storage type an operator expects for an input differs from the storage type
// the blob actually has, and wherever an external output is not stored on the
// CPU.
//
// Steps:
//   1. For every input whose expected and actual storage differ between CPU
//      and image, a CopyFromMPSCNN / CopyToMPSCNN op producing `<blob>_I` is
//      inserted before the operator and the input is renamed to `<blob>_I`.
//   2. In operators whose type contains "MPSCNN" and is registered in the CPU
//      operator registry, inputs equal to the first external input are
//      renamed to "__METAL_INPUT_COPY__".
//   3. Outputs that are external outputs and not stored on the CPU are renamed
//      to `<blob>_M`, followed by a CopyFromMPSCNN op back to `<blob>`.
//   4. Temp image arguments are added until no more are found (bounded by
//      3 * op count), then redundant copy operators are merged.
//
// Requires at least one external input, one external output and one operator
// (enforced). Throws std::out_of_range when an operator listed in
// inputStorageTypeMap has more inputs than its table row describes.
NetDef insertCopies(const NetDef& def) {
  // For this version, we insert CopyFromMPSCNN both for
  // intermediate nodes and the output node when necessary
  CAFFE_ENFORCE_GE(def.external_input_size(), 1);
  CAFFE_ENFORCE_GE(def.external_output_size(), 1);

  Analysis analysis;
  ssaAnalysis(analysis, def);

  CAFFE_ENFORCE_GE(def.op_size(), 1);

  const auto& outputBlob = def.external_output(0);
  const auto& outputBlobVersion = analysis.ssa.back().outVersions[outputBlob];

  // This should hold true by definition of the SSA analysis.
  auto& outputBlobInfos = analysis.blobInfoMap[outputBlob];
  CAFFE_ENFORCE(outputBlobInfos.find(outputBlobVersion) == outputBlobInfos.end() ||
                outputBlobInfos[outputBlobVersion].inUsages.size() == 0);

  NetDef mdef;
  mdef.CopyFrom(def);
  mdef.clear_op();

  const auto& opKeyList = CPUOperatorRegistry()->Keys();
  const std::set<std::string> opKeySet(opKeyList.begin(), opKeyList.end());

  for (auto i = 0; i < def.op_size(); ++i) {
    const auto& ogOp = def.op(i);
    const bool isMPSCNNOp = ogOp.type().find("MPSCNN") != std::string::npos;
    const auto expectedTypes = inputStorageTypeMap.find(ogOp.type());
    std::vector<int> inputsToRewrite;

    for (auto j = 0; j < ogOp.input_size(); j++) {
      const auto& inputBlob = ogOp.input(j);
      const auto version = analysis.ssa[i].inVersions[inputBlob];

      // The storage type of the blob as produced by previous operators. If it
      // is not produced by a previous operator it should be a network
      // parameter, and those are stored on the CPU.
      auto actualBlobType = StorageType::CPU;
      if (analysis.blobInfoMap.find(inputBlob) != analysis.blobInfoMap.end() &&
          analysis.blobInfoMap[inputBlob][version].storageType != StorageType::INVALID) {
        actualBlobType = analysis.blobInfoMap[inputBlob][version].storageType;
        VLOG(2) << "Found " << inputBlob << " " << j << " with type"
                << asString(actualBlobType);
      }

      // The storage type accepted by the operator: CPU for non-MPSCNN
      // operators, temporary image for MPSCNN operators, unless the table
      // lists the operator explicitly.
      auto expectedBlobType =
          isMPSCNNOp ? StorageType::MPSTEMPORARYIMAGE : StorageType::CPU;
      if (expectedTypes != inputStorageTypeMap.end()) {
        // Bounds-checked: an operator with more inputs than the table row
        // describes must not read past the end of the row.
        expectedBlobType = expectedTypes->second.at(j);
      }

      if (expectedBlobType == actualBlobType) {
        continue;
      }
      if (expectedBlobType == StorageType::CPU && isImage(actualBlobType)) {
        // copy input(MPSCNN) to input_I(CPU); the operator reads input_I
        insertInputCopyFromMPSCNNOp(mdef, inputBlob);
        inputsToRewrite.push_back(j);
      } else if (isImage(expectedBlobType) && actualBlobType == StorageType::CPU) {
        // copy input(CPU) to input_I(MPSCNN); the operator reads input_I
        insertInputCopyToMPSCNNOp(mdef, inputBlob);
        inputsToRewrite.push_back(j);
      } // We don't need to insert copies in other cases
    }

    auto op = mdef.add_op();
    op->CopyFrom(ogOp);

    for (const int inputIndex : inputsToRewrite) {
      rewriteInput(op, inputIndex);
    }

    // rewrite name for (single) external input
    if (isMPSCNNOp && opKeySet.find(op->type()) != opKeySet.end()) {
      // input used by multiple ops
      const auto& externalInput = def.external_input(0);
      const auto& usages = analysis.blobInfoMap[externalInput][0].inUsages;
      if (std::find(usages.begin(), usages.end(), i) != usages.end()) {
        for (auto j = 0; j < op->input_size(); ++j) {
          if (op->input(j) == externalInput) {
            op->set_input(j, "__METAL_INPUT_COPY__");
          }
        }
      }
    }

    // if the output is in external output, copy from metal when necessary
    for (auto j = 0; j < op->output_size(); ++j) {
      for (auto k = 0; k < def.external_output_size(); ++k) {
        // Assuming external output blob has unique name, e.g. only version 0
        // of the blob is used as the output
        if (op->output(j) == def.external_output(k) &&
            analysis.blobInfoMap[op->output(j)][0].storageType != StorageType::CPU) {
          // copy output_M(MPSCNN) to output(CPU)
          insertOutputCopyFromMPSCNNOp(mdef, op->output(j));
          // rewrite output to output_M for the operator
          rewriteOutput(op, j);
        }
      }
    }
  }

  // Since adding temp image arg changes the result for command buffer analysis,
  // which is the analysis the function is based on, we'll add one temp image
  // arg at a time and re-run ssa analysis after each and repeat the process
  // until convergence (or until the iteration bound is reached).
  int rounds = 0;
  while (addTempImageArgs(analysis, mdef) && rounds < 3 * mdef.op_size()) {
    rounds++;
  }

  mdef = mergeCopyFromMPSCNN(analysis, mdef);
  mdef = mergeCopyToMPSCNN(analysis, mdef);

  return mdef;
}
522

523
// Produces the MPSCNN version of `def`:
//   - every operator whose "MPSCNN"-prefixed type is registered in the CPU
//     operator registry is switched to that type;
//   - copy operators are inserted, the fusion pass is run and the special
//     arguments are set;
//   - the result must contain at least two operators and start with one of
//     the known MPSCNN input operators (enforced).
NetDef rewriteForMetalI(const NetDef& def) {
  static std::set<std::string> mpscnnInputOps = {
      "CopyToMPSCNN", "MPSCNNPackedInt8BGRANHWCToNCHWCStylizerPreprocess"};

  NetDef mdef;
  mdef.CopyFrom(def);

  const auto& opKeyList = CPUOperatorRegistry()->Keys();
  const std::set<std::string> registeredOps(opKeyList.begin(), opKeyList.end());
  for (auto& op : *mdef.mutable_op()) {
    const std::string mpscnnType = std::string("MPSCNN") + op.type();
    if (registeredOps.count(mpscnnType) != 0) {
      op.set_type(mpscnnType);
    }
  }

  mdef = insertCopies(mdef);
  mdef = runMPSCNNFusion(mdef);
  mdef = setSpecialArgs(mdef);

  CAFFE_ENFORCE_GE(mdef.op_size(), 2);
  CAFFE_ENFORCE(mpscnnInputOps.find(mdef.op(0).type()) != mpscnnInputOps.end());
  return mdef;
}
550
} // namespace
551

552
// Returns a copy of `def` in which the first argument of every
// MPSCNNGenerateProposalsCPP / GenerateProposalsCPP operator has its integer
// value set to 36.
//
// Fails an enforce when such an operator carries no argument at all, instead
// of touching a non-existent element.
NetDef setSpecialArgs(const NetDef& def) {
  NetDef mdef;
  mdef.CopyFrom(def);
  for (auto i = 0; i < mdef.op_size(); ++i) {
    auto* op = mdef.mutable_op(i);
    if (op->type() != "MPSCNNGenerateProposalsCPP" && op->type() != "GenerateProposalsCPP") {
      continue;
    }
    // setting post_nms_top_N for MPSCNNGenerateProposalsCPP to 36 due to the
    // texture array length constraint in RoIWarp
    // NOTE(review): this assumes post_nms_top_N is the operator's first
    // argument; the argument is not looked up by name - confirm.
    CAFFE_ENFORCE_GE(op->arg_size(), 1);
    op->mutable_arg(0)->set_i(36);
  }
  return mdef;
}
566

567
// Tries to rewrite `predictNet` for MPSCNN.
//
// Returns true and stores the rewritten net in `*metalPredictNet` when
//   - the system version is not lower than 11.0,
//   - the default Metal device supports MTLFeatureSet_iOS_GPUFamily3_v2,
//   - `initNet` runs, and
//   - the rewritten net can be instantiated in a workspace.
// Returns false otherwise (the reason is logged) and leaves
// `*metalPredictNet` untouched. A null `metalPredictNet` also yields false.
bool tryConvertToMPSCNNIntermediateCopies(const NetDef& initNet,
                                          const NetDef& predictNet,
                                          NetDef* metalPredictNet) {
  if (metalPredictNet == nullptr) {
    LOG(ERROR) << "metalPredictNet must not be null";
    return false;
  }

  // iOS 11.0 and above.
  NSString* systemVersion = [[UIDevice currentDevice] systemVersion];
  if ([systemVersion compare:@"11.0" options:NSNumericSearch] == NSOrderedAscending) {
    LOG(ERROR) << "MPSCNN is only supported for ios version above 11.0.";
    return false;
  }

  // The iOS GPU Family 3 v2 feature set. Introduced with the Apple A9 GPU and iOS 10.0.
  // Don't instantiate the MPSCNNContext, as that compiles the kernel source.
  if (![MTLCreateSystemDefaultDevice() supportsFeatureSet:MTLFeatureSet_iOS_GPUFamily3_v2]) {
    LOG(ERROR) << "The iOS GPU is less than an A9, so MPSCNN is not available";
    return false;
  }

  try {
    // Run the init net and instantiate the rewritten predict net; any failure
    // here means the net cannot be used with MPSCNN.
    Workspace ws;
    if (!ws.RunNetOnce(initNet)) {
      LOG(ERROR) << "Failed to run the init net while converting NetDef to MPSCNN";
      return false;
    }
    // Throws if unsupported operators are found.
    NetDef converted = rewriteForMetalI(predictNet);
    converted = annotateDefWithReadCounts(converted);
    // Throws, or returns null, if unsupported parameters are found.
    if (ws.CreateNet(converted) == nullptr) {
      LOG(ERROR) << "Failed to create the converted MPSCNN net";
      return false;
    }
    // Publish the result only once every step has succeeded.
    *metalPredictNet = converted;
    LOG(INFO) << "MPSCNN is successfully enabled";
    return true;
  } catch (const std::exception& e) {
    LOG(ERROR) << "Caught exception trying to convert NetDef to MPSCNN: " << e.what();
    return false;
  }
}
607
} // caffe2
608

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.