2
#include "caffe2/core/common.h"
6
#include "mpscnn_context.h"
7
#include "mpscnn_kernels.h"
9
#include "caffe2/core/logging.h"
10
#include "caffe2/core/timer.h"
16
#import <Metal/MTLFunctionConstantValues.h>
20
// Returns the process-wide MPSCNNContext singleton.
//
// The context is a function-local static. Its one-time setup is guarded by
// std::call_once, so concurrent first callers run the initializer only once.
// Setup creates the system default Metal device, compiles the kernel library
// from the MPSCNN_KERNELS source string, and creates a command queue.
// Throws (CAFFE_THROW) if the kernel library fails to load.
MPSCNNContext& getMPSCNNContext() {
21
static std::once_flag once;
22
static MPSCNNContext ctx;
23
std::call_once(once, []() {
24
NSError* compileError = nil;
25
// NOTE(review): MTLCreateSystemDefaultDevice() may return nil where Metal is
// unavailable; no nil check on ctx.device is visible here -- confirm.
ctx.device = MTLCreateSystemDefaultDevice();
26
// Compile the Metal kernels at runtime from the embedded source string.
ctx.library = [ctx.device newLibraryWithSource:[NSString stringWithUTF8String:MPSCNN_KERNELS]
29
// NOTE(review): Cocoa convention is to test the return value, not the error
// pointer. compileError may be non-nil even when a library was produced
// (e.g. compiler diagnostics). Consider failing only when ctx.library == nil.
if (compileError != nil || ctx.library == nil) {
30
CAFFE_THROW("Failed to load kernels: ", [[compileError localizedDescription] UTF8String]);
32
ctx.commandQueue = [ctx.device newCommandQueue];
37
// Returns the compute pipeline state for the Metal kernel function named
// `kernel`, building and caching it on first use.
//
// Results are memoized in pipelineCache_, keyed by the kernel name. The lock on
// pipelineCacheMutex_ is taken before the cache lookup and held by a
// lock_guard, so the build on a cache miss also happens under the lock.
// Throws (CAFFE_THROW) if the function cannot be found in the library or the
// pipeline state cannot be created.
id<MTLComputePipelineState> MPSCNNContext::getPipelineState(NSString* kernel) {
38
std::string kernelStr = std::string([kernel UTF8String]);
39
std::lock_guard<std::mutex> g(pipelineCacheMutex_);
40
// NOTE(review): find() followed by operator[] does two lookups; reusing the
// iterator returned by find() would avoid the second.
if (pipelineCache_.find(kernelStr) != pipelineCache_.end()) {
41
VLOG(1) << "Hit in pipeline cache for: " << kernelStr;
42
return pipelineCache_[kernelStr];
44
LOG(INFO) << "Miss in pipeline cache for: " << kernelStr;
45
// Look up the unspecialized kernel function in the compiled library.
id<MTLFunction> func = [library newFunctionWithName:kernel];
47
CAFFE_THROW("Couldn't get function: ", kernelStr);
51
id<MTLComputePipelineState> state =
52
[device newComputePipelineStateWithFunction:func error:&errors];
54
CAFFE_THROW("Couldn't get state: ", kernelStr);
57
// Memoize so later calls for this kernel hit the cache.
pipelineCache_[kernelStr] = state;
61
// Returns the compute pipeline state for `kernel` specialized with Metal
// function constants, building and caching it on first use.
//
// constants[i] is bound as a ushort (MTLDataTypeUShort) function constant at
// index i. The cache key is the kernel name followed by "_<value>" for each
// constant. The cache and mutex (pipelineCache_ / pipelineCacheMutex_) are the
// same ones used by getPipelineState.
// Throws (CAFFE_THROW) with the Metal error description if specialization or
// pipeline creation fails.
id<MTLComputePipelineState> MPSCNNContext::getSpecializedPipelineState(
62
NSString* kernel, const std::vector<ushort>& constants) {
63
// NOTE(review): this key format can collide with getPipelineState's keys in the
// shared cache. For example, getPipelineState(@"foo_1") and
// getSpecializedPipelineState(@"foo", {1}) both map to "foo_1". Confirm kernel
// names rule this out.
std::string kernelStr = std::string([kernel UTF8String]);
64
// NOTE(review): `auto i = 0` deduces int, so comparing it with constants.size()
// (size_t) is a signed/unsigned comparison; size_t would be cleaner.
for (auto i = 0; i < constants.size(); ++i) {
65
kernelStr += "_" + std::to_string(constants[i]);
67
std::lock_guard<std::mutex> g(pipelineCacheMutex_);
68
if (pipelineCache_.find(kernelStr) != pipelineCache_.end()) {
69
VLOG(1) << "Hit in pipeline cache for: " << kernelStr;
70
return pipelineCache_[kernelStr];
72
// Bind each constant value to its function-constant index for specialization.
MTLFunctionConstantValues* constantValues = [MTLFunctionConstantValues new];
73
for (auto i = 0; i < constants.size(); ++i) {
74
[constantValues setConstantValue:&constants[i] type:MTLDataTypeUShort atIndex:i];
78
LOG(INFO) << "Miss in pipeline cache for: " << kernelStr;
79
id<MTLFunction> func =
80
[library newFunctionWithName:kernel constantValues:constantValues error:&errors];
82
CAFFE_THROW("Couldn't get function: ",
85
[[errors localizedDescription] UTF8String]);
88
id<MTLComputePipelineState> state =
89
[device newComputePipelineStateWithFunction:func error:&errors];
91
// NOTE(review): this failure comes from pipeline-state creation, but the
// message says "Couldn't get function". getPipelineState reports the same
// failure as "Couldn't get state", so this looks like a copy-paste slip.
CAFFE_THROW("Couldn't get function: ",
94
[[errors localizedDescription] UTF8String]);
97
// Memoize so later calls with the same kernel and constants hit the cache.
pipelineCache_[kernelStr] = state;