onnxruntime
246 строк · 8.6 Кб
1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
4#import "ort_value_internal.h"
5
6#include <optional>
7
8#import "cxx_api.h"
9#import "error_utils.h"
10#import "ort_enums_internal.h"
11
12NS_ASSUME_NONNULL_BEGIN
13
namespace {

// Builds the public ORTTensorTypeAndShapeInfo (element type + shape) from the C++ API view.
ORTTensorTypeAndShapeInfo* CXXAPIToPublicTensorTypeAndShapeInfo(
    const Ort::ConstTensorTypeAndShapeInfo& CXXAPITensorTypeAndShapeInfo) {
  const ONNXTensorElementDataType capiElementType = CXXAPITensorTypeAndShapeInfo.GetElementType();
  const std::vector<int64_t> dims = CXXAPITensorTypeAndShapeInfo.GetShape();

  NSMutableArray<NSNumber*>* publicShape = [NSMutableArray arrayWithCapacity:dims.size()];
  for (const int64_t dim : dims) {
    [publicShape addObject:@(dim)];
  }

  ORTTensorTypeAndShapeInfo* info = [[ORTTensorTypeAndShapeInfo alloc] init];
  info.elementType = CAPIToPublicTensorElementType(capiElementType);
  info.shape = publicShape;
  return info;
}

// Builds the public ORTValueTypeInfo from the C++ API type info. The tensor type/shape details
// are only filled in when the value is a tensor.
ORTValueTypeInfo* CXXAPIToPublicValueTypeInfo(const Ort::TypeInfo& CXXAPITypeInfo) {
  const ONNXType onnxType = CXXAPITypeInfo.GetONNXType();

  ORTValueTypeInfo* info = [[ORTValueTypeInfo alloc] init];
  info.type = CAPIToPublicValueType(onnxType);
  if (onnxType != ONNX_TYPE_TENSOR) {
    return info;
  }

  const auto CXXAPITensorInfo = CXXAPITypeInfo.GetTensorTypeAndShapeInfo();
  info.tensorTypeAndShapeInfo = CXXAPIToPublicTensorTypeAndShapeInfo(CXXAPITensorInfo);
  return info;
}

// Stores a * b in `out`.
// Returns true if the product fits in size_t, false if it overflowed (in which case `out` holds
// the wrapped product).
bool SafeMultiply(size_t a, size_t b, size_t& out) {
  const bool overflowed = __builtin_mul_overflow(a, b, &out);
  return !overflowed;
}

}  // namespace
54
@interface ORTValue ()

// Caller-provided tensor storage, if any, that must stay alive for as long as this ORTValue
// exists. Held by strong reference (not copied) so the exact same buffer is retained.
@property(nonatomic, strong, nullable) NSMutableData* externalTensorData;

@end
61
@implementation ORTValue {
  // The wrapped C++ API value; assigned by initWithCXXAPIOrtValue:externalTensorData:error:.
  std::optional<Ort::Value> _value;
  // Type information queried from the value at the time it was wrapped.
  std::optional<Ort::TypeInfo> _typeInfo;
}

#pragma mark - Public

// Creates a (non-string) tensor backed by the bytes of `tensorData`.
// `tensorData` is kept alive as externalTensorData for the lifetime of this ORTValue.
// Returns nil and sets `error` on failure (e.g. a string element type was requested).
- (nullable instancetype)initWithTensorData:(NSMutableData*)tensorData
                                elementType:(ORTTensorElementDataType)elementType
                                      shape:(NSArray<NSNumber*>*)shape
                                      error:(NSError**)error {
  try {
    if (elementType == ORTTensorElementDataTypeString) {
      ORT_CXX_API_THROW(
          "ORTTensorElementDataTypeString element type provided. "
          "Please call initWithTensorStringData:shape:error: instead to create an ORTValue with string data.",
          ORT_INVALID_ARGUMENT);
    }
    const auto memoryInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
    const auto ONNXElementType = PublicToCAPITensorElementType(elementType);
    const auto shapeVector = [shape]() {
      std::vector<int64_t> result{};
      result.reserve(shape.count);
      for (NSNumber* dim in shape) {
        result.push_back(dim.longLongValue);
      }
      return result;
    }();
    // The tensor is created over tensorData's buffer, which is why tensorData is passed on as
    // externalTensorData below.
    Ort::Value ortValue = Ort::Value::CreateTensor(
        memoryInfo, tensorData.mutableBytes, tensorData.length,
        shapeVector.data(), shapeVector.size(), ONNXElementType);

    return [self initWithCXXAPIOrtValue:std::move(ortValue)
                     externalTensorData:tensorData
                                  error:error];
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}

// Creates a string tensor with the given shape from `tensorStringData`. The number of strings
// must equal the product of the dimensions in `shape`. Each string is stored as UTF-8.
// Returns nil and sets `error` on failure.
- (nullable instancetype)initWithTensorStringData:(NSArray<NSString*>*)tensorStringData
                                            shape:(NSArray<NSNumber*>*)shape
                                            error:(NSError**)error {
  try {
    Ort::AllocatorWithDefaultOptions allocator;
    size_t tensorSize = 1U;
    const auto shapeVector = [&tensorSize, shape]() {
      std::vector<int64_t> result{};
      result.reserve(shape.count);
      for (NSNumber* dim in shape) {
        const auto dimValue = dim.longLongValue;
        // Reject negative dimensions and element counts that do not fit in size_t.
        if (dimValue < 0 || !SafeMultiply(static_cast<size_t>(dimValue), tensorSize, tensorSize)) {
          ORT_CXX_API_THROW("Failed to compute the tensor size.", ORT_RUNTIME_EXCEPTION);
        }
        result.push_back(dimValue);
      }
      return result;
    }();

    if (tensorSize != [tensorStringData count]) {
      ORT_CXX_API_THROW(
          "Computed tensor size does not equal the length of the provided tensor string data.",
          ORT_INVALID_ARGUMENT);
    }

    Ort::Value ortValue = Ort::Value::CreateTensor(
        allocator, shapeVector.data(), shapeVector.size(), ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING);

    size_t index = 0;
    for (NSString* stringData in tensorStringData) {
      // -UTF8String is declared nullable: it yields NULL when the string cannot be represented in
      // UTF-8 (e.g. it contains an unpaired surrogate). A NULL must never reach
      // FillStringTensorElement, so report it as an error instead.
      const char* utf8String = [stringData UTF8String];
      if (utf8String == nullptr) {
        ORT_CXX_API_THROW("Failed to convert tensor string data element to UTF-8.",
                          ORT_INVALID_ARGUMENT);
      }
      ortValue.FillStringTensorElement(utf8String, index++);
    }

    return [self initWithCXXAPIOrtValue:std::move(ortValue)
                     externalTensorData:nil
                                  error:error];
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}

// Returns the public type info of the wrapped value, or nil with `error` set on failure.
- (nullable ORTValueTypeInfo*)typeInfoWithError:(NSError**)error {
  try {
    return CXXAPIToPublicValueTypeInfo(*_typeInfo);
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}

// Returns the tensor type and shape info, or nil with `error` set if the value is not a tensor.
- (nullable ORTTensorTypeAndShapeInfo*)tensorTypeAndShapeInfoWithError:(NSError**)error {
  try {
    const auto tensorTypeAndShapeInfo = [self CXXAPITensorTypeAndShapeInfoOrThrow];
    return CXXAPIToPublicTensorTypeAndShapeInfo(tensorTypeAndShapeInfo);
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}

// Returns the raw bytes of a non-string tensor without copying them.
// The returned NSMutableData does not own the bytes (freeWhenDone:NO). They belong to the wrapped
// value or its external tensor data, so they are only valid while this ORTValue is alive.
// Returns nil with `error` set for non-tensor or string tensor values.
- (nullable NSMutableData*)tensorDataWithError:(NSError**)error {
  try {
    const auto tensorTypeAndShapeInfo = [self CXXAPITensorTypeAndShapeInfoOrThrow];
    if (tensorTypeAndShapeInfo.GetElementType() == ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) {
      ORT_CXX_API_THROW(
          "This ORTValue holds string data. Please call tensorStringDataWithError: "
          "instead to retrieve the string data from this ORTValue.",
          ORT_RUNTIME_EXCEPTION);
    }
    const size_t elementCount = tensorTypeAndShapeInfo.GetElementCount();
    const size_t elementSize = SizeOfCAPITensorElementType(tensorTypeAndShapeInfo.GetElementType());
    size_t rawDataLength;
    if (!SafeMultiply(elementCount, elementSize, rawDataLength)) {
      ORT_CXX_API_THROW("failed to compute tensor data length", ORT_RUNTIME_EXCEPTION);
    }

    void* rawData;
    Ort::ThrowOnError(Ort::GetApi().GetTensorMutableData(*_value, &rawData));

    return [NSMutableData dataWithBytesNoCopy:rawData
                                       length:rawDataLength
                                 freeWhenDone:NO];
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}

// Returns a copy of the elements of a string tensor, decoded as UTF-8.
// Returns nil with `error` set if the value is not a string tensor or an element is not valid
// UTF-8.
- (nullable NSArray<NSString*>*)tensorStringDataWithError:(NSError**)error {
  try {
    const auto tensorTypeAndShapeInfo = [self CXXAPITensorTypeAndShapeInfoOrThrow];
    if (tensorTypeAndShapeInfo.GetElementType() != ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) {
      ORT_CXX_API_THROW(
          "This ORTValue does not hold string data. Please call tensorDataWithError: "
          "instead to retrieve the data from this ORTValue.",
          ORT_RUNTIME_EXCEPTION);
    }
    const size_t elementCount = tensorTypeAndShapeInfo.GetElementCount();
    const size_t tensorStringDataLength = _value->GetStringTensorDataLength();
    std::vector<char> tensorStringData(tensorStringDataLength, '\0');
    std::vector<size_t> offsets(elementCount);
    _value->GetStringTensorContent(tensorStringData.data(), tensorStringDataLength,
                                   offsets.data(), offsets.size());

    NSMutableArray<NSString*>* result = [NSMutableArray arrayWithCapacity:elementCount];
    for (size_t idx = 0; idx < elementCount; ++idx) {
      // Element idx spans [offsets[idx], next offset), or up to the end of the buffer for the
      // last element.
      const size_t strEnd = (idx + 1 < elementCount) ? offsets[idx + 1] : tensorStringDataLength;
      const size_t strLength = strEnd - offsets[idx];
      NSString* str = [[NSString alloc] initWithBytes:tensorStringData.data() + offsets[idx]
                                               length:strLength
                                             encoding:NSUTF8StringEncoding];
      // -initWithBytes:length:encoding: returns nil for bytes that are not valid UTF-8, and
      // -addObject: raises NSInvalidArgumentException for nil. Report a regular error instead.
      if (str == nil) {
        ORT_CXX_API_THROW("Failed to decode tensor string data element as UTF-8.",
                          ORT_RUNTIME_EXCEPTION);
      }
      [result addObject:str];
    }
    return result;
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}

#pragma mark - Internal

// Wraps an existing C++ API value, taking ownership of it.
// `externalTensorData`, if non-nil, is retained for the lifetime of this instance.
- (nullable instancetype)initWithCXXAPIOrtValue:(Ort::Value&&)existingCXXAPIOrtValue
                             externalTensorData:(nullable NSMutableData*)externalTensorData
                                          error:(NSError**)error {
  if ((self = [super init]) == nil) {
    return nil;
  }

  try {
    _typeInfo = existingCXXAPIOrtValue.GetTypeInfo();
    _externalTensorData = externalTensorData;

    // transfer C++ Ort::Value ownership to this instance
    _value = std::move(existingCXXAPIOrtValue);
    return self;
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}

- (Ort::Value&)CXXAPIOrtValue {
  return *_value;
}

#pragma mark - Private

// Returns the C++ API tensor type/shape info of the wrapped value.
// Throws an Ort::Exception ("ORTValue is not a tensor.") if the value is not a tensor.
// The result is obtained from _typeInfo; presumably it refers to data owned by _typeInfo
// (confirm in cxx_api.h), so it should not be kept beyond the lifetime of this instance.
- (Ort::ConstTensorTypeAndShapeInfo)CXXAPITensorTypeAndShapeInfoOrThrow {
  const auto tensorTypeAndShapeInfo = _typeInfo->GetTensorTypeAndShapeInfo();
  if (!tensorTypeAndShapeInfo) {
    ORT_CXX_API_THROW("ORTValue is not a tensor.", ORT_RUNTIME_EXCEPTION);
  }
  return tensorTypeAndShapeInfo;
}

@end
239
@implementation ORTValueTypeInfo
// No custom methods: accessors for the declared properties are auto-synthesized.
@end
242
@implementation ORTTensorTypeAndShapeInfo
// No custom methods: accessors for the declared properties are auto-synthesized.
@end
245
246NS_ASSUME_NONNULL_END
247