1
#include "caffe2/core/common.h"
2
#include "caffe2/core/context.h"
4
#if defined(CAFFE2_USE_MPSCNN) && defined(C10_MOBILE)
6
#include "caffe2/core/operator.h"
7
#include "caffe2/core/timer.h"
8
#include "caffe2/operators/conv_pool_op_base.h"
9
#include "caffe2/operators/conv_transpose_unpool_op_base.h"
10
#include "caffe2/operators/generate_proposals_op.h"
11
#include "caffe2/operators/generate_proposals_op_util_boxes.h"
12
#include "caffe2/operators/spatial_batch_norm_op.h"
15
#include "mpscnn_context.h"
17
#import <Metal/Metal.h>
18
#import <MetalPerformanceShaders/MetalPerformanceShaders.h>
19
#import <UIKit/UIDevice.h>
21
#define SYSTEM_VERSION_GREATER_THAN_OR_EQUAL_TO(v) \
22
([[[UIDevice currentDevice] systemVersion] \
24
options:NSNumericSearch] != NSOrderedAscending)
26
// Only compiles against Base SDK iOS 11.0 or greater
27
@interface ConvDataSource : NSObject<MPSCNNConvolutionDataSource>
28
@property float* weights_;
29
@property float* bias_;
30
@property MPSCNNConvolutionDescriptor* desc_;
33
@implementation ConvDataSource
34
- (id)initWithWeight:(float*)weights
36
desc:(MPSCNNConvolutionDescriptor*)desc {
38
self.weights_ = weights;
47
- (MPSDataType)dataType {
48
return MPSDataTypeFloat32;
50
- (MPSCNNConvolutionDescriptor*)descriptor {
59
- (float*)lookupTableForUInt8Kernel {
65
- (vector_float2*)rangesForUInt8Kernel {
72
- (id)copyWithZone:(NSZone*)zone {
73
ConvDataSource* newDataSource = [[self class] allocWithZone:zone];
74
newDataSource.weights_ = self.weights_;
75
newDataSource.bias_ = self.bias_;
76
newDataSource.desc_ = self.desc_;
84
auto divRoundUp(uint x, uint y) -> uint {
85
return (x + y - 1) / y;
88
MPSTemporaryImage* createTemporaryImage(
89
const OperatorBase* op,
90
id<MTLCommandBuffer> commandBuffer,
95
size_t output_idx = 0) {
96
auto* image = [MPSTemporaryImage
97
temporaryImageWithCommandBuffer:commandBuffer
100
imageDescriptorWithChannelFormat:
101
MPSImageFeatureChannelFormatFloat16
104
featureChannels:channels
107
MTLTextureUsageShaderRead |
108
MTLTextureUsageShaderWrite]];
109
// We'll try to look at the per-output_idx read-count argument, otherwise,
110
// we'll use the operator-global default.
111
const auto& readCounts = op->GetRepeatedArgument<int>(kMPSCNNReadCountArg);
112
const auto readCount = readCounts.size()
113
? readCounts.at(output_idx)
114
: op->GetSingleArgument<int>(kMPSCNNReadCountArg, 1);
115
CAFFE_ENFORCE_GE(readCount, 1);
116
image.readCount = readCount;
120
MPSImage* createStaticImage(int n, int height, int width, int channels) {
121
return [[MPSImage alloc]
122
initWithDevice:getMPSCNNContext().device
125
imageDescriptorWithChannelFormat:
126
MPSImageFeatureChannelFormatFloat16
129
featureChannels:channels
131
usage:MTLTextureUsageShaderRead |
132
MTLTextureUsageShaderWrite]];
135
class MPSImageWrapper {
139
const OperatorBase* op,
140
MPSImageWrapper* parent,
145
size_t output_idx = 0) {
146
/* If the parent wrapper contains a temporary image, we need to pass on the
147
* command buffer because the temporary images are attached to the command
148
* buffer, we will need to use the same command buffer in order to use the
149
* temporary image. We don't want to synchronize the parent wrapper because
150
* it is still in use. If the parent wrapper contains a static image, we
151
* should create a new command buffer because we use static image so it can
152
* survive synchronization(commit of the command buffer), which means if we
153
* pass on the command buffer the command buffer will be committed in
154
* multiple places in the graph. Also since we don't pass on parent's
155
* command buffer,we need to synchronize(commit) it since it won't be used
158
bool passOnCb = parent != nullptr && parent->isTemporaryImage_;
159
commandBuffer_ = passOnCb ? parent->commandBuffer_
160
: [getMPSCNNContext().commandQueue commandBuffer];
162
bool commitInputCb = parent != nullptr && !parent->isTemporaryImage_;
164
parent->synchronize();
167
const auto& isTemporaryImages =
168
op->GetRepeatedArgument<int>(kMPSCNNOutputIsTempImageArg);
169
isTemporaryImage_ = isTemporaryImages.size()
170
? isTemporaryImages.at(output_idx)
171
: op->GetSingleArgument<int>(kMPSCNNOutputIsTempImageArg, 1);
172
if (isTemporaryImage_) {
173
image_ = createTemporaryImage(
174
op, commandBuffer_, n, height, width, channels, output_idx);
176
image_ = createStaticImage(n, height, width, channels);
181
if (isTemporaryImage_) {
182
MPSTemporaryImage* tempImg = (MPSTemporaryImage*)image_;
183
tempImg.readCount -= 1;
187
MPSImage* getImage() const {
191
id<MTLCommandBuffer> getCommandBuffer() const {
192
return commandBuffer_;
196
// commit the command buffer if it is notEnqueued
197
if (commandBuffer_ != nullptr && commandBuffer_.status == 0) {
198
[commandBuffer_ commit];
207
void copyToOutputBlob(Blob* output) {
208
output->GetMutable<MPSImageWrapper>()->image_ = image_;
209
output->GetMutable<MPSImageWrapper>()->commandBuffer_ = commandBuffer_;
210
output->GetMutable<MPSImageWrapper>()->isTemporaryImage_ =
215
MPSImage* image_{nullptr};
216
id<MTLCommandBuffer> commandBuffer_{nullptr};
217
bool isTemporaryImage_ = true;
221
kernelFor(const MPSImage* X, NSString* arrayKernel, NSString* nonArrayKernel) {
222
if (X.featureChannels > 4) {
225
if (X.numberOfImages > 1) {
228
return nonArrayKernel;
232
MTLSize threadsPerThreadgroup;
233
MTLSize threadgroupsPerGrid;
236
LaunchParams spatialPointwiseKernelLaunchParams(
237
id<MTLComputePipelineState> pipeline,
238
const MPSImage* im) {
239
const auto maxThreadsPerThreadgroup =
240
[pipeline maxTotalThreadsPerThreadgroup];
241
const auto threadExecutionWidth = [pipeline threadExecutionWidth];
242
const auto threadsPerThreadgroup = MTLSizeMake(
243
8 /* threadExecutionWidth */,
244
4 /* maxThreadsPerThreadgroup / threadExecutionWidth */,
246
const auto threadgroupsPerGrid = MTLSizeMake(
247
divRoundUp(im.width, threadsPerThreadgroup.width),
248
divRoundUp(im.height, threadsPerThreadgroup.height),
249
im.numberOfImages * divRoundUp(im.featureChannels, 4));
250
return {threadsPerThreadgroup, threadgroupsPerGrid};
254
ConvPoolOpBase<CPUContext>* op,
259
Tensor input = caffe2::empty({1, 1, H, W}, at::dtype<float>().device(CPU));
260
auto sizes = op->GetOutputSize(input, 1);
261
CAFFE_ENFORCE_EQ(sizes.size(), 4);
266
constexpr int computeMPSAlignOffset(int kernel, int pad) {
267
// To set the offset, we can just match the top-left pixel (in the input
268
// image, with negative values for padding) that we look at. For 3x3s1p1, we
269
// look at the (-1, -1) pixel in the original impl. For 3x3s1p0, we look at
270
// (0, 0) pixel. For 3x3s1p2, look at (-2, -2) MPSCNN always looks at
271
// (-floor(kernel_size - 1 / 2), -floor(kernel_size - 1 / 2)) Thus, we just
272
// need to match this up.
274
// For 3x3s1p1, offset should be (0, 0)
275
// For 3x3s1p0, offset should be (1, 1)
276
// For 3x3s1p2, offset should be (-1, -1)
277
const int mps_offset = kernel / 2;
278
const int c2_offset = pad;
279
return mps_offset - c2_offset;
282
// Compute the 1-d index of a n-dimensional contiguous row-major tensor for
283
// a given n-dimensional index 'index'
284
size_t ComputeStartIndex(
285
const TensorCPU& tensor,
286
const std::vector<int>& index) {
287
TORCH_DCHECK_EQ(index.size(), tensor.dim());
290
for (int i = 0; i < index.size(); i++) {
291
ret += index[i] * tensor.size_from_dim(i + 1);
297
// Get a sub tensor view from 'tensor' using data pointer from 'tensor'
299
utils::ConstTensorView<T> GetSubTensorView(
300
const TensorCPU& tensor,
301
int dim0_start_index) {
302
TORCH_DCHECK_EQ(tensor.meta().itemsize(), sizeof(T));
304
if (tensor.size() == 0) {
305
return utils::ConstTensorView<T>(nullptr, {});
308
std::vector<int> start_dims(tensor.dim(), 0);
309
start_dims.at(0) = dim0_start_index;
310
auto st_idx = ComputeStartIndex(tensor, start_dims);
311
auto ptr = tensor.data<T>() + st_idx;
313
auto input_dims = tensor.sizes();
314
std::vector<int> ret_dims(input_dims.begin() + 1, input_dims.end());
316
utils::ConstTensorView<T> ret(ptr, ret_dims);
320
class CopyToMPSCNNOp final : public Operator<CPUContext> {
322
CopyToMPSCNNOp(const OperatorDef& operator_def, Workspace* ws)
323
: Operator<CPUContext>(operator_def, ws) {}
325
bool RunOnDevice() override {
326
inputBuffers_.resize(Inputs().size());
327
std::vector<MPSImageWrapper> wrappers(Inputs().size());
328
for (auto i = 0; i < Inputs().size(); ++i) {
329
const auto& X = Input(i);
330
CAFFE_ENFORCE(X.dim() > 0 && X.dim() <= 4);
331
std::vector<int64_t> XDims = {1, 1, 1, 1};
332
XDims.assign(X.sizes().begin(), X.sizes().end());
335
const auto n = XDims[0];
336
const auto width = XDims[3];
337
const auto height = XDims[2];
338
const auto channels = XDims[1];
340
if (!inputBuffers_[i] || inputBuffers_[i].length != X.nbytes()) {
341
inputBuffers_[i] = [getMPSCNNContext().device
342
newBufferWithLength:X.nbytes()
343
options:MTLResourceOptionCPUCacheModeWriteCombined];
345
memcpy([inputBuffers_[i] contents], X.raw_data(), X.nbytes());
346
VLOG(2) << "CopyToMPSCNNOp input copy took: " << copyT.MilliSeconds();
349
MPSImageWrapper(this, nullptr, n, height, width, channels, i);
352
MPSImageWrapper(this, &wrappers[0], n, height, width, channels, i);
354
auto commandBuffer = wrappers[i].getCommandBuffer();
355
MPSImage* output = wrappers[i].getImage();
356
id<MTLComputeCommandEncoder> encoder =
357
[commandBuffer computeCommandEncoder];
358
id<MTLComputePipelineState> state =
359
getMPSCNNContext().getSpecializedPipelineState(
362
@"copy_nchw_to_metal",
363
@"copy_nchw_to_metal_nonarray"),
364
{{ushort(channels), ushort(height), ushort(width)}});
365
[encoder setComputePipelineState:state];
366
[encoder setBuffer:inputBuffers_[i] offset:0 atIndex:0];
367
[encoder setTexture:[output texture] atIndex:0];
368
const auto& launchParams =
369
spatialPointwiseKernelLaunchParams(state, output);
370
[encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
371
threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
372
[encoder endEncoding];
373
VLOG(2) << "CopyToMPSCNNOp took: " << t.MilliSeconds();
374
wrappers[i].copyToOutputBlob(Outputs()[i]);
380
std::vector<id<MTLBuffer>> inputBuffers_;
383
REGISTER_CPU_OPERATOR(CopyToMPSCNN, CopyToMPSCNNOp);
384
OPERATOR_SCHEMA(CopyToMPSCNN)
385
.NumInputs(1, INT_MAX)
386
.NumOutputs(1, INT_MAX)
387
.SameNumberOfOutput();
389
auto mpsImageSize = [](MPSImage* X) {
390
return X.featureChannels * X.height * X.width * X.numberOfImages;
393
class CopyFromMPSCNNOp final : public Operator<CPUContext> {
395
CopyFromMPSCNNOp(const OperatorDef& operator_def, Workspace* ws)
396
: Operator<CPUContext>(operator_def, ws) {}
398
bool RunOnDevice() override {
400
auto Wrapper = [&](size_t i) {
401
return Inputs()[i]->template Get<MPSImageWrapper>();
403
auto cb = [&](size_t i) { return Wrapper(i).getCommandBuffer(); };
404
auto X = [&](size_t i) { return Wrapper(i).getImage(); };
407
outputBuffers_.resize(Inputs().size());
408
for (auto i = 0; i < Inputs().size(); ++i) {
409
CAFFE_ENFORCE_EQ(cb0, cb(i));
411
if (!outputBuffers_[i] ||
412
outputBuffers_[i].length != mpsImageSize(Xi) * sizeof(float)) {
413
outputBuffers_[i] = [getMPSCNNContext().device
414
newBufferWithLength:mpsImageSize(Xi) * sizeof(float)
415
options:MTLResourceOptionCPUCacheModeDefault];
417
id<MTLComputeCommandEncoder> encoder = [cb0 computeCommandEncoder];
418
id<MTLComputePipelineState> state =
419
getMPSCNNContext().getSpecializedPipelineState(
421
Xi, @"copy_metal_to_nchw", @"copy_metal_to_nchw_nonarray"),
422
{{ushort(Xi.featureChannels),
426
[encoder setComputePipelineState:state];
427
[encoder setBuffer:outputBuffers_[i] offset:0 atIndex:0];
428
[encoder setTexture:[Xi texture] atIndex:0];
430
const auto& launchParams = spatialPointwiseKernelLaunchParams(state, Xi);
431
[encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
432
threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
433
[encoder endEncoding];
434
Wrapper(i).markRead();
437
[cb0 waitUntilCompleted];
439
for (auto i = 0; i < Inputs().size(); ++i) {
440
caffe2::Timer copyOutT;
443
Xi.numberOfImages, Xi.featureChannels, Xi.height, Xi.width);
444
Output(i)->mutable_data<float>();
445
CAFFE_ENFORCE_EQ(outputBuffers_[i].length, Output(i)->nbytes());
447
Output(i)->mutable_data<float>(),
448
[outputBuffers_[i] contents],
449
outputBuffers_[i].length);
450
VLOG(2) << "CopyFromMPSCNNOp memcpy took: " << copyOutT.MilliSeconds();
452
VLOG(2) << "CopyFromMPSCNNOp took: " << t.MilliSeconds();
457
std::vector<id<MTLBuffer>> outputBuffers_;
460
REGISTER_CPU_OPERATOR(CopyFromMPSCNN, CopyFromMPSCNNOp);
461
OPERATOR_SCHEMA(CopyFromMPSCNN)
462
.NumInputs(1, INT_MAX)
463
.NumOutputs(1, INT_MAX)
464
.SameNumberOfOutput();
466
class MPSCNNPackedInt8BGRANHWCToNCHWCStylizerPreprocessOp final
467
: public Operator<CPUContext> {
469
MPSCNNPackedInt8BGRANHWCToNCHWCStylizerPreprocessOp(
470
const OperatorDef& operator_def,
472
: Operator<CPUContext>(operator_def, ws), ws_(ws) {}
474
bool RunOnDevice() override {
475
const auto& X = Input(0);
476
const auto& mean = Input(1);
477
CAFFE_ENFORCE_EQ(mean.size(), 3);
478
CAFFE_ENFORCE_EQ(X.dim(), 4);
479
CAFFE_ENFORCE_EQ(X.size(0), 1);
480
CAFFE_ENFORCE_EQ(X.size(3), 4);
481
const auto H = X.size(1);
482
const auto W = X.size(2);
486
auto* noiseBlob = ws_->CreateBlob("__CAFFE2_STYLIZER_NOISE__");
487
ushort noiseSize = OperatorBase::GetSingleArgument<int>(
488
"noise_size", 491 /* prime to avoid artifacts */);
489
// Treaded as half4 in the kernel, so need half4 here.
490
noiseSize = divRoundUp(noiseSize, 4) * 4;
491
if (!BlobIsTensorType(*noiseBlob, CPU) ||
492
noiseBlob->Get<TensorCPU>().size() != noiseSize) {
493
VLOG(2) << "Initializing stylizer with noise: " << noiseSize;
495
// Initialize random noise on first use.
496
// Cache it to maintain temporal consistency.
497
auto* t = BlobGetMutableTensor(noiseBlob, CPU);
498
t->Resize(noiseSize);
499
math::RandGaussian<float, CPUContext>(
502
OperatorBase::GetSingleArgument<float>("noise_std", 10.0),
503
t->template mutable_data<float>(),
505
VLOG(2) << "Preprocess initializing noise: " << rt.MilliSeconds();
507
const auto& noise = noiseBlob->Get<TensorCPU>();
509
if (!inputBuffer_ || inputBuffer_.length != X.nbytes()) {
512
inputBuffer_ = [getMPSCNNContext().device
513
newBufferWithLength:X.nbytes()
514
options:MTLResourceOptionCPUCacheModeWriteCombined];
515
meanBuffer_ = [getMPSCNNContext().device
516
newBufferWithLength:4 * 2 // (3/4 half-floats).
517
options:MTLResourceOptionCPUCacheModeWriteCombined];
518
noiseBuffer_ = [getMPSCNNContext().device
519
newBufferWithLength:noiseSize * sizeof(float16_t)
520
options:MTLResourceOptionCPUCacheModeWriteCombined];
522
float16_t* meanBufferPtr = (float16_t*)[meanBuffer_ contents];
523
CAFFE_ENFORCE(meanBufferPtr);
524
for (auto i = 0; i < mean.size(); ++i) {
525
meanBufferPtr[i] = mean.data<float>()[i];
527
float16_t* noiseBufferPtr = (float16_t*)[noiseBuffer_ contents];
528
CAFFE_ENFORCE(noiseBufferPtr);
529
for (auto i = 0; i < noise.size(); ++i) {
530
noiseBufferPtr[i] = noise.data<float>()[i];
533
VLOG(2) << "Preprocess construct took: " << pt.MilliSeconds();
538
memcpy([inputBuffer_ contents], X.raw_data(), X.nbytes());
539
VLOG(2) << "Preprocess memcpy took: " << ct.MilliSeconds();
541
auto outputWrapper = MPSImageWrapper(this, nullptr, 1, H, W, 3);
542
auto commandBuffer = outputWrapper.getCommandBuffer();
543
MPSImage* output = outputWrapper.getImage();
545
id<MTLComputeCommandEncoder> encoder =
546
[commandBuffer computeCommandEncoder];
547
id<MTLComputePipelineState> state =
548
getMPSCNNContext().getSpecializedPipelineState(
549
@"preprocess_stylizer", {noiseSize});
551
[encoder setComputePipelineState:state];
552
[encoder setBuffer:inputBuffer_ offset:0 atIndex:0];
553
[encoder setBuffer:meanBuffer_ offset:0 atIndex:1];
554
[encoder setBuffer:noiseBuffer_ offset:0 atIndex:2];
556
[encoder setTexture:[output texture] atIndex:0];
557
const auto& launchParams =
558
spatialPointwiseKernelLaunchParams(state, output);
559
[encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
560
threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
561
[encoder endEncoding];
562
outputWrapper.copyToOutputBlob(Outputs()[0]);
564
VLOG(2) << "Preprocess took: " << t.MilliSeconds();
569
Workspace* ws_{nullptr};
570
id<MTLBuffer> inputBuffer_{nullptr};
571
id<MTLBuffer> noiseBuffer_{nullptr};
572
id<MTLBuffer> meanBuffer_{nullptr};
575
REGISTER_CPU_OPERATOR(
576
MPSCNNPackedInt8BGRANHWCToNCHWCStylizerPreprocess,
577
MPSCNNPackedInt8BGRANHWCToNCHWCStylizerPreprocessOp);
578
OPERATOR_SCHEMA(MPSCNNPackedInt8BGRANHWCToNCHWCStylizerPreprocess)
582
class MPSCNNBRGNCHWCToPackedInt8BGRAStylizerDeprocessOp final
583
: public Operator<CPUContext> {
585
MPSCNNBRGNCHWCToPackedInt8BGRAStylizerDeprocessOp(
586
const OperatorDef& operator_def,
588
: Operator<CPUContext>(operator_def, ws) {}
590
bool RunOnDevice() override {
591
auto inputWrapper = Inputs()[0]->Get<MPSImageWrapper>();
592
MPSImage* X = inputWrapper.getImage();
593
id<MTLCommandBuffer> commandBuffer = inputWrapper.getCommandBuffer();
595
const auto& mean = Input(1);
597
const auto W = X.width;
598
const auto H = X.height;
599
CAFFE_ENFORCE_EQ(X.featureChannels, 3);
600
CAFFE_ENFORCE_EQ(X.numberOfImages, 1);
602
if (!outputBuffer_ || outputBuffer_.length != X.height * X.width * 4) {
605
outputBuffer_ = [getMPSCNNContext().device
606
newBufferWithLength:X.height * X.width * 4
607
options:MTLResourceOptionCPUCacheModeDefault];
608
meanBuffer_ = [getMPSCNNContext().device
609
newBufferWithLength:4 * 2 // (3/4 half-floats).
610
options:MTLResourceOptionCPUCacheModeWriteCombined];
611
float16_t* meanBufferPtr = (float16_t*)[meanBuffer_ contents];
612
for (auto i = 0; i < mean.size(); ++i) {
613
meanBufferPtr[i] = mean.data<float>()[i];
615
VLOG(2) << "Deprocess copy took: " << pt.MilliSeconds();
617
id<MTLComputeCommandEncoder> encoder =
618
[commandBuffer computeCommandEncoder];
619
id<MTLComputePipelineState> state =
620
getMPSCNNContext().getPipelineState(@"deprocess_stylizer");
622
CAFFE_ENFORCE_EQ(outputBuffer_.length, X.height * X.width * 4);
623
[encoder setComputePipelineState:state];
624
[encoder setBuffer:outputBuffer_ offset:0 atIndex:0];
625
[encoder setBuffer:meanBuffer_ offset:0 atIndex:1];
626
[encoder setTexture:[X texture] atIndex:0];
627
const auto& launchParams = spatialPointwiseKernelLaunchParams(state, X);
628
[encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
629
threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
630
[encoder endEncoding];
631
inputWrapper.markRead();
633
[commandBuffer commit];
634
[commandBuffer waitUntilCompleted];
636
Output(0)->Resize(1, X.height, X.width, 4);
640
Output(0)->mutable_data<uint8_t>(),
641
[outputBuffer_ contents],
642
[outputBuffer_ length]);
643
VLOG(2) << "Deprocess copy: " << t.MilliSeconds();
645
CAFFE_ENFORCE_EQ(Output(0)->nbytes(), [outputBuffer_ length]);
646
VLOG(2) << "Deprocess took: " << t.MilliSeconds();
652
id<MTLBuffer> outputBuffer_{nullptr};
653
id<MTLBuffer> meanBuffer_{nullptr};
656
REGISTER_CPU_OPERATOR(
657
MPSCNNBRGNCHWCToPackedInt8BGRAStylizerDeprocess,
658
MPSCNNBRGNCHWCToPackedInt8BGRAStylizerDeprocessOp);
659
OPERATOR_SCHEMA(MPSCNNBRGNCHWCToPackedInt8BGRAStylizerDeprocess)
663
template <typename Neuron>
664
class MPSCNNNeuronOp final : public Operator<CPUContext> {
666
MPSCNNNeuronOp(const OperatorDef& operator_def, Workspace* ws)
667
: Operator<CPUContext>(operator_def, ws) {}
669
bool RunOnDevice() override {
671
auto inputWrapper = Inputs()[0]->template Get<MPSImageWrapper>();
672
MPSImage* X = inputWrapper.getImage();
674
auto outputWrapper = MPSImageWrapper(
681
auto commandBuffer = outputWrapper.getCommandBuffer();
682
MPSImage* output = outputWrapper.getImage();
683
CAFFE_ENFORCE_EQ(output.width, X.width);
684
CAFFE_ENFORCE_EQ(output.height, X.height);
685
CAFFE_ENFORCE_EQ(output.featureChannels, X.featureChannels);
688
neuron_ = Neuron::t();
690
[neuron_ encodeToCommandBuffer:commandBuffer
692
destinationImage:output];
693
outputWrapper.copyToOutputBlob(Outputs()[0]);
695
VLOG(2) << "ElementwiseAdd took: " << t.MilliSeconds();
698
MPSCNNNeuron* neuron_{nullptr};
701
#define INIT_NEURON_OP(n) \
702
REGISTER_CPU_OPERATOR(MPSCNN##n, MPSCNNNeuronOp<n##NeuronInit>); \
703
OPERATOR_SCHEMA(MPSCNN##n).NumInputs(1).NumOutputs(1).AllowInplace({{0, 0}});
705
struct SigmoidNeuronInit {
706
static MPSCNNNeuron* t() {
708
[[MPSCNNNeuronSigmoid alloc] initWithDevice:getMPSCNNContext().device];
711
INIT_NEURON_OP(Sigmoid);
713
struct ReluNeuronInit {
714
static MPSCNNNeuron* t() {
716
[[MPSCNNNeuronReLU alloc] initWithDevice:getMPSCNNContext().device a:0];
721
struct TanhNeuronInit {
722
static MPSCNNNeuron* t() {
723
return [[MPSCNNNeuronTanH alloc] initWithDevice:getMPSCNNContext().device
732
template <typename Neuron>
733
class MPSCNNConvOp final : public ConvPoolOpBase<CPUContext> {
735
MPSCNNConvOp(const OperatorDef& operator_def, Workspace* ws)
736
: ConvPoolOpBase<CPUContext>(operator_def, ws) {
737
OPERATOR_NEEDS_FEATURE(
738
this->order_ == StorageOrder::NCHW, "Metal only supports NCHW order.");
739
OPERATOR_NEEDS_FEATURE(
740
kernel_h() == kernel_w(),
741
"Metal only supports equal kernel dimension.");
744
bool RunOnDeviceWithOrderNCHW() override {
746
auto inputWrapper = Inputs()[0]->template Get<MPSImageWrapper>();
747
MPSImage* X = inputWrapper.getImage();
749
auto& filter = Input(FILTER);
750
auto& bias = Input(BIAS);
751
CAFFE_ENFORCE_EQ(filter.dim(), 4);
752
// For NCHW, X.dim32(1), inputChannels
753
const int C = X.featureChannels;
754
const int M = filter.dim32(0);
755
const int Cf = filter.dim32(1);
757
CAFFE_ENFORCE(filter.dim32(2) == kernel_h(), "");
758
CAFFE_ENFORCE(filter.dim32(3) == kernel_w(), "");
759
CAFFE_ENFORCE(bias.dim() == 1, "");
760
CAFFE_ENFORCE(bias.dim32(0) == M, "");
762
const auto kH = kernel_h();
763
const auto kW = kernel_w();
765
// ConvPoolOpBase<CPUContext>::SetOutputSize(X, Y, filter.dim32(0));
766
// Reformat weights from [M][C][kH][kW] to [M][kH][kW][C].
769
std::vector<float> refilter(M * kH * kW * Cf);
770
auto* filter_ = filter.template data<float>();
771
for (auto m = 0; m < M; ++m) {
772
for (auto c = 0; c < Cf; ++c) {
773
for (auto kh = 0; kh < kH; ++kh) {
774
for (auto kw = 0; kw < kW; ++kw) {
775
// refilter[m][kh][kw][c]
776
refilter[m * kH * kW * Cf + kh * kW * Cf + kw * Cf + c] =
777
// filter[m][c][kh][kw]
778
filter_[m * Cf * kH * kW + c * kH * kW + kh * kW + kw];
783
// DepthwiseConv path
784
bool runtimeAtLeastIOS11 =
785
SYSTEM_VERSION_GREATER_THAN_OR_EQUAL_TO(@"11.0");
786
// Only inputFeatureChannels == outputFeatureChannels is supported right
788
if (runtimeAtLeastIOS11 && this->group_ > 1 && Cf == 1 &&
790
MPSCNNDepthWiseConvolutionDescriptor* desc =
791
[MPSCNNDepthWiseConvolutionDescriptor
792
cnnConvolutionDescriptorWithKernelWidth:kW
794
inputFeatureChannels:C
795
outputFeatureChannels:M
796
neuronFilter:Neuron::t()];
797
desc.strideInPixelsX = stride_w();
798
desc.strideInPixelsY = stride_h();
800
auto data_source = [[ConvDataSource alloc]
801
initWithWeight:refilter.data()
802
bias:const_cast<float*>(bias.template data<float>())
805
[[MPSCNNConvolution alloc] initWithDevice:getMPSCNNContext().device
806
weights:data_source];
808
if (this->group_ > 1) {
812
"MPSCNNConvolution requires number of input \
813
channels in each group to be multiple of 4 for \
816
MPSCNNConvolutionDescriptor* desc = [MPSCNNConvolutionDescriptor
817
cnnConvolutionDescriptorWithKernelWidth:kW
819
inputFeatureChannels:C
820
outputFeatureChannels:M
821
neuronFilter:Neuron::t()];
822
desc.strideInPixelsX = stride_w();
823
desc.strideInPixelsY = stride_h();
824
desc.groups = this->group_;
825
auto data_source = [[ConvDataSource alloc]
826
initWithWeight:refilter.data()
827
bias:const_cast<float*>(bias.template data<float>())
830
[[MPSCNNConvolution alloc] initWithDevice:getMPSCNNContext().device
831
weights:data_source];
834
[conv_ setEdgeMode:MPSImageEdgeModeZero];
837
offset.x = computeMPSAlignOffset(kW, pad_l());
838
offset.y = computeMPSAlignOffset(kH, pad_t());
840
[conv_ setOffset:offset];
841
VLOG(2) << "MPSCNNConv ConvDesc took: " << consT.MilliSeconds();
844
CAFFE_ENFORCE_EQ(conv_.strideInPixelsY, stride_h());
845
CAFFE_ENFORCE_EQ(conv_.strideInPixelsX, stride_w());
846
CAFFE_ENFORCE_EQ(conv_.inputFeatureChannels, Cf * this->group_);
847
CAFFE_ENFORCE_EQ(M % conv_.groups, 0);
848
CAFFE_ENFORCE_EQ(conv_.outputFeatureChannels, M);
849
CAFFE_ENFORCE_EQ(conv_.kernelWidth, kW);
850
CAFFE_ENFORCE_EQ(conv_.kernelHeight, kH);
854
computeOutputHW(this, X.height, X.width, &output_height, &output_width);
855
int output_channels = M;
857
VLOG(2) << "Output height: " << output_height;
858
VLOG(2) << "Output width:" << output_width;
859
VLOG(2) << "Output channels:" << output_channels;
860
auto outputWrapper = MPSImageWrapper(
867
auto commandBuffer = outputWrapper.getCommandBuffer();
868
MPSImage* output = outputWrapper.getImage();
869
CAFFE_ENFORCE_EQ(output.height, output_height);
870
CAFFE_ENFORCE_EQ(output.width, output_width);
871
[conv_ encodeToCommandBuffer:commandBuffer
873
destinationImage:output];
874
outputWrapper.copyToOutputBlob(Outputs()[0]);
876
VLOG(2) << "MPSCNNConv took: " << t.MilliSeconds();
882
INPUT_TAGS(INPUT, FILTER, BIAS);
884
MPSCNNConvolution* conv_{nullptr};
888
struct EmptyNeuronInit {
889
static MPSCNNNeuron* t() {
894
// We can allow the input weights/bias and output to alias each other,
895
// for example when doing a Conv + out-of-place ReLU, then fusing.
896
#define INIT_CONV_NEURON_OP(name, neuron) \
897
REGISTER_CPU_OPERATOR(name, MPSCNNConvOp<neuron>); \
898
OPERATOR_SCHEMA(name).NumInputs(3).NumOutputs(1).AllowInplace( \
901
INIT_CONV_NEURON_OP(MPSCNNConv, EmptyNeuronInit);
902
INIT_CONV_NEURON_OP(MPSCNNConvRelu, ReluNeuronInit);
903
INIT_CONV_NEURON_OP(MPSCNNConvSigmoid, SigmoidNeuronInit);
905
#undef INIT_CONV_NEURON_OP
907
class MPSCNNPadImageOp final : public ConvPoolOpBase<CPUContext> {
909
MPSCNNPadImageOp(const OperatorDef& operator_def, Workspace* ws)
910
: ConvPoolOpBase<CPUContext>(operator_def, ws) {
911
OPERATOR_NEEDS_FEATURE(
912
this->order_ == StorageOrder::NCHW, "Metal only supports NCHW order.");
914
OPERATOR_NEEDS_FEATURE(
915
OperatorBase::GetSingleArgument<string>("mode", "") == "reflect",
916
"Metal only supports reflection");
917
kernel_[0] = kernel_[1] = 1;
920
bool RunOnDeviceWithOrderNCHW() override {
922
auto inputWrapper = Inputs()[0]->Get<MPSImageWrapper>();
923
MPSImage* X = inputWrapper.getImage();
925
const auto pH = pad_t();
926
const auto pW = pad_l();
927
const auto output_height = X.height + 2 * pH;
928
const auto output_width = X.width + 2 * pW;
929
VLOG(1) << "Output height: " << output_height;
930
VLOG(1) << "Output width:" << output_width;
931
VLOG(2) << "Output channels:" << X.featureChannels;
932
auto outputWrapper = MPSImageWrapper(
939
auto commandBuffer = outputWrapper.getCommandBuffer();
940
MPSImage* output = outputWrapper.getImage();
941
CAFFE_ENFORCE_EQ(output.height, output_height);
942
CAFFE_ENFORCE_EQ(output.width, output_width);
943
id<MTLComputeCommandEncoder> encoder =
944
[commandBuffer computeCommandEncoder];
945
id<MTLComputePipelineState> state =
946
getMPSCNNContext().getPipelineState(kernelFor(
947
output, @"reflection_padding", @"reflection_padding_nonarray"));
948
[encoder setComputePipelineState:state];
949
[encoder setTexture:[X texture] atIndex:0];
950
[encoder setTexture:[output texture] atIndex:1];
951
const auto& launchParams =
952
spatialPointwiseKernelLaunchParams(state, output);
953
[encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
954
threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
955
[encoder endEncoding];
956
inputWrapper.markRead();
957
outputWrapper.copyToOutputBlob(Outputs()[0]);
959
VLOG(2) << "PadImage took: " << t.MilliSeconds();
964
REGISTER_CPU_OPERATOR(MPSCNNPadImage, MPSCNNPadImageOp);
965
OPERATOR_SCHEMA(MPSCNNPadImage).NumInputs(1).NumOutputs(1);
967
class MPSCNNMulOp final : public Operator<CPUContext> {
969
MPSCNNMulOp(const OperatorDef& operator_def, Workspace* ws)
970
: Operator<CPUContext>(operator_def, ws) {
971
OPERATOR_NEEDS_FEATURE(
972
OperatorBase::GetSingleArgument<int>("broadcast", 0) == 1,
973
"MPSCNNMul only supports broadcast");
975
OPERATOR_NEEDS_FEATURE(
976
OperatorBase::HasArgument("axis") == false,
977
"MPSCNNMul does not support axis");
980
bool RunOnDevice() override {
983
auto wrapper0 = Inputs()[0]->Get<MPSImageWrapper>();
984
MPSImage* X0 = wrapper0.getImage();
986
const auto& X1 = Input(1);
990
"MPSCNNMulOp: Only dim == 1 for Input(1) is supported for now");
992
auto X1_ = [getMPSCNNContext().device
993
newBufferWithBytes:X1.template data<float>()
994
length:sizeof(float) * X1.size()
995
options:MTLResourceOptionCPUCacheModeDefault];
997
auto outputWrapper = MPSImageWrapper(
1003
X0.featureChannels);
1004
auto commandBuffer = outputWrapper.getCommandBuffer();
1005
MPSImage* output = outputWrapper.getImage();
1007
id<MTLComputeCommandEncoder> encoder =
1008
[commandBuffer computeCommandEncoder];
1009
id<MTLComputePipelineState> state =
1010
getMPSCNNContext().getSpecializedPipelineState(
1012
{{ushort(X0.numberOfImages),
1013
ushort(X0.featureChannels),
1014
ushort(X1.dim32(0))}});
1016
[encoder setComputePipelineState:state];
1017
[encoder setTexture:[X0 texture] atIndex:0];
1018
[encoder setBuffer:X1_ offset:0 atIndex:1];
1019
[encoder setTexture:[output texture] atIndex:2];
1020
const auto& launchParams =
1021
spatialPointwiseKernelLaunchParams(state, output);
1022
[encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
1023
threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
1024
[encoder endEncoding];
1025
wrapper0.markRead();
1026
outputWrapper.copyToOutputBlob(Outputs()[0]);
1027
VLOG(2) << "ElementwiseMul took: " << t.MilliSeconds();
1032
REGISTER_CPU_OPERATOR(MPSCNNMul, MPSCNNMulOp);
1033
OPERATOR_SCHEMA(MPSCNNMul).NumInputs(2).NumOutputs(1).AllowInplace({{0, 0}});
1035
class MPSCNNSubOp final : public Operator<CPUContext> {
1037
MPSCNNSubOp(const OperatorDef& operator_def, Workspace* ws)
1038
: Operator<CPUContext>(operator_def, ws) {
1039
OPERATOR_NEEDS_FEATURE(
1040
OperatorBase::GetSingleArgument<int>("broadcast", 0) == 1,
1041
"MPSCNNSub only supports broadcast");
1043
OPERATOR_NEEDS_FEATURE(
1044
OperatorBase::HasArgument("axis") == false,
1045
"MPSCNNSub does not support axis");
1048
bool RunOnDevice() override {
1051
auto wrapper0 = Inputs()[0]->Get<MPSImageWrapper>();
1052
MPSImage* X0 = wrapper0.getImage();
1054
const auto& X1 = Input(1);
1058
"MPSCNNSubOp: Only dim == 1 for Input(1) is supported for now");
1060
auto X1_ = [getMPSCNNContext().device
1061
newBufferWithBytes:X1.template data<float>()
1062
length:sizeof(float) * X1.size()
1063
options:MTLResourceOptionCPUCacheModeDefault];
1065
auto outputWrapper = MPSImageWrapper(
1071
X0.featureChannels);
1072
auto commandBuffer = outputWrapper.getCommandBuffer();
1073
MPSImage* output = outputWrapper.getImage();
1075
id<MTLComputeCommandEncoder> encoder =
1076
[commandBuffer computeCommandEncoder];
1077
id<MTLComputePipelineState> state =
1078
getMPSCNNContext().getSpecializedPipelineState(
1080
{{ushort(X0.numberOfImages),
1081
ushort(X0.featureChannels),
1082
ushort(X1.dim32(0))}});
1084
[encoder setComputePipelineState:state];
1085
[encoder setTexture:[X0 texture] atIndex:0];
1086
[encoder setBuffer:X1_ offset:0 atIndex:1];
1087
[encoder setTexture:[output texture] atIndex:2];
1088
const auto& launchParams =
1089
spatialPointwiseKernelLaunchParams(state, output);
1090
[encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
1091
threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
1092
[encoder endEncoding];
1093
wrapper0.markRead();
1094
outputWrapper.copyToOutputBlob(Outputs()[0]);
1095
VLOG(2) << "ElementwiseSub took: " << t.MilliSeconds();
1100
REGISTER_CPU_OPERATOR(MPSCNNSub, MPSCNNSubOp);
1101
OPERATOR_SCHEMA(MPSCNNSub).NumInputs(2).NumOutputs(1).AllowInplace({{0, 0}});
// MPSCNNAdd: elementwise sum of the two MPSImages in Input(0) and Input(1),
// written to a new output image. Both inputs must come from the same command
// buffer and have equal width, height and featureChannels (no broadcasting).
// The kernel is "elementwise_add" or "elementwise_add_nonarray", as chosen by
// kernelFor(X0, ...).
1103
class MPSCNNAddOp final : public Operator<CPUContext> {
1105
MPSCNNAddOp(const OperatorDef& operator_def, Workspace* ws)
1106
: Operator<CPUContext>(operator_def, ws) {}
1108
bool RunOnDevice() override {
1111
auto wrapper0 = Inputs()[0]->Get<MPSImageWrapper>();
1112
auto wrapper1 = Inputs()[1]->Get<MPSImageWrapper>();
1113
MPSImage* X0 = wrapper0.getImage();
1114
MPSImage* X1 = wrapper1.getImage();
// Both inputs must have been produced on the same command buffer.
1115
CAFFE_ENFORCE_EQ(wrapper0.getCommandBuffer(), wrapper1.getCommandBuffer());
1117
auto outputWrapper = MPSImageWrapper(
1123
X0.featureChannels);
1124
auto commandBuffer = outputWrapper.getCommandBuffer();
1125
MPSImage* output = outputWrapper.getImage();
// Shapes must match exactly.
1126
CAFFE_ENFORCE_EQ(X1.width, X0.width);
1127
CAFFE_ENFORCE_EQ(X1.height, X0.height);
1128
CAFFE_ENFORCE_EQ(X1.featureChannels, X0.featureChannels);
1129
id<MTLComputeCommandEncoder> encoder =
1130
[commandBuffer computeCommandEncoder];
1131
id<MTLComputePipelineState> state = getMPSCNNContext().getPipelineState(
1132
kernelFor(X0, @"elementwise_add", @"elementwise_add_nonarray"));
// Bindings: texture 0 = Input(0), texture 1 = Input(1), texture 2 = output.
1134
[encoder setComputePipelineState:state];
1135
[encoder setTexture:[X0 texture] atIndex:0];
1136
[encoder setTexture:[X1 texture] atIndex:1];
1137
[encoder setTexture:[output texture] atIndex:2];
1138
const auto& launchParams =
1139
spatialPointwiseKernelLaunchParams(state, output);
1140
[encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
1141
threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
1142
[encoder endEncoding];
1143
wrapper0.markRead();
1144
wrapper1.markRead();
1145
outputWrapper.copyToOutputBlob(Outputs()[0]);
1147
VLOG(2) << "ElementwiseAdd took: " << t.MilliSeconds();
1152
REGISTER_CPU_OPERATOR(MPSCNNAdd, MPSCNNAddOp);
1153
// Not really in-place per se (RunOnDevice always creates a new output image),
// but letting Output(0) share Input(0)'s blob is semantically valid.
1155
OPERATOR_SCHEMA(MPSCNNAdd).NumInputs(2).NumOutputs(1).AllowInplace({{0, 0}});
// MPSCNNAveragePool: average pooling of the MPSImage in Input(0) using
// MPSCNNPoolingAverage. Only NCHW order and square kernels are accepted.
// pool_ is configured when it is still null, and again on every run when
// global pooling is enabled (ComputePads is re-run with the current input
// size in that case).
1157
class MPSCNNAveragePoolOp final : public ConvPoolOpBase<CPUContext> {
1159
MPSCNNAveragePoolOp(const OperatorDef& operator_def, Workspace* ws)
1160
: ConvPoolOpBase<CPUContext>(operator_def, ws) {
1161
OPERATOR_NEEDS_FEATURE(
1162
this->order_ == StorageOrder::NCHW, "Metal only supports NCHW order.");
1163
OPERATOR_NEEDS_FEATURE(
1164
kernel_h() == kernel_w(),
1165
"Metal only supports equal kernel dimension.");
1168
bool RunOnDeviceWithOrderNCHW() override {
1170
auto inputWrapper = Inputs()[0]->Get<MPSImageWrapper>();
1171
MPSImage* X = inputWrapper.getImage();
1173
if (!pool_ || this->global_pooling_) {
1174
caffe2::Timer consT;
1175
this->ComputePads({(int)X.height, (int)X.width});
1177
[[MPSCNNPoolingAverage alloc] initWithDevice:getMPSCNNContext().device
1178
kernelWidth:kernel_w()
1179
kernelHeight:kernel_h()
1180
strideInPixelsX:stride_w()
1181
strideInPixelsY:stride_h()];
// Reads outside the source image are clamped to the nearest edge pixel.
1183
[pool_ setEdgeMode:MPSImageEdgeModeClamp];
// Kernel offset derived from kernel size and left/top padding by
// computeMPSAlignOffset (presumably aligning the MPS window placement with
// Caffe2's padding convention -- confirm in that helper).
1185
offset.x = computeMPSAlignOffset(kernel_w(), pad_l());
1186
offset.y = computeMPSAlignOffset(kernel_h(), pad_t());
1188
[pool_ setOffset:offset];
1189
VLOG(2) << "MPSCNNAveragePool PoolDesc took: " << consT.MilliSeconds();
// The (possibly cached) pooling kernel must use the current strides.
1192
CAFFE_ENFORCE_EQ(pool_.strideInPixelsY, stride_h());
1193
CAFFE_ENFORCE_EQ(pool_.strideInPixelsX, stride_w());
1196
computeOutputHW(this, X.height, X.width, &output_height, &output_width);
1198
VLOG(2) << "Output height: " << output_height;
1199
VLOG(2) << "Output width:" << output_width;
1200
VLOG(2) << "Output channels:" << X.featureChannels;
1201
auto outputWrapper = MPSImageWrapper(
1208
auto commandBuffer = outputWrapper.getCommandBuffer();
1209
MPSImage* output = outputWrapper.getImage();
// The allocated output image must have the size computed by computeOutputHW.
1210
CAFFE_ENFORCE_EQ(output.height, output_height);
1211
CAFFE_ENFORCE_EQ(output.width, output_width);
1212
[pool_ encodeToCommandBuffer:commandBuffer
1214
destinationImage:output];
1215
outputWrapper.copyToOutputBlob(Outputs()[0]);
1217
VLOG(2) << "MPSCNNAveragePool took: " << t.MilliSeconds();
// Cached MPS pooling kernel, (re)configured in RunOnDeviceWithOrderNCHW.
1221
MPSCNNPoolingAverage* pool_{nullptr};
1224
REGISTER_CPU_OPERATOR(MPSCNNAveragePool, MPSCNNAveragePoolOp);
1225
OPERATOR_SCHEMA(MPSCNNAveragePool).NumInputs(1).NumOutputs(1);
// MPSCNNMaxPool: max pooling of the MPSImage in Input(0) using
// MPSCNNPoolingMax. Only NCHW order and square kernels are accepted.
// pool_ is created when it is still null, and again on every run when global
// pooling is enabled (ComputePads is re-run with the current input size in
// that case).
1227
class MPSCNNMaxPoolOp final : public ConvPoolOpBase<CPUContext> {
1229
MPSCNNMaxPoolOp(const OperatorDef& operator_def, Workspace* ws)
1230
: ConvPoolOpBase<CPUContext>(operator_def, ws) {
1231
OPERATOR_NEEDS_FEATURE(
1232
this->order_ == StorageOrder::NCHW, "Metal only supports NCHW order.");
1233
OPERATOR_NEEDS_FEATURE(
1234
kernel_h() == kernel_w(),
1235
"Metal only supports equal kernel dimension.");
1238
bool RunOnDeviceWithOrderNCHW() override {
1240
auto inputWrapper = Inputs()[0]->Get<MPSImageWrapper>();
1241
MPSImage* X = inputWrapper.getImage();
1243
if (!pool_ || this->global_pooling_) {
1244
caffe2::Timer consT;
1245
this->ComputePads({(int)X.height, (int)X.width});
1246
pool_ = [[MPSCNNPoolingMax alloc] initWithDevice:getMPSCNNContext().device
1247
kernelWidth:kernel_w()
1248
kernelHeight:kernel_h()
1249
strideInPixelsX:stride_w()
1250
strideInPixelsY:stride_h()];
// Reads outside the source image are clamped to the nearest edge pixel.
1252
[pool_ setEdgeMode:MPSImageEdgeModeClamp];
// Kernel offset derived from kernel size and left/top padding by
// computeMPSAlignOffset (presumably aligning the MPS window placement with
// Caffe2's padding convention -- confirm in that helper).
1254
offset.x = computeMPSAlignOffset(kernel_w(), pad_l());
1255
offset.y = computeMPSAlignOffset(kernel_h(), pad_t());
1257
[pool_ setOffset:offset];
1258
VLOG(2) << "MPSCNNMaxPool PoolDesc took: " << consT.MilliSeconds();
// The (possibly cached) pooling kernel must use the current strides.
1261
CAFFE_ENFORCE_EQ(pool_.strideInPixelsY, stride_h());
1262
CAFFE_ENFORCE_EQ(pool_.strideInPixelsX, stride_w());
1266
computeOutputHW(this, X.height, X.width, &output_height, &output_width);
1268
VLOG(2) << "Output height: " << output_height;
1269
VLOG(2) << "Output width:" << output_width;
1270
VLOG(2) << "Output channels:" << X.featureChannels;
1271
auto outputWrapper = MPSImageWrapper(
1278
auto commandBuffer = outputWrapper.getCommandBuffer();
1279
MPSImage* output = outputWrapper.getImage();
// The allocated output image must have the size computed by computeOutputHW.
1280
CAFFE_ENFORCE_EQ(output.height, output_height);
1281
CAFFE_ENFORCE_EQ(output.width, output_width);
1282
[pool_ encodeToCommandBuffer:commandBuffer
1284
destinationImage:output];
1285
outputWrapper.copyToOutputBlob(Outputs()[0]);
1287
VLOG(2) << "MPSCNNMaxPool took: " << t.MilliSeconds();
// Cached MPS pooling kernel, (re)created in RunOnDeviceWithOrderNCHW.
1291
MPSCNNPoolingMax* pool_{nullptr};
1294
REGISTER_CPU_OPERATOR(MPSCNNMaxPool, MPSCNNMaxPoolOp);
1295
OPERATOR_SCHEMA(MPSCNNMaxPool).NumInputs(1).NumOutputs(1);
// MPSCNNSoftmax: applies MPSCNNSoftMax to the MPSImage in Input(0) and writes
// the result to a new output image. The input must be spatially 1x1.
1297
class MPSCNNSoftmaxOp final : public Operator<CPUContext> {
1299
MPSCNNSoftmaxOp(const OperatorDef& operator_def, Workspace* ws)
1300
: Operator<CPUContext>(operator_def, ws) {}
1302
bool RunOnDevice() override {
1304
auto inputWrapper = Inputs()[0]->Get<MPSImageWrapper>();
1305
MPSImage* X = inputWrapper.getImage();
// Only height == 1 and width == 1 inputs are accepted.
1306
CAFFE_ENFORCE_EQ(X.height, 1);
1307
CAFFE_ENFORCE_EQ(X.width, 1);
1310
[[MPSCNNSoftMax alloc] initWithDevice:getMPSCNNContext().device];
1312
auto outputWrapper = MPSImageWrapper(
1319
auto commandBuffer = outputWrapper.getCommandBuffer();
1320
MPSImage* output = outputWrapper.getImage();
1321
[softmax_ encodeToCommandBuffer:commandBuffer
1323
destinationImage:output];
1324
outputWrapper.copyToOutputBlob(Outputs()[0]);
1325
VLOG(2) << "MPSCNNSoftmax took: " << t.MilliSeconds();
// MPS softmax kernel used by RunOnDevice.
1329
MPSCNNSoftMax* softmax_{nullptr};
1332
REGISTER_CPU_OPERATOR(MPSCNNSoftmax, MPSCNNSoftmaxOp);
1333
OPERATOR_SCHEMA(MPSCNNSoftmax).NumInputs(1).NumOutputs(1);
// MPSCNNFullyConnectedOp: fully-connected layer implemented as an
// MPSCNNConvolution whose kernel spans the whole input image
// (kernelWidth = X.width, kernelHeight = X.height). It produces a 1x1 output
// image with W.dim32(0) channels for each of X.numberOfImages images.
// Neuron::t() supplies the neuron filter passed to the convolution
// descriptor (used to fuse e.g. ReLU).
// Inputs: X (MPSImage), W (CPU tensor; W.dim32(1) / X.width / X.height must
// equal X.featureChannels), b (CPU bias tensor).
1335
template <typename Neuron>
1336
class MPSCNNFullyConnectedOp final : public Operator<CPUContext> {
1338
MPSCNNFullyConnectedOp(const OperatorDef& operator_def, Workspace* ws)
1339
: Operator<CPUContext>(operator_def, ws) {}
1341
bool RunOnDevice() override {
1343
auto inputWrapper = Inputs()[0]->template Get<MPSImageWrapper>();
1344
MPSImage* X = inputWrapper.getImage();
1345
const auto& W = Input(1);
1346
const auto& b = Input(2);
1348
const auto input_channels = W.dim32(1) / X.width / X.height;
1349
CAFFE_ENFORCE_EQ(input_channels, X.featureChannels);
1350
const auto output_channels = W.dim32(0);
// Reorder W from [M][C][kH][kW] into the [M][kH][kW][C] layout handed to
// ConvDataSource.
1352
const auto M = output_channels;
1353
const auto kH = X.height;
1354
const auto kW = X.width;
1355
const auto C = input_channels;
1356
std::vector<float> refilter(M * kH * kW * C);
1357
auto* filter_ = W.template data<float>();
1358
for (auto m = 0; m < M; ++m) {
1359
for (auto c = 0; c < C; ++c) {
1360
for (auto kh = 0; kh < kH; ++kh) {
1361
for (auto kw = 0; kw < kW; ++kw) {
1362
// refilter[m][kh][kw][c]
1363
refilter[m * kH * kW * C + kh * kW * C + kw * C + c] =
1364
// filter[m][c][kh][kw]
1365
filter_[m * C * kH * kW + c * kH * kW + kh * kW + kw];
// Convolution descriptor whose kernel covers the full input image.
1371
MPSCNNConvolutionDescriptor* desc = [MPSCNNConvolutionDescriptor
1372
cnnConvolutionDescriptorWithKernelWidth:X.width
1373
kernelHeight:X.height
1374
inputFeatureChannels:input_channels
1375
outputFeatureChannels:output_channels
1376
neuronFilter:Neuron::t()];
1377
auto data_source = [[ConvDataSource alloc]
1378
initWithWeight:refilter.data()
1379
bias:const_cast<float*>(b.template data<float>())
1381
fc_ = [[MPSCNNConvolution alloc] initWithDevice:getMPSCNNContext().device
1382
weights:data_source];
1384
// Note that X.numberOfImages can change between calls, but X.height and
1385
// X.width are static by definition.
1386
VLOG(2) << "MPSCNNFC: " << X.numberOfImages << ", " << X.width << ", "
1387
<< X.height << ", " << X.featureChannels << ", " << output_channels;
// Write only a 1x1 region for each of the X.numberOfImages images.
1389
[fc_ setClipRect:MTLRegionMake3D(0, 0, 0, 1, 1, X.numberOfImages)];
// Kernel offset of half the input width/height (see MPSCNNKernel.offset),
// so the full-image kernel is evaluated at a single output position.
1391
off.x = X.width / 2;
1392
off.y = X.height / 2;
1394
[fc_ setOffset:off];
1395
auto outputWrapper = MPSImageWrapper(
1396
this, &inputWrapper, X.numberOfImages, 1, 1, output_channels);
1397
auto commandBuffer = outputWrapper.getCommandBuffer();
1398
MPSImage* output = outputWrapper.getImage();
1399
[fc_ encodeToCommandBuffer:commandBuffer
1401
destinationImage:output];
1402
outputWrapper.copyToOutputBlob(Outputs()[0]);
1403
VLOG(2) << "MPSCNNFC took: " << t.MilliSeconds();
1407
MPSCNNConvolution* fc_{nullptr};
// Registers an FC operator variant and its schema (3 inputs: X, W, b;
// 1 output) for the given neuron-initializer type.
1410
#define INIT_FC_NEURON_OP(name, neuron) \
1411
REGISTER_CPU_OPERATOR(name, MPSCNNFullyConnectedOp<neuron>); \
1412
OPERATOR_SCHEMA(name).NumInputs(3).NumOutputs(1);
1414
INIT_FC_NEURON_OP(MPSCNNFC, EmptyNeuronInit);
1415
INIT_FC_NEURON_OP(MPSCNNFCRelu, ReluNeuronInit);
1416
#undef INIT_FC_NEURON_OP
// MPSCNNDropout: inference-time dropout. It does no GPU work and forwards
// Input(0)'s MPSImageWrapper unchanged to Output(0).
1418
class MPSCNNDropoutOp final : public Operator<CPUContext> {
1420
MPSCNNDropoutOp(const OperatorDef& operator_def, Workspace* ws)
1421
: Operator<CPUContext>(operator_def, ws) {}
1423
// Just pass inputs through, since we assume inference-time only.
1424
bool RunOnDevice() override {
1425
auto inputWrapper = Inputs()[0]->Get<MPSImageWrapper>();
1426
inputWrapper.copyToOutputBlob(Outputs()[0]);
1431
REGISTER_CPU_OPERATOR(MPSCNNDropout, MPSCNNDropoutOp);
1432
// Never use the second output (the mask).
1433
OPERATOR_SCHEMA(MPSCNNDropout)
1436
.AllowInplace({{0, 0}});
// MPSCNNConvTranspose: transposed convolution of the MPSImage in Input(0)
// with filter Input(FILTER), laid out as [IC][OC][kH][kW], and bias
// Input(BIAS) of length OC. Requires NCHW order and kernel_w() == kernel_h().
// One of two implementations is chosen at runtime:
//  * iOS 11+: MPSCNNConvolutionTranspose with the weights reordered to
//    [OC][kH][kW][IC] and rotated 180 degrees spatially.
//  * earlier iOS: a 1x1 MPSCNNConvolution ("GEMM") producing OC*kH*kW
//    channels into a temporary image, followed by a specialized compute
//    kernel that writes the final output image using the float16 bias
//    buffer (biasBuffer_).
// The MPS objects are built on the first run only (while both conv_trans_ and
// conv_ are still null).
1438
class MPSCNNConvTransposeOp final : public ConvTransposeUnpoolBase<CPUContext> {
1440
MPSCNNConvTransposeOp(const OperatorDef& operator_def, Workspace* ws)
1441
: ConvTransposeUnpoolBase<CPUContext>(operator_def, ws) {
1442
OPERATOR_NEEDS_FEATURE(
1443
this->order_ == StorageOrder::NCHW, "Metal only supports NCHW order.");
1445
kernel_w(), kernel_h(), "Metal only supports equal kernel dimensions");
1448
bool RunOnDeviceWithOrderNCHW() override {
1450
auto inputWrapper = Inputs()[0]->Get<MPSImageWrapper>();
1452
MPSImage* X = inputWrapper.getImage();
1454
auto& filter = Input(FILTER);
1455
auto& bias = Input(BIAS);
// The filter must be 4-D. CAFFE_ENFORCE(filter.dim(), 4) would only check
// that filter.dim() is non-zero (the 4 became part of the message), so
// compare the rank explicitly.
1456
CAFFE_ENFORCE_EQ(filter.dim(), 4);
1457
const int output_channels = filter.dim32(1);
1458
const int input_channels = filter.dim32(0);
1460
CAFFE_ENFORCE(X.featureChannels == input_channels, "");
1461
CAFFE_ENFORCE(filter.dim32(2) == kernel_h(), "");
1462
CAFFE_ENFORCE(filter.dim32(3) == kernel_w(), "");
1463
CAFFE_ENFORCE(bias.dim() == 1, "");
1464
CAFFE_ENFORCE(bias.dim32(0) == output_channels, "");
1466
const auto kH = kernel_h();
1467
const auto kW = kernel_w();
// Output size per spatial dim: (in - 1) * stride + k - pad_begin - pad_end
// + adj.
1470
(X.height - 1) * stride_h() + kH - pad_b() - pad_t() + adj_h();
1472
(X.width - 1) * stride_w() + kW - pad_l() - pad_r() + adj_w();
1474
VLOG(2) << "Output height: " << output_height;
1475
VLOG(2) << "Output width:" << output_width;
1476
VLOG(2) << "Output channels:" << output_channels;
1478
auto outputWrapper = MPSImageWrapper(
1485
auto commandBuffer = outputWrapper.getCommandBuffer();
// Runtime OS check selects the MPSCNNConvolutionTranspose path (iOS 11+) or
// the GEMM + custom kernel path.
1487
bool runtimeAtLeastIOS11 = SYSTEM_VERSION_GREATER_THAN_OR_EQUAL_TO(@"11.0");
1489
if (!conv_trans_ && !conv_) {
1490
caffe2::Timer consT;
1491
std::vector<float> refilter(kH * kW * output_channels * input_channels);
1492
refilter.assign(kH * kW * output_channels * input_channels, 0.0f);
1493
TORCH_DCHECK_EQ(refilter.size(), filter.size());
1494
auto* filter_ = filter.template data<float>();
1495
// For iOS11+ Reformat weights from WT[IC][OC][kH][kW] to
1496
// W[OC][kH][kW][IC]; For previous versions, reformat weights
1497
// to W[kH][kW][OC][IC]
1498
// The iOS11+ layout also rotates the weights spatially by 180 degrees
// (kH - 1 - kh, kW - 1 - kw); the pre-iOS11 layout does not.
1499
for (auto oc = 0; oc < output_channels; ++oc) {
1500
for (auto ic = 0; ic < input_channels; ++ic) {
1501
for (auto kh = 0; kh < kH; ++kh) {
1502
for (auto kw = 0; kw < kW; ++kw) {
1503
const auto inputIdx =
1504
ic * output_channels * kH * kW + oc * kH * kW + kh * kW + kw;
1506
if (runtimeAtLeastIOS11) {
1507
outputIdx = oc * kH * kW * input_channels +
1508
(kH - 1 - kh) * kW * input_channels +
1509
(kW - 1 - kw) * input_channels + ic;
1511
outputIdx = kh * kW * output_channels * input_channels +
1512
kw * output_channels * input_channels +
1513
oc * input_channels + ic;
1515
TORCH_DCHECK_LT(inputIdx, filter.size());
1516
TORCH_DCHECK_LT(outputIdx, filter.size());
1517
refilter[outputIdx] = filter_[inputIdx];
1522
TORCH_DCHECK_EQ(filter.size(), input_channels * output_channels * kH * kW);
1523
// initialize data structures
1524
if (runtimeAtLeastIOS11) {
1525
MPSCNNConvolutionDescriptor* desc = [MPSCNNConvolutionDescriptor
1526
cnnConvolutionDescriptorWithKernelWidth:kW
1528
inputFeatureChannels:input_channels
1529
outputFeatureChannels:output_channels];
1530
desc.strideInPixelsX = this->stride_w();
1531
desc.strideInPixelsY = this->stride_h();
1533
auto data_source = [[ConvDataSource alloc]
1534
initWithWeight:refilter.data()
1535
bias:const_cast<float*>(bias.data<float>())
1538
conv_trans_ = [[MPSCNNConvolutionTranspose alloc]
1539
initWithDevice:getMPSCNNContext().device
1540
weights:data_source];
1545
[conv_trans_ setOffset:offset];
1546
// kernel offset + padding offset
1547
conv_trans_.kernelOffsetX = kW / 2 - kW + 1 + this->pad_l();
1548
conv_trans_.kernelOffsetY = kH / 2 - kH + 1 + this->pad_t();
1549
VLOG(2) << "MPSCNNConvTranspose ConvDesc took: "
1550
<< consT.MilliSeconds();
1552
MPSCNNConvolutionDescriptor* desc = [MPSCNNConvolutionDescriptor
1553
cnnConvolutionDescriptorWithKernelWidth:1
1555
inputFeatureChannels:input_channels
1556
outputFeatureChannels:output_channels * kH * kW
1558
// We need to zero-fill the bias here.
1559
std::vector<float> fakeBias;
1560
fakeBias.assign(output_channels * kH * kW, 0);
1562
desc.strideInPixelsX = 1;
1563
desc.strideInPixelsY = 1;
1565
[[ConvDataSource alloc] initWithWeight:refilter.data()
1566
bias:fakeBias.data()
1569
[[MPSCNNConvolution alloc] initWithDevice:getMPSCNNContext().device
1570
weights:data_source];
1571
[conv_ setEdgeMode:MPSImageEdgeModeZero];
1576
[conv_ setOffset:offset];
// The real bias is stored as float16 in a Metal buffer rounded up to a
// multiple of 4 elements (2 bytes each) and bound by the second kernel.
1578
const auto biasBytes = divRoundUp(bias.size(), 4) * 4 * 2;
1579
biasBuffer_ = [getMPSCNNContext().device
1580
newBufferWithLength:biasBytes
1581
options:MTLResourceOptionCPUCacheModeDefault];
1582
for (auto i = 0; i < bias.size(); ++i) {
1583
((float16_t*)[biasBuffer_ contents])[i] = bias.data<float>()[i];
1586
VLOG(2) << "MPSCNNConvTranspose ConvDesc took: "
1587
<< consT.MilliSeconds();
1588
} // data structure initialization
// Exactly one of the two implementations must have been built.
1590
CAFFE_ENFORCE((conv_trans_ && !conv_) || (!conv_trans_ && conv_));
1592
// run the computation
1594
MPSImage* output = outputWrapper.getImage();
1595
X = inputWrapper.getImage();
1596
CAFFE_ENFORCE_EQ(conv_trans_.groups, 1);
1597
[conv_trans_ encodeToCommandBuffer:commandBuffer
1599
destinationImage:output];
1601
CAFFE_ENFORCE_EQ(conv_.strideInPixelsY, 1);
1602
CAFFE_ENFORCE_EQ(conv_.strideInPixelsX, 1);
1603
CAFFE_ENFORCE_EQ(conv_.groups, 1);
1604
CAFFE_ENFORCE_EQ(conv_.inputFeatureChannels, input_channels);
1605
CAFFE_ENFORCE_EQ(conv_.outputFeatureChannels, output_channels * kH * kW);
1606
CAFFE_ENFORCE_EQ(conv_.kernelWidth, 1);
1607
CAFFE_ENFORCE_EQ(conv_.kernelHeight, 1);
// The GEMM intermediate needs numberOfImages * OC * kH * kW channels
// (4 channels per texture slice); log an error and clean up both wrappers
// if that exceeds kMetalMaxTextureArrLength.
1608
if (divRoundUp(X.numberOfImages * output_channels * kH * kW, 4) >
1609
kMetalMaxTextureArrLength) {
1610
LOG(INFO) << "ConvTranspose " << X.numberOfImages << " "
1611
<< output_channels << " " << kH << " " << kW;
1613
<< "arrayLength exceeds the maximum allowed length in texture";
1614
inputWrapper.cleanup();
1615
outputWrapper.cleanup();
1618
VLOG(2) << "ConvTranspose:" << output_channels << " " << kH << " " << kW
1619
<< " " << X.numberOfImages;
1621
auto gemmed = createTemporaryImage(
1627
output_channels * kH * kW);
1630
[conv_ encodeToCommandBuffer:commandBuffer
1632
destinationImage:gemmed];
1633
VLOG(2) << "MPSCNNConvTranspose GEMM took: " << gt.MilliSeconds();
1635
MPSImage* output = outputWrapper.getImage();
1639
id<MTLComputePipelineState> state =
1640
getMPSCNNContext().getSpecializedPipelineState(
1642
{{ushort(kernel_h()),
1648
ushort(output.featureChannels),
1649
ushort(output.numberOfImages),
1650
ushort(gemmed.height),
1651
ushort(gemmed.width)}});
1652
id<MTLComputeCommandEncoder> encoder =
1653
[commandBuffer computeCommandEncoder];
// Bindings: texture 0 = GEMM result, texture 1 = output, buffer 0 = bias.
1654
[encoder setComputePipelineState:state];
1655
[encoder setTexture:[gemmed texture] atIndex:0];
1656
[encoder setTexture:[output texture] atIndex:1];
1657
[encoder setBuffer:biasBuffer_ offset:0 atIndex:0];
1658
const auto& launchParams =
1659
spatialPointwiseKernelLaunchParams(state, output);
1660
[encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
1661
threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
1662
[encoder endEncoding];
// Consume this op's read of the temporary GEMM image
// (MPSTemporaryImage.readCount).
1663
gemmed.readCount -= 1;
1664
VLOG(2) << "MPSCNNConvTranspose upscaling took: " << cit.MilliSeconds();
1667
outputWrapper.copyToOutputBlob(Outputs()[0]);
1668
VLOG(2) << "MPSCNNConvTranspose took: " << t.MilliSeconds();
1674
INPUT_TAGS(INPUT, FILTER, BIAS);
// conv_trans_ is used on iOS 11+; conv_ and biasBuffer_ on earlier
// versions.
1675
MPSCNNConvolutionTranspose* conv_trans_{nullptr};
1676
id<MTLBuffer> biasBuffer_;
1677
MPSCNNConvolution* conv_{nullptr};
1681
#define INIT_CONV_TRANSPOSE_NEURON_OP(name) \
1682
REGISTER_CPU_OPERATOR(name, MPSCNNConvTransposeOp); \
1683
OPERATOR_SCHEMA(name).NumInputs(3).NumOutputs(1);
1685
INIT_CONV_TRANSPOSE_NEURON_OP(MPSCNNConvTranspose);
1686
#undef INIT_CONV_TRANSPOSE_NEURON_OP
// Selects whether MPSCNNInstanceNormOp additionally applies a fused PReLU
// (PRELU, weights from Input(3)) or not (NONE).
1688
enum class InstanceNormFusionTy {
// MPSCNNInstanceNorm[PRelu]: instance normalization of the MPSImage in
// Input(0) with per-channel scale (Input(1)) and bias (Input(2)). Both must
// have X.featureChannels elements. The PRELU variant also binds PReLU
// weights from Input(3). The kernel is "instance_norm" or
// "instance_norm_nonarray", as chosen by kernelFor(X, ...).
// Scale and bias are uploaded as float16 into Metal buffers that are rebuilt
// only when missing or when their byte length differs from scaleBytes.
// NOTE(review): the rebuild condition depends only on sizes, so new
// scale/bias values of the same length are not re-uploaded. Presumably these
// inputs are constant weights -- confirm with callers.
1693
template <InstanceNormFusionTy fusionTy>
1694
class MPSCNNInstanceNormOp final : public Operator<CPUContext> {
1696
MPSCNNInstanceNormOp(const OperatorDef& operator_def, Workspace* ws)
1697
: Operator<CPUContext>(operator_def, ws) {}
1699
bool RunOnDevice() override {
1700
auto inputWrapper = Inputs()[0]->template Get<MPSImageWrapper>();
1701
MPSImage* X = inputWrapper.getImage();
1703
const auto& scale = Input(1);
1704
const auto& bias = Input(2);
1705
CAFFE_ENFORCE_EQ(scale.size(), X.featureChannels);
1706
CAFFE_ENFORCE_EQ(bias.size(), X.featureChannels);
// Byte size: element count rounded up to a multiple of 4, 2 bytes per
// float16 element.
1707
const auto scaleBytes = divRoundUp(scale.size(), 4) * 4 * 2;
1708
if (!scaleBuffer_ || !biasBuffer_ || scaleBuffer_.length != scaleBytes ||
1709
biasBuffer_.length != scaleBytes) {
1711
// Round up to the nearest multiple of 4 elements,
1712
// so accesses to X[i * 4 + 3] in the kernel are valid.
1713
scaleBuffer_ = [getMPSCNNContext().device
1714
newBufferWithLength:scaleBytes
1715
options:MTLResourceOptionCPUCacheModeDefault];
1716
biasBuffer_ = [getMPSCNNContext().device
1717
newBufferWithLength:scaleBytes
1718
options:MTLResourceOptionCPUCacheModeDefault];
1719
for (auto i = 0; i < scale.size(); ++i) {
1720
((float16_t*)[scaleBuffer_ contents])[i] =
1721
scale.template data<float>()[i];
1723
for (auto i = 0; i < bias.size(); ++i) {
1724
((float16_t*)[biasBuffer_ contents])[i] =
1725
bias.template data<float>()[i];
1727
if (fusionTy == InstanceNormFusionTy::PRELU) {
1728
const auto& preluWeight = Input(3);
1729
preluWeightBuffer_ = [getMPSCNNContext().device
1730
newBufferWithLength:divRoundUp(preluWeight.size(), 4) * 4 * 2
1731
options:MTLResourceOptionCPUCacheModeDefault];
1732
for (auto i = 0; i < preluWeight.size(); ++i) {
1733
((float16_t*)[preluWeightBuffer_ contents])[i] =
1734
preluWeight.template data<float>()[i];
1737
VLOG(2) << "Buffer setup took: " << cvt.MilliSeconds();
1740
auto outputWrapper = MPSImageWrapper(
1747
auto commandBuffer = inputWrapper.getCommandBuffer();
1748
MPSImage* output = outputWrapper.getImage();
1751
id<MTLComputeCommandEncoder> encoder =
1752
[commandBuffer computeCommandEncoder];
// Specialization constants: channel count, then (for PRELU) the number of
// PReLU weights.
1753
id<MTLComputePipelineState> state =
1754
getMPSCNNContext().getSpecializedPipelineState(
1755
kernelFor(X, @"instance_norm", @"instance_norm_nonarray"),
1756
{{ushort(X.featureChannels),
1757
fusionTy == InstanceNormFusionTy::PRELU ? ushort(Input(3).size())
// Bindings: buffer 0 = scale, buffer 1 = bias, texture 0 = input,
// texture 1 = output, buffer 2 = PReLU weights (PRELU only).
1760
[encoder setComputePipelineState:state];
1761
[encoder setBuffer:scaleBuffer_ offset:0 atIndex:0];
1762
[encoder setBuffer:biasBuffer_ offset:0 atIndex:1];
1763
[encoder setTexture:[X texture] atIndex:0];
1764
[encoder setTexture:[output texture] atIndex:1];
1765
if (fusionTy == InstanceNormFusionTy::PRELU) {
1766
[encoder setBuffer:preluWeightBuffer_ offset:0 atIndex:2];
// Explicit dispatch with 16x16x1 threadgroups.
1768
[encoder dispatchThreadgroups:MTLSizeMake(
1772
divRoundUp(X.featureChannels, 4))
1773
threadsPerThreadgroup:MTLSizeMake(16, 16, 1)];
1774
[encoder endEncoding];
1775
inputWrapper.markRead();
1776
VLOG(2) << "InstanceNorm took: " << t.MilliSeconds();
1777
outputWrapper.copyToOutputBlob(Outputs()[0]);
// float16 parameter buffers, kept across runs.
1783
id<MTLBuffer> scaleBuffer_;
1784
id<MTLBuffer> biasBuffer_;
1785
id<MTLBuffer> preluWeightBuffer_;
1788
REGISTER_CPU_OPERATOR(
1790
MPSCNNInstanceNormOp<InstanceNormFusionTy::NONE>);
1791
OPERATOR_SCHEMA(MPSCNNInstanceNorm).NumInputs(3).NumOutputs(1);
1792
REGISTER_CPU_OPERATOR(
1793
MPSCNNInstanceNormPRelu,
1794
MPSCNNInstanceNormOp<InstanceNormFusionTy::PRELU>);
1795
OPERATOR_SCHEMA(MPSCNNInstanceNormPRelu).NumInputs(4).NumOutputs(1);
// MPSCNNNormalizePlanarYUV: per-channel normalization (X - mean) / std of the
// MPSImage in Input(0). mean (Input(1)) and std (Input(2)) must each have
// X.featureChannels elements. It is evaluated with the "affine" /
// "affine_nonarray" kernel as X * scale + shift, where scale = 1 / std and
// shift = -mean / std, both stored as float16 in Metal buffers. Those buffers
// are rebuilt only when missing or when their byte length changes.
// NOTE(review): std entries are not checked for zero, and new mean/std values
// of the same length are not re-uploaded; presumably these are constants.
1797
class MPSCNNNormalizePlanarYUVOp final : public Operator<CPUContext> {
1799
MPSCNNNormalizePlanarYUVOp(const OperatorDef& operator_def, Workspace* ws)
1800
: Operator<CPUContext>(operator_def, ws) {}
1802
bool RunOnDevice() override {
1803
auto inputWrapper = Inputs()[0]->template Get<MPSImageWrapper>();
1804
MPSImage* X = inputWrapper.getImage();
1806
const auto& mean = Input(1);
1807
const auto& std = Input(2);
1808
CAFFE_ENFORCE_EQ(mean.size(), X.featureChannels);
1809
CAFFE_ENFORCE_EQ(std.size(), X.featureChannels);
// Element count rounded up to a multiple of 4, 2 bytes per float16.
1810
const auto scaleBytes = divRoundUp(mean.size(), 4) * 4 * 2;
1811
if (!scaleBuffer_ || !shiftBuffer_ || scaleBuffer_.length != scaleBytes ||
1812
shiftBuffer_.length != scaleBytes) {
1814
scaleBuffer_ = [getMPSCNNContext().device
1815
newBufferWithLength:scaleBytes
1816
options:MTLResourceOptionCPUCacheModeDefault];
1817
shiftBuffer_ = [getMPSCNNContext().device
1818
newBufferWithLength:scaleBytes
1819
options:MTLResourceOptionCPUCacheModeDefault];
1820
// op computes (X - mean) / std = X * 1/std + (-mean/std)
1821
// Thus set scale = 1.0/std, shift = (-mean/std)
1822
for (auto i = 0; i < mean.size(); ++i) {
1823
((float16_t*)[scaleBuffer_ contents])[i] =
1824
1.0 / double(std.template data<float>()[i]);
1825
((float16_t*)[shiftBuffer_ contents])[i] =
1826
double(-mean.template data<float>()[i]) /
1827
double(std.template data<float>()[i]);
1829
VLOG(2) << "Buffer setup took: " << cvt.MilliSeconds();
1832
auto outputWrapper = MPSImageWrapper(
1839
auto commandBuffer = inputWrapper.getCommandBuffer();
1840
MPSImage* output = outputWrapper.getImage();
1843
id<MTLComputeCommandEncoder> encoder =
1844
[commandBuffer computeCommandEncoder];
1845
id<MTLComputePipelineState> state =
1846
getMPSCNNContext().getSpecializedPipelineState(
1847
kernelFor(X, @"affine", @"affine_nonarray"),
1848
{ushort(X.featureChannels)});
// Bindings: buffer 0 = scale, buffer 1 = shift, texture 0 = input,
// texture 1 = output.
1850
[encoder setComputePipelineState:state];
1851
[encoder setBuffer:scaleBuffer_ offset:0 atIndex:0];
1852
[encoder setBuffer:shiftBuffer_ offset:0 atIndex:1];
1853
[encoder setTexture:[X texture] atIndex:0];
1854
[encoder setTexture:[output texture] atIndex:1];
1855
const auto& launchParams =
1856
spatialPointwiseKernelLaunchParams(state, output);
1857
[encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
1858
threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
1859
[encoder endEncoding];
1860
inputWrapper.markRead();
// Timing log labelled with this op's own name (it previously said
// "InstanceNorm", a copy-paste leftover).
1861
VLOG(2) << "NormalizePlanarYUV took: " << t.MilliSeconds();
1862
outputWrapper.copyToOutputBlob(Outputs()[0]);
// float16 scale (1/std) and shift (-mean/std) buffers, kept across runs.
1868
id<MTLBuffer> scaleBuffer_;
1869
id<MTLBuffer> shiftBuffer_;
1872
REGISTER_CPU_OPERATOR(MPSCNNNormalizePlanarYUV, MPSCNNNormalizePlanarYUVOp);
1873
OPERATOR_SCHEMA(MPSCNNNormalizePlanarYUV).NumInputs(3).NumOutputs(1);
// MPSCNNPRelu: PReLU on the MPSImage in Input(0) with slopes from Input(1).
// It uses the "prelu_nonshared" / "prelu_nonshared_nonarray" kernel,
// specialized on X.featureChannels and the number of slopes (scale.size()).
// Slopes are uploaded as float16 into scaleBuffer_, which is rebuilt only
// when missing or when its byte length changes.
1875
class MPSCNNPReluOp final : public Operator<CPUContext> {
1877
MPSCNNPReluOp(const OperatorDef& operator_def, Workspace* ws)
1878
: Operator<CPUContext>(operator_def, ws) {}
1880
bool RunOnDevice() override {
1881
auto inputWrapper = Inputs()[0]->Get<MPSImageWrapper>();
1882
const MPSImage* X = inputWrapper.getImage();
1884
const auto& scale = Input(1);
// Element count rounded up to a multiple of 4, 2 bytes per float16.
1885
const auto scaleBytes = divRoundUp(scale.size(), 4) * 4 * 2;
1886
if (!scaleBuffer_ || scaleBuffer_.length != scaleBytes) {
1888
scaleBuffer_ = [getMPSCNNContext().device
1889
newBufferWithLength:scaleBytes
1890
options:MTLResourceOptionCPUCacheModeDefault];
1891
for (auto i = 0; i < scale.size(); ++i) {
1892
((float16_t*)[scaleBuffer_ contents])[i] = scale.data<float>()[i];
1894
VLOG(2) << "Buffer setup took: " << cvt.MilliSeconds();
1897
auto outputWrapper = MPSImageWrapper(
1904
auto commandBuffer = inputWrapper.getCommandBuffer();
1905
MPSImage* output = outputWrapper.getImage();
1907
id<MTLComputeCommandEncoder> encoder =
1908
[commandBuffer computeCommandEncoder];
1909
id<MTLComputePipelineState> state =
1910
getMPSCNNContext().getSpecializedPipelineState(
1911
kernelFor(X, @"prelu_nonshared", @"prelu_nonshared_nonarray"),
1912
{{ushort(X.featureChannels), ushort(scale.size())}});
// Bindings: buffer 0 = slopes, texture 0 = input, texture 1 = output.
1914
[encoder setComputePipelineState:state];
1915
[encoder setBuffer:scaleBuffer_ offset:0 atIndex:0];
1916
[encoder setTexture:[X texture] atIndex:0];
1917
[encoder setTexture:[output texture] atIndex:1];
1919
const auto& launchParams =
1920
spatialPointwiseKernelLaunchParams(state, output);
1921
[encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
1922
threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
1923
[encoder endEncoding];
1924
inputWrapper.markRead();
1925
VLOG(2) << "PRelu took: " << t.MilliSeconds();
1926
outputWrapper.copyToOutputBlob(Outputs()[0]);
// float16 slope buffer, kept across runs.
1932
id<MTLBuffer> scaleBuffer_;
1935
REGISTER_CPU_OPERATOR(MPSCNNPRelu, MPSCNNPReluOp);
1936
// Allow in-place isn't *really* valid here, since nothing is in-place for Metal
1937
// texture arrays, but requires re-export.
1938
OPERATOR_SCHEMA(MPSCNNPRelu).NumInputs(2).NumOutputs(1).AllowInplace({{0, 0}});
// MPSCNNRoIWarp: RoI warping of the single-image MPSImage in Input(0) using
// the RoIs in Input(1). Input(1) is a 2-D tensor of shape [N, 4] (boxes) or
// [N, 5] (a leading batch index followed by the box). The four box values of
// each RoI are converted to float16 into roiBuffer_ and passed, along with
// the input and output textures, to a specialized compute kernel. The kernel
// itself is not visible here; presumably it samples a pooled_h x pooled_w
// crop per RoI and channel.
// Arguments: spatial_scale (default 1.), pooled_h and pooled_w (default 1),
// sampling_ratio (default -1).
// NOTE(review): the constructor enforces sampling_ratio >= 0, so the -1
// default always fails and the argument must be supplied explicitly.
// NOTE(review): spatial_scale reaches the kernel as
// ushort(spatial_scale_ * 10000), i.e. quantized to 1e-4 and only
// representable below about 6.55.
1940
class MPSCNNRoIWarpOp final : public Operator<CPUContext> {
1942
MPSCNNRoIWarpOp(const OperatorDef& operator_def, Workspace* ws)
1943
: Operator<CPUContext>(operator_def, ws),
1945
OperatorBase::GetSingleArgument<float>("spatial_scale", 1.)),
1946
pooled_height_(OperatorBase::GetSingleArgument<int>("pooled_h", 1)),
1947
pooled_width_(OperatorBase::GetSingleArgument<int>("pooled_w", 1)),
1949
OperatorBase::GetSingleArgument<int>("sampling_ratio", -1)) {
1950
CAFFE_ENFORCE_GT(spatial_scale_, 0);
1951
CAFFE_ENFORCE_GT(pooled_height_, 0);
1952
CAFFE_ENFORCE_GT(pooled_width_, 0);
1953
CAFFE_ENFORCE_GE(sampling_ratio_, 0);
1954
VLOG(1) << "spatial_scale: " << spatial_scale_;
1955
VLOG(1) << "pooled_h: " << pooled_height_;
1956
VLOG(1) << "pooled_w: " << pooled_width_;
1957
VLOG(1) << "sampling_ratio: " << sampling_ratio_;
1960
bool RunOnDevice() override {
1962
auto inputWrapper = Inputs()[0]->Get<MPSImageWrapper>();
1963
auto X = inputWrapper.getImage();
// Only a single input image is supported.
1964
CAFFE_ENFORCE_EQ(X.numberOfImages, 1);
1965
const auto& R = Input(1);
1966
CAFFE_ENFORCE_EQ(R.dim(), 2);
1967
CAFFE_ENFORCE(R.dim32(1) == 4 || R.dim32(1) == 5);
// Four float16 box values per RoI.
// NOTE(review): the buffer is sized before the R.dim32(0) <= 0 check below,
// so an empty RoI set requests a zero-length MTLBuffer.
1968
const auto roiBytes = R.dim32(0) * 4 * sizeof(float16_t);
1969
if (!roiBuffer_ || roiBuffer_.length != roiBytes) {
1971
roiBuffer_ = [getMPSCNNContext().device
1972
newBufferWithLength:roiBytes
1973
options:MTLResourceOptionCPUCacheModeDefault];
1975
float16_t* roiBuffer = (float16_t*)[roiBuffer_ contents];
1976
// Help compiler generate vcvt?
1977
const auto Rdim = R.dim32(1);
1978
CAFFE_ENFORCE(Rdim == 4 || Rdim == 5);
// Skip the leading batch-index column when present.
1979
auto off = Rdim == 5 ? 1 : 0;
1980
for (auto i = 0; i < R.dim32(0); ++i) {
1982
// Only a batch size of one is handled, so the batch index must be zero.
// NOTE(review): as written here this check also runs when Rdim == 4, where
// column 0 is a box coordinate rather than a batch index -- confirm it is
// guarded by Rdim == 5.
1983
CAFFE_ENFORCE_EQ(R.data<float>()[i * Rdim], 0.0);
1985
roiBuffer[i * 4 + 0] = R.data<float>()[i * Rdim + off + 0];
1986
roiBuffer[i * 4 + 1] = R.data<float>()[i * Rdim + off + 1];
1987
roiBuffer[i * 4 + 2] = R.data<float>()[i * Rdim + off + 2];
1988
roiBuffer[i * 4 + 3] = R.data<float>()[i * Rdim + off + 3];
1990
auto featureChannels = X.featureChannels;
1991
VLOG(1) << "RoIWarp input size:" << X.numberOfImages << " "
1992
<< featureChannels << " " << X.height << " " << X.width;
1993
VLOG(1) << "RoIWarp output size:" << R.dim32(0) << " " << featureChannels
1994
<< " " << pooled_width_ << " " << pooled_height_;
// No RoIs: log an error and clean up the input wrapper.
1995
if (R.dim32(0) <= 0) {
1996
LOG(ERROR) << "number of RoIs <= 0 in RoIWarp " << R.dim32(0);
1997
inputWrapper.cleanup();
// The output needs R.dim32(0) * featureChannels channels (4 per texture
// slice); log an error and clean up if that exceeds
// kMetalMaxTextureArrLength.
2000
if (divRoundUp(R.dim32(0) * featureChannels, 4) >
2001
kMetalMaxTextureArrLength) {
2002
LOG(INFO) << "MPSCNNRoIWarp " << R.dim32(0) << " " << featureChannels;
2003
LOG(ERROR) << "arrayLength exceeds the maximum allowed length in texture";
2004
inputWrapper.cleanup();
2007
auto outputWrapper = MPSImageWrapper(
2014
auto commandBuffer = outputWrapper.getCommandBuffer();
2015
MPSImage* output = outputWrapper.getImage();
2016
VLOG(1) << "output: " << output.numberOfImages << ", "
2017
<< output.featureChannels << ", " << output.height << ", "
2019
id<MTLComputeCommandEncoder> encoder =
2020
[commandBuffer computeCommandEncoder];
// Specialization constants: spatial_scale * 10000, sampling_ratio, channel
// count, input image count, output image count (all as ushort).
2021
id<MTLComputePipelineState> state =
2022
getMPSCNNContext().getSpecializedPipelineState(
2024
{{ushort(spatial_scale_ * 10000),
2025
ushort(sampling_ratio_),
2026
ushort(featureChannels),
2027
ushort(X.numberOfImages),
2028
ushort(output.numberOfImages)}});
// Bindings: buffer 0 = RoIs, texture 0 = input, texture 1 = output.
2030
[encoder setComputePipelineState:state];
2031
[encoder setBuffer:roiBuffer_ offset:0 atIndex:0];
2032
[encoder setTexture:[X texture] atIndex:0];
2033
[encoder setTexture:[output texture] atIndex:1];
2035
const auto& launchParams =
2036
spatialPointwiseKernelLaunchParams(state, output);
2037
[encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
2038
threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
2039
[encoder endEncoding];
2040
inputWrapper.markRead();
2041
VLOG(2) << "RoIWarp took: " << t.MilliSeconds();
2042
VLOG(1) << "ROIWarp size: " << output.numberOfImages << ", "
2043
<< output.featureChannels << ", " << output.height << ", "
2045
outputWrapper.copyToOutputBlob(Outputs()[0]);
2051
float spatial_scale_;
2054
int sampling_ratio_;
// float16 RoI box buffer (4 values per RoI), reallocated when its size
// changes.
2056
id<MTLBuffer> roiBuffer_;
2059
REGISTER_CPU_OPERATOR(MPSCNNRoIWarp, MPSCNNRoIWarpOp);
2060
OPERATOR_SCHEMA(MPSCNNRoIWarp).NumInputs(2).NumOutputs(1);
2062
class MPSCNNGenerateProposalsCPPOp final : public Operator<CPUContext> {
2064
MPSCNNGenerateProposalsCPPOp(const OperatorDef& operator_def, Workspace* ws)
2065
: Operator<CPUContext>(operator_def, ws),
2067
OperatorBase::GetSingleArgument<float>("spatial_scale", 1.0 / 16)),
2068
feat_stride_(1.0 / spatial_scale_),
2070
OperatorBase::GetSingleArgument<int>("pre_nms_topN", 6000)),
2072
OperatorBase::GetSingleArgument<int>("post_nms_topN", 300)),
2074
OperatorBase::GetSingleArgument<float>("nms_thresh", 0.7f)),
2075
rpn_min_size_(OperatorBase::GetSingleArgument<float>("min_size", 16)),
2077
this->template GetSingleArgument<bool>("legacy_plus_one", true)) {}
// Greedy non-maximum suppression. The pairwise suppression bitmasks are
// computed on the GPU by a specialized Metal pipeline; the kernel name is on a
// line missing from this copy (presumably @"nms" — confirm). The CPU then
// walks the masks to select the boxes.
//
// proposals: (R, 4) boxes; scores: (R, 1). Both have the same row count,
//   which is enforced below.
// sorted_indices: candidate rows of `proposals`. The only caller
//   (ProposalsForOneImage) passes them sorted by descending score.
// thresh: NOTE(review): this parameter is not read in the visible code. The
//   kernel is given rpn_nms_thresh_ instead. The caller passes
//   rpn_nms_thresh_, so the values agree today, but any other value would
//   be silently ignored.
// Returns the retained entries of sorted_indices, in sorted order. `keep` is
// sized min(rpn_post_nms_topN_, sorted_indices.size()), so at most that many
// survive. The return statement is on a line missing from this copy.
2079
template <class Derived1, class Derived2>
2080
std::vector<int> nms_metal(
2081
const Eigen::ArrayBase<Derived1>& proposals, // EArrXXf
2082
const Eigen::ArrayBase<Derived2>& scores, // EArrXf
2083
const std::vector<int>& sorted_indices,
2084
float thresh) const {
2085
CAFFE_ENFORCE_EQ(proposals.rows(), scores.rows());
2086
CAFFE_ENFORCE_EQ(proposals.cols(), 4);
2087
CAFFE_ENFORCE_EQ(scores.cols(), 1);
2088
CAFFE_ENFORCE_LE(sorted_indices.size(), proposals.rows());
// Materialize the Eigen expression into a contiguous float array laid out as
// ERArrXXf, so it can be uploaded to a Metal buffer.
2090
std::vector<float> proposals_cpu(proposals.size());
2091
Eigen::Map<ERArrXXf>(
2092
&proposals_cpu[0], proposals.rows(), proposals.cols()) = proposals;
// Each uint32 mask word covers maxThreadsPerThreadgroup boxes (see the
// nblock/inblock split below), so col_blocks is the number of mask words
// needed to cover every candidate box.
2094
int box_num = sorted_indices.size();
2095
int col_blocks = divRoundUp(box_num, maxThreadsPerThreadgroup);
2096
auto pre_nms_size = box_num;
// Upload the proposals and the candidate ordering. Both buffers are
// allocated fresh on every call.
2097
auto preNmsProposalsBuffer_ = [getMPSCNNContext().device
2098
newBufferWithBytes:proposals_cpu.data()
2099
length:proposals.size() * sizeof(float)
2100
options:MTLResourceOptionCPUCacheModeDefault];
2101
auto sortedIndicesBuffer_ = [getMPSCNNContext().device
2102
newBufferWithBytes:sorted_indices.data()
2103
length:pre_nms_size * sizeof(int)
2104
options:MTLResourceOptionCPUCacheModeDefault];
// pose_nms_size is the maximum number of boxes to keep.
// NOTE(review): if rpn_post_nms_topN_ is 0 (ProposalsForOneImage treats 0 as
// "no limit") and box_num > 0, then pose_nms_size is 0 and log(0) is -inf.
// That makes batch_size 0, so the batch loop below never advances `offset`
// and never terminates. batch_size should be clamped to >= 1, or 0 treated
// as box_num.
2106
int pose_nms_size = fmin(rpn_post_nms_topN_, pre_nms_size);
2107
// round pose_nms_size up to the next power of 2
2108
int batch_size = pow(2, ceil(log(pose_nms_size) / log(2)));
// Mask layout: batch_size rows (one per box in the current batch), each of
// col_blocks uint32 words. The CPU loop reads bit b of word j in row r as
// "sorted position j * maxThreadsPerThreadgroup + b is suppressed by sorted
// position offset + r". Presumably the kernel sets that bit when the two
// boxes overlap above the threshold.
2110
auto maskBuffer_ = [getMPSCNNContext().device
2111
newBufferWithLength:batch_size * col_blocks * sizeof(uint32_t)
2112
options:MTLResourceOptionCPUCacheModeDefault];
2113
std::vector<uint32_t> masks(batch_size * col_blocks);
// keep: selected indices; remv: accumulated "suppressed" bits, one bit per
// sorted position.
2115
std::vector<int> keep(pose_nms_size);
2116
int num_to_keep = 0;
2117
bool terminate = false;
2118
std::vector<uint32_t> remv(col_blocks);
// Process candidates in batches of batch_size sorted positions. Each batch
// is one synchronous GPU round trip (commit followed by waitUntilCompleted).
2120
for (int offset = 0; !terminate && offset < box_num; offset += batch_size) {
2121
auto commandBuffer = [getMPSCNNContext().commandQueue commandBuffer];
2122
auto encoder = [commandBuffer computeCommandEncoder];
// Pipeline constants: batch size, threads per threadgroup, and the NMS
// threshold in fixed point (x 10000). A further constant, presumably the
// batch offset, is on a line missing from this copy.
2123
auto state = getMPSCNNContext().getSpecializedPipelineState(
2125
{{ushort(batch_size),
2126
maxThreadsPerThreadgroup,
2127
ushort(rpn_nms_thresh_ * 10000),
// Bindings: buffer 0 = output masks, 1 = proposals, 2 = sorted indices.
2129
[encoder setComputePipelineState:state];
2130
[encoder setBuffer:maskBuffer_ offset:0 atIndex:0];
2131
[encoder setBuffer:preNmsProposalsBuffer_ offset:0 atIndex:1];
2132
[encoder setBuffer:sortedIndicesBuffer_ offset:0 atIndex:2];
// Grid, in units of maxThreadsPerThreadgroup: x spans the batch_size rows of
// this batch and y spans all box_num candidates.
2133
const auto threadsPerThreadgroup =
2134
MTLSizeMake(maxThreadsPerThreadgroup, 1, 1);
2135
const auto threadgroupsPerGrid = MTLSizeMake(
2136
divRoundUp(batch_size, maxThreadsPerThreadgroup),
2137
divRoundUp(box_num, maxThreadsPerThreadgroup),
2139
[encoder dispatchThreadgroups:threadgroupsPerGrid
2140
threadsPerThreadgroup:threadsPerThreadgroup];
2141
[encoder endEncoding];
2142
[commandBuffer commit];
2143
[commandBuffer waitUntilCompleted];
// Copy the whole mask buffer back into `masks`. Part of this statement is on
// lines missing from this copy (presumably a std::copy call).
2144
uint32_t* maskBufferPointer = (uint32_t*)[maskBuffer_ contents];
2147
maskBufferPointer + (maskBuffer_.length / sizeof(uint32_t)),
// Greedy pass over this batch in sorted order. A box is kept unless an
// earlier kept box has set its bit in `remv`. A kept box's mask row is then
// merged into `remv`; the merge is in the loop body missing from this copy,
// presumably remv[j] |= p[j]. Because bits are shifted within a uint32,
// maxThreadsPerThreadgroup must not exceed 32.
2150
for (int i = offset; i < fmin(offset + batch_size, box_num); ++i) {
2151
int nblock = i / maxThreadsPerThreadgroup;
2152
int inblock = i % maxThreadsPerThreadgroup;
2153
if (!(remv[nblock] & (1U << inblock))) {
2154
keep[num_to_keep++] = sorted_indices[i];
2155
if (num_to_keep >= pose_nms_size) {
2159
uint* p = &masks[0] + (i - offset) * col_blocks;
2160
for (int j = nblock; j < col_blocks; j++) {
// Trim `keep` to the number of boxes actually retained.
2166
keep.resize(num_to_keep);
// Generates the final proposals for one image.
//
// im_info: [height, width, scale] of the image. im_info[0] and im_info[1] are
//   used as the clipping bounds below.
// all_anchors: one row per anchor for every (h, w, a) position. It must have
//   the same row count as the reshaped deltas, which is enforced below.
// bbox_deltas_tensor: (4 * A, H, W) conv output.
// scores_tensor: (A, H, W) conv output.
// out_boxes / out_probs: receive the surviving boxes and their scores.
//
// Pipeline:
//   1. decode deltas against the anchors;
//   2. clip boxes to the image;
//   3. drop boxes smaller than min_size;
//   4. sort by score;
//   5. cap at pre_nms_topN;
//   6. run nms_metal;
//   7. cap at post_nms_topN.
// A topN value of 0 disables that cap.
// NOTE(review): several lines of this function are missing from this copy
// (bbox_transform arguments, the declaration of `keep`, closing braces).
2169
void ProposalsForOneImage(
2170
const Eigen::Array3f& im_info,
2171
const Eigen::Map<const ERMatXf>& all_anchors,
2172
const utils::ConstTensorView<float>& bbox_deltas_tensor,
2173
const utils::ConstTensorView<float>& scores_tensor,
2174
ERArrXXf* out_boxes,
2175
EArrXf* out_probs) const {
2176
const auto& pre_nms_topN = rpn_pre_nms_topN_;
2177
const auto& post_nms_topN = rpn_post_nms_topN_;
2178
const auto& nms_thresh = rpn_nms_thresh_;
2179
const auto& min_size = rpn_min_size_;
2181
// Transpose and reshape predicted bbox transformations to get them
2182
// into the same order as the anchors:
2183
// - bbox deltas will be (4 * A, H, W) format from conv output
2184
// - transpose to (H, W, 4 * A)
2185
// - reshape to (H * W * A, 4) where rows are ordered by (H, W, A)
2186
// in slowest to fastest order to match the enumerated anchors
2187
CAFFE_ENFORCE_EQ(bbox_deltas_tensor.ndim(), 3);
2188
CAFFE_ENFORCE_EQ(bbox_deltas_tensor.dim(0) % 4, 0);
2189
auto A = bbox_deltas_tensor.dim(0) / 4;
2190
auto H = bbox_deltas_tensor.dim(1);
2191
auto W = bbox_deltas_tensor.dim(2);
2192
// equivalent to python code
2193
// bbox_deltas = bbox_deltas.transpose((1, 2, 0)).reshape((-1, 4))
2194
ERArrXXf bbox_deltas(H * W * A, 4);
2195
Eigen::Map<ERMatXf>(bbox_deltas.data(), H * W, 4 * A) =
2196
Eigen::Map<const ERMatXf>(bbox_deltas_tensor.data(), A * 4, H * W)
2198
CAFFE_ENFORCE_EQ(bbox_deltas.rows(), all_anchors.rows());
2200
// - scores are (A, H, W) format from conv output
2201
// - transpose to (H, W, A)
2202
// - reshape to (H * W * A, 1) where rows are ordered by (H, W, A)
2203
// to match the order of anchors and bbox_deltas
2204
CAFFE_ENFORCE_EQ(scores_tensor.ndim(), 3);
2205
CAFFE_ENFORCE_EQ(scores_tensor.dims(), (vector<int>{A, H, W}));
2206
// equivalent to python code
2207
// scores = scores.transpose((1, 2, 0)).reshape((-1, 1))
2208
EArrXf scores(scores_tensor.size());
2209
Eigen::Map<ERMatXf>(scores.data(), H * W, A) =
2210
Eigen::Map<const ERMatXf>(scores_tensor.data(), A, H * W).transpose();
2211
// 1. Transform anchors into proposals via bbox transformations
2212
auto proposals = utils::bbox_transform(
2213
all_anchors.array(),
2215
std::vector<float>{1.0, 1.0, 1.0, 1.0},
2216
utils::BBOX_XFORM_CLIP_DEFAULT,
2219
// 2. clip proposals to image (may result in proposals with zero area
2220
// that will be removed in the next step)
2221
proposals = utils::clip_boxes(
2222
proposals, im_info[0], im_info[1], 1.0, legacy_plus_one_);
2224
// 3. remove predicted boxes with either height or width < min_size
2226
utils::filter_boxes(proposals, min_size, im_info, legacy_plus_one_);
// `keep` is declared on a line missing from this copy, presumably as the
// result of filter_boxes. From here on it holds the surviving row indices.
2228
TORCH_DCHECK_LE(keep.size(), scores.size());
2230
// 4. sort all (proposal, score) pairs by score from highest to lowest
2231
// 5. take top pre_nms_topN (e.g. 6000)
// NOTE: std::sort is not stable, so boxes with equal scores may come out in
// any relative order.
2232
std::sort(keep.begin(), keep.end(), [&scores](int lhs, int rhs) {
2233
return scores[lhs] > scores[rhs];
2236
if (pre_nms_topN > 0 && pre_nms_topN < keep.size()) {
2237
keep.resize(pre_nms_topN);
2240
// 6. apply loose nms (e.g. threshold = 0.7)
2241
// 7. take after_nms_topN (e.g. 300)
2242
// 8. return the top proposals (-> RoIs top)
2243
keep = nms_metal(proposals, scores, keep, nms_thresh);
2244
if (post_nms_topN > 0 && post_nms_topN < keep.size()) {
2245
keep.resize(post_nms_topN);
// Gather the surviving boxes and their scores into the output arrays.
2248
utils::GetSubArrayRows(proposals, utils::AsEArrXt(keep), out_boxes);
2249
utils::GetSubArray(scores, utils::AsEArrXt(keep), out_probs);
// Generates proposals for every image in the batch and concatenates them.
//
// Inputs:  0 scores      (num_images, A, H, W), float
//          1 bbox_deltas (num_images, 4 * A, H, W)
//          2 im_info     (num_images, 3), float; [height, width, scale] per image
//          3 anchors     (A, 4), float
// Outputs: 0 rois        (R, 5); each row is [image_index, box (4 values)]
//          1 roi probs   (R)
// Images are processed one at a time, and each image's proposals are
// appended to the outputs.
// NOTE(review): some lines of this method (CAFFE_ENFORCE heads, the
// ProposalsForOneImage arguments, the tail of the method) are missing from
// this copy.
2252
bool RunOnDevice() override {
2253
const auto& scores = Input(0);
2254
const auto& bbox_deltas = Input(1);
2255
const auto& im_info_tensor = Input(2);
2256
const auto& anchors = Input(3);
2257
auto* out_rois = Output(0);
2258
auto* out_rois_probs = Output(1);
2260
CAFFE_ENFORCE_EQ(scores.dim(), 4, scores.dim());
2261
CAFFE_ENFORCE(scores.template IsType<float>(), scores.meta().name());
2262
const auto num_images = scores.size(0);
2263
const auto A = scores.size(1);
2264
const auto height = scores.size(2);
2265
const auto width = scores.size(3);
2266
const auto K = height * width;
2268
// bbox_deltas: (num_images, A * 4, H, W)
2270
bbox_deltas.sizes(), (vector<int64_t>{num_images, 4 * A, height, width}));
2272
// im_info_tensor: (num_images, 3), format [height, width, scale; ...]
2273
CAFFE_ENFORCE_EQ(im_info_tensor.sizes(), (vector<int64_t>{num_images, 3}));
2275
im_info_tensor.template IsType<float>(), im_info_tensor.meta().name());
2278
CAFFE_ENFORCE_EQ(anchors.sizes(), (vector<int64_t>{A, 4}));
2279
CAFFE_ENFORCE(anchors.template IsType<float>(), anchors.meta().name());
2280
// Broadcast the anchors to all pixels
2281
auto all_anchors_vec =
2282
utils::ComputeAllAnchors(anchors, height, width, feat_stride_);
2283
Eigen::Map<const ERMatXf> all_anchors(all_anchors_vec.data(), K * A, 4);
2285
Eigen::Map<const ERArrXXf> im_info(
2286
im_info_tensor.data<float>(),
2287
im_info_tensor.size(0),
2288
im_info_tensor.size(1));
// Start both outputs empty and grow them image by image.
2290
const int roi_col_count = 5;
2291
out_rois->Resize(0, roi_col_count);
2292
out_rois_probs->Resize(0);
2294
// Use openmp for acceleration?
2295
for (int i = 0; i < num_images; i++) {
2296
auto cur_im_info = im_info.row(i);
2297
auto cur_bbox_deltas = GetSubTensorView<float>(bbox_deltas, i);
2298
auto cur_scores = GetSubTensorView<float>(scores, i);
2300
ERArrXXf im_i_boxes;
2302
ProposalsForOneImage(
// Append this image's csz proposals after the rows already written.
2310
int csz = im_i_boxes.rows();
2311
int cur_start_idx = out_rois->size(0);
2313
out_rois->Extend(csz, 50);
2314
out_rois_probs->Extend(csz, 50);
2317
Eigen::Map<ERArrXXf> cur_rois(
2318
out_rois->mutable_data<float>() + cur_start_idx * roi_col_count,
// Column 0 holds the image index; columns 1-4 hold the box.
2321
cur_rois.col(0).setConstant(i);
2322
cur_rois.block(0, 1, csz, 4) = im_i_boxes;
2326
out_rois_probs->mutable_data<float>() + cur_start_idx, csz) =
// Members are initialized in declaration order. The constructor computes
// feat_stride_ from spatial_scale_, hence the ordering constraint below.
2334
// spatial_scale_ must be declared before feat_stride_
2335
float spatial_scale_{1.0};
2336
float feat_stride_{1.0};
// The topN limits are stored as ushort, so int arguments above 65535 are
// truncated. A value of 0 disables the corresponding cap in
// ProposalsForOneImage.
2338
// RPN_PRE_NMS_TOP_N
2339
ushort rpn_pre_nms_topN_{6000};
2340
// RPN_POST_NMS_TOP_N
2341
ushort rpn_post_nms_topN_{300};
// NMS threshold. nms_metal passes it to the kernel as
// ushort(rpn_nms_thresh_ * 10000).
2343
float rpn_nms_thresh_{0.7};
// Boxes with height or width below this are dropped before NMS.
2345
float rpn_min_size_{16};
2346
// The infamous "+ 1" for box width and height dating back to the DPM days
2347
bool legacy_plus_one_{true};
2348
// threads per thread group, used in nms
// nms_metal also uses this as the number of boxes per uint32 mask word
// (1U << inblock), so it must not exceed 32.
2349
ushort maxThreadsPerThreadgroup{32};
// Registers MPSCNNGenerateProposalsCPPOp as "MPSCNNGenerateProposalsCPP".
// Inputs: scores, bbox_deltas, im_info, anchors.
// Outputs: rois, roi probabilities.
// These match the Input()/Output() indices used in RunOnDevice.
2352
REGISTER_CPU_OPERATOR(MPSCNNGenerateProposalsCPP, MPSCNNGenerateProposalsCPPOp);
2353
OPERATOR_SCHEMA(MPSCNNGenerateProposalsCPP).NumInputs(4).NumOutputs(2);
// Spatial batch norm on an MPSImage using the estimated statistics
// (EST_MEAN / EST_VAR). Scale, bias, mean and variance are folded into
// per-channel float16 scale/shift buffers that the "affine" Metal kernel
// applies (presumably as x * scale + shift; the kernel source is not in this
// chunk).
// NOTE(review): this copy has bare line-number lines interleaved with the
// code, and several lines are missing (timer declarations, closing braces,
// MPSImageWrapper arguments). It does not compile as-is.
2355
class MPSCNNSpatialBNOp final : public SpatialBNOp<CPUContext> {
2357
MPSCNNSpatialBNOp(const OperatorDef& operator_def, Workspace* ws)
2358
: SpatialBNOp<CPUContext>(operator_def, ws) {}
2360
bool RunOnDevice() override {
2361
auto inputWrapper = Inputs()[0]->Get<MPSImageWrapper>();
2362
const MPSImage* X = inputWrapper.getImage();
2363
const auto& scale = Input(SCALE);
2364
const auto& bias = Input(BIAS);
2365
const auto& var = Input(EST_VAR);
2366
const auto& mean = Input(EST_MEAN);
// Every parameter blob must hold exactly one value per input channel.
2367
CAFFE_ENFORCE_EQ(scale.size(), X.featureChannels);
2368
CAFFE_ENFORCE_EQ(bias.size(), X.featureChannels);
2369
CAFFE_ENFORCE_EQ(var.size(), X.featureChannels);
2370
CAFFE_ENFORCE_EQ(mean.size(), X.featureChannels);
// Buffer size in bytes: one float16 (2 bytes) per channel, with the channel
// count rounded up to a multiple of 4.
// NOTE(review): the buffers are (re)built only when they are missing or this
// size changes. If the parameter blobs change between runs while the channel
// count stays the same, the previously fused values are reused.
2372
const auto scaleBytes = divRoundUp(scale.size(), 4) * 4 * 2;
2373
if (!scaleBuffer_ || scaleBuffer_.length != scaleBytes) {
2375
scaleBuffer_ = [getMPSCNNContext().device
2376
newBufferWithLength:scaleBytes
2377
options:MTLResourceOptionCPUCacheModeDefault];
2378
shiftBuffer_ = [getMPSCNNContext().device
2379
newBufferWithLength:scaleBytes
2380
options:MTLResourceOptionCPUCacheModeDefault];
2381
for (auto i = 0; i < scale.size(); ++i) {
2382
// We can fuse the output computation as follows:
2383
// (x - est_mean) * inv_std * scale + bias
2385
// = (x * inv_std * scale) + (bias - est_mean * inv_std * scale)
// where inv_std = 1 / sqrt(est_var + epsilon). inv_std is computed in double,
// and the results are narrowed to float16_t when stored.
2387
const auto inv_std = 1.0 / std::sqrt(var.data<float>()[i] + epsilon_);
2388
((float16_t*)[scaleBuffer_ contents])[i] =
2389
scale.data<float>()[i] * inv_std;
2390
((float16_t*)[shiftBuffer_ contents])[i] = bias.data<float>()[i] -
2391
mean.data<float>()[i] * inv_std * scale.data<float>()[i];
2393
VLOG(2) << "Buffer setup took: " << cvt.MilliSeconds();
2396
auto outputWrapper = MPSImageWrapper(
2403
auto commandBuffer = outputWrapper.getCommandBuffer();
2404
MPSImage* output = outputWrapper.getImage();
// Bindings: scale = buffer 0, shift = buffer 1, input = texture 0,
// output = texture 1. The pipeline is specialized on the channel count.
2406
id<MTLComputeCommandEncoder> encoder =
2407
[commandBuffer computeCommandEncoder];
2408
id<MTLComputePipelineState> state =
2409
getMPSCNNContext().getSpecializedPipelineState(
2410
kernelFor(output, @"affine", @"affine_nonarray"),
2411
{ushort(X.featureChannels)});
2413
[encoder setComputePipelineState:state];
2414
[encoder setBuffer:scaleBuffer_ offset:0 atIndex:0];
2415
[encoder setBuffer:shiftBuffer_ offset:0 atIndex:1];
2416
[encoder setTexture:[X texture] atIndex:0];
2417
[encoder setTexture:[output texture] atIndex:1];
2419
const auto& launchParams =
2420
spatialPointwiseKernelLaunchParams(state, output);
2421
[encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
2422
threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
2423
[encoder endEncoding];
2424
inputWrapper.markRead();
2425
VLOG(2) << "SpatialBN took: " << t.MilliSeconds();
2426
outputWrapper.copyToOutputBlob(Outputs()[0]);
// Cached fused per-channel parameters (float16), reused across runs.
2432
id<MTLBuffer> scaleBuffer_;
2433
id<MTLBuffer> shiftBuffer_;
// Registers MPSCNNSpatialBNOp as "MPSCNNSpatialBN". The 5 inputs are the
// image plus the scale, bias, mean and variance blobs read in RunOnDevice.
// There is 1 output.
2436
REGISTER_CPU_OPERATOR(MPSCNNSpatialBN, MPSCNNSpatialBNOp);
2437
OPERATOR_SCHEMA(MPSCNNSpatialBN).NumInputs(5).NumOutputs(1);
// Concatenates up to four MPSImages along the channel dimension (the
// per-input channel counts are summed) with the "concat" Metal kernel.
// All inputs must share the same command buffer, height and width.
// NOTE(review): this copy has bare line-number lines interleaved with the
// code and several lines missing. It does not compile as-is.
2439
class MPSCNNConcatOp final : public Operator<CPUContext> {
2441
MPSCNNConcatOp(const OperatorDef& operator_def, Workspace* ws)
2442
: Operator<CPUContext>(operator_def, ws) {
2443
// Only up to four inputs are supported for now.
2444
OPERATOR_NEEDS_FEATURE(
2445
Inputs().size() <= 4, "MPSCNNConcat only handles up to four inputs");
2448
bool RunOnDevice() override {
2449
auto Wrapper = [&](size_t i) {
2450
return Inputs()[i]->template Get<MPSImageWrapper>();
2452
auto cb = [&](size_t i) { return Wrapper(i).getCommandBuffer(); };
2453
auto X = [&](size_t i) { return Wrapper(i).getImage(); };
2455
// Pipeline constants: C0..C3 = per-input channel counts (0 for absent
// inputs), C = total channel count, N = number of images.
2456
std::vector<ushort> channels = {
2457
{0, 0, 0, 0, 0, ushort(X(0).numberOfImages)}};
2458
size_t channelCount = 0;
2459
for (auto i = 0; i < Inputs().size(); ++i) {
2460
// All inputs must share one command buffer; this does not hold for
// non-temporary image inputs.
2461
CAFFE_ENFORCE_EQ(cb(0), cb(i));
2462
CAFFE_ENFORCE_EQ(X(0).height, X(i).height);
2463
CAFFE_ENFORCE_EQ(X(0).width, X(i).width);
2464
channels[i] = X(i).featureChannels;
2465
channelCount += X(i).featureChannels;
2467
channels[4] = channelCount;
2469
auto wrapper0 = Inputs()[0]->template Get<MPSImageWrapper>();
2470
auto outputWrapper = MPSImageWrapper(
2473
X(0).numberOfImages,
2477
auto commandBuffer = outputWrapper.getCommandBuffer();
2478
MPSImage* output = outputWrapper.getImage();
2480
id<MTLComputeCommandEncoder> encoder =
2481
[commandBuffer computeCommandEncoder];
2482
id<MTLComputePipelineState> state =
2483
getMPSCNNContext().getSpecializedPipelineState(@"concat", channels);
// Input i binds to texture i (0..3). The output always binds to texture 5.
2485
[encoder setComputePipelineState:state];
2486
for (auto i = 0; i < Inputs().size(); ++i) {
2487
[encoder setTexture:[X(i) texture] atIndex:i];
2489
[encoder setTexture:[output texture] atIndex:5];
2490
const auto& launchParams =
2491
spatialPointwiseKernelLaunchParams(state, output);
2492
[encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
2493
threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
2494
[encoder endEncoding];
// Mark every input wrapper as read once the kernel has been encoded.
2495
for (auto i = 0; i < Inputs().size(); ++i) {
2496
Wrapper(i).markRead();
2499
VLOG(2) << "Concat took: " << t.MilliSeconds();
2500
outputWrapper.copyToOutputBlob(Outputs()[0]);
// Registers MPSCNNConcatOp as "MPSCNNConcat" with 2 to 4 inputs.
2505
REGISTER_CPU_OPERATOR(MPSCNNConcat, MPSCNNConcatOp);
2506
// Only store one output in practice (ignore the shape argument).
// The schema allows a second output, but RunOnDevice writes only Outputs()[0].
2507
OPERATOR_SCHEMA(MPSCNNConcat).NumInputs(2, 4).NumOutputs(1, 2);
// Resizes an MPSImage by (height_scale, width_scale) with the
// "resize_nearest" Metal kernel. The output is
// int(H * height_scale) x int(W * width_scale); N and C are unchanged.
// NOTE(review): this copy has bare line-number lines interleaved with the
// code and several lines missing. It does not compile as-is.
2509
class MPSCNNResizeNearestOp final : public Operator<CPUContext> {
2511
MPSCNNResizeNearestOp(const OperatorDef& operator_def, Workspace* ws)
2512
: Operator<CPUContext>(operator_def, ws) {
2513
width_scale_ = OperatorBase::GetSingleArgument<float>("width_scale", 1);
2514
height_scale_ = OperatorBase::GetSingleArgument<float>("height_scale", 1);
2515
CAFFE_ENFORCE_GT(width_scale_, 0);
2516
CAFFE_ENFORCE_GT(height_scale_, 0);
2518
// The scales are passed to the kernel as ushort(scale * 10000), so values
// above 6.5535 would overflow. Scales are therefore capped at 6.5.
2520
CAFFE_ENFORCE_LE(width_scale_, 6.5);
2521
CAFFE_ENFORCE_LE(height_scale_, 6.5);
2524
bool RunOnDevice() override {
2525
auto inputWrapper = Inputs()[0]->Get<MPSImageWrapper>();
2526
const MPSImage* X = inputWrapper.getImage();
// The declaration of W continues on a line missing from this copy.
2528
const int N = X.numberOfImages, C = X.featureChannels, H = X.height,
// The output size is truncated toward zero by the int conversion.
2530
int output_width = W * width_scale_;
2531
int output_height = H * height_scale_;
2532
auto outputWrapper =
2533
MPSImageWrapper(this, &inputWrapper, N, output_height, output_width, C);
// NOTE(review): the other ops in this file take the command buffer from the
// output wrapper; this one uses the input wrapper's. They are presumably the
// same buffer, since the output wrapper is built from &inputWrapper — confirm.
2534
auto commandBuffer = inputWrapper.getCommandBuffer();
2535
MPSImage* output = outputWrapper.getImage();
// Pipeline constants: output height and width, plus both scales in fixed
// point (x 10000).
2537
auto encoder = [commandBuffer computeCommandEncoder];
2538
auto state = getMPSCNNContext().getSpecializedPipelineState(
2539
kernelFor(output, @"resize_nearest", @"resize_nearest_nonarray"),
2540
{{ushort(output_height),
2541
ushort(output_width),
2542
ushort(height_scale_ * 10000),
2543
ushort(width_scale_ * 10000)}});
2544
[encoder setComputePipelineState:state];
2545
[encoder setTexture:[X texture] atIndex:0];
2546
[encoder setTexture:[output texture] atIndex:1];
2547
auto launchParams = spatialPointwiseKernelLaunchParams(state, output);
2548
[encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
2549
threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
2550
[encoder endEncoding];
2551
inputWrapper.markRead();
2552
outputWrapper.copyToOutputBlob(Outputs()[0]);
// Scale factors read in the constructor. width_scale_ is declared on a line
// missing from this copy.
2559
float height_scale_;
// Registers MPSCNNResizeNearestOp as "MPSCNNResizeNearest" (1 input, 1 output).
2562
REGISTER_CPU_OPERATOR(MPSCNNResizeNearest, MPSCNNResizeNearestOp);
2563
OPERATOR_SCHEMA(MPSCNNResizeNearest).NumInputs(1).NumOutputs(1);
// Channel shuffle of an MPSImage (NCHW order only). featureChannels must be
// divisible by group_, and the output keeps the input's height and width.
// The pipeline is specialized on N, C, C / group_ and group_; the kernel name
// is on a line missing from this copy.
// NOTE(review): this copy has bare line-number lines interleaved with the
// code and several lines missing. It does not compile as-is.
2565
class MPSCNNChannelShuffleOp final : public ConvPoolOpBase<CPUContext> {
2567
MPSCNNChannelShuffleOp(const OperatorDef& operator_def, Workspace* ws)
2568
: ConvPoolOpBase<CPUContext>(operator_def, ws) {
2569
OPERATOR_NEEDS_FEATURE(
2570
this->order_ == StorageOrder::NCHW, "Metal only supports NCHW order.");
// The op has no spatial kernel, so force a 1x1 kernel on the ConvPoolOpBase
// base class (presumably to keep its kernel-size handling trivial — confirm).
2571
kernel_[0] = kernel_[1] = 1;
2574
bool RunOnDeviceWithOrderNCHW() override {
2576
auto inputWrapper = Inputs()[0]->Get<MPSImageWrapper>();
2577
MPSImage* X = inputWrapper.getImage();
2578
CAFFE_ENFORCE_EQ(X.featureChannels % this->group_, 0);
2579
auto output_height = X.height;
2580
auto output_width = X.width;
2581
auto outputWrapper = MPSImageWrapper(
2588
auto commandBuffer = outputWrapper.getCommandBuffer();
2589
MPSImage* output = outputWrapper.getImage();
2590
CAFFE_ENFORCE_EQ(output.height, output_height);
2591
CAFFE_ENFORCE_EQ(output.width, output_width);
2592
id<MTLComputeCommandEncoder> encoder =
2593
[commandBuffer computeCommandEncoder];
2594
id<MTLComputePipelineState> state =
2595
getMPSCNNContext().getSpecializedPipelineState(
2598
ushort(X.numberOfImages),
2599
ushort(X.featureChannels),
2600
ushort(X.featureChannels / this->group_),
2601
ushort(this->group_),
// Bindings: input = texture 0, output = texture 1.
2603
[encoder setComputePipelineState:state];
2604
[encoder setTexture:[X texture] atIndex:0];
2605
[encoder setTexture:[output texture] atIndex:1];
2606
const auto& launchParams =
2607
spatialPointwiseKernelLaunchParams(state, output);
2608
[encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
2609
threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
2610
[encoder endEncoding];
2611
inputWrapper.markRead();
2612
outputWrapper.copyToOutputBlob(Outputs()[0]);
2614
VLOG(2) << "ChannelShuffle took: " << t.MilliSeconds();
// Registers MPSCNNChannelShuffleOp as "MPSCNNChannelShuffle" (1 input,
// 1 output).
2619
REGISTER_CPU_OPERATOR(MPSCNNChannelShuffle, MPSCNNChannelShuffleOp);
2620
OPERATOR_SCHEMA(MPSCNNChannelShuffle).NumInputs(1).NumOutputs(1);
// Registers MPSImageWrapper with caffe2's type registry. The operators above
// read it out of input blobs via Get<MPSImageWrapper>().
2623
CAFFE_KNOWN_TYPE(MPSImageWrapper);
2624
} // namespace caffe2