onnxruntime
502 строки · 14.9 Кб
1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
4import { Guid } from 'guid-typescript';
5import Long from 'long';
6
7import { onnxruntime } from './ort-schema/flatbuffers/ort-generated';
8import { onnx } from './ort-schema/protobuf/onnx';
9import { decodeUtf8String, ProtoUtil, ShapeUtil } from './util';
10
11import ortFbs = onnxruntime.experimental.fbs;
12
export declare namespace Tensor {
  /** Maps each supported tensor data type name to the JS container that stores its elements. */
  export interface DataTypeMap {
    bool: Uint8Array;
    float32: Float32Array;
    float64: Float64Array;
    string: string[];
    int8: Int8Array;
    uint8: Uint8Array;
    int16: Int16Array;
    uint16: Uint16Array;
    int32: Int32Array;
    uint32: Uint32Array;
    int64: BigInt64Array;
  }

  /** Name of a supported tensor data type, e.g. 'float32'. */
  export type DataType = keyof DataTypeMap;

  /** Storage of a 'string' tensor. */
  export type StringType = DataTypeMap['string'];
  /** Storage of a 'bool' tensor (one byte per element). */
  export type BooleanType = DataTypeMap['bool'];
  /** Storage of any integer tensor up to 32 bits wide. */
  export type IntegerType = DataTypeMap['int8' | 'uint8' | 'int16' | 'uint16' | 'int32' | 'uint32'];
  /** Storage of a floating-point tensor. */
  export type FloatType = DataTypeMap['float32' | 'float64'];
  /** Storage of any numeric tensor. NOTE: 'int64' (BigInt64Array) is deliberately not part of this union. */
  export type NumberType = BooleanType | IntegerType | FloatType;

  /** Identifier used to look up a tensor's data through its data providers. */
  export type Id = Guid;
}
44
/** Union of every concrete storage type a Tensor can hold. */
type TensorData = Tensor.DataTypeMap[keyof Tensor.DataTypeMap];

/** Synchronously resolves the backing data for the tensor with the given id. */
type DataProvider = (tensorId: Tensor.Id) => TensorData;
/** Asynchronously resolves the backing data for the tensor with the given id. */
type AsyncDataProvider = (tensorId: Tensor.Id) => Promise<TensorData>;
49
export class Tensor {
  /**
   * Get the underlying tensor data.
   *
   * When nothing is cached yet, the data is fetched through the synchronous data provider, checked against
   * `size`, and cached for later calls.
   *
   * @throws Error if nothing is cached and no synchronous data provider was supplied
   * @throws Error if the provider returns a different number of elements than `size`
   */
  get data(): TensorData {
    if (this.cache === undefined) {
      if (this.dataProvider === undefined) {
        throw new Error('Tensor data is not available synchronously; use getData() instead.');
      }
      this.cache = this.validateProvidedData(this.dataProvider(this.dataId));
    }
    return this.cache;
  }

  /**
   * Get the underlying string tensor data. Should only be used when type is 'string'.
   *
   * @throws TypeError if the tensor is not a string tensor
   */
  get stringData() {
    if (this.type !== 'string') {
      throw new TypeError('data type is not string');
    }

    return this.data as Tensor.StringType;
  }

  /**
   * Get the underlying integer tensor data. Should only be used when type is one of: uint8, int8, uint16, int16,
   * int32, uint32, bool.
   *
   * @throws TypeError for any other type
   */
  get integerData() {
    switch (this.type) {
      case 'uint8':
      case 'int8':
      case 'uint16':
      case 'int16':
      case 'int32':
      case 'uint32':
      case 'bool':
        return this.data as Tensor.IntegerType;

      default:
        throw new TypeError('data type is not integer (uint8, int8, uint16, int16, int32, uint32, bool)');
    }
  }

  /**
   * Get the underlying float tensor data. Should only be used when type is 'float32' or 'float64'.
   *
   * @throws TypeError for any other type
   */
  get floatData() {
    switch (this.type) {
      case 'float32':
      case 'float64':
        return this.data as Tensor.FloatType;

      default:
        throw new TypeError('data type is not float (float32, float64)');
    }
  }

  /**
   * Get the underlying numeric tensor data. Should only be used when type is one of: uint8, int8, uint16, int16,
   * int32, uint32, bool, float32, float64.
   *
   * NOTE(review): only 'string' is rejected here, so an 'int64' tensor (BigInt64Array) also passes this check even
   * though it is not part of Tensor.NumberType.
   *
   * @throws TypeError if the tensor is a string tensor
   */
  get numberData() {
    if (this.type !== 'string') {
      return this.data as Tensor.NumberType;
    }
    throw new TypeError('type cannot be non-number (string)');
  }

  /**
   * Get the value of the element at the given indices.
   */
  get(indices: readonly number[]): Tensor.DataTypeMap[Tensor.DataType][number] {
    return this.data[ShapeUtil.indicesToOffset(indices, this.strides)];
  }

  /**
   * Set the value of the element at the given indices.
   */
  set(indices: readonly number[], value: Tensor.DataTypeMap[Tensor.DataType][number]) {
    this.data[ShapeUtil.indicesToOffset(indices, this.strides)] = value;
  }

  /**
   * Get the underlying tensor data asynchronously.
   *
   * Uses the asynchronous data provider when one was supplied; otherwise it falls back to the synchronous `data`
   * getter. Provider data is validated against `size` and cached, exactly as in `data`.
   *
   * @throws Error (as a rejected promise) if the provider returns a different number of elements than `size`
   */
  async getData(): Promise<TensorData> {
    if (this.cache === undefined) {
      if (this.asyncDataProvider === undefined) {
        return this.data;
      }
      this.cache = this.validateProvidedData(await this.asyncDataProvider(this.dataId));
    }
    return this.cache;
  }

  /**
   * The number of elements in the tensor.
   */
  public readonly size: number;

  // Lazily computed by the `strides` getter.
  private _strides?: readonly number[];

  /**
   * The strides for each dimension, computed from `dims` on first access.
   */
  get strides(): readonly number[] {
    if (this._strides === undefined) {
      this._strides = ShapeUtil.computeStrides(this.dims);
    }
    return this._strides;
  }

  /**
   * @param dims tensor dimensions; validated by ShapeUtil.validateDimsAndCalcSize, which also yields `size`
   * @param type element data type
   * @param dataProvider optional synchronous source of the data, consulted lazily by `data`
   * @param asyncDataProvider optional asynchronous source of the data, consulted lazily by `getData()`
   * @param cache optional data supplied up front; must hold exactly `size` elements of the container for `type`
   * @param dataId identifier passed to the data providers
   * @throws RangeError if `cache` length does not match the dims
   * @throws TypeError if `cache` is not the container type expected for `type`
   */
  constructor(
    /**
     * get the dimensions of the tensor
     */
    public readonly dims: readonly number[],
    /**
     * get the type of the tensor
     */
    public readonly type: Tensor.DataType,
    private dataProvider?: DataProvider,
    private asyncDataProvider?: AsyncDataProvider,
    private cache?: TensorData,
    /**
     * get the data ID that used to map to a tensor data
     */
    public readonly dataId: Guid = Guid.create(),
  ) {
    this.size = ShapeUtil.validateDimsAndCalcSize(dims);
    const size = this.size;
    // With no source of data at all, allocate zero-initialized storage so `data` is always resolvable.
    const empty = dataProvider === undefined && asyncDataProvider === undefined && cache === undefined;

    if (cache !== undefined) {
      if (cache.length !== size) {
        throw new RangeError("Input dims doesn't match data length.");
      }
    }

    if (type === 'string') {
      if (cache !== undefined && (!Array.isArray(cache) || !cache.every((i) => typeof i === 'string'))) {
        throw new TypeError('cache should be a string array');
      }

      if (empty) {
        this.cache = new Array<string>(size);
      }
    } else {
      if (cache !== undefined) {
        const constructor = dataviewConstructor(type);
        if (!(cache instanceof constructor)) {
          throw new TypeError(`cache should be type ${constructor.name}`);
        }
      }

      if (empty) {
        const buf = new ArrayBuffer(size * sizeof(type));
        this.cache = createView(buf, type);
      }
    }
  }

  /**
   * Construct a new Tensor from an ONNX TensorProto.
   *
   * Values are taken from 'stringData' for STRING tensors, from non-empty 'rawData' when present, and otherwise
   * from the type-specific repeated field (floatData, int32Data, int64Data, doubleData or uint64Data).
   *
   * @param tensorProto the ONNX tensor
   * @throws Error if `tensorProto` is missing, or the number of stored elements does not match the tensor dims
   */
  static fromProto(tensorProto: onnx.ITensorProto): Tensor {
    if (!tensorProto) {
      throw new Error('cannot construct Value from an empty tensor');
    }
    const dataType = tensorProto.dataType!;
    const type = ProtoUtil.tensorDataTypeFromProto(dataType);
    const dims = ProtoUtil.tensorDimsFromProto(tensorProto.dims!);

    const value = new Tensor(dims, type);

    if (type === 'string') {
      // When it's STRING type, the value should always be stored in field 'stringData'.
      const strings = tensorProto.stringData;
      const count = strings ? strings.length : 0;
      // Without this check, surplus strings would silently grow the string[] beyond `size`.
      if (count !== value.size) {
        throw new Error('array length mismatch');
      }
      if (strings) {
        const data = value.data;
        strings.forEach((str, i) => {
          data[i] = decodeUtf8String(str);
        });
      }
    } else if (
      tensorProto.rawData &&
      typeof tensorProto.rawData.byteLength === 'number' &&
      tensorProto.rawData.byteLength > 0
    ) {
      Tensor.fillFromRawData(value.data, tensorProto.rawData, dataType);
    } else {
      Tensor.fillFromTypedFields(value.data, tensorProto, dataType);
    }

    return value;
  }

  /**
   * Construct a new Tensor from raw data.
   * @param data the raw data object. Should be a string array for 'string' tensor, and the corresponding typed array
   * for other types of tensor.
   * @param dims the dimensions of the tensor
   * @param type the type of the tensor
   */
  static fromData(data: Tensor.DataTypeMap[Tensor.DataType], dims: readonly number[], type: Tensor.DataType) {
    return new Tensor(dims, type, undefined, undefined, data);
  }

  /**
   * Construct a new Tensor from an ORT-format (flatbuffers) tensor.
   *
   * STRING values come from 'stringData'; other types are decoded from 'rawData' when it is present and non-empty.
   * Otherwise the tensor is left with its zero-initialized storage.
   *
   * @throws Error if `ortTensor` is missing, or the number of stored elements does not match the tensor dims
   */
  static fromOrtTensor(ortTensor: ortFbs.Tensor): Tensor {
    if (!ortTensor) {
      throw new Error('cannot construct Value from an empty tensor');
    }
    const dims = ProtoUtil.tensorDimsFromORTFormat(ortTensor);
    const dataType = ortTensor.dataType();
    const type = ProtoUtil.tensorDataTypeFromProto(dataType);

    const value = new Tensor(dims, type);

    if (type === 'string') {
      // When it's STRING type, the value should always be stored in field 'stringData'.
      const count = ortTensor.stringDataLength();
      if (count !== value.size) {
        throw new Error('array length mismatch');
      }
      const data = value.data;
      for (let i = 0; i < count; i++) {
        data[i] = ortTensor.stringData(i);
      }
    } else {
      const rawData = ortTensor.rawDataArray();
      if (rawData && rawData.byteLength > 0) {
        Tensor.fillFromRawData(value.data, rawData, dataType);
      }
    }
    return value;
  }

  /** Ensure provider-supplied data holds exactly `size` elements, returning it unchanged. */
  private validateProvidedData(data: TensorData): TensorData {
    if (data.length !== this.size) {
      throw new Error('Length of data provided by the Data Provider is inconsistent with the dims of this Tensor.');
    }
    return data;
  }

  /**
   * Decode little-endian raw tensor bytes into `dest`, element by element via readProto.
   *
   * @throws Error if the byte length is not a multiple of the element size, or does not match `dest.length`
   */
  private static fillFromRawData(
    dest: TensorData,
    rawData: Uint8Array,
    dataType: onnx.TensorProto.DataType | ortFbs.TensorDataType,
  ): void {
    // NOT considering segment for now (IMPORTANT)
    const view = new DataView(rawData.buffer, rawData.byteOffset, rawData.byteLength);
    const elementSize = sizeofProto(dataType);

    if (rawData.byteLength % elementSize !== 0) {
      throw new Error('invalid buffer length');
    }
    const length = rawData.byteLength / elementSize;
    if (dest.length !== length) {
      throw new Error('buffer length mismatch');
    }

    for (let i = 0; i < length; i++) {
      dest[i] = readProto(view, dataType, i * elementSize);
    }
  }

  /**
   * Populate `dest` from the type-specific repeated field of a TensorProto (used when 'rawData' is absent or empty).
   * `Long` elements (64-bit fields) are narrowed with longToNumber.
   *
   * @throws Error if the field is missing, its length does not match `dest.length`, or the type is unsupported
   */
  private static fillFromTypedFields(
    dest: TensorData,
    tensorProto: onnx.ITensorProto,
    dataType: onnx.TensorProto.DataType,
  ): void {
    let array: Array<number | Long> | null | undefined;
    switch (dataType) {
      case onnx.TensorProto.DataType.FLOAT:
        array = tensorProto.floatData;
        break;
      case onnx.TensorProto.DataType.INT32:
      case onnx.TensorProto.DataType.INT16:
      case onnx.TensorProto.DataType.UINT16:
      case onnx.TensorProto.DataType.INT8:
      case onnx.TensorProto.DataType.UINT8:
      case onnx.TensorProto.DataType.BOOL:
        array = tensorProto.int32Data;
        break;
      case onnx.TensorProto.DataType.INT64:
        array = tensorProto.int64Data;
        break;
      case onnx.TensorProto.DataType.DOUBLE:
        array = tensorProto.doubleData;
        break;
      case onnx.TensorProto.DataType.UINT32:
      case onnx.TensorProto.DataType.UINT64:
        // ONNX stores both uint32 and uint64 values in 'uint64Data'.
        array = tensorProto.uint64Data;
        break;
      default:
        // should never run here
        throw new Error('unspecific error');
    }

    if (array === null || array === undefined) {
      throw new Error('failed to populate data from a tensorproto value');
    }
    if (dest.length !== array.length) {
      throw new Error('array length mismatch');
    }

    for (let i = 0; i < array.length; i++) {
      const element = array[i];
      dest[i] = Long.isLong(element) ? longToNumber(element, dataType) : element;
    }
  }
}
368
/**
 * Byte width of one element of a tensor data type. The Tensor constructor uses it to size the ArrayBuffer
 * allocated for an empty tensor.
 *
 * @param type the tensor data type
 * @returns the element size in bytes
 * @throws Error for types without a fixed-width typed-array representation (i.e. 'string')
 */
function sizeof(type: Tensor.DataType): number {
  switch (type) {
    case 'bool':
    case 'int8':
    case 'uint8':
      return 1;
    case 'int16':
    case 'uint16':
      return 2;
    case 'int32':
    case 'uint32':
    case 'float32':
      return 4;
    case 'int64': // backed by BigInt64Array (see dataviewConstructor); previously missing, so empty int64 tensors threw
    case 'float64':
      return 8;
    default:
      throw new Error(`cannot calculate sizeof() on type ${type}`);
  }
}
388
/**
 * Byte width of one element of an ONNX tensor data type, as laid out in raw tensor bytes.
 *
 * ORT flatbuffers `TensorDataType` values are matched against the ONNX enum cases below. This presumes the two
 * enums share numeric values — confirm if either schema changes.
 *
 * @param type the element data type
 * @returns 1, 2, 4 or 8
 * @throws Error for any type not listed below
 */
function sizeofProto(type: onnx.TensorProto.DataType | ortFbs.TensorDataType): number {
  let bytes: number;
  switch (type) {
    case onnx.TensorProto.DataType.INT64:
    case onnx.TensorProto.DataType.UINT64:
    case onnx.TensorProto.DataType.DOUBLE:
      bytes = 8;
      break;
    case onnx.TensorProto.DataType.INT32:
    case onnx.TensorProto.DataType.UINT32:
    case onnx.TensorProto.DataType.FLOAT:
      bytes = 4;
      break;
    case onnx.TensorProto.DataType.INT16:
    case onnx.TensorProto.DataType.UINT16:
      bytes = 2;
      break;
    case onnx.TensorProto.DataType.BOOL:
    case onnx.TensorProto.DataType.INT8:
    case onnx.TensorProto.DataType.UINT8:
      bytes = 1;
      break;
    default:
      throw new Error(`cannot calculate sizeof() on type ${onnx.TensorProto.DataType[type]}`);
  }
  return bytes;
}
410
/**
 * Wrap `dataBuffer` in the typed-array view used to store tensors of `type`.
 *
 * @throws Error (from dataviewConstructor) when `type` has no typed-array representation, e.g. 'string'
 */
function createView(dataBuffer: ArrayBuffer, type: Tensor.DataType) {
  const ViewType = dataviewConstructor(type);
  return new ViewType(dataBuffer);
}
414
/**
 * Map a tensor data type to the typed-array constructor that stores its elements.
 *
 * 'bool' shares Uint8Array with 'uint8'. 'string' has no typed-array representation.
 *
 * @throws Error('unspecified error') for 'string' or any unrecognized type
 */
function dataviewConstructor(type: Tensor.DataType) {
  const constructors = {
    bool: Uint8Array,
    uint8: Uint8Array,
    int8: Int8Array,
    int16: Int16Array,
    uint16: Uint16Array,
    int32: Int32Array,
    uint32: Uint32Array,
    int64: BigInt64Array,
    float32: Float32Array,
    float64: Float64Array,
  };

  // Own-key check so inherited names such as 'constructor' or 'toString' are never treated as valid types.
  if (!Object.prototype.hasOwnProperty.call(constructors, type)) {
    // should never run to here
    throw new Error('unspecified error');
  }
  return constructors[type as keyof typeof constructors];
}
441
/**
 * Narrow a 64-bit `Long` to a JS number, rejecting values outside the 32-bit range of the source type.
 *
 * - INT64: accepted range is [-2^31, 2^31), i.e. the signed 32-bit range (a cast-down to int32).
 * - UINT32 / UINT64: accepted range is [0, 2^32), i.e. the unsigned 32-bit range.
 *
 * Both ONNX and ORT flatbuffers enum values are recognized for each of these types.
 *
 * @param i the 64-bit value to convert
 * @param type the data type the value was read as
 * @returns `i` as a number
 * @throws TypeError if `i` is out of range for `type`, or `type` is none of the types above
 */
function longToNumber(i: Long, type: onnx.TensorProto.DataType | ortFbs.TensorDataType): number {
  const isInt64 = type === onnx.TensorProto.DataType.INT64 || type === ortFbs.TensorDataType.INT64;
  const isUnsignedLong =
    type === onnx.TensorProto.DataType.UINT32 ||
    type === ortFbs.TensorDataType.UINT32 ||
    type === onnx.TensorProto.DataType.UINT64 ||
    type === ortFbs.TensorDataType.UINT64;

  if (isInt64) {
    const fitsInInt32 = i.lessThan(2147483648) && i.greaterThanOrEqual(-2147483648);
    if (!fitsInInt32) {
      throw new TypeError('int64 is not supported');
    }
  } else if (isUnsignedLong) {
    const fitsInUint32 = i.lessThan(4294967296) && i.greaterThanOrEqual(0);
    if (!fitsInUint32) {
      throw new TypeError('uint64 is not supported');
    }
  } else {
    throw new TypeError(`not a LONG type: ${onnx.TensorProto.DataType[type]}`);
  }

  return i.toNumber();
}
464
/**
 * Read the element starting at `byteOffset` from little-endian raw tensor bytes.
 *
 * 64-bit integer elements are assembled from two little-endian 32-bit halves (low word first). They are then
 * narrowed with longToNumber, which throws when the value does not fit the 32-bit range.
 *
 * @param view view over the raw tensor bytes
 * @param type element data type (ONNX enum values; ORT flatbuffers values are matched against the same cases)
 * @param byteOffset offset of the element within `view`
 * @throws Error for types not handled below
 */
function readProto(
  view: DataView,
  type: onnx.TensorProto.DataType | ortFbs.TensorDataType,
  byteOffset: number,
): number {
  const littleEndian = true;
  switch (type) {
    case onnx.TensorProto.DataType.BOOL:
    case onnx.TensorProto.DataType.UINT8:
      return view.getUint8(byteOffset);
    case onnx.TensorProto.DataType.INT8:
      return view.getInt8(byteOffset);
    case onnx.TensorProto.DataType.UINT16:
      return view.getUint16(byteOffset, littleEndian);
    case onnx.TensorProto.DataType.INT16:
      return view.getInt16(byteOffset, littleEndian);
    case onnx.TensorProto.DataType.FLOAT:
      return view.getFloat32(byteOffset, littleEndian);
    case onnx.TensorProto.DataType.INT32:
      return view.getInt32(byteOffset, littleEndian);
    case onnx.TensorProto.DataType.UINT32:
      return view.getUint32(byteOffset, littleEndian);
    case onnx.TensorProto.DataType.DOUBLE:
      return view.getFloat64(byteOffset, littleEndian);
    case onnx.TensorProto.DataType.INT64:
    case onnx.TensorProto.DataType.UINT64: {
      const low = view.getUint32(byteOffset, littleEndian);
      const high = view.getUint32(byteOffset + 4, littleEndian);
      const unsigned = type === onnx.TensorProto.DataType.UINT64;
      return longToNumber(Long.fromBits(low, high, unsigned), type);
    }
    default:
      throw new Error(`cannot read from DataView for type ${onnx.TensorProto.DataType[type]}`);
  }
}
503