onnxruntime

Форк
0
/
ort_value.mm 
246 строк · 8.6 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
#import "ort_value_internal.h"
5

6
#include <optional>
7

8
#import "cxx_api.h"
9
#import "error_utils.h"
10
#import "ort_enums_internal.h"
11

12
NS_ASSUME_NONNULL_BEGIN
13

14
namespace {
15

16
// Builds the public ORTTensorTypeAndShapeInfo object from the C++ API's
// tensor type-and-shape descriptor.
// The element type is translated with CAPIToPublicTensorElementType and each
// dimension is boxed into an NSNumber, keeping the original dimension order.
ORTTensorTypeAndShapeInfo* CXXAPIToPublicTensorTypeAndShapeInfo(
    const Ort::ConstTensorTypeAndShapeInfo& CXXAPITensorTypeAndShapeInfo) {
  ORTTensorTypeAndShapeInfo* publicInfo = [[ORTTensorTypeAndShapeInfo alloc] init];
  const auto capiElementType = CXXAPITensorTypeAndShapeInfo.GetElementType();
  const std::vector<int64_t> dims = CXXAPITensorTypeAndShapeInfo.GetShape();

  publicInfo.elementType = CAPIToPublicTensorElementType(capiElementType);

  NSMutableArray<NSNumber*>* boxedDims = [[NSMutableArray alloc] initWithCapacity:dims.size()];
  for (const int64_t dim : dims) {
    [boxedDims addObject:@(dim)];
  }
  publicInfo.shape = boxedDims;

  return publicInfo;
}
31

32
// Builds the public ORTValueTypeInfo object from the C++ API's Ort::TypeInfo.
// The value type is always populated; tensorTypeAndShapeInfo is populated only
// when the value type is ONNX_TYPE_TENSOR.
ORTValueTypeInfo* CXXAPIToPublicValueTypeInfo(
    const Ort::TypeInfo& CXXAPITypeInfo) {
  ORTValueTypeInfo* publicInfo = [[ORTValueTypeInfo alloc] init];
  const auto onnxType = CXXAPITypeInfo.GetONNXType();

  publicInfo.type = CAPIToPublicValueType(onnxType);

  // Non-tensor values carry no tensor type/shape details.
  if (onnxType != ONNX_TYPE_TENSOR) {
    return publicInfo;
  }

  const auto tensorInfo = CXXAPITypeInfo.GetTensorTypeAndShapeInfo();
  publicInfo.tensorTypeAndShapeInfo = CXXAPIToPublicTensorTypeAndShapeInfo(tensorInfo);
  return publicInfo;
}
46

47
// Multiplies a by b, writing the product to `out`.
// Returns true iff the multiplication did not overflow size_t.
bool SafeMultiply(size_t a, size_t b, size_t& out) {
  const bool overflowed = __builtin_mul_overflow(a, b, &out);
  return !overflowed;
}
52

53
}  // namespace
54

55
@interface ORTValue ()

// Reference to externally provided tensor storage, if any. It is retained here
// so the buffer stays alive for as long as this ORTValue does.
@property(nullable, nonatomic) NSMutableData* externalTensorData;

@end
61

62
@implementation ORTValue {
63
  std::optional<Ort::Value> _value;
64
  std::optional<Ort::TypeInfo> _typeInfo;
65
}
66

67
#pragma mark - Public
68

69
// Creates an ORTValue tensor that wraps the bytes of `tensorData` directly.
// `tensorData` is handed to the internal initializer as externalTensorData so
// that it is kept alive by the new instance.
// String element types are rejected here; string tensors are created with
// initWithTensorStringData:shape:error: instead.
// Returns nil and reports through `error` if any step throws.
- (nullable instancetype)initWithTensorData:(NSMutableData*)tensorData
                                elementType:(ORTTensorElementDataType)elementType
                                      shape:(NSArray<NSNumber*>*)shape
                                      error:(NSError**)error {
  try {
    if (elementType == ORTTensorElementDataTypeString) {
      ORT_CXX_API_THROW(
          "ORTTensorElementDataTypeString element type provided. "
          "Please call initWithTensorStringData:shape:error: instead to create an ORTValue with string data.",
          ORT_INVALID_ARGUMENT);
    }

    const auto cpuMemoryInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
    const auto capiElementType = PublicToCAPITensorElementType(elementType);

    // Unbox the NSNumber dimensions into the int64_t array the C++ API expects.
    std::vector<int64_t> dims;
    dims.reserve(shape.count);
    for (NSNumber* dimension in shape) {
      dims.push_back(dimension.longLongValue);
    }

    Ort::Value wrappedTensor = Ort::Value::CreateTensor(
        cpuMemoryInfo, tensorData.mutableBytes, tensorData.length,
        dims.data(), dims.size(), capiElementType);

    return [self initWithCXXAPIOrtValue:std::move(wrappedTensor)
                     externalTensorData:tensorData
                                  error:error];
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}
100

101
// Creates a string tensor ORTValue from an array of NSString elements.
//
// The tensor is allocated with the default ORT allocator and each string is
// copied into it, so no external data is retained by the new instance.
//
// Fails (returns nil and reports through `error`) when:
//  - a dimension is negative or the element count overflows size_t,
//  - the element count implied by `shape` differs from tensorStringData.count,
//  - a string cannot be represented as a UTF-8 C string,
//  - the underlying C++ API throws.
- (nullable instancetype)initWithTensorStringData:(NSArray<NSString*>*)tensorStringData
                                            shape:(NSArray<NSNumber*>*)shape
                                            error:(NSError**)error {
  try {
    Ort::AllocatorWithDefaultOptions allocator;

    // Convert the shape and accumulate the total element count at the same time,
    // rejecting negative dimensions and size_t overflow.
    size_t tensorSize = 1U;
    const auto shapeVector = [&tensorSize, shape]() {
      std::vector<int64_t> result{};
      result.reserve(shape.count);
      for (NSNumber* dim in shape) {
        const auto dimValue = dim.longLongValue;
        if (dimValue < 0 || !SafeMultiply(static_cast<size_t>(dimValue), tensorSize, tensorSize)) {
          ORT_CXX_API_THROW("Failed to compute the tensor size.", ORT_RUNTIME_EXCEPTION);
        }
        result.push_back(dimValue);
      }
      return result;
    }();

    if (tensorSize != [tensorStringData count]) {
      ORT_CXX_API_THROW(
          "Computed tensor size does not equal the length of the provided tensor string data.",
          ORT_INVALID_ARGUMENT);
    }

    Ort::Value ortValue = Ort::Value::CreateTensor(
        allocator, shapeVector.data(), shapeVector.size(), ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING);

    size_t index = 0;
    for (NSString* stringData in tensorStringData) {
      // -UTF8String is nullable: it yields NULL when the string has no UTF-8
      // representation. Report that as an error instead of passing NULL on.
      const char* utf8String = [stringData UTF8String];
      if (utf8String == nullptr) {
        ORT_CXX_API_THROW(
            "Failed to convert a tensor string data element to a UTF-8 string.",
            ORT_INVALID_ARGUMENT);
      }
      ortValue.FillStringTensorElement(utf8String, index++);
    }

    return [self initWithCXXAPIOrtValue:std::move(ortValue)
                     externalTensorData:nil
                                  error:error];
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}
140

141
// Returns the public type information for this value, converted from the
// Ort::TypeInfo stored in _typeInfo.
// Returns nil and reports through `error` if the conversion throws.
- (nullable ORTValueTypeInfo*)typeInfoWithError:(NSError**)error {
  try {
    const Ort::TypeInfo& storedTypeInfo = *_typeInfo;
    return CXXAPIToPublicValueTypeInfo(storedTypeInfo);
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}
147

148
// Returns the tensor element type and shape of this value as a public
// ORTTensorTypeAndShapeInfo object.
// Returns nil and reports through `error` when the stored type info yields no
// tensor info ("ORTValue is not a tensor.") or when the C++ API throws.
- (nullable ORTTensorTypeAndShapeInfo*)tensorTypeAndShapeInfoWithError:(NSError**)error {
  try {
    const auto tensorInfo = _typeInfo->GetTensorTypeAndShapeInfo();
    if (!tensorInfo) {
      ORT_CXX_API_THROW("ORTValue is not a tensor.", ORT_RUNTIME_EXCEPTION);
    }

    ORTTensorTypeAndShapeInfo* publicInfo = CXXAPIToPublicTensorTypeAndShapeInfo(tensorInfo);
    return publicInfo;
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}
158

159
// Returns an NSMutableData view over this tensor's raw element buffer.
// The bytes are not copied and are not freed by the returned object
// (dataWithBytesNoCopy with freeWhenDone:NO).
// Returns nil and reports through `error` when the value is not a tensor, when
// it holds string data (use tensorStringDataWithError: for that), when the
// byte length overflows size_t, or when the C++ API throws.
- (nullable NSMutableData*)tensorDataWithError:(NSError**)error {
  try {
    const auto tensorInfo = _typeInfo->GetTensorTypeAndShapeInfo();
    if (!tensorInfo) {
      ORT_CXX_API_THROW("ORTValue is not a tensor.", ORT_RUNTIME_EXCEPTION);
    }

    const auto capiElementType = tensorInfo.GetElementType();
    if (capiElementType == ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) {
      ORT_CXX_API_THROW(
          "This ORTValue holds string data. Please call tensorStringDataWithError: "
          "instead to retrieve the string data from this ORTValue.",
          ORT_RUNTIME_EXCEPTION);
    }

    // Byte length = element count * bytes per element, with overflow detection.
    const size_t numElements = tensorInfo.GetElementCount();
    const size_t bytesPerElement = SizeOfCAPITensorElementType(capiElementType);
    size_t byteLength = 0;
    if (!SafeMultiply(numElements, bytesPerElement, byteLength)) {
      ORT_CXX_API_THROW("failed to compute tensor data length", ORT_RUNTIME_EXCEPTION);
    }

    void* tensorBytes = nullptr;
    Ort::ThrowOnError(Ort::GetApi().GetTensorMutableData(*_value, &tensorBytes));

    return [NSMutableData dataWithBytesNoCopy:tensorBytes
                                       length:byteLength
                                 freeWhenDone:NO];
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}
187

188
// Returns the elements of this string tensor as an array of NSString objects.
//
// The tensor's string content is fetched as one contiguous byte buffer plus a
// per-element start offset; each element is then decoded as UTF-8.
//
// Returns nil and reports through `error` when the value is not a tensor, when
// an element's bytes are not valid UTF-8, or when the C++ API throws.
- (nullable NSArray<NSString*>*)tensorStringDataWithError:(NSError**)error {
  try {
    const auto tensorTypeAndShapeInfo = _typeInfo->GetTensorTypeAndShapeInfo();
    if (!tensorTypeAndShapeInfo) {
      ORT_CXX_API_THROW("ORTValue is not a tensor.", ORT_RUNTIME_EXCEPTION);
    }
    const size_t elementCount = tensorTypeAndShapeInfo.GetElementCount();
    const size_t tensorStringDataLength = _value->GetStringTensorDataLength();
    std::vector<char> tensorStringData(tensorStringDataLength, '\0');
    std::vector<size_t> offsets(elementCount);
    _value->GetStringTensorContent(tensorStringData.data(), tensorStringDataLength,
                                   offsets.data(), offsets.size());

    NSMutableArray<NSString*>* result = [NSMutableArray arrayWithCapacity:elementCount];
    for (size_t idx = 0; idx < elementCount; ++idx) {
      // An element spans from its own offset to the next element's offset; the
      // last element runs to the end of the buffer.
      const size_t strLength = (idx == elementCount - 1) ? tensorStringDataLength - offsets[idx]
                                                         : offsets[idx + 1] - offsets[idx];
      NSString* element = [[NSString alloc] initWithBytes:tensorStringData.data() + offsets[idx]
                                                   length:strLength
                                                 encoding:NSUTF8StringEncoding];
      // initWithBytes:length:encoding: returns nil for bytes that are not valid
      // UTF-8; adding nil to the array would raise an Objective-C exception
      // instead of reporting through `error`.
      if (element == nil) {
        ORT_CXX_API_THROW(
            "Failed to decode a tensor string data element as a UTF-8 string.",
            ORT_RUNTIME_EXCEPTION);
      }
      [result addObject:element];
    }
    return result;
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}
213

214
#pragma mark - Internal
215

216
// Internal initializer: takes ownership of an existing C++ Ort::Value.
// The value's type info is stored in _typeInfo, `externalTensorData` (which
// may be nil) is retained, and the Ort::Value itself is moved into _value.
// Returns nil if [super init] fails, or nil with `error` reported if fetching
// the type info throws.
- (nullable instancetype)initWithCXXAPIOrtValue:(Ort::Value&&)existingCXXAPIOrtValue
                             externalTensorData:(nullable NSMutableData*)externalTensorData
                                          error:(NSError**)error {
  self = [super init];
  if (self == nil) {
    return nil;
  }

  try {
    // Query the type info first, while the incoming value is still intact.
    _typeInfo = existingCXXAPIOrtValue.GetTypeInfo();
    _externalTensorData = externalTensorData;

    // From here on this instance owns the C++ Ort::Value.
    _value = std::move(existingCXXAPIOrtValue);
    return self;
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}
233

234
// Returns a reference to the C++ Ort::Value held in _value.
// The optional is dereferenced without checking that it is engaged.
- (Ort::Value&)CXXAPIOrtValue {
  Ort::Value& heldValue = *_value;
  return heldValue;
}
237

238
@end
239

240
// Intentionally empty: no methods are defined for ORTValueTypeInfo here.
// NOTE(review): its properties (e.g. type, tensorTypeAndShapeInfo, as set
// elsewhere in this file) are presumably declared in the header and
// auto-synthesized - confirm.
@implementation ORTValueTypeInfo
@end
242

243
// Intentionally empty: no methods are defined for ORTTensorTypeAndShapeInfo here.
// NOTE(review): its properties (e.g. elementType, shape, as set elsewhere in
// this file) are presumably declared in the header and auto-synthesized - confirm.
@implementation ORTTensorTypeAndShapeInfo
@end
245

246
NS_ASSUME_NONNULL_END
247

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.