pytorch

Форк
0
2626 строк · 95.2 Кб
1
#include "caffe2/core/common.h"
2
#include "caffe2/core/context.h"
3

4
#if defined(CAFFE2_USE_MPSCNN) && defined(C10_MOBILE)
5

6
#include "caffe2/core/operator.h"
7
#include "caffe2/core/timer.h"
8
#include "caffe2/operators/conv_pool_op_base.h"
9
#include "caffe2/operators/conv_transpose_unpool_op_base.h"
10
#include "caffe2/operators/generate_proposals_op.h"
11
#include "caffe2/operators/generate_proposals_op_util_boxes.h"
12
#include "caffe2/operators/spatial_batch_norm_op.h"
13

14
#include "mpscnn.h"
15
#include "mpscnn_context.h"
16

17
#import <Metal/Metal.h>
18
#import <MetalPerformanceShaders/MetalPerformanceShaders.h>
19
#import <UIKit/UIDevice.h>
20

21
// Evaluates to true when the running OS version (UIDevice systemVersion) is
// numerically >= `v`, where `v` is an NSString such as @"11.0". The comparison
// uses NSNumericSearch so that e.g. "11.10" sorts after "11.9".
#define SYSTEM_VERSION_GREATER_THAN_OR_EQUAL_TO(v)        \
  ([[UIDevice currentDevice].systemVersion compare:(v)   \
                                          options:NSNumericSearch] != \
   NSOrderedAscending)
25

26
// Only compiles against Base SDK iOS 11.0 or greater.
//
// Minimal MPSCNNConvolutionDataSource that hands MPS float32 weights and
// biases supplied by the caller. The weight and bias pointers are NOT owned by
// this object; whoever creates it must keep that memory alive for as long as
// MPS may read it (NOTE(review): confirm MPS copies weights during
// MPSCNNConvolution initialization).
@interface ConvDataSource : NSObject<MPSCNNConvolutionDataSource>
// Float32 weights returned from -weights. Not owned.
@property(nonatomic, assign) float* weights_;
// Float32 bias terms returned from -biasTerms. Not owned.
@property(nonatomic, assign) float* bias_;
// Convolution descriptor returned from -descriptor.
@property(nonatomic, strong) MPSCNNConvolutionDescriptor* desc_;
@end
32

33
@implementation ConvDataSource

// Designated initializer: stores the (unowned) weight and bias pointers and the
// descriptor that the MPSCNNConvolutionDataSource methods below hand to MPS.
- (instancetype)initWithWeight:(float*)weights
                          bias:(float*)bias
                          desc:(MPSCNNConvolutionDescriptor*)desc {
  self = [super init];
  if (self) {
    _weights_ = weights;
    _bias_ = bias;
    _desc_ = desc;
  }
  return self;
}

// Bias terms, or NULL if none were supplied.
- (float*)biasTerms {
  return self.bias_;
}

// Weights and biases are always float32.
- (MPSDataType)dataType {
  return MPSDataTypeFloat32;
}

- (MPSCNNConvolutionDescriptor*)descriptor {
  return self.desc_;
}

// No debug label.
- (NSString*)label {
  return nil;
}

// Data is already resident in memory owned by the caller; nothing to load.
- (BOOL)load {
  return YES;
}

// Not a quantized (uint8) kernel, so there is no lookup table.
- (float*)lookupTableForUInt8Kernel {
  return nullptr;
}

// Memory is owned by the caller; nothing to release.
- (void)purge {
}

// Not a quantized (uint8) kernel, so there are no ranges.
- (vector_float2*)rangesForUInt8Kernel {
  return nullptr;
}

- (void*)weights {
  return self.weights_;
}

// Shallow copy: the new data source shares the same (unowned) weight/bias
// memory and descriptor. The previous implementation returned an object that
// was allocated but never sent an initializer; route through the designated
// initializer instead.
- (id)copyWithZone:(NSZone*)zone {
  return [[[self class] allocWithZone:zone] initWithWeight:self.weights_
                                                      bias:self.bias_
                                                      desc:self.desc_];
}

@end
80

81
namespace caffe2 {
82

83
namespace {
84
// Ceiling division for unsigned operands: the smallest q with q * y >= x.
// `y` must be non-zero.
uint divRoundUp(uint x, uint y) {
  const uint paddedNumerator = x + (y - 1);
  return paddedNumerator / y;
}
87

88
// Allocates an fp16 MPSTemporaryImage on `commandBuffer` with the given
// batch/height/width/channel sizes and shader read+write usage.
//
// The image's readCount comes from the repeated kMPSCNNReadCountArg argument
// at position `output_idx` when that argument is present. Otherwise it comes
// from the scalar kMPSCNNReadCountArg argument (default 1). It must be >= 1.
MPSTemporaryImage* createTemporaryImage(
    const OperatorBase* op,
    id<MTLCommandBuffer> commandBuffer,
    int n,
    int height,
    int width,
    int channels,
    size_t output_idx = 0) {
  MPSImageDescriptor* descriptor = [MPSImageDescriptor
      imageDescriptorWithChannelFormat:MPSImageFeatureChannelFormatFloat16
                                 width:width
                                height:height
                       featureChannels:channels
                        numberOfImages:n
                                 usage:MTLTextureUsageShaderRead |
                                       MTLTextureUsageShaderWrite];
  MPSTemporaryImage* tempImage =
      [MPSTemporaryImage temporaryImageWithCommandBuffer:commandBuffer
                                         imageDescriptor:descriptor];

  const auto& perOutputReadCounts =
      op->GetRepeatedArgument<int>(kMPSCNNReadCountArg);
  int readCount;
  if (perOutputReadCounts.empty()) {
    readCount = op->GetSingleArgument<int>(kMPSCNNReadCountArg, 1);
  } else {
    readCount = perOutputReadCounts.at(output_idx);
  }
  CAFFE_ENFORCE_GE(readCount, 1);
  tempImage.readCount = readCount;
  return tempImage;
}
119

120
// Allocates a persistent (non-temporary) fp16 MPSImage on the shared MPSCNN
// device. It has the given batch/height/width/channel sizes and shader
// read+write usage. Unlike an MPSTemporaryImage it is not tied to a command
// buffer.
MPSImage* createStaticImage(int n, int height, int width, int channels) {
  MPSImageDescriptor* descriptor = [MPSImageDescriptor
      imageDescriptorWithChannelFormat:MPSImageFeatureChannelFormatFloat16
                                 width:width
                                height:height
                       featureChannels:channels
                        numberOfImages:n
                                 usage:MTLTextureUsageShaderRead |
                                       MTLTextureUsageShaderWrite];
  id<MTLDevice> device = getMPSCNNContext().device;
  return [[MPSImage alloc] initWithDevice:device imageDescriptor:descriptor];
}
134

135
// Pairs an MPS image with the command buffer its producing work was encoded
// on, plus whether the image is an MPSTemporaryImage. Instances are stored by
// value in caffe2 blobs so MPSCNN operators can pass images to each other.
class MPSImageWrapper {
 public:
  MPSImageWrapper() = default;

  // Creates the image for output `output_idx` of `op`, chained after `parent`
  // (which may be null for the first op in a chain).
  //
  // If the parent wrapper contains a temporary image, we need to pass on the
  // command buffer. Temporary images are attached to the command buffer they
  // were created on, so we must use that same command buffer in order to use
  // the temporary image. We don't synchronize the parent in that case because
  // it is still in use.
  //
  // If the parent wrapper contains a static image, we create a new command
  // buffer. Static images survive synchronization (commit of the command
  // buffer), and passing the parent's buffer on would cause it to be committed
  // in multiple places in the graph. Since we don't pass the parent's command
  // buffer on, we synchronize (commit) it here; it won't be used again.
  //
  // Whether the new image is temporary is read from the repeated
  // kMPSCNNOutputIsTempImageArg argument at `output_idx` when present.
  // Otherwise it comes from the scalar argument (default 1 = temporary).
  MPSImageWrapper(
      const OperatorBase* op,
      MPSImageWrapper* parent,
      int n,
      int height,
      int width,
      int channels,
      size_t output_idx = 0) {
    const bool passOnCb = parent != nullptr && parent->isTemporaryImage_;
    commandBuffer_ = passOnCb ? parent->commandBuffer_
                              : [getMPSCNNContext().commandQueue commandBuffer];

    const bool commitInputCb = parent != nullptr && !parent->isTemporaryImage_;
    if (commitInputCb) {
      parent->synchronize();
    }

    const auto& isTemporaryImages =
        op->GetRepeatedArgument<int>(kMPSCNNOutputIsTempImageArg);
    isTemporaryImage_ = isTemporaryImages.size()
        ? isTemporaryImages.at(output_idx)
        : op->GetSingleArgument<int>(kMPSCNNOutputIsTempImageArg, 1);
    if (isTemporaryImage_) {
      image_ = createTemporaryImage(
          op, commandBuffer_, n, height, width, channels, output_idx);
    } else {
      image_ = createStaticImage(n, height, width, channels);
    }
  }

  // Decrements the readCount of a temporary image by one. Used after encoding
  // a custom compute kernel that reads the image. No-op for static images.
  void markRead() {
    if (isTemporaryImage_) {
      MPSTemporaryImage* tempImg = (MPSTemporaryImage*)image_;
      tempImg.readCount -= 1;
    }
  }

  MPSImage* getImage() const {
    return image_;
  }

  id<MTLCommandBuffer> getCommandBuffer() const {
    return commandBuffer_;
  }

  // Commits the command buffer if it has not been enqueued yet. Skips buffers
  // that were already enqueued or committed, presumably so a shared buffer is
  // never committed twice.
  void synchronize() {
    if (commandBuffer_ != nil &&
        commandBuffer_.status == MTLCommandBufferStatusNotEnqueued) {
      [commandBuffer_ commit];
    }
  }

  // Marks the image as read once, then commits the command buffer if needed.
  void cleanup() {
    markRead();
    synchronize();
  }

  // Stores a copy of this wrapper (image, command buffer, temporary flag) in
  // `output`. The underlying Objective-C objects are shared, not duplicated.
  void copyToOutputBlob(Blob* output) {
    *output->GetMutable<MPSImageWrapper>() = *this;
  }

 private:
  MPSImage* image_{nullptr};
  id<MTLCommandBuffer> commandBuffer_{nullptr};
  bool isTemporaryImage_ = true;
};
219

220
// Chooses between the array and non-array variant of a Metal kernel for
// image `X`. The array variant is picked when `X` has more than four feature
// channels or more than one image. Per Apple's MPSImage documentation, such
// images are stored in 2D-array textures.
NSString* kernelFor(
    const MPSImage* X,
    NSString* arrayKernel,
    NSString* nonArrayKernel) {
  const bool needsArrayKernel =
      X.featureChannels > 4 || X.numberOfImages > 1;
  return needsArrayKernel ? arrayKernel : nonArrayKernel;
}
230

231
// Dispatch geometry for a compute kernel: pass these to
// -[MTLComputeCommandEncoder dispatchThreadgroups:threadsPerThreadgroup:].
struct LaunchParams {
  // Threads in a single threadgroup along x/y/z.
  MTLSize threadsPerThreadgroup;
  // Number of threadgroups in the grid along x/y/z.
  MTLSize threadgroupsPerGrid;
};
235

236
// Launch geometry for a per-pixel kernel over image `im`.
//
// Uses a fixed 8x4x1 threadgroup. The grid covers width x height, and along z
// it covers numberOfImages * ceil(featureChannels / 4) slices, i.e. one z
// index per 4-channel slice of each image.
//
// NOTE(review): `pipeline` is currently unused. The threadgroup shape is
// hard-coded rather than derived from its threadExecutionWidth /
// maxTotalThreadsPerThreadgroup. The parameter is kept for interface
// compatibility.
LaunchParams spatialPointwiseKernelLaunchParams(
    id<MTLComputePipelineState> pipeline,
    const MPSImage* im) {
  (void)pipeline;
  const MTLSize threadsPerThreadgroup = MTLSizeMake(8, 4, 1);
  const MTLSize threadgroupsPerGrid = MTLSizeMake(
      divRoundUp(im.width, threadsPerThreadgroup.width),
      divRoundUp(im.height, threadsPerThreadgroup.height),
      im.numberOfImages * divRoundUp(im.featureChannels, 4));
  return {threadsPerThreadgroup, threadgroupsPerGrid};
}
252

253
// Computes the spatial output size (*OH, *OW) that `op` would produce for an
// H x W input. It builds a dummy 1x1xHxW CPU float tensor and asks the op's
// ConvPoolOpBase logic (kernel/stride/pad) for the NCHW output size.
void computeOutputHW(
    ConvPoolOpBase<CPUContext>* op,
    int H,
    int W,
    int* OH,
    int* OW) {
  Tensor dummyInput =
      caffe2::empty({1, 1, H, W}, at::dtype<float>().device(CPU));
  auto outputDims = op->GetOutputSize(dummyInput, 1);
  CAFFE_ENFORCE_EQ(outputDims.size(), 4);
  *OH = outputDims[2];
  *OW = outputDims[3];
}
265

266
// Offset to give an MPSCNN kernel so that it samples the same input window as
// caffe2 does for a given kernel size and padding.
//
// We match the top-left input pixel that is looked at (negative coordinates
// mean padding). With caffe2's pad p, that pixel is (-p, -p); for 3x3s1p1 it is
// (-1, -1), for 3x3s1p0 it is (0, 0), and for 3x3s1p2 it is (-2, -2). With
// this offset formula, MPSCNN is treated as looking at (-kernel / 2,
// -kernel / 2), which equals (-(kernel - 1) / 2, ...) for odd kernels. The
// offset is the difference:
//   3x3s1p1 -> 0,  3x3s1p0 -> 1,  3x3s1p2 -> -1.
constexpr int computeMPSAlignOffset(int kernel, int pad) {
  return kernel / 2 - pad;
}
281

282
// Returns the flat (1-d) offset of the n-dimensional index `index` into a
// contiguous row-major `tensor`. Each coordinate is multiplied by the product
// of the sizes of all following dimensions.
size_t ComputeStartIndex(
    const TensorCPU& tensor,
    const std::vector<int>& index) {
  TORCH_DCHECK_EQ(index.size(), tensor.dim());

  size_t flatOffset = 0;
  for (size_t axis = 0; axis < index.size(); ++axis) {
    const int nextAxis = static_cast<int>(axis) + 1;
    flatOffset += index[axis] * tensor.size_from_dim(nextAxis);
  }
  return flatOffset;
}
296

297
// Returns a non-owning view of slice `dim0_start_index` along the leading
// dimension of `tensor`. The view points into `tensor`'s own data and has the
// remaining dimensions (dims 1..n-1). An empty tensor yields an empty view
// with a null data pointer.
template <class T>
utils::ConstTensorView<T> GetSubTensorView(
    const TensorCPU& tensor,
    int dim0_start_index) {
  TORCH_DCHECK_EQ(tensor.meta().itemsize(), sizeof(T));

  if (tensor.size() == 0) {
    return utils::ConstTensorView<T>(nullptr, {});
  }

  // Flat offset of element [dim0_start_index, 0, ..., 0].
  std::vector<int> sliceStart(tensor.dim(), 0);
  sliceStart.at(0) = dim0_start_index;
  const auto sliceData =
      tensor.data<T>() + ComputeStartIndex(tensor, sliceStart);

  // The view drops the leading dimension.
  const auto fullDims = tensor.sizes();
  std::vector<int> sliceDims(fullDims.begin() + 1, fullDims.end());
  return utils::ConstTensorView<T>(sliceData, sliceDims);
}
319

320
// Uploads each CPU input tensor (rank 1..4, interpreted as NCHW) into an MPS
// image and stores the image wrapper in the matching output blob.
class CopyToMPSCNNOp final : public Operator<CPUContext> {
 public:
  CopyToMPSCNNOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<CPUContext>(operator_def, ws) {}

  bool RunOnDevice() override {
    const int numInputs = InputSize();
    inputBuffers_.resize(numInputs);
    std::vector<MPSImageWrapper> wrappers(numInputs);
    for (int i = 0; i < numInputs; ++i) {
      const auto& X = Input(i);
      CAFFE_ENFORCE(X.dim() > 0 && X.dim() <= 4);
      // Dimensions the tensor does not have default to 1. The previous
      // `XDims.assign(...)` replaced the whole vector, so for inputs of rank
      // < 4 the XDims[1..3] reads below went out of bounds.
      std::vector<int64_t> XDims = {1, 1, 1, 1};
      for (int d = 0; d < X.dim(); ++d) {
        XDims[d] = X.size(d);
      }

      caffe2::Timer t;
      const auto n = XDims[0];
      const auto width = XDims[3];
      const auto height = XDims[2];
      const auto channels = XDims[1];

      // Stage the raw tensor bytes in a cached MTLBuffer, reallocating only
      // when the byte size changes.
      caffe2::Timer copyT;
      if (!inputBuffers_[i] || inputBuffers_[i].length != X.nbytes()) {
        inputBuffers_[i] = [getMPSCNNContext().device
            newBufferWithLength:X.nbytes()
                        options:MTLResourceOptionCPUCacheModeWriteCombined];
      }
      memcpy([inputBuffers_[i] contents], X.raw_data(), X.nbytes());
      VLOG(2) << "CopyToMPSCNNOp input copy took: " << copyT.MilliSeconds();

      // Inputs after the first are chained to wrappers[0] so that, for
      // temporary images, they share its command buffer.
      MPSImageWrapper* parent = (i == 0) ? nullptr : &wrappers[0];
      wrappers[i] =
          MPSImageWrapper(this, parent, n, height, width, channels, i);

      auto commandBuffer = wrappers[i].getCommandBuffer();
      MPSImage* output = wrappers[i].getImage();
      id<MTLComputeCommandEncoder> encoder =
          [commandBuffer computeCommandEncoder];
      id<MTLComputePipelineState> state =
          getMPSCNNContext().getSpecializedPipelineState(
              kernelFor(
                  output,
                  @"copy_nchw_to_metal",
                  @"copy_nchw_to_metal_nonarray"),
              {{ushort(channels), ushort(height), ushort(width)}});
      [encoder setComputePipelineState:state];
      [encoder setBuffer:inputBuffers_[i] offset:0 atIndex:0];
      [encoder setTexture:[output texture] atIndex:0];
      const auto& launchParams =
          spatialPointwiseKernelLaunchParams(state, output);
      [encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
              threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
      [encoder endEncoding];
      VLOG(2) << "CopyToMPSCNNOp took: " << t.MilliSeconds();
      wrappers[i].copyToOutputBlob(Outputs()[i]);
    }
    return true;
  }

 private:
  // One staging buffer per input, reused across runs while the size matches.
  std::vector<id<MTLBuffer>> inputBuffers_;
};

REGISTER_CPU_OPERATOR(CopyToMPSCNN, CopyToMPSCNNOp);
OPERATOR_SCHEMA(CopyToMPSCNN)
    .NumInputs(1, INT_MAX)
    .NumOutputs(1, INT_MAX)
    .SameNumberOfOutput();
388

389
// Total number of elements in an MPS image:
// images x channels x height x width.
auto mpsImageSize = [](MPSImage* X) {
  return X.numberOfImages * X.featureChannels * X.height * X.width;
};
392

393
// Downloads every input MPS image into an NCHW float CPU tensor.
//
// All inputs must share the same command buffer. One copy kernel per input is
// encoded on it, the buffer is committed and waited on, and the staged bytes
// are then copied into the outputs.
class CopyFromMPSCNNOp final : public Operator<CPUContext> {
 public:
  CopyFromMPSCNNOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<CPUContext>(operator_def, ws) {}

  bool RunOnDevice() override {
    caffe2::Timer t;
    // Returns a copy of input i's wrapper. Copies share the underlying image
    // and command buffer objects.
    auto wrapperAt = [&](size_t i) {
      return Inputs()[i]->template Get<MPSImageWrapper>();
    };

    id<MTLCommandBuffer> sharedCommandBuffer =
        wrapperAt(0).getCommandBuffer();
    const size_t numInputs = Inputs().size();
    outputBuffers_.resize(numInputs);

    // Phase 1: encode a texture -> buffer copy for each input.
    for (size_t i = 0; i < numInputs; ++i) {
      auto wrapper = wrapperAt(i);
      CAFFE_ENFORCE_EQ(sharedCommandBuffer, wrapper.getCommandBuffer());
      MPSImage* image = wrapper.getImage();
      const auto byteCount = mpsImageSize(image) * sizeof(float);
      if (!outputBuffers_[i] || outputBuffers_[i].length != byteCount) {
        outputBuffers_[i] = [getMPSCNNContext().device
            newBufferWithLength:byteCount
                        options:MTLResourceOptionCPUCacheModeDefault];
      }

      id<MTLComputeCommandEncoder> encoder =
          [sharedCommandBuffer computeCommandEncoder];
      id<MTLComputePipelineState> state =
          getMPSCNNContext().getSpecializedPipelineState(
              kernelFor(
                  image, @"copy_metal_to_nchw", @"copy_metal_to_nchw_nonarray"),
              {{ushort(image.featureChannels),
                ushort(image.height),
                ushort(image.width)}});
      [encoder setComputePipelineState:state];
      [encoder setBuffer:outputBuffers_[i] offset:0 atIndex:0];
      [encoder setTexture:[image texture] atIndex:0];
      const auto& launchParams =
          spatialPointwiseKernelLaunchParams(state, image);
      [encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
              threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
      [encoder endEncoding];
      wrapper.markRead();
    }

    // Phase 2: run everything and block until the GPU is done.
    [sharedCommandBuffer commit];
    [sharedCommandBuffer waitUntilCompleted];

    // Phase 3: copy the staged buffers into NCHW CPU outputs.
    for (size_t i = 0; i < numInputs; ++i) {
      caffe2::Timer copyOutT;
      MPSImage* image = wrapperAt(i).getImage();
      auto* Y = Output(i);
      Y->Resize(
          image.numberOfImages, image.featureChannels, image.height, image.width);
      // Allocate as float before checking nbytes() (presumably needed so the
      // dtype is set when nbytes() is read).
      float* destination = Y->mutable_data<float>();
      CAFFE_ENFORCE_EQ(outputBuffers_[i].length, Y->nbytes());
      memcpy(
          destination,
          [outputBuffers_[i] contents],
          outputBuffers_[i].length);
      VLOG(2) << "CopyFromMPSCNNOp memcpy took: " << copyOutT.MilliSeconds();
    }
    VLOG(2) << "CopyFromMPSCNNOp took: " << t.MilliSeconds();
    return true;
  }

 private:
  // One staging buffer per input, reused across runs while the size matches.
  std::vector<id<MTLBuffer>> outputBuffers_;
};

REGISTER_CPU_OPERATOR(CopyFromMPSCNN, CopyFromMPSCNNOp);
OPERATOR_SCHEMA(CopyFromMPSCNN)
    .NumInputs(1, INT_MAX)
    .NumOutputs(1, INT_MAX)
    .SameNumberOfOutput();
465

466
// Stylizer preprocessing on the GPU.
//
// Input 0 is a 1xHxWx4 packed image (per the op name, 8-bit BGRA; not checked
// here). Input 1 is a 3-element mean. The op uploads the image, the mean and a
// cached Gaussian noise vector, then runs the `preprocess_stylizer` kernel,
// producing a 1x3xHxW MPS image in output 0.
class MPSCNNPackedInt8BGRANHWCToNCHWCStylizerPreprocessOp final
    : public Operator<CPUContext> {
 public:
  MPSCNNPackedInt8BGRANHWCToNCHWCStylizerPreprocessOp(
      const OperatorDef& operator_def,
      Workspace* ws)
      : Operator<CPUContext>(operator_def, ws), ws_(ws) {}

  bool RunOnDevice() override {
    const auto& X = Input(0);
    const auto& mean = Input(1);
    CAFFE_ENFORCE_EQ(mean.size(), 3);
    CAFFE_ENFORCE_EQ(X.dim(), 4);
    CAFFE_ENFORCE_EQ(X.size(0), 1);
    CAFFE_ENFORCE_EQ(X.size(3), 4);
    const auto H = X.size(1);
    const auto W = X.size(2);

    caffe2::Timer t;

    Blob* noiseBlob = ws_->CreateBlob("__CAFFE2_STYLIZER_NOISE__");
    ushort noiseSize = OperatorBase::GetSingleArgument<int>(
        "noise_size", 491 /* prime to avoid artifacts */);
    // Treated as half4 in the kernel, so round up to a multiple of 4.
    noiseSize = divRoundUp(noiseSize, 4) * 4;
    const auto& noise = cachedNoise(noiseBlob, noiseSize);

    // (Re)build the GPU-side buffers only when the input byte size changes;
    // the mean and noise contents are uploaded at the same time.
    if (!inputBuffer_ || inputBuffer_.length != X.nbytes()) {
      caffe2::Timer pt;

      id<MTLDevice> device = getMPSCNNContext().device;
      inputBuffer_ = [device
          newBufferWithLength:X.nbytes()
                      options:MTLResourceOptionCPUCacheModeWriteCombined];
      meanBuffer_ = [device
          newBufferWithLength:4 * 2 // (3/4 half-floats).
                      options:MTLResourceOptionCPUCacheModeWriteCombined];
      noiseBuffer_ = [device
          newBufferWithLength:noiseSize * sizeof(float16_t)
                      options:MTLResourceOptionCPUCacheModeWriteCombined];

      auto* meanHalfs = static_cast<float16_t*>([meanBuffer_ contents]);
      CAFFE_ENFORCE(meanHalfs);
      const float* meanValues = mean.data<float>();
      for (int64_t k = 0; k < mean.size(); ++k) {
        meanHalfs[k] = meanValues[k];
      }

      auto* noiseHalfs = static_cast<float16_t*>([noiseBuffer_ contents]);
      CAFFE_ENFORCE(noiseHalfs);
      const float* noiseValues = noise.data<float>();
      for (int64_t k = 0; k < noise.size(); ++k) {
        noiseHalfs[k] = noiseValues[k];
      }

      VLOG(2) << "Preprocess construct took: " << pt.MilliSeconds();
    }

    {
      caffe2::Timer ct;
      memcpy([inputBuffer_ contents], X.raw_data(), X.nbytes());
      VLOG(2) << "Preprocess memcpy took: " << ct.MilliSeconds();
    }

    MPSImageWrapper outputWrapper(this, nullptr, 1, H, W, 3);
    id<MTLCommandBuffer> commandBuffer = outputWrapper.getCommandBuffer();
    MPSImage* output = outputWrapper.getImage();

    id<MTLComputeCommandEncoder> encoder =
        [commandBuffer computeCommandEncoder];
    id<MTLComputePipelineState> state =
        getMPSCNNContext().getSpecializedPipelineState(
            @"preprocess_stylizer", {noiseSize});

    [encoder setComputePipelineState:state];
    [encoder setBuffer:inputBuffer_ offset:0 atIndex:0];
    [encoder setBuffer:meanBuffer_ offset:0 atIndex:1];
    [encoder setBuffer:noiseBuffer_ offset:0 atIndex:2];
    [encoder setTexture:[output texture] atIndex:0];
    const auto& launchParams =
        spatialPointwiseKernelLaunchParams(state, output);
    [encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
            threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
    [encoder endEncoding];
    outputWrapper.copyToOutputBlob(Outputs()[0]);

    VLOG(2) << "Preprocess took: " << t.MilliSeconds();
    return true;
  }

 private:
  // Returns the noise tensor cached in `noiseBlob`, (re)generating it with
  // Gaussian samples (mean 0, std from "noise_std", default 10) when the blob
  // is not a CPU tensor of exactly `noiseSize` elements. Caching keeps the
  // noise stable across frames.
  const TensorCPU& cachedNoise(Blob* noiseBlob, ushort noiseSize) {
    const bool needsInit = !BlobIsTensorType(*noiseBlob, CPU) ||
        noiseBlob->Get<TensorCPU>().size() != noiseSize;
    if (needsInit) {
      VLOG(2) << "Initializing stylizer with noise: " << noiseSize;
      caffe2::Timer rt;
      auto* noiseTensor = BlobGetMutableTensor(noiseBlob, CPU);
      noiseTensor->Resize(noiseSize);
      math::RandGaussian<float, CPUContext>(
          noiseTensor->size(),
          0.0,
          OperatorBase::GetSingleArgument<float>("noise_std", 10.0),
          noiseTensor->mutable_data<float>(),
          &context_);
      VLOG(2) << "Preprocess initializing noise: " << rt.MilliSeconds();
    }
    return noiseBlob->Get<TensorCPU>();
  }

  Workspace* ws_{nullptr};
  id<MTLBuffer> inputBuffer_{nullptr};
  id<MTLBuffer> noiseBuffer_{nullptr};
  id<MTLBuffer> meanBuffer_{nullptr};
};

REGISTER_CPU_OPERATOR(
    MPSCNNPackedInt8BGRANHWCToNCHWCStylizerPreprocess,
    MPSCNNPackedInt8BGRANHWCToNCHWCStylizerPreprocessOp);
OPERATOR_SCHEMA(MPSCNNPackedInt8BGRANHWCToNCHWCStylizerPreprocess)
    .NumInputs(2)
    .NumOutputs(1);
581

582
// Stylizer postprocessing on the GPU.
//
// Input 0 is a single 3-channel MPS image. Input 1 is a 3-element mean. The
// `deprocess_stylizer` kernel is run into a cached H*W*4-byte buffer, the
// image's command buffer is committed and waited on, and the bytes are copied
// into a 1xHxWx4 uint8 CPU tensor in output 0.
class MPSCNNBRGNCHWCToPackedInt8BGRAStylizerDeprocessOp final
    : public Operator<CPUContext> {
 public:
  MPSCNNBRGNCHWCToPackedInt8BGRAStylizerDeprocessOp(
      const OperatorDef& operator_def,
      Workspace* ws)
      : Operator<CPUContext>(operator_def, ws) {}

  bool RunOnDevice() override {
    auto inputWrapper = Inputs()[0]->Get<MPSImageWrapper>();
    MPSImage* X = inputWrapper.getImage();
    id<MTLCommandBuffer> commandBuffer = inputWrapper.getCommandBuffer();

    const auto& mean = Input(1);
    // The mean buffer below only holds 4 half-floats. Validate the size (as
    // the preprocess op does) so a larger mean cannot overrun it.
    CAFFE_ENFORCE_EQ(mean.size(), 3);
    caffe2::Timer t;
    CAFFE_ENFORCE_EQ(X.featureChannels, 3);
    CAFFE_ENFORCE_EQ(X.numberOfImages, 1);

    const auto outputBytes = X.height * X.width * 4;
    if (!outputBuffer_ || outputBuffer_.length != outputBytes) {
      caffe2::Timer pt;

      outputBuffer_ = [getMPSCNNContext().device
          newBufferWithLength:outputBytes
                      options:MTLResourceOptionCPUCacheModeDefault];
      meanBuffer_ = [getMPSCNNContext().device
          newBufferWithLength:4 * 2 // (3/4 half-floats).
                      options:MTLResourceOptionCPUCacheModeWriteCombined];
      float16_t* meanBufferPtr = (float16_t*)[meanBuffer_ contents];
      CAFFE_ENFORCE(meanBufferPtr);
      for (auto i = 0; i < mean.size(); ++i) {
        meanBufferPtr[i] = mean.data<float>()[i];
      }
      VLOG(2) << "Deprocess copy took: " << pt.MilliSeconds();
    }
    id<MTLComputeCommandEncoder> encoder =
        [commandBuffer computeCommandEncoder];
    id<MTLComputePipelineState> state =
        getMPSCNNContext().getPipelineState(@"deprocess_stylizer");

    CAFFE_ENFORCE_EQ(outputBuffer_.length, outputBytes);
    [encoder setComputePipelineState:state];
    [encoder setBuffer:outputBuffer_ offset:0 atIndex:0];
    [encoder setBuffer:meanBuffer_ offset:0 atIndex:1];
    [encoder setTexture:[X texture] atIndex:0];
    const auto& launchParams = spatialPointwiseKernelLaunchParams(state, X);
    [encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
            threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
    [encoder endEncoding];
    inputWrapper.markRead();

    [commandBuffer commit];
    [commandBuffer waitUntilCompleted];

    Output(0)->Resize(1, X.height, X.width, 4);
    {
      caffe2::Timer ct;
      uint8_t* destination = Output(0)->mutable_data<uint8_t>();
      // Check the sizes agree *before* copying so a mismatch cannot overflow
      // the output allocation.
      CAFFE_ENFORCE_EQ(Output(0)->nbytes(), [outputBuffer_ length]);
      memcpy(destination, [outputBuffer_ contents], [outputBuffer_ length]);
      VLOG(2) << "Deprocess copy: " << ct.MilliSeconds();
    }
    VLOG(2) << "Deprocess took: " << t.MilliSeconds();

    return true;
  }

 private:
  id<MTLBuffer> outputBuffer_{nullptr};
  id<MTLBuffer> meanBuffer_{nullptr};
};

REGISTER_CPU_OPERATOR(
    MPSCNNBRGNCHWCToPackedInt8BGRAStylizerDeprocess,
    MPSCNNBRGNCHWCToPackedInt8BGRAStylizerDeprocessOp);
OPERATOR_SCHEMA(MPSCNNBRGNCHWCToPackedInt8BGRAStylizerDeprocess)
    .NumInputs(2)
    .NumOutputs(1);
662

663
// Applies an MPS neuron (activation) kernel to an MPS image.
//
// `Neuron::t()` must return a newly created MPSCNNNeuron. It is called once,
// lazily, on the first run, and the kernel is reused afterwards. The output
// image has the same size as the input and is chained to the input wrapper's
// command buffer according to MPSImageWrapper's rules.
template <typename Neuron>
class MPSCNNNeuronOp final : public Operator<CPUContext> {
 public:
  MPSCNNNeuronOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<CPUContext>(operator_def, ws) {}

  bool RunOnDevice() override {
    caffe2::Timer t;
    auto inputWrapper = Inputs()[0]->template Get<MPSImageWrapper>();
    MPSImage* X = inputWrapper.getImage();

    auto outputWrapper = MPSImageWrapper(
        this,
        &inputWrapper,
        X.numberOfImages,
        X.height,
        X.width,
        X.featureChannels);
    auto commandBuffer = outputWrapper.getCommandBuffer();
    MPSImage* output = outputWrapper.getImage();
    CAFFE_ENFORCE_EQ(output.width, X.width);
    CAFFE_ENFORCE_EQ(output.height, X.height);
    CAFFE_ENFORCE_EQ(output.featureChannels, X.featureChannels);

    if (!neuron_) {
      neuron_ = Neuron::t();
    }
    [neuron_ encodeToCommandBuffer:commandBuffer
                       sourceImage:X
                  destinationImage:output];
    outputWrapper.copyToOutputBlob(Outputs()[0]);

    // The previous label ("ElementwiseAdd") was a copy-paste error.
    VLOG(2) << "MPSCNNNeuron took: " << t.MilliSeconds();
    return true;
  }

  // Lazily created neuron kernel, reused across runs.
  MPSCNNNeuron* neuron_{nullptr};
};
700

701
// Registers operator MPSCNN<n> as MPSCNNNeuronOp<<n>NeuronInit> with one input
// and one output that may alias the input.
#define INIT_NEURON_OP(n)                                          \
  REGISTER_CPU_OPERATOR(MPSCNN##n, MPSCNNNeuronOp<n##NeuronInit>); \
  OPERATOR_SCHEMA(MPSCNN##n).NumInputs(1).NumOutputs(1).AllowInplace({{0, 0}});

// Each <Name>NeuronInit::t() creates a new MPS neuron kernel on the shared
// MPSCNN device.

struct SigmoidNeuronInit {
  static MPSCNNNeuron* t() {
    id<MTLDevice> device = getMPSCNNContext().device;
    return [[MPSCNNNeuronSigmoid alloc] initWithDevice:device];
  }
};
INIT_NEURON_OP(Sigmoid);

// ReLU with parameter a = 0 (per Apple's docs, a is the slope applied to
// negative inputs, so this is a plain ReLU).
struct ReluNeuronInit {
  static MPSCNNNeuron* t() {
    id<MTLDevice> device = getMPSCNNContext().device;
    return [[MPSCNNNeuronReLU alloc] initWithDevice:device a:0];
  }
};
INIT_NEURON_OP(Relu);

// TanH with a = b = 1 (per Apple's docs MPS computes a * tanh(b * x), so this
// is a plain tanh).
struct TanhNeuronInit {
  static MPSCNNNeuron* t() {
    id<MTLDevice> device = getMPSCNNContext().device;
    return [[MPSCNNNeuronTanH alloc] initWithDevice:device a:1 b:1];
  }
};
INIT_NEURON_OP(Tanh);

#undef INIT_NEURON_OP
731

732
// 2-D convolution (optionally fused with the neuron returned by `Neuron::t()`,
// which may be nil) implemented with MPSCNNConvolution.
//
// Inputs: X (MPS image wrapper), filter [M][Cf][kH][kW], bias [M].
// Output: Y (MPS image wrapper). The MPS kernel is built once on the first run
// from the filter/bias seen then, and reused afterwards.
template <typename Neuron>
class MPSCNNConvOp final : public ConvPoolOpBase<CPUContext> {
 public:
  MPSCNNConvOp(const OperatorDef& operator_def, Workspace* ws)
      : ConvPoolOpBase<CPUContext>(operator_def, ws) {
    OPERATOR_NEEDS_FEATURE(
        this->order_ == StorageOrder::NCHW, "Metal only supports NCHW order.");
    OPERATOR_NEEDS_FEATURE(
        kernel_h() == kernel_w(),
        "Metal only supports equal kernel dimension.");
  }

  bool RunOnDeviceWithOrderNCHW() override {
    caffe2::Timer t;
    auto inputWrapper = Inputs()[0]->template Get<MPSImageWrapper>();
    MPSImage* X = inputWrapper.getImage();

    auto& filter = Input(FILTER);
    auto& bias = Input(BIAS);
    CAFFE_ENFORCE_EQ(filter.dim(), 4);
    // Input channel count comes from the image (NCHW dim 1).
    const int C = X.featureChannels;
    const int M = filter.dim32(0);
    const int Cf = filter.dim32(1);

    CAFFE_ENFORCE(filter.dim32(2) == kernel_h(), "");
    CAFFE_ENFORCE(filter.dim32(3) == kernel_w(), "");
    CAFFE_ENFORCE(bias.dim() == 1, "");
    CAFFE_ENFORCE(bias.dim32(0) == M, "");

    const auto kH = kernel_h();
    const auto kW = kernel_w();

    if (!conv_) {
      caffe2::Timer consT;
      // Local scratch: the data source only stores a pointer to it.
      // NOTE(review): this relies on MPSCNNConvolution copying the weights
      // during initialization; confirm nothing reloads from the data source
      // after this scope ends.
      std::vector<float> refilter =
          reorderFilterToMHWC(filter.template data<float>(), M, Cf, kH, kW);

      // Depthwise path: only one input channel per group with
      // inputFeatureChannels == outputFeatureChannels is supported right now,
      // and only when running on iOS 11 or later.
      const bool runtimeAtLeastIOS11 =
          SYSTEM_VERSION_GREATER_THAN_OR_EQUAL_TO(@"11.0");
      const bool useDepthwise = runtimeAtLeastIOS11 && this->group_ > 1 &&
          Cf == 1 && M == this->group_;

      MPSCNNConvolutionDescriptor* desc = nil;
      if (useDepthwise) {
        desc = [MPSCNNDepthWiseConvolutionDescriptor
            cnnConvolutionDescriptorWithKernelWidth:kW
                                       kernelHeight:kH
                               inputFeatureChannels:C
                              outputFeatureChannels:M
                                       neuronFilter:Neuron::t()];
        desc.groups = 1;
      } else {
        if (this->group_ > 1) {
          CAFFE_ENFORCE_EQ(
              Cf % 4,
              0,
              "MPSCNNConvolution requires number of input channels in each "
              "group to be multiple of 4 for group > 1.");
        }
        desc = [MPSCNNConvolutionDescriptor
            cnnConvolutionDescriptorWithKernelWidth:kW
                                       kernelHeight:kH
                               inputFeatureChannels:C
                              outputFeatureChannels:M
                                       neuronFilter:Neuron::t()];
        desc.groups = this->group_;
      }
      desc.strideInPixelsX = stride_w();
      desc.strideInPixelsY = stride_h();

      auto data_source = [[ConvDataSource alloc]
          initWithWeight:refilter.data()
                    bias:const_cast<float*>(bias.template data<float>())
                    desc:desc];
      conv_ =
          [[MPSCNNConvolution alloc] initWithDevice:getMPSCNNContext().device
                                            weights:data_source];

      [conv_ setEdgeMode:MPSImageEdgeModeZero];

      // Align MPS's sampling window with caffe2's padding.
      MPSOffset offset;
      offset.x = computeMPSAlignOffset(kW, pad_l());
      offset.y = computeMPSAlignOffset(kH, pad_t());
      offset.z = 0;
      [conv_ setOffset:offset];
      VLOG(2) << "MPSCNNConv ConvDesc took: " << consT.MilliSeconds();
    }

    CAFFE_ENFORCE_EQ(conv_.strideInPixelsY, stride_h());
    CAFFE_ENFORCE_EQ(conv_.strideInPixelsX, stride_w());
    CAFFE_ENFORCE_EQ(conv_.inputFeatureChannels, Cf * this->group_);
    CAFFE_ENFORCE_EQ(M % conv_.groups, 0);
    CAFFE_ENFORCE_EQ(conv_.outputFeatureChannels, M);
    CAFFE_ENFORCE_EQ(conv_.kernelWidth, kW);
    CAFFE_ENFORCE_EQ(conv_.kernelHeight, kH);

    int output_height;
    int output_width;
    computeOutputHW(this, X.height, X.width, &output_height, &output_width);
    int output_channels = M;

    VLOG(2) << "Output height: " << output_height;
    VLOG(2) << "Output width:" << output_width;
    VLOG(2) << "Output channels:" << output_channels;
    auto outputWrapper = MPSImageWrapper(
        this,
        &inputWrapper,
        X.numberOfImages,
        output_height,
        output_width,
        output_channels);
    auto commandBuffer = outputWrapper.getCommandBuffer();
    MPSImage* output = outputWrapper.getImage();
    CAFFE_ENFORCE_EQ(output.height, output_height);
    CAFFE_ENFORCE_EQ(output.width, output_width);
    [conv_ encodeToCommandBuffer:commandBuffer
                     sourceImage:X
                destinationImage:output];
    outputWrapper.copyToOutputBlob(Outputs()[0]);

    VLOG(2) << "MPSCNNConv took: " << t.MilliSeconds();
    return true;
  }

  // Input: X, W, b
  // Output: Y
  INPUT_TAGS(INPUT, FILTER, BIAS);

  // Lazily built MPS convolution kernel, reused across runs.
  MPSCNNConvolution* conv_{nullptr};

 private:
  // Reformats weights from caffe2's [M][Cf][kH][kW] layout to the
  // [M][kH][kW][Cf] layout expected by MPSCNNConvolution.
  static std::vector<float> reorderFilterToMHWC(
      const float* src,
      int M,
      int Cf,
      int kH,
      int kW) {
    std::vector<float> dst(M * kH * kW * Cf);
    for (int m = 0; m < M; ++m) {
      for (int c = 0; c < Cf; ++c) {
        for (int kh = 0; kh < kH; ++kh) {
          for (int kw = 0; kw < kW; ++kw) {
            // dst[m][kh][kw][c] = src[m][c][kh][kw]
            dst[((m * kH + kh) * kW + kw) * Cf + c] =
                src[((m * Cf + c) * kH + kh) * kW + kw];
          }
        }
      }
    }
    return dst;
  }
};
886

887
// Neuron policy for the un-fused operator variants: t() yields nil, so no
// neuron filter is attached to the MPS kernel.
struct EmptyNeuronInit {
  static MPSCNNNeuron* t() {
    MPSCNNNeuron* noNeuron = nil;
    return noNeuron;
  }
};
893

894
// Registers an MPSCNN convolution operator parameterized by a neuron policy.
// Inputs are (X, W, b) and there is a single output. The output may alias the
// weights or the bias (input 1 or 2 -> output 0), e.g. when a Conv followed by
// an out-of-place ReLU has been fused.
#define INIT_CONV_NEURON_OP(name, neuron)            \
  REGISTER_CPU_OPERATOR(name, MPSCNNConvOp<neuron>); \
  OPERATOR_SCHEMA(name)                              \
      .NumInputs(3)                                  \
      .NumOutputs(1)                                 \
      .AllowInplace({{1, 0}, {2, 0}});

INIT_CONV_NEURON_OP(MPSCNNConv, EmptyNeuronInit);
INIT_CONV_NEURON_OP(MPSCNNConvRelu, ReluNeuronInit);
INIT_CONV_NEURON_OP(MPSCNNConvSigmoid, SigmoidNeuronInit);

#undef INIT_CONV_NEURON_OP
906

907
// Reflection-pads an NCHW MPSImage spatially with the "reflection_padding"
// Metal kernel.
//
// The output size is X.height + 2 * pad_t and X.width + 2 * pad_l, and no pad
// amounts are handed to the kernel. Only symmetric padding
// (pad_b == pad_t, pad_r == pad_l) can therefore be honoured, and any other
// configuration is rejected as an unsupported feature.
class MPSCNNPadImageOp final : public ConvPoolOpBase<CPUContext> {
 public:
  MPSCNNPadImageOp(const OperatorDef& operator_def, Workspace* ws)
      : ConvPoolOpBase<CPUContext>(operator_def, ws) {
    OPERATOR_NEEDS_FEATURE(
        this->order_ == StorageOrder::NCHW, "Metal only supports NCHW order.");

    OPERATOR_NEEDS_FEATURE(
        OperatorBase::GetSingleArgument<string>("mode", "") == "reflect",
        "Metal only supports reflection");

    // The output shape below doubles pad_t/pad_l and ignores pad_b/pad_r, so
    // asymmetric padding would silently produce a differently-sized result.
    OPERATOR_NEEDS_FEATURE(
        pad_t() == pad_b() && pad_l() == pad_r(),
        "Metal only supports symmetric padding");
    kernel_[0] = kernel_[1] = 1;
  }

  bool RunOnDeviceWithOrderNCHW() override {
    caffe2::Timer t;
    auto inputWrapper = Inputs()[0]->Get<MPSImageWrapper>();
    MPSImage* X = inputWrapper.getImage();

    const auto pH = pad_t();
    const auto pW = pad_l();
    const auto output_height = X.height + 2 * pH;
    const auto output_width = X.width + 2 * pW;
    VLOG(2) << "Output height: " << output_height;
    VLOG(2) << "Output width:" << output_width;
    VLOG(2) << "Output channels:" << X.featureChannels;
    auto outputWrapper = MPSImageWrapper(
        this,
        &inputWrapper,
        X.numberOfImages,
        output_height,
        output_width,
        X.featureChannels);
    auto commandBuffer = outputWrapper.getCommandBuffer();
    MPSImage* output = outputWrapper.getImage();
    CAFFE_ENFORCE_EQ(output.height, output_height);
    CAFFE_ENFORCE_EQ(output.width, output_width);

    // Bindings: texture 0 = input, texture 1 = padded output.
    id<MTLComputeCommandEncoder> encoder =
        [commandBuffer computeCommandEncoder];
    id<MTLComputePipelineState> state =
        getMPSCNNContext().getPipelineState(kernelFor(
            output, @"reflection_padding", @"reflection_padding_nonarray"));
    [encoder setComputePipelineState:state];
    [encoder setTexture:[X texture] atIndex:0];
    [encoder setTexture:[output texture] atIndex:1];
    const auto& launchParams =
        spatialPointwiseKernelLaunchParams(state, output);
    [encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
            threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
    [encoder endEncoding];
    inputWrapper.markRead();
    outputWrapper.copyToOutputBlob(Outputs()[0]);

    VLOG(2) << "PadImage took: " << t.MilliSeconds();
    return true;
  }
};

REGISTER_CPU_OPERATOR(MPSCNNPadImage, MPSCNNPadImageOp);
OPERATOR_SCHEMA(MPSCNNPadImage).NumInputs(1).NumOutputs(1);
966

967
// Multiplies an MPSImage (input 0) by a 1-D CPU tensor (input 1) using the
// specialized "elementwise_mul" Metal kernel. Only the broadcast form is
// accepted, and an explicit "axis" argument is rejected.
class MPSCNNMulOp final : public Operator<CPUContext> {
 public:
  MPSCNNMulOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<CPUContext>(operator_def, ws) {
    const bool isBroadcast =
        OperatorBase::GetSingleArgument<int>("broadcast", 0) == 1;
    OPERATOR_NEEDS_FEATURE(isBroadcast, "MPSCNNMul only supports broadcast");

    const bool hasAxis = OperatorBase::HasArgument("axis");
    OPERATOR_NEEDS_FEATURE(!hasAxis, "MPSCNNMul does not support axis");
  }

  bool RunOnDevice() override {
    caffe2::Timer timer;

    auto imageWrapper = Inputs()[0]->Get<MPSImageWrapper>();
    MPSImage* image = imageWrapper.getImage();

    const auto& factor = Input(1);
    CAFFE_ENFORCE_EQ(
        factor.dim(),
        1,
        "MPSCNNMulOp: Only dim == 1 for Input(1) is supported for now");

    // Upload the CPU operand into a Metal buffer for the kernel.
    const auto factorBytes = sizeof(float) * factor.size();
    auto factorBuffer = [getMPSCNNContext().device
        newBufferWithBytes:factor.template data<float>()
                    length:factorBytes
                   options:MTLResourceOptionCPUCacheModeDefault];

    auto outputWrapper = MPSImageWrapper(
        this,
        &imageWrapper,
        image.numberOfImages,
        image.height,
        image.width,
        image.featureChannels);
    auto commandBuffer = outputWrapper.getCommandBuffer();
    MPSImage* output = outputWrapper.getImage();

    id<MTLComputeCommandEncoder> encoder =
        [commandBuffer computeCommandEncoder];
    id<MTLComputePipelineState> pipeline =
        getMPSCNNContext().getSpecializedPipelineState(
            @"elementwise_mul",
            {{ushort(image.numberOfImages),
              ushort(image.featureChannels),
              ushort(factor.dim32(0))}});

    // Bindings: texture 0 = image, buffer 1 = operand, texture 2 = output.
    [encoder setComputePipelineState:pipeline];
    [encoder setTexture:[image texture] atIndex:0];
    [encoder setBuffer:factorBuffer offset:0 atIndex:1];
    [encoder setTexture:[output texture] atIndex:2];
    const auto& grid = spatialPointwiseKernelLaunchParams(pipeline, output);
    [encoder dispatchThreadgroups:grid.threadgroupsPerGrid
            threadsPerThreadgroup:grid.threadsPerThreadgroup];
    [encoder endEncoding];

    imageWrapper.markRead();
    outputWrapper.copyToOutputBlob(Outputs()[0]);
    VLOG(2) << "ElementwiseMul took: " << timer.MilliSeconds();
    return true;
  }
};

REGISTER_CPU_OPERATOR(MPSCNNMul, MPSCNNMulOp);
OPERATOR_SCHEMA(MPSCNNMul).NumInputs(2).NumOutputs(1).AllowInplace({{0, 0}});
1034

1035
// Subtracts a 1-D CPU tensor (input 1) from an MPSImage (input 0) using the
// specialized "elementwise_sub" Metal kernel. Only the broadcast form is
// accepted, and an explicit "axis" argument is rejected.
class MPSCNNSubOp final : public Operator<CPUContext> {
 public:
  MPSCNNSubOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<CPUContext>(operator_def, ws) {
    const bool isBroadcast =
        OperatorBase::GetSingleArgument<int>("broadcast", 0) == 1;
    OPERATOR_NEEDS_FEATURE(isBroadcast, "MPSCNNSub only supports broadcast");

    const bool hasAxis = OperatorBase::HasArgument("axis");
    OPERATOR_NEEDS_FEATURE(!hasAxis, "MPSCNNSub does not support axis");
  }

  bool RunOnDevice() override {
    caffe2::Timer timer;

    auto imageWrapper = Inputs()[0]->Get<MPSImageWrapper>();
    MPSImage* image = imageWrapper.getImage();

    const auto& subtrahend = Input(1);
    CAFFE_ENFORCE_EQ(
        subtrahend.dim(),
        1,
        "MPSCNNSubOp: Only dim == 1 for Input(1) is supported for now");

    // Upload the CPU operand into a Metal buffer for the kernel.
    const auto subtrahendBytes = sizeof(float) * subtrahend.size();
    auto subtrahendBuffer = [getMPSCNNContext().device
        newBufferWithBytes:subtrahend.template data<float>()
                    length:subtrahendBytes
                   options:MTLResourceOptionCPUCacheModeDefault];

    auto outputWrapper = MPSImageWrapper(
        this,
        &imageWrapper,
        image.numberOfImages,
        image.height,
        image.width,
        image.featureChannels);
    auto commandBuffer = outputWrapper.getCommandBuffer();
    MPSImage* output = outputWrapper.getImage();

    id<MTLComputeCommandEncoder> encoder =
        [commandBuffer computeCommandEncoder];
    id<MTLComputePipelineState> pipeline =
        getMPSCNNContext().getSpecializedPipelineState(
            @"elementwise_sub",
            {{ushort(image.numberOfImages),
              ushort(image.featureChannels),
              ushort(subtrahend.dim32(0))}});

    // Bindings: texture 0 = image, buffer 1 = operand, texture 2 = output.
    [encoder setComputePipelineState:pipeline];
    [encoder setTexture:[image texture] atIndex:0];
    [encoder setBuffer:subtrahendBuffer offset:0 atIndex:1];
    [encoder setTexture:[output texture] atIndex:2];
    const auto& grid = spatialPointwiseKernelLaunchParams(pipeline, output);
    [encoder dispatchThreadgroups:grid.threadgroupsPerGrid
            threadsPerThreadgroup:grid.threadsPerThreadgroup];
    [encoder endEncoding];

    imageWrapper.markRead();
    outputWrapper.copyToOutputBlob(Outputs()[0]);
    VLOG(2) << "ElementwiseSub took: " << timer.MilliSeconds();
    return true;
  }
};

REGISTER_CPU_OPERATOR(MPSCNNSub, MPSCNNSubOp);
OPERATOR_SCHEMA(MPSCNNSub).NumInputs(2).NumOutputs(1).AllowInplace({{0, 0}});
1102

1103
// Elementwise sum of two MPSImages using the "elementwise_add" Metal kernel.
// Both inputs must share a command buffer and have identical batch size,
// height, width and channel count. The output has the shape of input 0.
class MPSCNNAddOp final : public Operator<CPUContext> {
 public:
  MPSCNNAddOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<CPUContext>(operator_def, ws) {}

  bool RunOnDevice() override {
    caffe2::Timer t;

    auto wrapper0 = Inputs()[0]->Get<MPSImageWrapper>();
    auto wrapper1 = Inputs()[1]->Get<MPSImageWrapper>();
    MPSImage* X0 = wrapper0.getImage();
    MPSImage* X1 = wrapper1.getImage();
    CAFFE_ENFORCE_EQ(wrapper0.getCommandBuffer(), wrapper1.getCommandBuffer());
    // Validate the shapes before the output image is created. The kernel is
    // dispatched over the output (X0's shape) and also samples X1, so the
    // batch size has to match as well as the spatial/channel dimensions.
    CAFFE_ENFORCE_EQ(X1.numberOfImages, X0.numberOfImages);
    CAFFE_ENFORCE_EQ(X1.width, X0.width);
    CAFFE_ENFORCE_EQ(X1.height, X0.height);
    CAFFE_ENFORCE_EQ(X1.featureChannels, X0.featureChannels);

    auto outputWrapper = MPSImageWrapper(
        this,
        &wrapper0,
        X0.numberOfImages,
        X0.height,
        X0.width,
        X0.featureChannels);
    auto commandBuffer = outputWrapper.getCommandBuffer();
    MPSImage* output = outputWrapper.getImage();

    id<MTLComputeCommandEncoder> encoder =
        [commandBuffer computeCommandEncoder];
    id<MTLComputePipelineState> state = getMPSCNNContext().getPipelineState(
        kernelFor(X0, @"elementwise_add", @"elementwise_add_nonarray"));

    // Bindings: texture 0 = X0, texture 1 = X1, texture 2 = output.
    [encoder setComputePipelineState:state];
    [encoder setTexture:[X0 texture] atIndex:0];
    [encoder setTexture:[X1 texture] atIndex:1];
    [encoder setTexture:[output texture] atIndex:2];
    const auto& launchParams =
        spatialPointwiseKernelLaunchParams(state, output);
    [encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
            threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
    [encoder endEncoding];
    wrapper0.markRead();
    wrapper1.markRead();
    outputWrapper.copyToOutputBlob(Outputs()[0]);

    VLOG(2) << "ElementwiseAdd took: " << t.MilliSeconds();
    return true;
  }
};

REGISTER_CPU_OPERATOR(MPSCNNAdd, MPSCNNAddOp);
// Not really in-place per-se, but semantically is valid and preserves
// compatibility.
OPERATOR_SCHEMA(MPSCNNAdd).NumInputs(2).NumOutputs(1).AllowInplace({{0, 0}});
1156

1157
// Average pooling on an NCHW MPSImage via MPSCNNPoolingAverage. Only square
// kernels are accepted. The MPS kernel is built on the first run and rebuilt
// on every run when global pooling is enabled.
class MPSCNNAveragePoolOp final : public ConvPoolOpBase<CPUContext> {
 public:
  MPSCNNAveragePoolOp(const OperatorDef& operator_def, Workspace* ws)
      : ConvPoolOpBase<CPUContext>(operator_def, ws) {
    OPERATOR_NEEDS_FEATURE(
        this->order_ == StorageOrder::NCHW, "Metal only supports NCHW order.");
    OPERATOR_NEEDS_FEATURE(
        kernel_h() == kernel_w(),
        "Metal only supports equal kernel dimension.");
  }

  bool RunOnDeviceWithOrderNCHW() override {
    caffe2::Timer timer;
    auto inputWrapper = Inputs()[0]->Get<MPSImageWrapper>();
    MPSImage* input = inputWrapper.getImage();

    const bool rebuildPool = (pool_ == nullptr) || this->global_pooling_;
    if (rebuildPool) {
      caffe2::Timer buildTimer;
      // Resolve padding for the current input size before building the kernel.
      this->ComputePads({(int)input.height, (int)input.width});
      pool_ = [[MPSCNNPoolingAverage alloc]
           initWithDevice:getMPSCNNContext().device
              kernelWidth:kernel_w()
             kernelHeight:kernel_h()
          strideInPixelsX:stride_w()
          strideInPixelsY:stride_h()];
      [pool_ setEdgeMode:MPSImageEdgeModeClamp];

      MPSOffset alignOffset;
      alignOffset.x = computeMPSAlignOffset(kernel_w(), pad_l());
      alignOffset.y = computeMPSAlignOffset(kernel_h(), pad_t());
      alignOffset.z = 0;
      [pool_ setOffset:alignOffset];
      VLOG(2) << "MPSCNNAveragePool PoolDesc took: "
              << buildTimer.MilliSeconds();
    }

    CAFFE_ENFORCE_EQ(pool_.strideInPixelsY, stride_h());
    CAFFE_ENFORCE_EQ(pool_.strideInPixelsX, stride_w());

    int outHeight;
    int outWidth;
    computeOutputHW(this, input.height, input.width, &outHeight, &outWidth);
    VLOG(2) << "Output height: " << outHeight;
    VLOG(2) << "Output width:" << outWidth;
    VLOG(2) << "Output channels:" << input.featureChannels;

    auto outputWrapper = MPSImageWrapper(
        this,
        &inputWrapper,
        input.numberOfImages,
        outHeight,
        outWidth,
        input.featureChannels);
    auto commandBuffer = outputWrapper.getCommandBuffer();
    MPSImage* output = outputWrapper.getImage();
    CAFFE_ENFORCE_EQ(output.height, outHeight);
    CAFFE_ENFORCE_EQ(output.width, outWidth);

    [pool_ encodeToCommandBuffer:commandBuffer
                     sourceImage:input
                destinationImage:output];
    outputWrapper.copyToOutputBlob(Outputs()[0]);

    VLOG(2) << "MPSCNNAveragePool took: " << timer.MilliSeconds();
    return true;
  }

  MPSCNNPoolingAverage* pool_{nullptr};
};

REGISTER_CPU_OPERATOR(MPSCNNAveragePool, MPSCNNAveragePoolOp);
OPERATOR_SCHEMA(MPSCNNAveragePool).NumInputs(1).NumOutputs(1);
1226

1227
// Max pooling on an NCHW MPSImage via MPSCNNPoolingMax. Only square kernels
// are accepted. The MPS kernel is built on the first run and rebuilt on every
// run when global pooling is enabled.
class MPSCNNMaxPoolOp final : public ConvPoolOpBase<CPUContext> {
 public:
  MPSCNNMaxPoolOp(const OperatorDef& operator_def, Workspace* ws)
      : ConvPoolOpBase<CPUContext>(operator_def, ws) {
    OPERATOR_NEEDS_FEATURE(
        this->order_ == StorageOrder::NCHW, "Metal only supports NCHW order.");
    OPERATOR_NEEDS_FEATURE(
        kernel_h() == kernel_w(),
        "Metal only supports equal kernel dimension.");
  }

  bool RunOnDeviceWithOrderNCHW() override {
    caffe2::Timer timer;
    auto inputWrapper = Inputs()[0]->Get<MPSImageWrapper>();
    MPSImage* input = inputWrapper.getImage();

    const bool rebuildPool = (pool_ == nullptr) || this->global_pooling_;
    if (rebuildPool) {
      caffe2::Timer buildTimer;
      // Resolve padding for the current input size before building the kernel.
      this->ComputePads({(int)input.height, (int)input.width});
      pool_ = [[MPSCNNPoolingMax alloc]
           initWithDevice:getMPSCNNContext().device
              kernelWidth:kernel_w()
             kernelHeight:kernel_h()
          strideInPixelsX:stride_w()
          strideInPixelsY:stride_h()];
      [pool_ setEdgeMode:MPSImageEdgeModeClamp];

      MPSOffset alignOffset;
      alignOffset.x = computeMPSAlignOffset(kernel_w(), pad_l());
      alignOffset.y = computeMPSAlignOffset(kernel_h(), pad_t());
      alignOffset.z = 0;
      [pool_ setOffset:alignOffset];
      VLOG(2) << "MPSCNNMaxPool PoolDesc took: " << buildTimer.MilliSeconds();
    }

    CAFFE_ENFORCE_EQ(pool_.strideInPixelsY, stride_h());
    CAFFE_ENFORCE_EQ(pool_.strideInPixelsX, stride_w());

    int outHeight;
    int outWidth;
    computeOutputHW(this, input.height, input.width, &outHeight, &outWidth);
    VLOG(2) << "Output height: " << outHeight;
    VLOG(2) << "Output width:" << outWidth;
    VLOG(2) << "Output channels:" << input.featureChannels;

    auto outputWrapper = MPSImageWrapper(
        this,
        &inputWrapper,
        input.numberOfImages,
        outHeight,
        outWidth,
        input.featureChannels);
    auto commandBuffer = outputWrapper.getCommandBuffer();
    MPSImage* output = outputWrapper.getImage();
    CAFFE_ENFORCE_EQ(output.height, outHeight);
    CAFFE_ENFORCE_EQ(output.width, outWidth);

    [pool_ encodeToCommandBuffer:commandBuffer
                     sourceImage:input
                destinationImage:output];
    outputWrapper.copyToOutputBlob(Outputs()[0]);

    VLOG(2) << "MPSCNNMaxPool took: " << timer.MilliSeconds();
    return true;
  }

  MPSCNNPoolingMax* pool_{nullptr};
};

REGISTER_CPU_OPERATOR(MPSCNNMaxPool, MPSCNNMaxPoolOp);
OPERATOR_SCHEMA(MPSCNNMaxPool).NumInputs(1).NumOutputs(1);
1296

1297
// Softmax on an MPSImage via MPSCNNSoftMax. The input must be spatially 1x1,
// and the output has the same shape as the input. The MPS kernel is created
// lazily on the first run and reused afterwards.
class MPSCNNSoftmaxOp final : public Operator<CPUContext> {
 public:
  MPSCNNSoftmaxOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<CPUContext>(operator_def, ws) {}

  bool RunOnDevice() override {
    caffe2::Timer timer;
    auto inputWrapper = Inputs()[0]->Get<MPSImageWrapper>();
    MPSImage* input = inputWrapper.getImage();
    CAFFE_ENFORCE_EQ(input.height, 1);
    CAFFE_ENFORCE_EQ(input.width, 1);

    if (softmax_ == nullptr) {
      softmax_ =
          [[MPSCNNSoftMax alloc] initWithDevice:getMPSCNNContext().device];
    }

    auto outputWrapper = MPSImageWrapper(
        this,
        &inputWrapper,
        input.numberOfImages,
        input.height,
        input.width,
        input.featureChannels);
    auto commandBuffer = outputWrapper.getCommandBuffer();
    MPSImage* output = outputWrapper.getImage();

    [softmax_ encodeToCommandBuffer:commandBuffer
                        sourceImage:input
                   destinationImage:output];
    outputWrapper.copyToOutputBlob(Outputs()[0]);
    VLOG(2) << "MPSCNNSoftmax took: " << timer.MilliSeconds();
    return true;
  }

  MPSCNNSoftMax* softmax_{nullptr};
};

REGISTER_CPU_OPERATOR(MPSCNNSoftmax, MPSCNNSoftmaxOp);
OPERATOR_SCHEMA(MPSCNNSoftmax).NumInputs(1).NumOutputs(1);
1334

1335
// Fully-connected layer, optionally fused with a neuron, implemented as an
// MPSCNNConvolution whose kernel spans the whole input image.
//
// Inputs: X (MPSImage), W (CPU tensor whose dim 0 is the number of outputs
// and whose dim 1 must equal C * H * W of X), b (CPU tensor with one entry per
// output). Output: an MPSImage of size N x output_channels x 1 x 1.
template <typename Neuron>
class MPSCNNFullyConnectedOp final : public Operator<CPUContext> {
 public:
  MPSCNNFullyConnectedOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<CPUContext>(operator_def, ws) {}

  bool RunOnDevice() override {
    caffe2::Timer t;
    auto inputWrapper = Inputs()[0]->template Get<MPSImageWrapper>();
    MPSImage* X = inputWrapper.getImage();
    const auto& W = Input(1);
    const auto& b = Input(2);

    const auto input_channels = W.dim32(1) / X.width / X.height;
    CAFFE_ENFORCE_EQ(input_channels, X.featureChannels);
    // The division above truncates, so also require an exact C * H * W match.
    // Otherwise a W.dim32(1) that is not a multiple of H * W would pass.
    CAFFE_ENFORCE(
        input_channels * X.width * X.height ==
            static_cast<NSUInteger>(W.dim32(1)),
        "MPSCNNFC: W.dim32(1) (",
        W.dim32(1),
        ") must equal C * H * W of the input");
    const auto output_channels = W.dim32(0);
    // The convolution data source reads one bias value per output channel.
    CAFFE_ENFORCE_EQ(
        b.size(),
        output_channels,
        "MPSCNNFC: bias size must match the number of output channels");
    if (!fc_) {
      const auto M = output_channels;
      const auto kH = X.height;
      const auto kW = X.width;
      const auto C = input_channels;
      std::vector<float> refilter(M * kH * kW * C);
      auto* filter_ = W.template data<float>();
      // Reorder weights from [M][C][kH][kW] to the [M][kH][kW][C] layout
      // handed to the convolution data source.
      for (auto m = 0; m < M; ++m) {
        for (auto c = 0; c < C; ++c) {
          for (auto kh = 0; kh < kH; ++kh) {
            for (auto kw = 0; kw < kW; ++kw) {
              // refilter[m][kh][kw][c]
              refilter[m * kH * kW * C + kh * kW * C + kw * C + c] =
                  // filter[m][c][kh][kw]
                  filter_[m * C * kH * kW + c * kH * kW + kh * kW + kw];
            }
          }
        }
      }

      MPSCNNConvolutionDescriptor* desc = [MPSCNNConvolutionDescriptor
          cnnConvolutionDescriptorWithKernelWidth:X.width
                                     kernelHeight:X.height
                             inputFeatureChannels:input_channels
                            outputFeatureChannels:output_channels
                                     neuronFilter:Neuron::t()];
      // NOTE(review): the data source keeps raw pointers into `refilter` (a
      // local) and `b`. This relies on MPSCNNConvolution consuming the weights
      // during initialization; confirm it does not re-read them later.
      auto data_source = [[ConvDataSource alloc]
          initWithWeight:refilter.data()
                    bias:const_cast<float*>(b.template data<float>())
                    desc:desc];
      fc_ = [[MPSCNNConvolution alloc] initWithDevice:getMPSCNNContext().device
                                              weights:data_source];
    }
    // Note that X.numberOfImages can change between calls, but X.height and
    // X.width are static by definition.
    VLOG(2) << "MPSCNNFC: " << X.numberOfImages << ", " << X.width << ", "
            << X.height << ", " << X.featureChannels << ", " << output_channels;

    // Evaluate the kernel once per image, centred on the input.
    [fc_ setClipRect:MTLRegionMake3D(0, 0, 0, 1, 1, X.numberOfImages)];
    MPSOffset off;
    off.x = X.width / 2;
    off.y = X.height / 2;
    off.z = 0;
    [fc_ setOffset:off];
    auto outputWrapper = MPSImageWrapper(
        this, &inputWrapper, X.numberOfImages, 1, 1, output_channels);
    auto commandBuffer = outputWrapper.getCommandBuffer();
    MPSImage* output = outputWrapper.getImage();
    [fc_ encodeToCommandBuffer:commandBuffer
                   sourceImage:X
              destinationImage:output];
    outputWrapper.copyToOutputBlob(Outputs()[0]);
    VLOG(2) << "MPSCNNFC took: " << t.MilliSeconds();
    return true;
  }

  MPSCNNConvolution* fc_{nullptr};
};

#define INIT_FC_NEURON_OP(name, neuron)                        \
  REGISTER_CPU_OPERATOR(name, MPSCNNFullyConnectedOp<neuron>); \
  OPERATOR_SCHEMA(name).NumInputs(3).NumOutputs(1);

INIT_FC_NEURON_OP(MPSCNNFC, EmptyNeuronInit);
INIT_FC_NEURON_OP(MPSCNNFCRelu, ReluNeuronInit);
#undef INIT_FC_NEURON_OP
1417

1418
// Inference-only dropout: the input image is forwarded to the output unchanged.
class MPSCNNDropoutOp final : public Operator<CPUContext> {
 public:
  MPSCNNDropoutOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<CPUContext>(operator_def, ws) {}

  // Dropout is the identity at inference time, so forward a copy of the
  // input wrapper straight into the output blob.
  bool RunOnDevice() override {
    MPSImageWrapper passthrough = Inputs()[0]->Get<MPSImageWrapper>();
    passthrough.copyToOutputBlob(Outputs()[0]);
    return true;
  }
};

REGISTER_CPU_OPERATOR(MPSCNNDropout, MPSCNNDropoutOp);
// The optional second output (the dropout mask) is never produced.
OPERATOR_SCHEMA(MPSCNNDropout)
    .NumInputs(1)
    .NumOutputs(1, 2)
    .AllowInplace({{0, 0}});
1437

1438
class MPSCNNConvTransposeOp final : public ConvTransposeUnpoolBase<CPUContext> {
1439
 public:
1440
  MPSCNNConvTransposeOp(const OperatorDef& operator_def, Workspace* ws)
1441
      : ConvTransposeUnpoolBase<CPUContext>(operator_def, ws) {
1442
    OPERATOR_NEEDS_FEATURE(
1443
        this->order_ == StorageOrder::NCHW, "Metal only supports NCHW order.");
1444
    CAFFE_ENFORCE_EQ(
1445
        kernel_w(), kernel_h(), "Metal only supports equal kernel dimensions");
1446
  }
1447

1448
  bool RunOnDeviceWithOrderNCHW() override {
1449
    caffe2::Timer t;
1450
    auto inputWrapper = Inputs()[0]->Get<MPSImageWrapper>();
1451

1452
    MPSImage* X = inputWrapper.getImage();
1453

1454
    auto& filter = Input(FILTER);
1455
    auto& bias = Input(BIAS);
1456
    CAFFE_ENFORCE(filter.dim(), 4);
1457
    const int output_channels = filter.dim32(1);
1458
    const int input_channels = filter.dim32(0);
1459

1460
    CAFFE_ENFORCE(X.featureChannels == input_channels, "");
1461
    CAFFE_ENFORCE(filter.dim32(2) == kernel_h(), "");
1462
    CAFFE_ENFORCE(filter.dim32(3) == kernel_w(), "");
1463
    CAFFE_ENFORCE(bias.dim() == 1, "");
1464
    CAFFE_ENFORCE(bias.dim32(0) == output_channels, "");
1465

1466
    const auto kH = kernel_h();
1467
    const auto kW = kernel_w();
1468

1469
    int output_height =
1470
        (X.height - 1) * stride_h() + kH - pad_b() - pad_t() + adj_h();
1471
    int output_width =
1472
        (X.width - 1) * stride_w() + kW - pad_l() - pad_r() + adj_w();
1473

1474
    VLOG(2) << "Output height: " << output_height;
1475
    VLOG(2) << "Output width:" << output_width;
1476
    VLOG(2) << "Output channels:" << output_channels;
1477

1478
    auto outputWrapper = MPSImageWrapper(
1479
        this,
1480
        &inputWrapper,
1481
        X.numberOfImages,
1482
        output_height,
1483
        output_width,
1484
        output_channels);
1485
    auto commandBuffer = outputWrapper.getCommandBuffer();
1486

1487
    bool runtimeAtLeastIOS11 = SYSTEM_VERSION_GREATER_THAN_OR_EQUAL_TO(@"11.0");
1488
    // initialization
1489
    if (!conv_trans_ && !conv_) {
1490
      caffe2::Timer consT;
1491
      std::vector<float> refilter(kH * kW * output_channels * input_channels);
1492
      refilter.assign(kH * kW * output_channels * input_channels, 0.0f);
1493
      TORCH_DCHECK_EQ(refilter.size(), filter.size());
1494
      auto* filter_ = filter.template data<float>();
1495
      // For iOS11+ Reformat weights from WT[IC][OC][kH][kW] to
1496
      // W[OC][kH][kW][IC]; For previous versions, reformat weights
1497
      // to W[kH][kW][OC][IC]
1498
      // Also rotate the weight matrix spatially by 180 degrees
1499
      for (auto oc = 0; oc < output_channels; ++oc) {
1500
        for (auto ic = 0; ic < input_channels; ++ic) {
1501
          for (auto kh = 0; kh < kH; ++kh) {
1502
            for (auto kw = 0; kw < kW; ++kw) {
1503
              const auto inputIdx =
1504
                  ic * output_channels * kH * kW + oc * kH * kW + kh * kW + kw;
1505
              int outputIdx;
1506
              if (runtimeAtLeastIOS11) {
1507
                outputIdx = oc * kH * kW * input_channels +
1508
                    (kH - 1 - kh) * kW * input_channels +
1509
                    (kW - 1 - kw) * input_channels + ic;
1510
              } else {
1511
                outputIdx = kh * kW * output_channels * input_channels +
1512
                    kw * output_channels * input_channels +
1513
                    oc * input_channels + ic;
1514
              }
1515
              TORCH_DCHECK_LT(inputIdx, filter.size());
1516
              TORCH_DCHECK_LT(outputIdx, filter.size());
1517
              refilter[outputIdx] = filter_[inputIdx];
1518
            }
1519
          }
1520
        }
1521
      }
1522
      TORCH_DCHECK_EQ(filter.size(), input_channels * output_channels * kH * kW);
1523
      // initialize data structures
1524
      if (runtimeAtLeastIOS11) {
1525
        MPSCNNConvolutionDescriptor* desc = [MPSCNNConvolutionDescriptor
1526
            cnnConvolutionDescriptorWithKernelWidth:kW
1527
                                       kernelHeight:kH
1528
                               inputFeatureChannels:input_channels
1529
                              outputFeatureChannels:output_channels];
1530
        desc.strideInPixelsX = this->stride_w();
1531
        desc.strideInPixelsY = this->stride_h();
1532
        desc.groups = 1;
1533
        auto data_source = [[ConvDataSource alloc]
1534
            initWithWeight:refilter.data()
1535
                      bias:const_cast<float*>(bias.data<float>())
1536
                      desc:desc];
1537

1538
        conv_trans_ = [[MPSCNNConvolutionTranspose alloc]
1539
            initWithDevice:getMPSCNNContext().device
1540
                   weights:data_source];
1541
        MPSOffset offset;
1542
        offset.x = 0;
1543
        offset.y = 0;
1544
        offset.z = 0;
1545
        [conv_trans_ setOffset:offset];
1546
        // kernel offset + padding offset
1547
        conv_trans_.kernelOffsetX = kW / 2 - kW + 1 + this->pad_l();
1548
        conv_trans_.kernelOffsetY = kH / 2 - kH + 1 + this->pad_t();
1549
        VLOG(2) << "MPSCNNConvTranspose ConvDesc took: "
1550
                << consT.MilliSeconds();
1551
      } else {
1552
        MPSCNNConvolutionDescriptor* desc = [MPSCNNConvolutionDescriptor
1553
            cnnConvolutionDescriptorWithKernelWidth:1
1554
                                       kernelHeight:1
1555
                               inputFeatureChannels:input_channels
1556
                              outputFeatureChannels:output_channels * kH * kW
1557
                                       neuronFilter:nil];
1558
        // We need to zero-fill the bias here.
1559
        std::vector<float> fakeBias;
1560
        fakeBias.assign(output_channels * kH * kW, 0);
1561

1562
        desc.strideInPixelsX = 1;
1563
        desc.strideInPixelsY = 1;
1564
        auto data_source =
1565
            [[ConvDataSource alloc] initWithWeight:refilter.data()
1566
                                              bias:fakeBias.data()
1567
                                              desc:desc];
1568
        conv_ =
1569
            [[MPSCNNConvolution alloc] initWithDevice:getMPSCNNContext().device
1570
                                              weights:data_source];
1571
        [conv_ setEdgeMode:MPSImageEdgeModeZero];
1572
        MPSOffset offset;
1573
        offset.x = 0;
1574
        offset.y = 0;
1575
        offset.z = 0;
1576
        [conv_ setOffset:offset];
1577

1578
        const auto biasBytes = divRoundUp(bias.size(), 4) * 4 * 2;
1579
        biasBuffer_ = [getMPSCNNContext().device
1580
            newBufferWithLength:biasBytes
1581
                        options:MTLResourceOptionCPUCacheModeDefault];
1582
        for (auto i = 0; i < bias.size(); ++i) {
1583
          ((float16_t*)[biasBuffer_ contents])[i] = bias.data<float>()[i];
1584
        }
1585

1586
        VLOG(2) << "MPSCNNConvTranspose ConvDesc took: "
1587
                << consT.MilliSeconds();
1588
      } // data structure initialization
1589
    } // initialization
1590
    CAFFE_ENFORCE((conv_trans_ && !conv_) || (!conv_trans_ && conv_));
1591

1592
    // run the computation
1593
    if (conv_trans_) {
1594
      MPSImage* output = outputWrapper.getImage();
1595
      X = inputWrapper.getImage();
1596
      CAFFE_ENFORCE_EQ(conv_trans_.groups, 1);
1597
      [conv_trans_ encodeToCommandBuffer:commandBuffer
1598
                             sourceImage:X
1599
                        destinationImage:output];
1600
    } else {
1601
      CAFFE_ENFORCE_EQ(conv_.strideInPixelsY, 1);
1602
      CAFFE_ENFORCE_EQ(conv_.strideInPixelsX, 1);
1603
      CAFFE_ENFORCE_EQ(conv_.groups, 1);
1604
      CAFFE_ENFORCE_EQ(conv_.inputFeatureChannels, input_channels);
1605
      CAFFE_ENFORCE_EQ(conv_.outputFeatureChannels, output_channels * kH * kW);
1606
      CAFFE_ENFORCE_EQ(conv_.kernelWidth, 1);
1607
      CAFFE_ENFORCE_EQ(conv_.kernelHeight, 1);
1608
      if (divRoundUp(X.numberOfImages * output_channels * kH * kW, 4) >
1609
          kMetalMaxTextureArrLength) {
1610
        LOG(INFO) << "ConvTranspose " << X.numberOfImages << " "
1611
                  << output_channels << " " << kH << " " << kW;
1612
        LOG(ERROR)
1613
            << "arrayLength exceeds the maximum allowed length in texture";
1614
        inputWrapper.cleanup();
1615
        outputWrapper.cleanup();
1616
        return false;
1617
      }
1618
      VLOG(2) << "ConvTranspose:" << output_channels << " " << kH << " " << kW
1619
              << " " << X.numberOfImages;
1620

1621
      auto gemmed = createTemporaryImage(
1622
          this,
1623
          commandBuffer,
1624
          X.numberOfImages,
1625
          X.height,
1626
          X.width,
1627
          output_channels * kH * kW);
1628
      {
1629
        caffe2::Timer gt;
1630
        [conv_ encodeToCommandBuffer:commandBuffer
1631
                         sourceImage:X
1632
                    destinationImage:gemmed];
1633
        VLOG(2) << "MPSCNNConvTranspose GEMM took: " << gt.MilliSeconds();
1634
      }
1635
      MPSImage* output = outputWrapper.getImage();
1636

1637
      {
1638
        caffe2::Timer cit;
1639
        id<MTLComputePipelineState> state =
1640
            getMPSCNNContext().getSpecializedPipelineState(
1641
                @"col2im",
1642
                {{ushort(kernel_h()),
1643
                  ushort(kernel_w()),
1644
                  ushort(stride_h()),
1645
                  ushort(stride_w()),
1646
                  ushort(pad_l()),
1647
                  ushort(pad_t()),
1648
                  ushort(output.featureChannels),
1649
                  ushort(output.numberOfImages),
1650
                  ushort(gemmed.height),
1651
                  ushort(gemmed.width)}});
1652
        id<MTLComputeCommandEncoder> encoder =
1653
            [commandBuffer computeCommandEncoder];
1654
        [encoder setComputePipelineState:state];
1655
        [encoder setTexture:[gemmed texture] atIndex:0];
1656
        [encoder setTexture:[output texture] atIndex:1];
1657
        [encoder setBuffer:biasBuffer_ offset:0 atIndex:0];
1658
        const auto& launchParams =
1659
            spatialPointwiseKernelLaunchParams(state, output);
1660
        [encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
1661
                threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
1662
        [encoder endEncoding];
1663
        gemmed.readCount -= 1;
1664
        VLOG(2) << "MPSCNNConvTranspose upscaling took: " << cit.MilliSeconds();
1665
      }
1666
    }
1667
    outputWrapper.copyToOutputBlob(Outputs()[0]);
1668
    VLOG(2) << "MPSCNNConvTranspose took: " << t.MilliSeconds();
1669
    return true;
1670
  }
1671

1672
  // Input: X, W, b
1673
  // Output: Y
1674
  INPUT_TAGS(INPUT, FILTER, BIAS);
1675
  MPSCNNConvolutionTranspose* conv_trans_{nullptr};
1676
  id<MTLBuffer> biasBuffer_;
1677
  MPSCNNConvolution* conv_{nullptr};
1678
};
1679

1680
// Register the Metal transposed-convolution operator. Its schema takes three
// inputs (X, W, b) and produces a single output Y.
REGISTER_CPU_OPERATOR(MPSCNNConvTranspose, MPSCNNConvTransposeOp);
OPERATOR_SCHEMA(MPSCNNConvTranspose).NumInputs(3).NumOutputs(1);
1687

1688
// Chooses what, if anything, MPSCNNInstanceNormOp fuses after normalization.
enum class InstanceNormFusionTy : int {
  NONE = 0, // plain instance normalization (3 inputs)
  PRELU = 1 // instance normalization followed by PRelu (extra slope input)
};
1692

1693
// Instance normalization of an MPSImage, optionally fused with PRelu.
//
// Inputs:  0 - X, an MPSImageWrapper
//          1 - per-channel scale (X.featureChannels floats)
//          2 - per-channel bias  (X.featureChannels floats)
//          3 - PRELU variant only: PRelu slopes, either a single shared value
//              or one value per channel
// Output:  0 - an MPSImageWrapper with the same shape as X
//
// The parameters are converted to fp16 and uploaded into Metal buffers. Those
// buffers are only rebuilt when their required size changes, so the parameter
// values are assumed not to change between runs of one operator instance.
template <InstanceNormFusionTy fusionTy>
class MPSCNNInstanceNormOp final : public Operator<CPUContext> {
 public:
  MPSCNNInstanceNormOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<CPUContext>(operator_def, ws) {}

  bool RunOnDevice() override {
    auto inputWrapper = Inputs()[0]->template Get<MPSImageWrapper>();
    MPSImage* X = inputWrapper.getImage();

    const auto& scale = Input(1);
    const auto& bias = Input(2);
    CAFFE_ENFORCE_EQ(scale.size(), X.featureChannels);
    CAFFE_ENFORCE_EQ(bias.size(), X.featureChannels);
    if (fusionTy == InstanceNormFusionTy::PRELU) {
      // The kernel is specialized on the slope count (see below). PRelu only
      // defines a shared slope or one slope per channel; any other size would
      // presumably make the kernel index preluWeightBuffer_ out of range.
      const auto preluSize = Input(3).size();
      CAFFE_ENFORCE(
          preluSize == 1 || preluSize == int64_t(X.featureChannels),
          "PRelu weight must have 1 or ",
          X.featureChannels,
          " elements, got ",
          preluSize);
    }

    // Round-up to nearest multiple of 4 (fp16 = 2 bytes each),
    // so accesses to X[i * 4 + 3] in kernel is valid.
    const auto scaleBytes = divRoundUp(scale.size(), 4) * 4 * 2;
    if (!scaleBuffer_ || !biasBuffer_ || scaleBuffer_.length != scaleBytes ||
        biasBuffer_.length != scaleBytes) {
      caffe2::Timer cvt;
      // bias.size() == scale.size() was enforced above, so both buffers get
      // exactly scaleBytes bytes.
      scaleBuffer_ = makeHalfBuffer(scale);
      biasBuffer_ = makeHalfBuffer(bias);
      if (fusionTy == InstanceNormFusionTy::PRELU) {
        preluWeightBuffer_ = makeHalfBuffer(Input(3));
      }
      VLOG(2) << "Buffer setup took: " << cvt.MilliSeconds();
    }

    auto outputWrapper = MPSImageWrapper(
        this,
        &inputWrapper,
        X.numberOfImages,
        X.height,
        X.width,
        X.featureChannels);
    auto commandBuffer = inputWrapper.getCommandBuffer();
    MPSImage* output = outputWrapper.getImage();

    caffe2::Timer t;
    id<MTLComputeCommandEncoder> encoder =
        [commandBuffer computeCommandEncoder];
    // Function constants: channel count and PRelu slope count (0 = no PRelu).
    id<MTLComputePipelineState> state =
        getMPSCNNContext().getSpecializedPipelineState(
            kernelFor(X, @"instance_norm", @"instance_norm_nonarray"),
            {{ushort(X.featureChannels),
              fusionTy == InstanceNormFusionTy::PRELU ? ushort(Input(3).size())
                                                      : ushort(0)}});

    [encoder setComputePipelineState:state];
    [encoder setBuffer:scaleBuffer_ offset:0 atIndex:0];
    [encoder setBuffer:biasBuffer_ offset:0 atIndex:1];
    [encoder setTexture:[X texture] atIndex:0];
    [encoder setTexture:[output texture] atIndex:1];
    if (fusionTy == InstanceNormFusionTy::PRELU) {
      [encoder setBuffer:preluWeightBuffer_ offset:0 atIndex:2];
    }
    // One threadgroup per (image, 4-channel slice); its 16x16 threads
    // presumably cooperate to cover the whole spatial extent of that slice.
    [encoder dispatchThreadgroups:MTLSizeMake(
                                      1,
                                      1,
                                      X.numberOfImages *
                                          divRoundUp(X.featureChannels, 4))
            threadsPerThreadgroup:MTLSizeMake(16, 16, 1)];
    [encoder endEncoding];
    inputWrapper.markRead();
    VLOG(2) << "InstanceNorm took: " << t.MilliSeconds();
    outputWrapper.copyToOutputBlob(Outputs()[0]);

    return true;
  }

 private:
  // Allocates a Metal buffer of divRoundUp(t.size(), 4) * 4 fp16 elements and
  // copies t's float values, converted to fp16, into its leading t.size()
  // elements.
  static id<MTLBuffer> makeHalfBuffer(const Tensor& t) {
    id<MTLBuffer> buffer = [getMPSCNNContext().device
        newBufferWithLength:divRoundUp(t.size(), 4) * 4 * 2
                    options:MTLResourceOptionCPUCacheModeDefault];
    float16_t* dst = (float16_t*)[buffer contents];
    const float* src = t.data<float>();
    for (auto i = 0; i < t.size(); ++i) {
      dst[i] = src[i];
    }
    return buffer;
  }

  id<MTLBuffer> scaleBuffer_;
  id<MTLBuffer> biasBuffer_;
  id<MTLBuffer> preluWeightBuffer_;
};

REGISTER_CPU_OPERATOR(
    MPSCNNInstanceNorm,
    MPSCNNInstanceNormOp<InstanceNormFusionTy::NONE>);
OPERATOR_SCHEMA(MPSCNNInstanceNorm).NumInputs(3).NumOutputs(1);
REGISTER_CPU_OPERATOR(
    MPSCNNInstanceNormPRelu,
    MPSCNNInstanceNormOp<InstanceNormFusionTy::PRELU>);
OPERATOR_SCHEMA(MPSCNNInstanceNormPRelu).NumInputs(4).NumOutputs(1);
1796

1797
// Per-channel normalization of an MPSImage: Y = (X - mean) / std, evaluated
// as a per-channel affine transform by the "affine" kernel.
//
// Inputs:  0 - X, an MPSImageWrapper
//          1 - mean (X.featureChannels floats)
//          2 - std  (X.featureChannels floats)
// Output:  0 - an MPSImageWrapper with the same shape as X
//
// The affine coefficients are stored as fp16 and cached. They are only
// recomputed when the required buffer size changes.
class MPSCNNNormalizePlanarYUVOp final : public Operator<CPUContext> {
 public:
  MPSCNNNormalizePlanarYUVOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<CPUContext>(operator_def, ws) {}

  bool RunOnDevice() override {
    auto inputWrapper = Inputs()[0]->template Get<MPSImageWrapper>();
    MPSImage* X = inputWrapper.getImage();

    const auto& mean = Input(1);
    const auto& std = Input(2);
    CAFFE_ENFORCE_EQ(mean.size(), X.featureChannels);
    CAFFE_ENFORCE_EQ(std.size(), X.featureChannels);
    // fp16 (2 bytes) per element, padded up to a multiple of 4 elements.
    const auto scaleBytes = divRoundUp(mean.size(), 4) * 4 * 2;
    if (!scaleBuffer_ || !shiftBuffer_ || scaleBuffer_.length != scaleBytes ||
        shiftBuffer_.length != scaleBytes) {
      caffe2::Timer cvt;
      scaleBuffer_ = [getMPSCNNContext().device
          newBufferWithLength:scaleBytes
                      options:MTLResourceOptionCPUCacheModeDefault];
      shiftBuffer_ = [getMPSCNNContext().device
          newBufferWithLength:scaleBytes
                      options:MTLResourceOptionCPUCacheModeDefault];
      // op computes (X - mean) / std = X * 1/std + (-mean/std)
      // Thus set scale = 1.0/std, shift = (-mean/std)
      float16_t* scaleData = (float16_t*)[scaleBuffer_ contents];
      float16_t* shiftData = (float16_t*)[shiftBuffer_ contents];
      const float* meanData = mean.template data<float>();
      const float* stdData = std.template data<float>();
      for (auto i = 0; i < mean.size(); ++i) {
        scaleData[i] = 1.0 / double(stdData[i]);
        shiftData[i] = double(-meanData[i]) / double(stdData[i]);
      }
      VLOG(2) << "Buffer setup took: " << cvt.MilliSeconds();
    }

    auto outputWrapper = MPSImageWrapper(
        this,
        &inputWrapper,
        X.numberOfImages,
        X.height,
        X.width,
        X.featureChannels);
    auto commandBuffer = inputWrapper.getCommandBuffer();
    MPSImage* output = outputWrapper.getImage();

    caffe2::Timer t;
    id<MTLComputeCommandEncoder> encoder =
        [commandBuffer computeCommandEncoder];
    id<MTLComputePipelineState> state =
        getMPSCNNContext().getSpecializedPipelineState(
            kernelFor(X, @"affine", @"affine_nonarray"),
            {ushort(X.featureChannels)});

    [encoder setComputePipelineState:state];
    [encoder setBuffer:scaleBuffer_ offset:0 atIndex:0];
    [encoder setBuffer:shiftBuffer_ offset:0 atIndex:1];
    [encoder setTexture:[X texture] atIndex:0];
    [encoder setTexture:[output texture] atIndex:1];
    const auto& launchParams =
        spatialPointwiseKernelLaunchParams(state, output);
    [encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
            threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
    [encoder endEncoding];
    inputWrapper.markRead();
    VLOG(2) << "NormalizePlanarYUV took: " << t.MilliSeconds();
    outputWrapper.copyToOutputBlob(Outputs()[0]);

    return true;
  }

 private:
  id<MTLBuffer> scaleBuffer_;
  id<MTLBuffer> shiftBuffer_;
};

REGISTER_CPU_OPERATOR(MPSCNNNormalizePlanarYUV, MPSCNNNormalizePlanarYUVOp);
OPERATOR_SCHEMA(MPSCNNNormalizePlanarYUV).NumInputs(3).NumOutputs(1);
1874

1875
// Parametric ReLU on an MPSImage.
//
// Inputs:  0 - X, an MPSImageWrapper
//          1 - slopes: either a single shared value or one per feature channel
// Output:  0 - an MPSImageWrapper with the same shape as X
//
// The slopes are stored as fp16 and cached. They are only re-uploaded when
// the required buffer size changes.
class MPSCNNPReluOp final : public Operator<CPUContext> {
 public:
  MPSCNNPReluOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<CPUContext>(operator_def, ws) {}

  bool RunOnDevice() override {
    auto inputWrapper = Inputs()[0]->Get<MPSImageWrapper>();
    const MPSImage* X = inputWrapper.getImage();

    const auto& scale = Input(1);
    // Caffe2's PRelu defines a shared slope or one slope per channel, and the
    // kernel below is specialized on the slope count. Reject anything else,
    // including an empty tensor (which would also request a 0-byte buffer).
    CAFFE_ENFORCE(
        scale.size() == 1 || scale.size() == int64_t(X.featureChannels),
        "PRelu slope must have 1 or ",
        X.featureChannels,
        " elements, got ",
        scale.size());
    // fp16 (2 bytes) per element, padded up to a multiple of 4 elements.
    const auto scaleBytes = divRoundUp(scale.size(), 4) * 4 * 2;
    if (!scaleBuffer_ || scaleBuffer_.length != scaleBytes) {
      caffe2::Timer cvt;
      scaleBuffer_ = [getMPSCNNContext().device
          newBufferWithLength:scaleBytes
                      options:MTLResourceOptionCPUCacheModeDefault];
      float16_t* dst = (float16_t*)[scaleBuffer_ contents];
      const float* src = scale.data<float>();
      for (auto i = 0; i < scale.size(); ++i) {
        dst[i] = src[i];
      }
      VLOG(2) << "Buffer setup took: " << cvt.MilliSeconds();
    }

    auto outputWrapper = MPSImageWrapper(
        this,
        &inputWrapper,
        X.numberOfImages,
        X.height,
        X.width,
        X.featureChannels);
    auto commandBuffer = inputWrapper.getCommandBuffer();
    MPSImage* output = outputWrapper.getImage();
    caffe2::Timer t;
    id<MTLComputeCommandEncoder> encoder =
        [commandBuffer computeCommandEncoder];
    id<MTLComputePipelineState> state =
        getMPSCNNContext().getSpecializedPipelineState(
            kernelFor(X, @"prelu_nonshared", @"prelu_nonshared_nonarray"),
            {{ushort(X.featureChannels), ushort(scale.size())}});

    [encoder setComputePipelineState:state];
    [encoder setBuffer:scaleBuffer_ offset:0 atIndex:0];
    [encoder setTexture:[X texture] atIndex:0];
    [encoder setTexture:[output texture] atIndex:1];

    const auto& launchParams =
        spatialPointwiseKernelLaunchParams(state, output);
    [encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
            threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
    [encoder endEncoding];
    inputWrapper.markRead();
    VLOG(2) << "PRelu took: " << t.MilliSeconds();
    outputWrapper.copyToOutputBlob(Outputs()[0]);

    return true;
  }

 private:
  id<MTLBuffer> scaleBuffer_;
};

REGISTER_CPU_OPERATOR(MPSCNNPRelu, MPSCNNPReluOp);
// Allow in-place isn't *really* valid here, since nothing is in-place for Metal
// texture arrays, but requires re-export.
OPERATOR_SCHEMA(MPSCNNPRelu).NumInputs(2).NumOutputs(1).AllowInplace({{0, 0}});
1939

1940
// Warps every RoI of a single-image MPSImage onto a fixed
// pooled_h x pooled_w grid using the "roi_warp" Metal kernel.
//
// Inputs:  0 - X, an MPSImageWrapper holding exactly one image
//          1 - RoIs, shape (R, 4), or (R, 5) with a leading batch index that
//              must be 0
// Output:  0 - an MPSImageWrapper with R images of pooled_h x pooled_w and
//              X.featureChannels channels
//
// Arguments: spatial_scale (> 0), pooled_h (> 0), pooled_w (> 0),
//            sampling_ratio (>= 0).
class MPSCNNRoIWarpOp final : public Operator<CPUContext> {
 public:
  MPSCNNRoIWarpOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<CPUContext>(operator_def, ws),
        spatial_scale_(
            OperatorBase::GetSingleArgument<float>("spatial_scale", 1.)),
        pooled_height_(OperatorBase::GetSingleArgument<int>("pooled_h", 1)),
        pooled_width_(OperatorBase::GetSingleArgument<int>("pooled_w", 1)),
        sampling_ratio_(
            OperatorBase::GetSingleArgument<int>("sampling_ratio", -1)) {
    CAFFE_ENFORCE_GT(spatial_scale_, 0);
    CAFFE_ENFORCE_GT(pooled_height_, 0);
    CAFFE_ENFORCE_GT(pooled_width_, 0);
    // NOTE(review): the default sampling_ratio of -1 fails this check, so the
    // argument must always be supplied explicitly. The value is later narrowed
    // to ushort for the kernel, where a negative value would wrap around.
    CAFFE_ENFORCE_GE(sampling_ratio_, 0);
    VLOG(1) << "spatial_scale: " << spatial_scale_;
    VLOG(1) << "pooled_h: " << pooled_height_;
    VLOG(1) << "pooled_w: " << pooled_width_;
    VLOG(1) << "sampling_ratio: " << sampling_ratio_;
  }

  bool RunOnDevice() override {
    caffe2::Timer t;
    auto inputWrapper = Inputs()[0]->Get<MPSImageWrapper>();
    auto X = inputWrapper.getImage();
    CAFFE_ENFORCE_EQ(X.numberOfImages, 1);
    const auto& R = Input(1);
    CAFFE_ENFORCE_EQ(R.dim(), 2);
    CAFFE_ENFORCE(R.dim32(1) == 4 || R.dim32(1) == 5);
    const auto numRois = R.dim32(0);
    const auto Rdim = R.dim32(1);
    auto featureChannels = X.featureChannels;
    VLOG(1) << "RoIWarp input size:" << X.numberOfImages << " "
            << featureChannels << " " << X.height << " " << X.width;
    VLOG(1) << "RoIWarp output size:" << numRois << " " << featureChannels
            << " " << pooled_width_ << " " << pooled_height_;

    // Validate the RoI count before touching any Metal buffer: an empty RoI
    // tensor would otherwise request a zero-length allocation.
    if (numRois <= 0) {
      LOG(ERROR) << "number of RoIs <= 0 in RoIWarp " << numRois;
      inputWrapper.cleanup();
      return false;
    }
    if (divRoundUp(numRois * featureChannels, 4) >
        kMetalMaxTextureArrLength) {
      LOG(INFO) << "MPSCNNRoIWarp " << numRois << " " << featureChannels;
      LOG(ERROR) << "arrayLength exceeds the maximum allowed length in texture";
      inputWrapper.cleanup();
      return false;
    }

    // Upload the four box coordinates of every RoI as fp16.
    const auto roiBytes = numRois * 4 * sizeof(float16_t);
    if (!roiBuffer_ || roiBuffer_.length != roiBytes) {
      roiBuffer_ = [getMPSCNNContext().device
          newBufferWithLength:roiBytes
                      options:MTLResourceOptionCPUCacheModeDefault];
    }
    float16_t* roiBuffer = (float16_t*)[roiBuffer_ contents];
    const float* roiData = R.data<float>();
    // In the 5-column layout the box follows the leading batch index.
    const auto off = Rdim == 5 ? 1 : 0;
    for (auto i = 0; i < numRois; ++i) {
      if (Rdim == 5) {
        // Only a batch size of one is supported, so the batch index must be 0.
        CAFFE_ENFORCE_EQ(R.data<float>()[i * Rdim], 0.0);
      }
      const float* box = roiData + i * Rdim + off;
      for (auto j = 0; j < 4; ++j) {
        roiBuffer[i * 4 + j] = box[j];
      }
    }

    auto outputWrapper = MPSImageWrapper(
        this,
        &inputWrapper,
        numRois,
        pooled_height_,
        pooled_width_,
        featureChannels);
    auto commandBuffer = outputWrapper.getCommandBuffer();
    MPSImage* output = outputWrapper.getImage();
    VLOG(1) << "output: " << output.numberOfImages << ", "
            << output.featureChannels << ", " << output.height << ", "
            << output.width;
    id<MTLComputeCommandEncoder> encoder =
        [commandBuffer computeCommandEncoder];
    // Function constants are ushort, so spatial_scale_ is passed scaled by
    // 10000 (presumably the kernel divides it back out).
    id<MTLComputePipelineState> state =
        getMPSCNNContext().getSpecializedPipelineState(
            @"roi_warp",
            {{ushort(spatial_scale_ * 10000),
              ushort(sampling_ratio_),
              ushort(featureChannels),
              ushort(X.numberOfImages),
              ushort(output.numberOfImages)}});

    [encoder setComputePipelineState:state];
    [encoder setBuffer:roiBuffer_ offset:0 atIndex:0];
    [encoder setTexture:[X texture] atIndex:0];
    [encoder setTexture:[output texture] atIndex:1];

    const auto& launchParams =
        spatialPointwiseKernelLaunchParams(state, output);
    [encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
            threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
    [encoder endEncoding];
    inputWrapper.markRead();
    VLOG(2) << "RoIWarp took: " << t.MilliSeconds();
    VLOG(1) << "ROIWarp size: " << output.numberOfImages << ", "
            << output.featureChannels << ", " << output.height << ", "
            << output.width;
    outputWrapper.copyToOutputBlob(Outputs()[0]);

    return true;
  }

 private:
  float spatial_scale_;
  int pooled_height_;
  int pooled_width_;
  int sampling_ratio_;

  // Cached fp16 RoI buffer; reallocated only when the RoI count changes.
  id<MTLBuffer> roiBuffer_;
};

REGISTER_CPU_OPERATOR(MPSCNNRoIWarp, MPSCNNRoIWarpOp);
OPERATOR_SCHEMA(MPSCNNRoIWarp).NumInputs(2).NumOutputs(1);
2061

2062
// Region-proposal generation (RPN post-processing). Everything runs on the CPU
// except non-maximum suppression, which is offloaded to the "nms" Metal kernel.
//
// Inputs:  0 - scores       (num_images, A, H, W), float
//          1 - bbox_deltas  (num_images, 4 * A, H, W)
//          2 - im_info      (num_images, 3): [height, width, scale] per image
//          3 - anchors      (A, 4), float
// Outputs: 0 - rois         (R, 5): [image_index, box(4)]
//          1 - rois_probs   (R)
class MPSCNNGenerateProposalsCPPOp final : public Operator<CPUContext> {
 public:
  MPSCNNGenerateProposalsCPPOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<CPUContext>(operator_def, ws),
        spatial_scale_(
            OperatorBase::GetSingleArgument<float>("spatial_scale", 1.0 / 16)),
        feat_stride_(1.0 / spatial_scale_),
        rpn_pre_nms_topN_(
            OperatorBase::GetSingleArgument<int>("pre_nms_topN", 6000)),
        rpn_post_nms_topN_(
            OperatorBase::GetSingleArgument<int>("post_nms_topN", 300)),
        rpn_nms_thresh_(
            OperatorBase::GetSingleArgument<float>("nms_thresh", 0.7f)),
        rpn_min_size_(OperatorBase::GetSingleArgument<float>("min_size", 16)),
        legacy_plus_one_(
            this->template GetSingleArgument<bool>("legacy_plus_one", true)) {}

  // Greedy non-maximum suppression over `sorted_indices` (ordered by
  // descending score). For each batch of candidate rows the GPU produces one
  // bitmask per candidate (presumably the later boxes whose overlap with it
  // exceeds `thresh`). The CPU then walks the candidates in order and keeps
  // every box that no previously kept box has masked out.
  //
  // Returns at most rpn_post_nms_topN_ kept indices (every surviving box when
  // rpn_post_nms_topN_ is 0), in the order they were kept.
  template <class Derived1, class Derived2>
  std::vector<int> nms_metal(
      const Eigen::ArrayBase<Derived1>& proposals, // EArrXXf
      const Eigen::ArrayBase<Derived2>& scores, // EArrXf
      const std::vector<int>& sorted_indices,
      float thresh) const {
    CAFFE_ENFORCE_EQ(proposals.rows(), scores.rows());
    CAFFE_ENFORCE_EQ(proposals.cols(), 4);
    CAFFE_ENFORCE_EQ(scores.cols(), 1);
    CAFFE_ENFORCE_LE(sorted_indices.size(), proposals.rows());

    const int box_num = sorted_indices.size();
    // A zero rpn_post_nms_topN_ means "no limit", as in ProposalsForOneImage.
    // Treating it literally would give a zero batch size and a loop that never
    // advances.
    const int post_nms_size = rpn_post_nms_topN_ > 0
        ? std::min<int>(rpn_post_nms_topN_, box_num)
        : box_num;
    if (post_nms_size == 0) {
      // No candidates: skip the GPU work and the zero-length Metal buffers.
      return {};
    }

    // Rows are processed in batches of the smallest power of two that holds
    // post_nms_size rows (exact integer math instead of pow/log).
    int batch_size = 1;
    while (batch_size < post_nms_size) {
      batch_size <<= 1;
    }
    // batch_size and offset reach the kernel as ushort function constants.
    constexpr int kMaxUShort = 0xFFFF;
    CAFFE_ENFORCE_LE(box_num, kMaxUShort, "Too many boxes for Metal NMS");
    CAFFE_ENFORCE_LE(
        batch_size, kMaxUShort, "NMS batch size too large for Metal");

    // One 32-bit word covers maxThreadsPerThreadgroup (32) boxes.
    const int col_blocks = divRoundUp(box_num, maxThreadsPerThreadgroup);

    // Row-major (rows, 4) copy of the proposals for upload.
    std::vector<float> proposals_cpu(proposals.size());
    Eigen::Map<ERArrXXf>(
        proposals_cpu.data(), proposals.rows(), proposals.cols()) = proposals;

    auto proposalsBuffer = [getMPSCNNContext().device
        newBufferWithBytes:proposals_cpu.data()
                    length:proposals_cpu.size() * sizeof(float)
                   options:MTLResourceOptionCPUCacheModeDefault];
    auto sortedIndicesBuffer = [getMPSCNNContext().device
        newBufferWithBytes:sorted_indices.data()
                    length:box_num * sizeof(int)
                   options:MTLResourceOptionCPUCacheModeDefault];
    auto maskBuffer = [getMPSCNNContext().device
        newBufferWithLength:batch_size * col_blocks * sizeof(uint32_t)
                    options:MTLResourceOptionCPUCacheModeDefault];
    std::vector<uint32_t> masks(batch_size * col_blocks);

    std::vector<int> keep;
    keep.reserve(post_nms_size);
    // Bit set of boxes already suppressed by a kept box.
    std::vector<uint32_t> remv(col_blocks, 0);
    bool terminate = false;

    for (int offset = 0; !terminate && offset < box_num; offset += batch_size) {
      auto commandBuffer = [getMPSCNNContext().commandQueue commandBuffer];
      auto encoder = [commandBuffer computeCommandEncoder];
      auto state = getMPSCNNContext().getSpecializedPipelineState(
          @"nms",
          {{ushort(batch_size),
            maxThreadsPerThreadgroup,
            ushort(thresh * 10000),
            ushort(offset)}});
      [encoder setComputePipelineState:state];
      [encoder setBuffer:maskBuffer offset:0 atIndex:0];
      [encoder setBuffer:proposalsBuffer offset:0 atIndex:1];
      [encoder setBuffer:sortedIndicesBuffer offset:0 atIndex:2];
      const auto threadsPerThreadgroup =
          MTLSizeMake(maxThreadsPerThreadgroup, 1, 1);
      const auto threadgroupsPerGrid = MTLSizeMake(
          divRoundUp(batch_size, maxThreadsPerThreadgroup),
          divRoundUp(box_num, maxThreadsPerThreadgroup),
          1);
      [encoder dispatchThreadgroups:threadgroupsPerGrid
              threadsPerThreadgroup:threadsPerThreadgroup];
      [encoder endEncoding];
      [commandBuffer commit];
      // The CPU pass below needs this batch's masks.
      [commandBuffer waitUntilCompleted];
      const uint32_t* maskData = (const uint32_t*)[maskBuffer contents];
      std::copy(maskData, maskData + masks.size(), masks.begin());

      const int batch_end = std::min(offset + batch_size, box_num);
      for (int i = offset; i < batch_end; ++i) {
        const int nblock = i / maxThreadsPerThreadgroup;
        const int inblock = i % maxThreadsPerThreadgroup;
        if (remv[nblock] & (1U << inblock)) {
          continue; // suppressed by an earlier kept box
        }
        keep.push_back(sorted_indices[i]);
        if (static_cast<int>(keep.size()) >= post_nms_size) {
          terminate = true;
          break;
        }
        const uint32_t* row = masks.data() + (i - offset) * col_blocks;
        for (int j = nblock; j < col_blocks; ++j) {
          remv[j] |= row[j];
        }
      }
    }
    return keep;
  }

  // Generates proposals for a single image: decode anchors with the predicted
  // deltas, clip to the image, drop small boxes, keep the top pre_nms_topN by
  // score, run NMS, and keep the top post_nms_topN survivors.
  void ProposalsForOneImage(
      const Eigen::Array3f& im_info,
      const Eigen::Map<const ERMatXf>& all_anchors,
      const utils::ConstTensorView<float>& bbox_deltas_tensor,
      const utils::ConstTensorView<float>& scores_tensor,
      ERArrXXf* out_boxes,
      EArrXf* out_probs) const {
    const auto& pre_nms_topN = rpn_pre_nms_topN_;
    const auto& post_nms_topN = rpn_post_nms_topN_;
    const auto& nms_thresh = rpn_nms_thresh_;
    const auto& min_size = rpn_min_size_;

    // Transpose and reshape predicted bbox transformations to get them
    // into the same order as the anchors:
    //   - bbox deltas will be (4 * A, H, W) format from conv output
    //   - transpose to (H, W, 4 * A)
    //   - reshape to (H * W * A, 4) where rows are ordered by (H, W, A)
    //     in slowest to fastest order to match the enumerated anchors
    CAFFE_ENFORCE_EQ(bbox_deltas_tensor.ndim(), 3);
    CAFFE_ENFORCE_EQ(bbox_deltas_tensor.dim(0) % 4, 0);
    auto A = bbox_deltas_tensor.dim(0) / 4;
    auto H = bbox_deltas_tensor.dim(1);
    auto W = bbox_deltas_tensor.dim(2);
    // equivalent to python code
    //  bbox_deltas = bbox_deltas.transpose((1, 2, 0)).reshape((-1, 4))
    ERArrXXf bbox_deltas(H * W * A, 4);
    Eigen::Map<ERMatXf>(bbox_deltas.data(), H * W, 4 * A) =
        Eigen::Map<const ERMatXf>(bbox_deltas_tensor.data(), A * 4, H * W)
            .transpose();
    CAFFE_ENFORCE_EQ(bbox_deltas.rows(), all_anchors.rows());

    // - scores are (A, H, W) format from conv output
    // - transpose to (H, W, A)
    // - reshape to (H * W * A, 1) where rows are ordered by (H, W, A)
    //   to match the order of anchors and bbox_deltas
    CAFFE_ENFORCE_EQ(scores_tensor.ndim(), 3);
    CAFFE_ENFORCE_EQ(scores_tensor.dims(), (vector<int>{A, H, W}));
    // equivalent to python code
    // scores = scores.transpose((1, 2, 0)).reshape((-1, 1))
    EArrXf scores(scores_tensor.size());
    Eigen::Map<ERMatXf>(scores.data(), H * W, A) =
        Eigen::Map<const ERMatXf>(scores_tensor.data(), A, H * W).transpose();
    // Transform anchors into proposals via bbox transformations
    auto proposals = utils::bbox_transform(
        all_anchors.array(),
        bbox_deltas,
        std::vector<float>{1.0, 1.0, 1.0, 1.0},
        utils::BBOX_XFORM_CLIP_DEFAULT,
        legacy_plus_one_);

    // 2. clip proposals to image (may result in proposals with zero area
    // that will be removed in the next step)
    proposals = utils::clip_boxes(
        proposals, im_info[0], im_info[1], 1.0, legacy_plus_one_);

    // 3. remove predicted boxes with either height or width < min_size
    auto keep =
        utils::filter_boxes(proposals, min_size, im_info, legacy_plus_one_);

    TORCH_DCHECK_LE(keep.size(), scores.size());

    // 4. sort all (proposal, score) pairs by score from highest to lowest
    // 5. take top pre_nms_topN (e.g. 6000)
    std::sort(keep.begin(), keep.end(), [&scores](int lhs, int rhs) {
      return scores[lhs] > scores[rhs];
    });

    if (pre_nms_topN > 0 && pre_nms_topN < keep.size()) {
      keep.resize(pre_nms_topN);
    }

    // 6. apply loose nms (e.g. threshold = 0.7)
    // 7. take after_nms_topN (e.g. 300)
    // 8. return the top proposals (-> RoIs top)
    keep = nms_metal(proposals, scores, keep, nms_thresh);
    if (post_nms_topN > 0 && post_nms_topN < keep.size()) {
      keep.resize(post_nms_topN);
    }
    // Generate outputs
    utils::GetSubArrayRows(proposals, utils::AsEArrXt(keep), out_boxes);
    utils::GetSubArray(scores, utils::AsEArrXt(keep), out_probs);
  }

  bool RunOnDevice() override {
    const auto& scores = Input(0);
    const auto& bbox_deltas = Input(1);
    const auto& im_info_tensor = Input(2);
    const auto& anchors = Input(3);
    auto* out_rois = Output(0);
    auto* out_rois_probs = Output(1);

    CAFFE_ENFORCE_EQ(scores.dim(), 4, scores.dim());
    CAFFE_ENFORCE(scores.template IsType<float>(), scores.meta().name());
    const auto num_images = scores.size(0);
    const auto A = scores.size(1);
    const auto height = scores.size(2);
    const auto width = scores.size(3);
    const auto K = height * width;

    // bbox_deltas: (num_images, A * 4, H, W)
    CAFFE_ENFORCE_EQ(
        bbox_deltas.sizes(), (vector<int64_t>{num_images, 4 * A, height, width}));

    // im_info_tensor: (num_images, 3), format [height, width, scale; ...]
    CAFFE_ENFORCE_EQ(im_info_tensor.sizes(), (vector<int64_t>{num_images, 3}));
    CAFFE_ENFORCE(
        im_info_tensor.template IsType<float>(), im_info_tensor.meta().name());

    // anchors: (A, 4)
    CAFFE_ENFORCE_EQ(anchors.sizes(), (vector<int64_t>{A, 4}));
    CAFFE_ENFORCE(anchors.template IsType<float>(), anchors.meta().name());
    // Broadcast the anchors to all pixels
    auto all_anchors_vec =
        utils::ComputeAllAnchors(anchors, height, width, feat_stride_);
    Eigen::Map<const ERMatXf> all_anchors(all_anchors_vec.data(), K * A, 4);

    Eigen::Map<const ERArrXXf> im_info(
        im_info_tensor.data<float>(),
        im_info_tensor.size(0),
        im_info_tensor.size(1));

    const int roi_col_count = 5;
    out_rois->Resize(0, roi_col_count);
    out_rois_probs->Resize(0);
    // Use openmp for acceleration?
    for (int i = 0; i < num_images; i++) {
      auto cur_im_info = im_info.row(i);
      auto cur_bbox_deltas = GetSubTensorView<float>(bbox_deltas, i);
      auto cur_scores = GetSubTensorView<float>(scores, i);

      ERArrXXf im_i_boxes;
      EArrXf im_i_probs;
      ProposalsForOneImage(
          cur_im_info,
          all_anchors,
          cur_bbox_deltas,
          cur_scores,
          &im_i_boxes,
          &im_i_probs);

      int csz = im_i_boxes.rows();
      int cur_start_idx = out_rois->size(0);

      out_rois->Extend(csz, 50);
      out_rois_probs->Extend(csz, 50);

      // write rois: column 0 is the image index, columns 1..4 the box
      Eigen::Map<ERArrXXf> cur_rois(
          out_rois->mutable_data<float>() + cur_start_idx * roi_col_count,
          csz,
          5);
      cur_rois.col(0).setConstant(i);
      cur_rois.block(0, 1, csz, 4) = im_i_boxes;

      // write rois_probs
      Eigen::Map<EArrXf>(
          out_rois_probs->mutable_data<float>() + cur_start_idx, csz) =
          im_i_probs;
    }

    return true;
  }

 protected:
  // spatial_scale_ must be declared before feat_stride_
  float spatial_scale_{1.0};
  float feat_stride_{1.0};

  // RPN_PRE_NMS_TOP_N
  ushort rpn_pre_nms_topN_{6000};
  // RPN_POST_NMS_TOP_N
  ushort rpn_post_nms_topN_{300};
  // RPN_NMS_THRESH
  float rpn_nms_thresh_{0.7};
  // RPN_MIN_SIZE
  float rpn_min_size_{16};
  // The infamous "+ 1" for box width and height dating back to the DPM days
  bool legacy_plus_one_{true};
  // threads per thread group, used in nms
  ushort maxThreadsPerThreadgroup{32};
};

REGISTER_CPU_OPERATOR(MPSCNNGenerateProposalsCPP, MPSCNNGenerateProposalsCPPOp);
OPERATOR_SCHEMA(MPSCNNGenerateProposalsCPP).NumInputs(4).NumOutputs(2);
2354

2355
// Inference-mode spatial batch normalization on an MPSImage. The four
// per-channel parameter tensors are folded into one affine transform
//   y = x * multiplier + offset
// which the "affine" kernel evaluates. The folded fp16 coefficients are cached
// and only recomputed when the required buffer size changes.
class MPSCNNSpatialBNOp final : public SpatialBNOp<CPUContext> {
 public:
  MPSCNNSpatialBNOp(const OperatorDef& operator_def, Workspace* ws)
      : SpatialBNOp<CPUContext>(operator_def, ws) {}

  bool RunOnDevice() override {
    auto inputWrapper = Inputs()[0]->Get<MPSImageWrapper>();
    const MPSImage* X = inputWrapper.getImage();
    const auto& scale = Input(SCALE);
    const auto& bias = Input(BIAS);
    const auto& var = Input(EST_VAR);
    const auto& mean = Input(EST_MEAN);
    CAFFE_ENFORCE_EQ(scale.size(), X.featureChannels);
    CAFFE_ENFORCE_EQ(bias.size(), X.featureChannels);
    CAFFE_ENFORCE_EQ(var.size(), X.featureChannels);
    CAFFE_ENFORCE_EQ(mean.size(), X.featureChannels);

    // fp16 (2 bytes) per channel, padded up to a multiple of 4 channels.
    const auto scaleBytes = divRoundUp(scale.size(), 4) * 4 * 2;
    const bool needsUpload =
        !scaleBuffer_ || scaleBuffer_.length != scaleBytes;
    if (needsUpload) {
      caffe2::Timer uploadTimer;
      scaleBuffer_ = [getMPSCNNContext().device
          newBufferWithLength:scaleBytes
                      options:MTLResourceOptionCPUCacheModeDefault];
      shiftBuffer_ = [getMPSCNNContext().device
          newBufferWithLength:scaleBytes
                      options:MTLResourceOptionCPUCacheModeDefault];
      auto* multipliers = (float16_t*)[scaleBuffer_ contents];
      auto* offsets = (float16_t*)[shiftBuffer_ contents];
      const float* gamma = scale.data<float>();
      const float* beta = bias.data<float>();
      const float* runningMean = mean.data<float>();
      const float* runningVar = var.data<float>();
      // (x - mean) * inv_std * gamma + beta
      //   == x * (gamma * inv_std) + (beta - mean * inv_std * gamma)
      for (auto c = 0; c < scale.size(); ++c) {
        const auto invStd = 1.0 / std::sqrt(runningVar[c] + epsilon_);
        multipliers[c] = gamma[c] * invStd;
        offsets[c] = beta[c] - runningMean[c] * invStd * gamma[c];
      }
      VLOG(2) << "Buffer setup took: " << uploadTimer.MilliSeconds();
    }

    auto outputWrapper = MPSImageWrapper(
        this,
        &inputWrapper,
        X.numberOfImages,
        X.height,
        X.width,
        X.featureChannels);
    auto commandBuffer = outputWrapper.getCommandBuffer();
    MPSImage* output = outputWrapper.getImage();

    caffe2::Timer encodeTimer;
    id<MTLComputeCommandEncoder> encoder =
        [commandBuffer computeCommandEncoder];
    const auto kernelName = kernelFor(output, @"affine", @"affine_nonarray");
    id<MTLComputePipelineState> pipeline =
        getMPSCNNContext().getSpecializedPipelineState(
            kernelName, {ushort(X.featureChannels)});
    [encoder setComputePipelineState:pipeline];
    [encoder setBuffer:scaleBuffer_ offset:0 atIndex:0];
    [encoder setBuffer:shiftBuffer_ offset:0 atIndex:1];
    [encoder setTexture:[X texture] atIndex:0];
    [encoder setTexture:[output texture] atIndex:1];

    const auto& grid = spatialPointwiseKernelLaunchParams(pipeline, output);
    [encoder dispatchThreadgroups:grid.threadgroupsPerGrid
            threadsPerThreadgroup:grid.threadsPerThreadgroup];
    [encoder endEncoding];
    inputWrapper.markRead();
    VLOG(2) << "SpatialBN took: " << encodeTimer.MilliSeconds();
    outputWrapper.copyToOutputBlob(Outputs()[0]);
    return true;
  }

 private:
  id<MTLBuffer> scaleBuffer_; // fp16 per-channel multipliers
  id<MTLBuffer> shiftBuffer_; // fp16 per-channel offsets
};

REGISTER_CPU_OPERATOR(MPSCNNSpatialBN, MPSCNNSpatialBNOp);
OPERATOR_SCHEMA(MPSCNNSpatialBN).NumInputs(5).NumOutputs(1);
2438

2439
class MPSCNNConcatOp final : public Operator<CPUContext> {
 public:
  MPSCNNConcatOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<CPUContext>(operator_def, ws) {
    // The "concat" kernel has four input texture slots (indices 0-3) and four
    // per-input channel constants, so at most four inputs are supported.
    OPERATOR_NEEDS_FEATURE(
        Inputs().size() <= 4, "MPSCNNConcat only handles up to four inputs");
  }

  // Concatenates the input images along the feature-channel axis into one
  // output image. The "concat" kernel is dispatched on the shared command
  // buffer. Every input must share the first input's command buffer, height,
  // and width.
  bool RunOnDevice() override {
    // Each helper fetches (a copy of) the i-th input's wrapper.
    auto Wrapper = [&](size_t i) {
      return Inputs()[i]->template Get<MPSImageWrapper>();
    };
    auto cb = [&](size_t i) { return Wrapper(i).getCommandBuffer(); };
    auto X = [&](size_t i) { return Wrapper(i).getImage(); };

    const size_t inputCount = Inputs().size();

    // Kernel function constants: C0, C1, C2, C3, C, N.
    //   C0..C3 are the per-input channel counts (0 for unused slots),
    //   C is the total channel count, and N is the number of images.
    std::vector<ushort> channels = {
        {0, 0, 0, 0, 0, ushort(X(0).numberOfImages)}};
    size_t channelCount = 0;
    for (size_t i = 0; i < inputCount; ++i) {
      // this does not hold for non-temp images inputs
      CAFFE_ENFORCE_EQ(cb(0), cb(i));
      CAFFE_ENFORCE_EQ(X(0).height, X(i).height);
      CAFFE_ENFORCE_EQ(X(0).width, X(i).width);
      channels[i] = X(i).featureChannels;
      channelCount += X(i).featureChannels;
    }
    // The total is handed to the kernel as a ushort constant, so refuse to
    // truncate it silently. Each per-input count is <= the total, so this
    // also covers C0..C3.
    CAFFE_ENFORCE_EQ(
        size_t(ushort(channelCount)),
        channelCount,
        "MPSCNNConcat total channel count does not fit in ushort");
    channels[4] = ushort(channelCount);

    auto wrapper0 = Inputs()[0]->template Get<MPSImageWrapper>();
    auto outputWrapper = MPSImageWrapper(
        this,
        &wrapper0,
        X(0).numberOfImages,
        X(0).height,
        X(0).width,
        channelCount);
    auto commandBuffer = outputWrapper.getCommandBuffer();
    MPSImage* output = outputWrapper.getImage();
    caffe2::Timer t;
    id<MTLComputeCommandEncoder> encoder =
        [commandBuffer computeCommandEncoder];
    id<MTLComputePipelineState> state =
        getMPSCNNContext().getSpecializedPipelineState(@"concat", channels);

    [encoder setComputePipelineState:state];
    for (size_t i = 0; i < inputCount; ++i) {
      [encoder setTexture:[X(i) texture] atIndex:i];
    }
    // The output texture sits at a fixed slot after the four input slots.
    [encoder setTexture:[output texture] atIndex:5];
    const auto& launchParams =
        spatialPointwiseKernelLaunchParams(state, output);
    [encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
            threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
    [encoder endEncoding];
    for (size_t i = 0; i < inputCount; ++i) {
      Wrapper(i).markRead();
    }

    VLOG(2) << "Concat took: " << t.MilliSeconds();
    outputWrapper.copyToOutputBlob(Outputs()[0]);
    return true;
  }
};
2504

2505
REGISTER_CPU_OPERATOR(MPSCNNConcat, MPSCNNConcatOp);
// Accepts two to four inputs. The schema allows an optional second output
// (the shape argument), but only one output is stored in practice: the
// operator writes Outputs()[0] alone.
OPERATOR_SCHEMA(MPSCNNConcat)
    .NumInputs(2, 4)
    .NumOutputs(1, 2);
2508

2509
class MPSCNNResizeNearestOp final : public Operator<CPUContext> {
 public:
  MPSCNNResizeNearestOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<CPUContext>(operator_def, ws) {
    width_scale_ = OperatorBase::GetSingleArgument<float>("width_scale", 1);
    height_scale_ = OperatorBase::GetSingleArgument<float>("height_scale", 1);
    CAFFE_ENFORCE_GT(width_scale_, 0);
    CAFFE_ENFORCE_GT(height_scale_, 0);

    // The scales reach the kernel as ushort(scale * 10000) function
    // constants. Anything above 6.5535 would overflow a ushort, so the
    // scale is capped at 6.5.
    CAFFE_ENFORCE_LE(width_scale_, 6.5);
    CAFFE_ENFORCE_LE(height_scale_, 6.5);
  }

  // Resizes the input image by (height_scale_, width_scale_) using the
  // "resize_nearest" Metal kernel. Output height/width are the input
  // dimensions multiplied by the scales and truncated toward zero. Batch
  // size and channel count are unchanged.
  bool RunOnDevice() override {
    auto inputWrapper = Inputs()[0]->Get<MPSImageWrapper>();
    const MPSImage* X = inputWrapper.getImage();

    const int N = X.numberOfImages, C = X.featureChannels, H = X.height,
              W = X.width;
    const int output_width = W * width_scale_;
    const int output_height = H * height_scale_;
    // Down-scaling a small image can truncate a dimension to zero. Fail
    // clearly instead of allocating an empty output image and passing a
    // zero size to the kernel.
    CAFFE_ENFORCE_GT(
        output_width, 0, "MPSCNNResizeNearest produced a zero output width");
    CAFFE_ENFORCE_GT(
        output_height, 0, "MPSCNNResizeNearest produced a zero output height");

    auto outputWrapper =
        MPSImageWrapper(this, &inputWrapper, N, output_height, output_width, C);
    // NOTE(review): unlike the sibling ops, the command buffer is taken from
    // inputWrapper rather than outputWrapper; presumably they are the same
    // buffer since the output is created from inputWrapper -- confirm.
    auto commandBuffer = inputWrapper.getCommandBuffer();
    MPSImage* output = outputWrapper.getImage();

    auto encoder = [commandBuffer computeCommandEncoder];
    // Function constants: output height, output width, and the two scales in
    // fixed point (scale * 10000).
    auto state = getMPSCNNContext().getSpecializedPipelineState(
        kernelFor(output, @"resize_nearest", @"resize_nearest_nonarray"),
        {{ushort(output_height),
          ushort(output_width),
          ushort(height_scale_ * 10000),
          ushort(width_scale_ * 10000)}});
    [encoder setComputePipelineState:state];
    [encoder setTexture:[X texture] atIndex:0];
    [encoder setTexture:[output texture] atIndex:1];
    auto launchParams = spatialPointwiseKernelLaunchParams(state, output);
    [encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
            threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
    [encoder endEncoding];
    inputWrapper.markRead();
    outputWrapper.copyToOutputBlob(Outputs()[0]);

    return true;
  }

 protected:
  float width_scale_;
  float height_scale_;
};
2561

2562
// Expose the Metal nearest-neighbour resize operator under the CPU device key:
// one input image in, one resized image out.
REGISTER_CPU_OPERATOR(MPSCNNResizeNearest, MPSCNNResizeNearestOp);
OPERATOR_SCHEMA(MPSCNNResizeNearest)
    .NumInputs(1)
    .NumOutputs(1);
2564

2565
class MPSCNNChannelShuffleOp final : public ConvPoolOpBase<CPUContext> {
 public:
  MPSCNNChannelShuffleOp(const OperatorDef& operator_def, Workspace* ws)
      : ConvPoolOpBase<CPUContext>(operator_def, ws) {
    OPERATOR_NEEDS_FEATURE(
        this->order_ == StorageOrder::NCHW, "Metal only supports NCHW order.");
    // group_ is used as a divisor (% and /) in RunOnDeviceWithOrderNCHW;
    // reject non-positive values up front instead of dividing by zero.
    CAFFE_ENFORCE_GT(
        this->group_, 0, "MPSCNNChannelShuffle requires group > 0");
    // Presumably forces a 1x1 kernel so ConvPoolOpBase's kernel bookkeeping
    // is satisfied for this kernel-less op -- TODO confirm.
    kernel_[0] = kernel_[1] = 1;
  }

  // Shuffles the feature channels of the input image across group_ groups
  // using the "channel_shuffle" Metal kernel. The output has the same
  // batch size, height, width, and channel count as the input.
  bool RunOnDeviceWithOrderNCHW() override {
    caffe2::Timer t;
    auto inputWrapper = Inputs()[0]->Get<MPSImageWrapper>();
    MPSImage* X = inputWrapper.getImage();
    // Channels must split evenly into groups.
    CAFFE_ENFORCE_EQ(X.featureChannels % this->group_, 0);
    auto output_height = X.height;
    auto output_width = X.width;
    auto outputWrapper = MPSImageWrapper(
        this,
        &inputWrapper,
        X.numberOfImages,
        output_height,
        output_width,
        X.featureChannels);
    auto commandBuffer = outputWrapper.getCommandBuffer();
    MPSImage* output = outputWrapper.getImage();
    CAFFE_ENFORCE_EQ(output.height, output_height);
    CAFFE_ENFORCE_EQ(output.width, output_width);
    id<MTLComputeCommandEncoder> encoder =
        [commandBuffer computeCommandEncoder];
    // Function constants: N, C, channels per group, number of groups.
    id<MTLComputePipelineState> state =
        getMPSCNNContext().getSpecializedPipelineState(
            @"channel_shuffle",
            {{
                ushort(X.numberOfImages),
                ushort(X.featureChannels),
                ushort(X.featureChannels / this->group_),
                ushort(this->group_),
            }});
    [encoder setComputePipelineState:state];
    [encoder setTexture:[X texture] atIndex:0];
    [encoder setTexture:[output texture] atIndex:1];
    const auto& launchParams =
        spatialPointwiseKernelLaunchParams(state, output);
    [encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
            threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
    [encoder endEncoding];
    inputWrapper.markRead();
    outputWrapper.copyToOutputBlob(Outputs()[0]);

    VLOG(2) << "ChannelShuffle took: " << t.MilliSeconds();
    return true;
  }
};
2618

2619
// Expose the Metal channel-shuffle operator under the CPU device key: one
// input image in, one shuffled image out.
REGISTER_CPU_OPERATOR(MPSCNNChannelShuffle, MPSCNNChannelShuffleOp);
OPERATOR_SCHEMA(MPSCNNChannelShuffle)
    .NumInputs(1)
    .NumOutputs(1);
2621
}
2622

2623
// Registers MPSImageWrapper with caffe2's type registry so Blobs can hold it.
// The operators in this file read their inputs via
// Blob::Get<MPSImageWrapper>().
CAFFE_KNOWN_TYPE(MPSImageWrapper);
2624
} // namespace caffe2
2625

2626
#endif
2627

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.