pytorch

Форк
0
/
mpscnn_context.mm 
101 строка · 3.2 Кб
1

2
#include "caffe2/core/common.h"
3

4
#ifdef C10_MOBILE
5

6
#include "mpscnn_context.h"
7
#include "mpscnn_kernels.h"
8

9
#include "caffe2/core/logging.h"
10
#include "caffe2/core/timer.h"
11

12
#include <array>
13
#include <mutex>
14
#include <thread>
15

16
#import <Metal/MTLFunctionConstantValues.h>
17

18
namespace caffe2 {
19

20
MPSCNNContext& getMPSCNNContext() {
21
  static std::once_flag once;
22
  static MPSCNNContext ctx;
23
  std::call_once(once, []() {
24
    NSError* compileError = nil;
25
    ctx.device = MTLCreateSystemDefaultDevice();
26
    ctx.library = [ctx.device newLibraryWithSource:[NSString stringWithUTF8String:MPSCNN_KERNELS]
27
                                           options:nil
28
                                             error:&compileError];
29
    if (compileError != nil || ctx.library == nil) {
30
      CAFFE_THROW("Failed to load kernels: ", [[compileError localizedDescription] UTF8String]);
31
    }
32
    ctx.commandQueue = [ctx.device newCommandQueue];
33
  });
34
  return ctx;
35
}
36

37
id<MTLComputePipelineState> MPSCNNContext::getPipelineState(NSString* kernel) {
38
  std::string kernelStr = std::string([kernel UTF8String]);
39
  std::lock_guard<std::mutex> g(pipelineCacheMutex_);
40
  if (pipelineCache_.find(kernelStr) != pipelineCache_.end()) {
41
    VLOG(1) << "Hit in pipeline cache for: " << kernelStr;
42
    return pipelineCache_[kernelStr];
43
  }
44
  LOG(INFO) << "Miss in pipeline cache for: " << kernelStr;
45
  id<MTLFunction> func = [library newFunctionWithName:kernel];
46
  if (!func) {
47
    CAFFE_THROW("Couldn't get function: ", kernelStr);
48
    return nullptr;
49
  }
50
  NSError* errors;
51
  id<MTLComputePipelineState> state =
52
      [device newComputePipelineStateWithFunction:func error:&errors];
53
  if (!state) {
54
    CAFFE_THROW("Couldn't get state: ", kernelStr);
55
    return nullptr;
56
  }
57
  pipelineCache_[kernelStr] = state;
58
  return state;
59
}
60

61
id<MTLComputePipelineState> MPSCNNContext::getSpecializedPipelineState(
62
    NSString* kernel, const std::vector<ushort>& constants) {
63
  std::string kernelStr = std::string([kernel UTF8String]);
64
  for (auto i = 0; i < constants.size(); ++i) {
65
    kernelStr += "_" + std::to_string(constants[i]);
66
  }
67
  std::lock_guard<std::mutex> g(pipelineCacheMutex_);
68
  if (pipelineCache_.find(kernelStr) != pipelineCache_.end()) {
69
    VLOG(1) << "Hit in pipeline cache for: " << kernelStr;
70
    return pipelineCache_[kernelStr];
71
  }
72
  MTLFunctionConstantValues* constantValues = [MTLFunctionConstantValues new];
73
  for (auto i = 0; i < constants.size(); ++i) {
74
    [constantValues setConstantValue:&constants[i] type:MTLDataTypeUShort atIndex:i];
75
  }
76
  NSError* errors;
77

78
  LOG(INFO) << "Miss in pipeline cache for: " << kernelStr;
79
  id<MTLFunction> func =
80
      [library newFunctionWithName:kernel constantValues:constantValues error:&errors];
81
  if (!func) {
82
    CAFFE_THROW("Couldn't get function: ",
83
                kernelStr,
84
                " error: ",
85
                [[errors localizedDescription] UTF8String]);
86
    return nullptr;
87
  }
88
  id<MTLComputePipelineState> state =
89
      [device newComputePipelineStateWithFunction:func error:&errors];
90
  if (!state) {
91
    CAFFE_THROW("Couldn't get function: ",
92
                kernelStr,
93
                " error: ",
94
                [[errors localizedDescription] UTF8String]);
95
    return nullptr;
96
  }
97
  pipelineCache_[kernelStr] = state;
98
  return state;
99
}
100
}
101

102
#endif
103

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.