pytorch

Форк
0
396 строк · 13.2 Кб
1
#include "caffe2/core/operator.h"
2
#include "mpscnn.h"
3
#include "mpscnn_context.h"
4

5
#import <Metal/Metal.h>
6
#import <MetalPerformanceShaders/MetalPerformanceShaders.h>
7
#import <UIKit/UIDevice.h>
8

9
namespace caffe2 {
10
// Versioned (SSA-style) description of a NetDef, as built by analyzeNet().
// Each blob name carries an integer version that changes when the blob is
// rewritten, so "X version 0" and "X version 1" can be told apart.
struct Analysis {
  // Read/write record for a single operator.
  struct SSA {
    // Blob name -> version number.
    using BlobVersions = std::unordered_map<std::string, size_t>;

    // Versions of the blobs this operator reads.
    BlobVersions inVersions;
    // Versions of the blobs this operator writes.
    BlobVersions outVersions;
  };

  // One record per operator; entry i describes net.op(i).
  std::vector<SSA> ssa;

  // Blob name -> version -> indices of the operators that read that version
  // (one entry per read, so an op reading the same version twice appears
  // twice).
  std::unordered_map<std::string, std::unordered_map<size_t, std::vector<size_t>>>
      inUsages;
};
22

23
// Builds the SSA-style version table for `net`, walking operators in order.
//
// A blob that has never been seen is at version 0. Each input read records
// the blob's current version and appends the reading operator's index to
// inUsages[blob][version]. Each output write bumps the version of a blob
// that has already been seen (as input or output); a never-seen blob is
// written at version 0.
Analysis analyzeNet(const NetDef& net) {
  Analysis result;
  // Latest known version of every blob encountered so far.
  Analysis::SSA::BlobVersions current;

  for (int opIdx = 0; opIdx < net.op_size(); ++opIdx) {
    const OperatorDef& op = net.op(opIdx);
    Analysis::SSA record;

    for (const auto& name : op.input()) {
      // operator[] registers a never-seen blob (e.g. an external input) at 0.
      const size_t version = current[name];
      record.inVersions[name] = version;
      result.inUsages[name][version].push_back(static_cast<size_t>(opIdx));
    }

    for (const auto& name : op.output()) {
      auto found = current.find(name);
      if (found == current.end()) {
        found = current.emplace(name, 0).first;
      } else {
        ++found->second;
      }
      record.outVersions[name] = found->second;
    }

    result.ssa.push_back(record);
  }
  return result;
}
48

49
// Renames input `i` of `op` to its Metal-side alias, i.e. appends "_M".
static void rewriteInput(OperatorDef* op, int i) {
  const std::string metalName = op->input(i) + "_M";
  op->set_input(i, metalName);
}
53

54
// Renames output `i` of `op` to its Metal-side alias, i.e. appends "_M".
static void rewriteOutput(OperatorDef* op, int i) {
  const std::string metalName = op->output(i) + "_M";
  op->set_output(i, metalName);
}
58

59
static void insertOutputCopyFromMPSCNNOp(
60
    NetDef& predictNet,
61
    const std::vector<std::string>& cpu_blobs) {
62
  auto* op = predictNet.add_op();
63
  op->set_type("CopyFromMPSCNN");
64
  for (int i = 0; i < cpu_blobs.size(); ++i) {
65
    op->add_input(cpu_blobs[i] + "_M");
66
    op->add_output(cpu_blobs[i]);
67
  }
68
}
69

70
// Returns a copy of `def` wrapped with CPU<->Metal copy operators:
//  - a CopyToMPSCNN op is prepended that copies external_input(0) into
//    "__METAL_INPUT_COPY__", which op(0) then reads in place of the input;
//  - every op output named like an external output "X" is renamed to "X_M",
//    and later ops that read "X" are redirected to "X_M";
//  - one CopyFromMPSCNN op is appended copying each "X_M" back to "X".
//
// Preconditions (enforced with CAFFE_ENFORCE, which throws on violation):
//  - at least one external input, one external output and one operator;
//  - version 0 of external_input(0) is read exactly once, by op(0), and it is
//    op(0)'s first input;
//  - external_output(0) is written by the last operator.
NetDef insertInputOutputCopyOps(const NetDef& def) {
  CAFFE_ENFORCE_GE(def.external_input_size(), 1);
  CAFFE_ENFORCE_GE(def.external_output_size(), 1);
  auto analysis = analyzeNet(def);
  CAFFE_ENFORCE_GE(def.op_size(), 1);
  const auto& inputBlob = def.external_input(0);
  // The input blob must have a single reader: the first operator.
  CAFFE_ENFORCE(analysis.inUsages[inputBlob][0] == (std::vector<size_t>{0}));
  // external_output(0) must be produced by the last operator in the net.
  const auto& outputBlob = def.external_output(0);
  CAFFE_ENFORCE(
      analysis.ssa.back().outVersions.find(outputBlob) !=
      analysis.ssa.back().outVersions.end());
  const auto& outputBlobVersion = analysis.ssa.back().outVersions[outputBlob];
  // No operator can read a version written by the last operator; this holds
  // by construction of the SSA analysis.
  CAFFE_ENFORCE(
      analysis.inUsages[outputBlob].find(outputBlobVersion) ==
      analysis.inUsages[outputBlob].end());

  NetDef mdef;
  mdef.CopyFrom(def);
  mdef.clear_op();

  {
    auto& op = *(mdef.add_op());
    op.set_type("CopyToMPSCNN");
    op.add_input(def.external_input(0));
    op.add_output("__METAL_INPUT_COPY__");
  }

  // Built once so each op output is tested in O(1). A single membership test
  // per output also guarantees an output is renamed at most once (the old
  // per-external-output loop could rename "X" -> "X_M" -> "X_M_M" when both
  // "X" and "X_M" were external outputs).
  const std::unordered_set<std::string> externalOutputs(
      def.external_output().begin(), def.external_output().end());
  // External outputs already renamed to "<name>_M" by an earlier operator.
  std::unordered_set<std::string> output_set;

  for (auto i = 0; i < def.op_size(); ++i) {
    auto* op = mdef.add_op();
    op->CopyFrom(def.op(i));
    if (i == 0) {
      CAFFE_ENFORCE_EQ(op->input(0), def.external_input(0));
      op->set_input(0, "__METAL_INPUT_COPY__");
    }
    // Suppose blob "X" is both an external output and read by later
    // operators, and it lives on Metal. Its producer is renamed to write
    // "X_M", so every later reader must read "X_M" as well. The
    // CopyFromMPSCNN op appended at the end copies "X_M" back to CPU as "X".
    for (auto j = 0; j < op->input_size(); ++j) {
      if (output_set.count(op->input(j)) != 0) {
        rewriteInput(op, j);
      }
    }
    // Assumes each external output name is written once, i.e. a single
    // version of the blob is the net's output.
    for (auto j = 0; j < op->output_size(); ++j) {
      if (externalOutputs.count(op->output(j)) != 0) {
        output_set.insert(op->output(j));
        rewriteOutput(op, j);
      }
    }
  }

  // Copy all external outputs from Metal to CPU with a single op at the end.
  const std::vector<std::string> external_outputs(
      def.external_output().begin(), def.external_output().end());
  insertOutputCopyFromMPSCNNOp(mdef, external_outputs);

  return mdef;
}
154

155
bool nextIsOnlyUserOfCurrent(
156
    const Analysis& analysis,
157
    size_t currentIdx,
158
    const OperatorDef& currentOp,
159
    const OperatorDef& nextOp) {
160
  CAFFE_ENFORCE_EQ(currentOp.output_size(), 1);
161
  CAFFE_ENFORCE_GE(nextOp.input_size(), 1);
162
  CAFFE_ENFORCE_EQ(currentOp.output(0), nextOp.input(0));
163
  const auto outputName = currentOp.output(0);
164
  // Find the version of the output name we are currently looking at.
165
  // This is guaranteed to exist by SSA analysis.
166
  const auto currentOutputVersion =
167
      analysis.ssa.at(currentIdx).outVersions.at(outputName);
168
  VLOG(2) << "Blob: " << outputName << ", idx: " << currentOutputVersion;
169
  // Find the usages of this in the SSA analysis.
170

171
  // Has this blob every been used?
172
  if (analysis.inUsages.find(outputName) == analysis.inUsages.end()) {
173
    return false;
174
  }
175

176
  // Has this version of the blob ever been used?
177
  if (analysis.inUsages.at(outputName).find(currentOutputVersion) ==
178
      analysis.inUsages.at(outputName).end()) {
179
    return false;
180
  }
181
  const auto currentOutputUsages =
182
      analysis.inUsages.at(outputName).at(currentOutputVersion);
183
  VLOG(2) << "Blob: " << outputName << ", idx: " << currentOutputVersion
184
          << ", usages[0]: " << currentOutputUsages[0];
185

186
  return currentOutputUsages == std::vector<size_t>{currentIdx + 1};
187
}
188
bool tryFuseAdjacentOps(
189
    const Analysis& analysis,
190
    size_t currentIdx,
191
    const OperatorDef& currentOp,
192
    const OperatorDef& nextOp,
193
    OperatorDef* fusedOp) {
194
  // Check for possible invalid opportunities.
195
  // Must be identical outputs, with either in-place usage for nextOp, *or* the
196
  // only use of the output of currentOp is the consumption by nextOp.
197
  if (currentOp.output_size() != 1 || !nextOp.input_size() ||
198
      nextOp.output_size() != 1) {
199
    return false;
200
  }
201

202
  if (currentOp.output(0) != nextOp.input(0)) {
203
    return false;
204
  }
205

206
  if (!nextIsOnlyUserOfCurrent(analysis, currentIdx, currentOp, nextOp)) {
207
    return false;
208
  }
209

210
  // Can we autogenerate this at registration time instead?
211
  static const std::map<std::pair<std::string, std::string>, std::string>
212
      fusionOpportunities = {{
213
          {{"MPSCNNConv", "MPSCNNRelu"}, "MPSCNNConvRelu"},
214
          {{"MPSCNNConv", "MPSCNNSigmoid"}, "MPSCNNConvSigmoid"},
215
          {{"MPSCNNFC", "MPSCNNRelu"}, "MPSCNNFCRelu"},
216
          {{"MPSCNNInstanceNorm", "MPSCNNPRelu"}, "MPSCNNInstanceNormPRelu"},
217
      }};
218
  auto it = fusionOpportunities.find({currentOp.type(), nextOp.type()});
219
  if (it == fusionOpportunities.end()) {
220
    return false;
221
  }
222
  // MPSCNNConvRelu and MPSCNNConvSigmoid cannot be in-place
223
  if (currentOp.type() == "MPSCNNConv" &&
224
      currentOp.input(0) == nextOp.output(0)) {
225
    return false;
226
  }
227
  LOG(INFO) << "Found a fusion between adjacent ops: (" << currentOp.type()
228
            << ", " << nextOp.type() << ") -> " << it->second;
229
  fusedOp->CopyFrom(currentOp);
230
  fusedOp->set_type(it->second);
231
  for (auto i = 1; i < nextOp.input_size(); ++i) {
232
    fusedOp->add_input(nextOp.input(i));
233
  }
234
  fusedOp->set_output(0, nextOp.output(0));
235
  return true;
236
}
237

238
// Single left-to-right pass over `def` that replaces each adjacent operator
// pair accepted by tryFuseAdjacentOps() with the fused operator and copies
// every other operator through unchanged. A fused pair is consumed as a unit,
// so the fused operator is not examined for further fusion.
// Throws if `def` has no operators.
NetDef runMPSCNNFusion(const NetDef& def) {
  CAFFE_ENFORCE_GE(def.op_size(), 1);
  NetDef mdef;
  mdef.CopyFrom(def);
  mdef.clear_op();
  const Analysis analysis = analyzeNet(def);
  const int numOps = def.op_size();

  for (int idx = 0; idx < numOps;) {
    if (idx + 1 == numOps) {
      VLOG(2) << "Last operator, skipping";
      mdef.add_op()->CopyFrom(def.op(idx));
      ++idx;
      continue;
    }

    OperatorDef fusedOp;
    if (tryFuseAdjacentOps(
            analysis, idx, def.op(idx), def.op(idx + 1), &fusedOp)) {
      VLOG(2) << "Found an adjacent fusion at: " << idx;
      // We can fuse.
      mdef.add_op()->CopyFrom(fusedOp);
      idx += 2;
    } else {
      VLOG(2) << "No fusion available";
      // Just emit the current type.
      mdef.add_op()->CopyFrom(def.op(idx));
      idx += 1;
    }
  }
  return mdef;
}
274

275
// Rewrites a CPU NetDef to run on Metal:
//  1. every op type T becomes "MPSCNN" + T; throws if that type is not in the
//     CPU operator registry;
//  2. adjacent ops are fused where possible (runMPSCNNFusion);
//  3. if the net neither starts with a Metal input op nor ends with a Metal
//     output op, CPU<->Metal copy ops are inserted (insertInputOutputCopyOps).
// Finally enforces that the result has at least two ops, starts with a Metal
// input op and ends with a Metal output op.
NetDef rewriteForMetal(const NetDef& def) {
  // Immutable lookup tables; const so no caller can alter the shared statics.
  static const std::set<std::string> mpscnnInputOps = {
      "CopyToMPSCNN", "MPSCNNPackedInt8BGRANHWCToNCHWCStylizerPreprocess"};
  static const std::set<std::string> mpscnnOutputOps = {
      "CopyFromMPSCNN", "MPSCNNBRGNCHWCToPackedInt8BGRAStylizerDeprocess"};

  NetDef mdef;
  mdef.CopyFrom(def);

  const auto& opKeyList = CPUOperatorRegistry()->Keys();
  const std::set<std::string> opKeySet(opKeyList.begin(), opKeyList.end());
  for (auto i = 0; i < mdef.op_size(); ++i) {
    auto* op = mdef.mutable_op(i);
    const auto mpscnnOp = std::string("MPSCNN") + op->type();
    // Name the offending operator: this message is what callers such as
    // tryConvertToMPSCNN end up logging.
    CAFFE_ENFORCE(
        opKeySet.find(mpscnnOp) != opKeySet.end(),
        "Operator ",
        op->type(),
        " is not supported by MPSCNN (",
        mpscnnOp,
        " is not registered)");
    op->set_type(mpscnnOp);
  }

  mdef = runMPSCNNFusion(mdef);

  const bool startsWithMetalInput =
      mpscnnInputOps.count(mdef.op(0).type()) != 0;
  const bool endsWithMetalOutput =
      mpscnnOutputOps.count(mdef.op(mdef.op_size() - 1).type()) != 0;
  if (!startsWithMetalInput && !endsWithMetalOutput) {
    mdef = insertInputOutputCopyOps(mdef);
  }
  CAFFE_ENFORCE_GE(mdef.op_size(), 2);
  CAFFE_ENFORCE(
      mpscnnInputOps.find(mdef.op(0).type()) != mpscnnInputOps.end(),
      "First operator is not a Metal input op: ",
      mdef.op(0).type());
  CAFFE_ENFORCE(
      mpscnnOutputOps.find(mdef.op(mdef.op_size() - 1).type()) !=
          mpscnnOutputOps.end(),
      "Last operator is not a Metal output op: ",
      mdef.op(mdef.op_size() - 1).type());
  return mdef;
}
307

308
void dumpDef(const NetDef& d) {
309
  for (const auto& op : d.op()) {
310
    LOG(INFO) << op.input(0) << " -> " << op.type() << " -> " << op.output(0);
311
  }
312
}
313

314
// Returns a copy of `net` where each operator whose first output version is
// read more than once gets an integer argument kMPSCNNReadCountArg holding
// that read count. Reads are counted per operator input slot, as recorded in
// Analysis::inUsages. Operators without outputs are left unannotated.
NetDef annotateDefWithReadCounts(const NetDef& net) {
  auto analysis = analyzeNet(net);
  using ReadCount = std::unordered_map<std::string, size_t>;
  // readCounts[i][blob] = number of reads of the version of `blob` that
  // operator i writes.
  std::vector<ReadCount> readCounts;
  readCounts.reserve(analysis.ssa.size());
  for (const auto& opSsa : analysis.ssa) {
    ReadCount rcs;
    for (const auto& bv : opSsa.outVersions) {
      // operator[] yields an empty reader list for a version nobody reads.
      rcs[bv.first] = analysis.inUsages[bv.first][bv.second].size();
    }
    readCounts.push_back(rcs);
  }

  NetDef annotatedNet;
  annotatedNet.CopyFrom(net);
  for (auto i = 0; i < annotatedNet.op_size(); ++i) {
    auto* op = annotatedNet.mutable_op(i);
    // Only the first output is annotated. Skip output-less ops rather than
    // reading output(0) out of range.
    if (op->output_size() == 0) {
      continue;
    }
    const auto& blob = op->output(0);
    const size_t readCount = readCounts[i][blob];
    if (readCount > 1) {
      auto* arg = op->add_arg();
      arg->set_name(kMPSCNNReadCountArg);
      arg->set_i(readCount);
      LOG(INFO) << "Op: " << i << ", ty: " << op->type() << ", blob: " << blob
                << ", read count: " << readCount;
    }
  }
  return annotatedNet;
}
350

351
// Tries to rewrite `predictNet` for MPSCNN (Metal) execution.
//
// Returns true and stores the rewritten net in `*metalPredictNet` on success.
// Logs the reason and returns false when:
//  - the OS is older than iOS 11.0;
//  - no Metal device exists, or it lacks the iOS GPU Family 3 v2 feature set;
//  - running `initNet` fails;
//  - the rewritten net cannot be created;
//  - any std::exception is thrown during the conversion.
// `*metalPredictNet` may already have been assigned when false is returned.
bool tryConvertToMPSCNN(
    const NetDef& initNet,
    const NetDef& predictNet,
    NetDef* metalPredictNet) {
  // Requires iOS 11.0 or newer.
  const NSOperatingSystemVersion minimumVersion = {11, 0, 0};
  if (![[NSProcessInfo processInfo]
          isOperatingSystemAtLeastVersion:minimumVersion]) {
    LOG(ERROR) << "MPSCNN is only supported on iOS 11.0 and above.";
    return false;
  }
  // The iOS GPU Family 3 v2 feature set. Introduced with the Apple A9 GPU and
  // iOS 10.0. Query the device directly instead of instantiating the
  // MPSCNNContext, as that compiles the kernel source.
  id<MTLDevice> device = MTLCreateSystemDefaultDevice();
  if (device == nil) {
    LOG(ERROR) << "No Metal device is available, so MPSCNN is not available";
    return false;
  }
  if (![device supportsFeatureSet:MTLFeatureSet_iOS_GPUFamily3_v2]) {
    LOG(ERROR) << "The iOS GPU is less than an A9, so MPSCNN is not available";
    return false;
  }

  try {
    // Instantiating the nets inside try/catch turns failures into a `false`
    // return instead of propagating them to the caller.
    Workspace ws;
    // RunNetOnce reports run failures through its return value, not a throw.
    if (!ws.RunNetOnce(initNet)) {
      LOG(ERROR) << "Failed to run the init net; MPSCNN is not enabled";
      return false;
    }
    // Throws if unsupported operators are found.
    *metalPredictNet = rewriteForMetal(predictNet);
    *metalPredictNet = annotateDefWithReadCounts(*metalPredictNet);
    // Throws if unsupported parameters are found; CreateNet may also signal
    // failure by returning nullptr.
    if (ws.CreateNet(*metalPredictNet) == nullptr) {
      LOG(ERROR) << "Failed to create the MPSCNN net; MPSCNN is not enabled";
      return false;
    }
    LOG(INFO) << "MPSCNN is successfully enabled";
    return true;
  } catch (const std::exception& e) {
    LOG(ERROR) << "Caught exception trying to convert NetDef to MPSCNN: "
               << e.what();
    return false;
  }
}
392

393
// Inserts a Metal debug-capture boundary on the command queue held by the
// context returned from getMPSCNNContext().
void mpscnnRecordExecutionFinish() {
  const auto& context = getMPSCNNContext();
  [context.commandQueue insertDebugCaptureBoundary];
}
396
}
397

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.