onnxruntime

Форк
0
502 строки · 14.9 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
import { Guid } from 'guid-typescript';
5
import Long from 'long';
6

7
import { onnxruntime } from './ort-schema/flatbuffers/ort-generated';
8
import { onnx } from './ort-schema/protobuf/onnx';
9
import { decodeUtf8String, ProtoUtil, ShapeUtil } from './util';
10

11
import ortFbs = onnxruntime.experimental.fbs;
12

13
/**
 * Type-level declarations that accompany the `Tensor` class.
 */
export declare namespace Tensor {
  /**
   * Maps every tensor element type name to the container type that holds its values.
   */
  export interface DataTypeMap {
    bool: Uint8Array;
    float32: Float32Array;
    float64: Float64Array;
    string: string[];
    int8: Int8Array;
    uint8: Uint8Array;
    int16: Int16Array;
    uint16: Uint16Array;
    int32: Int32Array;
    uint32: Uint32Array;
    int64: BigInt64Array;
  }

  /** Name of a tensor element type (any key of `DataTypeMap`). */
  export type DataType = keyof DataTypeMap;

  /** Container type of a 'string' tensor. */
  export type StringType = Tensor.DataTypeMap['string'];
  /** Container type of a 'bool' tensor. */
  export type BooleanType = Tensor.DataTypeMap['bool'];
  /** Container types of the 8/16/32-bit integer tensors ('int64' is not part of this union). */
  export type IntegerType =
    | Tensor.DataTypeMap['int8']
    | Tensor.DataTypeMap['uint8']
    | Tensor.DataTypeMap['int16']
    | Tensor.DataTypeMap['uint16']
    | Tensor.DataTypeMap['int32']
    | Tensor.DataTypeMap['uint32'];
  /** Container types of the floating point tensors. */
  export type FloatType = Tensor.DataTypeMap['float32'] | Tensor.DataTypeMap['float64'];
  /** Union of the boolean, integer and float container types. */
  export type NumberType = BooleanType | IntegerType | FloatType;

  /** Identifier passed to data providers to look up the data of a tensor. */
  export type Id = Guid;
}
44

45
/** Any of the containers a tensor may use to hold its elements. */
type TensorData = Tensor.DataTypeMap[Tensor.DataType];

/** Synchronously resolves a data ID to the tensor data it identifies. */
type DataProvider = (id: Tensor.Id) => TensorData;
/** Asynchronously resolves a data ID to the tensor data it identifies. */
type AsyncDataProvider = (id: Tensor.Id) => Promise<TensorData>;
49

50
/**
 * A multi-dimensional array of a single element type.
 *
 * The data is either held directly (`cache`) or fetched on first access from a synchronous and/or
 * asynchronous data provider keyed by `dataId`.
 */
export class Tensor {
  /**
   * get the underlying tensor data
   *
   * When nothing is cached yet, the data is fetched from the synchronous data provider, checked against
   * `size`, and cached for subsequent calls.
   *
   * @throws TypeError if no data is cached and no synchronous data provider was supplied
   * @throws Error if the provider returns data whose length differs from `size`
   */
  get data(): TensorData {
    if (this.cache === undefined) {
      // Fail with an explicit message instead of the opaque "is not a function" TypeError.
      if (this.dataProvider === undefined) {
        throw new TypeError('Tensor data is not cached and no synchronous data provider is available.');
      }
      const data = this.dataProvider(this.dataId);
      if (data.length !== this.size) {
        throw new Error('Length of data provided by the Data Provider is inconsistent with the dims of this Tensor.');
      }
      this.cache = data;
    }
    return this.cache;
  }

  /**
   * get the underlying string tensor data. Should only use when type is STRING
   *
   * @throws TypeError if the tensor type is not 'string'
   */
  get stringData() {
    if (this.type !== 'string') {
      throw new TypeError('data type is not string');
    }

    return this.data as Tensor.StringType;
  }

  /**
   * get the underlying integer tensor data. Should only use when type is one of the following: (UINT8, INT8, UINT16,
   * INT16, INT32, UINT32, BOOL)
   *
   * @throws TypeError for any other tensor type
   */
  get integerData() {
    switch (this.type) {
      case 'uint8':
      case 'int8':
      case 'uint16':
      case 'int16':
      case 'int32':
      case 'uint32':
      case 'bool':
        return this.data as Tensor.IntegerType;

      default:
        throw new TypeError('data type is not integer (uint8, int8, uint16, int16, int32, uint32, bool)');
    }
  }

  /**
   * get the underlying float tensor data. Should only use when type is one of the following: (FLOAT, DOUBLE)
   *
   * @throws TypeError for any other tensor type
   */
  get floatData() {
    switch (this.type) {
      case 'float32':
      case 'float64':
        return this.data as Tensor.FloatType;

      default:
        throw new TypeError('data type is not float (float32, float64)');
    }
  }

  /**
   * get the underlying number tensor data. Should only use when type is one of the following: (UINT8, INT8, UINT16,
   * INT16, INT32, UINT32, BOOL, FLOAT, DOUBLE)
   *
   * @throws TypeError if the tensor type is 'string'
   */
  get numberData() {
    if (this.type !== 'string') {
      return this.data as Tensor.NumberType;
    }
    throw new TypeError('type cannot be non-number (string)');
  }

  /**
   * get value of an element at the given indices
   *
   * @param indices one index per dimension; converted to a flat offset using `strides`
   */
  get(indices: readonly number[]): Tensor.DataTypeMap[Tensor.DataType][number] {
    return this.data[ShapeUtil.indicesToOffset(indices, this.strides)];
  }

  /**
   * set value of an element at the given indices
   *
   * @param indices one index per dimension; converted to a flat offset using `strides`
   * @param value the value to store at that offset
   */
  set(indices: readonly number[], value: Tensor.DataTypeMap[Tensor.DataType][number]) {
    this.data[ShapeUtil.indicesToOffset(indices, this.strides)] = value;
  }

  /**
   * get the underlying tensor data asynchronously
   *
   * Returns the cached data when present. Otherwise the data is fetched from the asynchronous data provider and
   * cached; when only a synchronous provider was supplied, the synchronous `data` path is used instead.
   */
  async getData(): Promise<TensorData> {
    if (this.cache === undefined) {
      if (this.asyncDataProvider === undefined) {
        // No async provider: reuse the synchronous path, which validates, caches, and reports a missing provider.
        return this.data;
      }
      this.cache = await this.asyncDataProvider(this.dataId);
    }
    return this.cache;
  }

  /**
   * get the number of elements in the tensor
   */
  public readonly size: number;

  // Computed lazily by the `strides` getter on first access.
  private _strides: readonly number[];
  /**
   * get the strides for each dimension
   */
  get strides(): readonly number[] {
    if (!this._strides) {
      this._strides = ShapeUtil.computeStrides(this.dims);
    }
    return this._strides;
  }

  /**
   * Create a tensor. When neither a provider nor a cache is supplied, storage for `size` elements is allocated.
   *
   * @throws RangeError if `cache` is supplied and its length differs from the element count implied by `dims`
   * @throws TypeError if `cache` is supplied and is not the container matching `type`
   */
  constructor(
    /**
     * get the dimensions of the tensor
     */
    public readonly dims: readonly number[],
    /**
     * get the type of the tensor
     */
    public readonly type: Tensor.DataType,
    private dataProvider?: DataProvider,
    private asyncDataProvider?: AsyncDataProvider,
    private cache?: TensorData,
    /**
     * get the data ID that used to map to a tensor data
     */
    public readonly dataId: Guid = Guid.create(),
  ) {
    this.size = ShapeUtil.validateDimsAndCalcSize(dims);
    const size = this.size;
    // "empty" means no data source of any kind was given, so the tensor owns freshly allocated storage.
    const empty = dataProvider === undefined && asyncDataProvider === undefined && cache === undefined;

    if (cache !== undefined) {
      if (cache.length !== size) {
        throw new RangeError("Input dims doesn't match data length.");
      }
    }

    if (type === 'string') {
      if (cache !== undefined && (!Array.isArray(cache) || !cache.every((i) => typeof i === 'string'))) {
        throw new TypeError('cache should be a string array');
      }

      if (empty) {
        this.cache = new Array<string>(size);
      }
    } else {
      if (cache !== undefined) {
        const constructor = dataviewConstructor(type);
        if (!(cache instanceof constructor)) {
          throw new TypeError(`cache should be type ${constructor.name}`);
        }
      }

      if (empty) {
        const buf = new ArrayBuffer(size * sizeof(type));
        this.cache = createView(buf, type);
      }
    }
  }

  /**
   * Construct new Tensor from a ONNX Tensor object
   *
   * String tensors are read from `stringData`; other tensors are read from `rawData` when it is non-empty and
   * otherwise from the typed repeated field selected by `dataType`.
   *
   * @param tensorProto the ONNX Tensor
   * @throws Error if `tensorProto` is falsy, the source length does not match the tensor size, or the data
   * type is not handled
   */
  static fromProto(tensorProto: onnx.ITensorProto): Tensor {
    if (!tensorProto) {
      throw new Error('cannot construct Value from an empty tensor');
    }
    const type = ProtoUtil.tensorDataTypeFromProto(tensorProto.dataType!);
    const dims = ProtoUtil.tensorDimsFromProto(tensorProto.dims!);

    const value = new Tensor(dims, type);

    if (type === 'string') {
      // When it's STRING type, the value should always be stored in field
      // 'stringData'
      tensorProto.stringData!.forEach((str, i) => {
        value.data[i] = decodeUtf8String(str);
      });
    } else if (
      tensorProto.rawData &&
      typeof tensorProto.rawData.byteLength === 'number' &&
      tensorProto.rawData.byteLength > 0
    ) {
      // NOT considering segment for now (IMPORTANT)

      // populate value from rawData
      const dataDest = value.data;
      const dataSource = new DataView(
        tensorProto.rawData.buffer,
        tensorProto.rawData.byteOffset,
        tensorProto.rawData.byteLength,
      );
      const elementSize = sizeofProto(tensorProto.dataType!);
      const length = tensorProto.rawData.byteLength / elementSize;

      // The buffer must hold a whole number of elements, and exactly as many as the tensor has.
      if (tensorProto.rawData.byteLength % elementSize !== 0) {
        throw new Error('invalid buffer length');
      }
      if (dataDest.length !== length) {
        throw new Error('buffer length mismatch');
      }

      for (let i = 0; i < length; i++) {
        const n = readProto(dataSource, tensorProto.dataType!, i * elementSize);
        dataDest[i] = n;
      }
    } else {
      // populate value from array
      let array: Array<number | Long>;
      switch (tensorProto.dataType) {
        case onnx.TensorProto.DataType.FLOAT:
          array = tensorProto.floatData!;
          break;
        case onnx.TensorProto.DataType.INT32:
        case onnx.TensorProto.DataType.INT16:
        case onnx.TensorProto.DataType.UINT16:
        case onnx.TensorProto.DataType.INT8:
        case onnx.TensorProto.DataType.UINT8:
        case onnx.TensorProto.DataType.BOOL:
          array = tensorProto.int32Data!;
          break;
        case onnx.TensorProto.DataType.INT64:
          array = tensorProto.int64Data!;
          break;
        case onnx.TensorProto.DataType.DOUBLE:
          array = tensorProto.doubleData!;
          break;
        case onnx.TensorProto.DataType.UINT32:
        case onnx.TensorProto.DataType.UINT64:
          array = tensorProto.uint64Data!;
          break;
        default:
          // should never run here
          throw new Error('unspecific error');
      }

      if (array === null || array === undefined) {
        throw new Error('failed to populate data from a tensorproto value');
      }

      const data = value.data;
      if (data.length !== array.length) {
        throw new Error('array length mismatch');
      }

      for (let i = 0; i < array.length; i++) {
        const element = array[i];
        // 64-bit fields may arrive as Long objects; those are range-checked and narrowed to a number.
        if (Long.isLong(element)) {
          data[i] = longToNumber(element, tensorProto.dataType);
        } else {
          data[i] = element;
        }
      }
    }

    return value;
  }

  /**
   * Construct new Tensor from raw data
   * @param data the raw data object. Should be a string array for 'string' tensor, and the corresponding typed array
   * for other types of tensor.
   * @param dims the dimensions of the tensor
   * @param type the type of the tensor
   */
  static fromData(data: Tensor.DataTypeMap[Tensor.DataType], dims: readonly number[], type: Tensor.DataType) {
    return new Tensor(dims, type, undefined, undefined, data);
  }

  /**
   * Construct new Tensor from an ORT format (flatbuffers) Tensor object
   *
   * String tensors are read from the string data entries; other tensors are read from the raw data bytes when
   * they are present and non-empty. Otherwise the tensor is returned with its freshly allocated storage.
   *
   * @param ortTensor the ORT format Tensor
   * @throws Error if `ortTensor` is falsy or the raw data length does not match the tensor size
   */
  static fromOrtTensor(ortTensor: ortFbs.Tensor) {
    if (!ortTensor) {
      throw new Error('cannot construct Value from an empty tensor');
    }
    const dims = ProtoUtil.tensorDimsFromORTFormat(ortTensor);
    const type = ProtoUtil.tensorDataTypeFromProto(ortTensor.dataType());

    const value = new Tensor(dims, type);

    if (type === 'string') {
      // When it's STRING type, the value should always be stored in field
      // 'stringData'
      for (let i = 0; i < ortTensor.stringDataLength(); i++) {
        value.data[i] = ortTensor.stringData(i);
      }
    } else if (
      ortTensor.rawDataArray() &&
      typeof ortTensor.rawDataLength() === 'number' &&
      ortTensor.rawDataLength() > 0
    ) {
      // NOT considering segment for now (IMPORTANT)

      // populate value from rawData
      const dataDest = value.data;
      const dataSource = new DataView(
        ortTensor.rawDataArray()!.buffer,
        ortTensor.rawDataArray()!.byteOffset,
        ortTensor.rawDataLength(),
      );
      const elementSize = sizeofProto(ortTensor.dataType());
      const length = ortTensor.rawDataLength() / elementSize;

      // The buffer must hold a whole number of elements, and exactly as many as the tensor has.
      if (ortTensor.rawDataLength() % elementSize !== 0) {
        throw new Error('invalid buffer length');
      }
      if (dataDest.length !== length) {
        throw new Error('buffer length mismatch');
      }

      for (let i = 0; i < length; i++) {
        const n = readProto(dataSource, ortTensor.dataType(), i * elementSize);
        dataDest[i] = n;
      }
    }
    return value;
  }
}
368

369
/**
 * Size in bytes of one element of the given tensor type.
 *
 * @param type the tensor element type
 * @returns the element size in bytes
 * @throws Error for types without a fixed element size ('string')
 */
function sizeof(type: Tensor.DataType): number {
  switch (type) {
    case 'bool':
    case 'int8':
    case 'uint8':
      return 1;
    case 'int16':
    case 'uint16':
      return 2;
    case 'int32':
    case 'uint32':
    case 'float32':
      return 4;
    // 'int64' is backed by BigInt64Array (see dataviewConstructor), i.e. 8 bytes per element.
    case 'int64':
    case 'float64':
      return 8;
    default:
      throw new Error(`cannot calculate sizeof() on type ${type}`);
  }
}
388

389
/**
 * Size in bytes of one element of the given ONNX / ORT-format tensor data type, as stored in raw data.
 *
 * @param type the serialized tensor data type
 * @returns the element size in bytes
 * @throws Error for data types not listed in the switch below
 */
function sizeofProto(type: onnx.TensorProto.DataType | ortFbs.TensorDataType): number {
  switch (type) {
    case onnx.TensorProto.DataType.UINT8:
    case onnx.TensorProto.DataType.INT8:
    case onnx.TensorProto.DataType.BOOL:
      return 1;
    case onnx.TensorProto.DataType.UINT16:
    case onnx.TensorProto.DataType.INT16:
      return 2;
    case onnx.TensorProto.DataType.FLOAT:
    case onnx.TensorProto.DataType.INT32:
    case onnx.TensorProto.DataType.UINT32:
      return 4;
    case onnx.TensorProto.DataType.INT64:
    case onnx.TensorProto.DataType.DOUBLE:
    case onnx.TensorProto.DataType.UINT64:
      return 8;
    default:
      throw new Error(`cannot calculate sizeof() on type ${onnx.TensorProto.DataType[type]}`);
  }
}
410

411
/**
 * Wrap a buffer in the typed array that corresponds to the given tensor type.
 *
 * @param dataBuffer the buffer to view
 * @param type the tensor element type; must be one that `dataviewConstructor` accepts
 * @returns a typed array over the whole of `dataBuffer`
 */
function createView(dataBuffer: ArrayBuffer, type: Tensor.DataType) {
  return new (dataviewConstructor(type))(dataBuffer);
}
414

415
/**
 * Look up the typed array constructor used to store elements of the given tensor type.
 *
 * @param type the tensor element type
 * @returns the typed array constructor ('bool' shares Uint8Array with 'uint8')
 * @throws Error for types that have no typed array ('string')
 */
function dataviewConstructor(type: Tensor.DataType) {
  switch (type) {
    case 'bool':
    case 'uint8':
      return Uint8Array;
    case 'int8':
      return Int8Array;
    case 'int16':
      return Int16Array;
    case 'uint16':
      return Uint16Array;
    case 'int32':
      return Int32Array;
    case 'uint32':
      return Uint32Array;
    case 'int64':
      return BigInt64Array;
    case 'float32':
      return Float32Array;
    case 'float64':
      return Float64Array;
    default:
      // should never run to here
      throw new Error('unspecified error');
  }
}
441

442
/**
 * Convert a Long to a plain number, rejecting values that do not fit in 32 bits (cast-down).
 *
 * @param i the Long value to convert
 * @param type the serialized data type the value came from; selects the signed or unsigned range check
 * @returns `i` as a number
 * @throws TypeError if the value is outside the 32-bit range for `type`, or `type` is not INT64 / UINT32 / UINT64
 */
function longToNumber(i: Long, type: onnx.TensorProto.DataType | ortFbs.TensorDataType): number {
  // INT64, UINT32, UINT64
  if (type === onnx.TensorProto.DataType.INT64 || type === ortFbs.TensorDataType.INT64) {
    // signed values must fit in [-2^31, 2^31)
    if (i.greaterThanOrEqual(2147483648) || i.lessThan(-2147483648)) {
      throw new TypeError('int64 is not supported');
    }
  } else if (
    type === onnx.TensorProto.DataType.UINT32 ||
    type === ortFbs.TensorDataType.UINT32 ||
    type === onnx.TensorProto.DataType.UINT64 ||
    type === ortFbs.TensorDataType.UINT64
  ) {
    // unsigned values must fit in [0, 2^32)
    if (i.greaterThanOrEqual(4294967296) || i.lessThan(0)) {
      throw new TypeError('uint64 is not supported');
    }
  } else {
    throw new TypeError(`not a LONG type: ${onnx.TensorProto.DataType[type]}`);
  }

  return i.toNumber();
}
464

465
/**
 * Read one value from TensorProto raw data.
 *
 * Multi-byte values are read as little-endian. 64-bit integers are assembled from two 32-bit halves and then
 * narrowed with `longToNumber`, which throws if the value does not fit in 32 bits.
 *
 * @param view a view over the raw data bytes
 * @param type the serialized data type of the element
 * @param byteOffset offset of the element within `view`, in bytes
 * @returns the element as a number
 * @throws Error for data types not listed in the switch below
 */
function readProto(
  view: DataView,
  type: onnx.TensorProto.DataType | ortFbs.TensorDataType,
  byteOffset: number,
): number {
  switch (type) {
    case onnx.TensorProto.DataType.BOOL:
    case onnx.TensorProto.DataType.UINT8:
      return view.getUint8(byteOffset);
    case onnx.TensorProto.DataType.INT8:
      return view.getInt8(byteOffset);
    case onnx.TensorProto.DataType.UINT16:
      return view.getUint16(byteOffset, true);
    case onnx.TensorProto.DataType.INT16:
      return view.getInt16(byteOffset, true);
    case onnx.TensorProto.DataType.FLOAT:
      return view.getFloat32(byteOffset, true);
    case onnx.TensorProto.DataType.INT32:
      return view.getInt32(byteOffset, true);
    case onnx.TensorProto.DataType.UINT32:
      return view.getUint32(byteOffset, true);
    case onnx.TensorProto.DataType.INT64:
      // low word first, then high word; `false` marks the Long as signed
      return longToNumber(
        Long.fromBits(view.getUint32(byteOffset, true), view.getUint32(byteOffset + 4, true), false),
        type,
      );
    case onnx.TensorProto.DataType.DOUBLE:
      return view.getFloat64(byteOffset, true);
    case onnx.TensorProto.DataType.UINT64:
      // low word first, then high word; `true` marks the Long as unsigned
      return longToNumber(
        Long.fromBits(view.getUint32(byteOffset, true), view.getUint32(byteOffset + 4, true), true),
        type,
      );
    default:
      throw new Error(`cannot read from DataView for type ${onnx.TensorProto.DataType[type]}`);
  }
}
503

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.