onnxruntime

Форк
0
558 строк · 17.1 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
/* eslint-disable no-param-reassign */
5

6
export class MatMulUtil {
  /**
   * Computes the result shape of a 2-D matrix product `A x B`.
   * @param a Shape of A as `[rows, cols]`. Should be a tuple of 2 positive integers
   * @param b Shape of B as `[rows, cols]`. Should be a tuple of 2 positive integers
   * @returns `[a[0], b[1]]` when the inner dimensions agree (`a[1] === b[0]`), otherwise `undefined`
   */
  static calcMatMulShape(a: [number, number], b: [number, number]): [number, number] | undefined {
    const [rowsA, innerA] = a;
    const [innerB, colsB] = b;
    if (innerA === innerB) {
      return [rowsA, colsB];
    }
    return undefined;
  }
}
17

18
export class BroadcastUtil {
  /**
   * Calculate the expected shape when broadcasting 2 tensors. Shapes are aligned to the right and a missing
   * leading dimension behaves like 1.
   * @param adims The shape of tensor A. Should be an array of non-negative integers
   * @param bdims The shape of tensor B. Should be an array of non-negative integers
   * @param isMatMul Whether the operation is MatMul. When true, both inputs must have rank >= 2, the last 2
   *     dimensions are combined as a matrix product, and only the leading dimensions are broadcast.
   * @returns The expected shape of the result, or undefined if the shapes are not compatible
   */
  static calcShape(
    adims: readonly number[],
    bdims: readonly number[],
    isMatMul = false,
  ): readonly number[] | undefined {
    const arank = adims.length;
    const brank = bdims.length;

    // MatMul needs a matrix (rank >= 2) on each side. Check this before the scalar shortcuts below so that a
    // rank-0 operand is rejected just like a rank-1 operand instead of silently returning the other shape.
    if (isMatMul && (arank < 2 || brank < 2)) {
      return undefined;
    }
    if (arank === 0) {
      return bdims;
    }
    if (brank === 0) {
      return adims;
    }
    const crank = Math.max(arank, brank);
    const cdims = new Array<number>(crank);

    // calculate the last 2 dimensions if it is MatMul
    if (isMatMul) {
      const cShapeMatMul = MatMulUtil.calcMatMulShape(
        [adims[arank - 2], adims[arank - 1]],
        [bdims[brank - 2], bdims[brank - 1]],
      );
      if (cShapeMatMul === undefined) {
        return undefined;
      }
      [cdims[crank - 2], cdims[crank - 1]] = cShapeMatMul;
    }

    // broadcast the remaining dimensions, walking from the right (i counts positions from the end)
    for (let i = isMatMul ? 3 : 1; i <= crank; i++) {
      const aLen = arank - i < 0 ? 1 : adims[arank - i];
      const bLen = brank - i < 0 ? 1 : bdims[brank - i];

      if (aLen !== bLen && aLen > 1 && bLen > 1) {
        return undefined;
      }
      const max = Math.max(aLen, bLen);
      if (aLen && bLen) {
        cdims[crank - i] = max;
      } else {
        // when either aLen or bLen is 0, the other should be either 0 or 1, otherwise it is not broadcastable.
        if (max > 1) {
          return undefined;
        }
        cdims[crank - i] = 0;
      }
    }

    return cdims;
  }

  /**
   * Determine if a shape is unidirectional broadcastable to another shape
   * @param shape The input shape
   * @param finalShape The desired shape after broadcasting
   * @returns true when `shape` has no more dimensions than `finalShape` and every one of its dimensions (aligned to
   *     the right) is either 1 or equal to the corresponding dimension of `finalShape`
   */
  static isValidBroadcast(shape: readonly number[], finalShape: readonly number[]): boolean {
    // align shape to the right
    const inputRank = shape.length;
    const finalRank = finalShape.length;
    if (inputRank > finalRank) {
      return false;
    }
    for (let i = 1; i <= inputRank; i++) {
      if (shape[inputRank - i] !== 1 && shape[inputRank - i] !== finalShape[finalRank - i]) {
        return false;
      }
    }
    return true;
  }
}
99

100
export class ShapeUtil {
  /**
   * Calculate the size (number of elements) of a tensor with the given dims. Empty dims (a scalar) give 1.
   * @throws Error if any dimension is negative
   */
  static size(dims: readonly number[]): number {
    return ShapeUtil.getSizeFromDimensionRange(dims, 0, dims.length);
  }

  /**
   * Convert dims to account for packing several elements into one wider element, e.g. uint8 data to uint32
   * (size = 4).
   *
   * The packing factor is absorbed from the innermost dimension outwards. The first dimension (from the right)
   * divisible by the remaining factor is divided by it. Every dimension skipped before that must divide the
   * remaining factor and becomes 1. Outer dimensions are copied unchanged.
   * NOTE(review): if every dimension is consumed before the factor is fully absorbed (e.g. dims [2], size 4), no
   * error is raised and the partially converted shape is returned; confirm callers never rely on this case.
   * @param dims the original shape
   * @param size the packing factor (number of source elements per packed element)
   * @throws Error('cannot convert shape') when a skipped dimension does not divide the remaining factor
   */
  static convertShape(dims: readonly number[], size = 4): readonly number[] {
    const rank = dims.length;
    if (rank === 0) {
      return [];
    }
    const newDims = new Array<number>(rank);
    let i = rank - 1;
    while (i >= 0) {
      if (dims[i] % size === 0) {
        newDims[i] = dims[i] / size;
        break;
      }
      if (size % dims[i] !== 0) {
        throw new Error('cannot convert shape');
      }
      newDims[i] = 1;
      size /= dims[i];
      i--;
    }
    for (i--; i >= 0; i--) {
      newDims[i] = dims[i];
    }
    return newDims;
  }

  /**
   * Calculate the size (number of elements) from the given axis (inclusive) to the last dimension.
   * @throws Error if axis is outside [0, dims.length]
   */
  static sizeFromDimension(dims: readonly number[], axis: number): number {
    if (axis < 0 || axis > dims.length) {
      throw new Error(`invalid dimension of ${axis} for sizeFromDimension as Tensor has ${dims.length} dimensions.`);
    }
    return ShapeUtil.getSizeFromDimensionRange(dims, axis, dims.length);
  }

  /**
   * Calculate the size (number of elements) from the first dimension up to the given axis (exclusive).
   * @throws Error if axis is outside [0, dims.length]
   */
  static sizeToDimension(dims: readonly number[], axis: number): number {
    if (axis < 0 || axis > dims.length) {
      throw new Error(`invalid dimension of ${axis} for sizeToDimension as Tensor has ${dims.length} dimensions.`);
    }
    return ShapeUtil.getSizeFromDimensionRange(dims, 0, axis);
  }

  /**
   * Calculate the size (number of elements) of the dimension range [start, end).
   * @throws Error if any dimension in the range is negative
   */
  static getSizeFromDimensionRange(dims: readonly number[], start: number, end: number): number {
    let size = 1;
    for (let i = start; i < end; i++) {
      // safety check as this method is called by multiple other methods requiring size.
      // size cannot be negative.
      if (dims[i] < 0) {
        throw new Error(
          // eslint-disable-next-line max-len
          'cannot get valid size from specified dimension range. Most likely the range contains negative values in them.',
        );
      }
      size *= dims[i];
    }
    return size;
  }

  /**
   * Compute row-major (innermost dimension last) strides, in elements, for the given dims.
   * Empty dims give [], a single dimension gives [1].
   */
  static computeStrides(dims: readonly number[]): readonly number[] {
    const rank = dims.length;
    if (rank === 0) {
      return [];
    } else if (rank === 1) {
      return [1];
    }
    const strides = new Array<number>(rank);
    strides[rank - 1] = 1;
    strides[rank - 2] = dims[rank - 1];
    for (let i = rank - 3; i >= 0; --i) {
      strides[i] = strides[i + 1] * dims[i + 1];
    }
    return strides;
  }

  /**
   * Normalize an axis in the range [-r, r) into [0, r), where r is the tensor rank.
   * @throws Error if axis is outside [-r, r)
   */
  static normalizeAxis(axis: number, tensorRank: number): number {
    // The original check combined the two bounds with `&&`, which can never be true, so out-of-range axes
    // were silently accepted. Either bound being violated makes the axis invalid.
    if (axis < -tensorRank || axis >= tensorRank) {
      throw new Error('unsupported axis for this operation.');
    }
    return axis < 0 ? axis + tensorRank : axis;
  }

  /**
   * Normalize every axis in `axes` (see normalizeAxis).
   * @param tensorRank the tensor rank; defaults to `axes.length` when omitted
   */
  static normalizeAxes(axes: readonly number[], tensorRank?: number): number[] {
    // Reference the class explicitly (not `this`) so the method also works when detached from ShapeUtil.
    return axes.map((x) => ShapeUtil.normalizeAxis(x, tensorRank ?? axes.length));
  }

  /**
   * Sorts a given array based on the indices in the Perm array
   * Used in Transpose
   * @param a Array to be sorted such as dims or strides
   * @param perm Perm given; if absent, a reversed copy of `a` is returned
   */
  static sortBasedOnPerm(a: readonly number[], perm?: readonly number[]): readonly number[] {
    if (perm) {
      return perm.map((v) => a[v]);
    } else {
      return a.slice().reverse();
    }
  }

  /**
   * Pads a given shape according to the padding values
   * @param dims shape of the Tensor to be padded
   * @param pad pad values: the first `dims.length` entries are the begin pads, the next `dims.length` the end pads
   */
  static padShape(dims: readonly number[], pad: readonly number[]): readonly number[] {
    const rank = dims.length;
    return dims.map((v, i) => v + pad[i] + pad[i + rank]);
  }

  /**
   * Determines if the two shapes are identical (same rank and same value in every dimension)
   * @param shape1
   * @param shape2
   */
  static areEqual(shape1: readonly number[], shape2: readonly number[]): boolean {
    if (shape1.length !== shape2.length) {
      return false;
    }
    return shape1.every((v, i) => v === shape2[i]);
  }
}
241

242
export class PoolConvUtil {
  /**
   * Adjust the kernel, strides, pads to correct rank. Set to default value if not present.
   * `kernelShape`, `strides`, `dilations` and `pads` are modified in place.
   * @param isGlobalOperator If true, perform global pooling: the kernel is resized to cover the dims after the first 2.
   * @param inputDims The input tensor dimension. The first 2 entries (batch size and channels) are not spatial.
   * @param kernelShape The size of the kernel along each axis.
   * @param strides Stride along each axis. Missing entries default to 1.
   * @param dilations Dilation along each axis. Missing entries default to 1.
   * @param pads Padding for the beginning and ending along each axis. Missing entries default to 0.
   * @throws Error if the kernel rank does not match the input rank, if a stride or dilation is smaller than 1, or if
   *     a pad is negative. Also throws if a kernel dimension is not positive, or a pad is not smaller than its kernel
   *     dimension.
   */
  static adjustPoolAttributes(
    isGlobalOperator: boolean,
    inputDims: readonly number[],
    kernelShape: number[],
    strides: number[],
    dilations: number[],
    pads: number[],
  ): void {
    if (!isGlobalOperator && kernelShape.length !== inputDims.length - 2) {
      throw new Error('length of specified kernel shapes should be 2 less than length of input dimensions');
    }

    if (isGlobalOperator) {
      // adjust kernel shape to cover the input dims
      for (let dim = 0; dim < inputDims.length - 2; dim++) {
        if (dim >= kernelShape.length) {
          kernelShape.push(inputDims[dim + 2]);
        } else {
          kernelShape[dim] = inputDims[dim + 2];
        }
      }
    }

    // adjust strides length to match kernel shape length
    for (let dim = 0; dim < kernelShape.length; dim++) {
      if (dim < strides.length) {
        // a stride of 0 would divide by zero when computing output sizes
        if (strides[dim] < 1) {
          throw new Error('strides should be greater than or equal to 1');
        }
      } else {
        strides.push(1);
      }
    }

    // adjust dilation value
    for (let dim = 0; dim < kernelShape.length; dim++) {
      if (dim < dilations.length) {
        if (dilations[dim] < 1) {
          throw new Error('dilations should be greater than or equal to 1');
        }
      } else {
        dilations.push(1);
      }
    }

    // adjust pads length to match 2 * kernel shape length
    for (let dim = 0; dim < kernelShape.length * 2; dim++) {
      if (dim < pads.length) {
        if (pads[dim] < 0) {
          throw new Error('pad should be greater than or equal to 0');
        }
      } else {
        pads.push(0);
      }
    }

    // sanity checks for values in kernel shapes and pads
    for (let dim = 0; dim < kernelShape.length; dim++) {
      if (kernelShape[dim] <= 0) {
        throw new Error('kernel shapes need to be greater than 0');
      }

      if (pads[dim] >= kernelShape[dim] || pads[dim + kernelShape.length] >= kernelShape[dim]) {
        throw new Error('pads should be smaller than kernel');
      }
    }
  }

  /**
   * Adjust pad values in place based on the 'autoPad' attribute. Does nothing when autoPad is empty or undefined.
   * @param isChannelLast If true, spatial dims start at index 1 of inputDims; otherwise at index 2.
   * @throws Error when pads/strides/kernelShape lengths do not match the number of spatial dims
   */
  static adjustPadsBasedOnAutoPad(
    inputDims: readonly number[],
    strides: readonly number[],
    dilations: readonly number[],
    kernelShape: readonly number[],
    pads: number[],
    isChannelLast: boolean,
    autoPad?: string,
  ): void {
    if (!autoPad) {
      return;
    }

    if (pads.length !== 2 * (inputDims.length - 2)) {
      throw new Error('length of pads should be twice the length of data dimensions');
    }

    if (strides.length !== inputDims.length - 2) {
      throw new Error('length of strides should be the length of data dimensions');
    }

    if (kernelShape.length !== inputDims.length - 2) {
      throw new Error('length of kernel shapes should be the length of data dimensions');
    }

    for (let dim = 0; dim < inputDims.length - 2; dim++) {
      PoolConvUtil.adjustPadAndReturnShape(
        inputDims[dim + (isChannelLast ? 1 : 2)],
        strides[dim],
        dilations[dim],
        kernelShape[dim],
        pads,
        dim,
        dim + inputDims.length - 2,
        autoPad,
      );
    }
  }

  /**
   * Calculate the output shape for Pool ops based on input attributes. (Should be used only for Pool ops)
   * @param isGlobalOperator If true, perform global pooling.
   * @param inputDims The input tensor dimension. (inputs[0].dims)
   * @param strides Stride along each axis.
   * @param dilations Dilation along each axis.
   * @param kernelShape The size of the kernel along each axis.
   * @param pads Padding for the beginning and ending along each axis. Updated in place when autoPad is set.
   * @param autoPad DEPRECATED attribute supported for legacy models. Specifies how to implicitly calculate pads in each
   *     dimension. Can take values NOTSET, SAME_UPPER, SAME_LOWER, or VALID.
   */
  static computePoolOutputShape(
    isGlobalOperator: boolean,
    inputDims: readonly number[],
    strides: number[],
    dilations: number[],
    kernelShape: number[],
    pads: number[],
    autoPad?: string,
  ): number[] {
    if (inputDims.length <= 0) {
      throw new Error('input shape must be of size greater than 0');
    }

    // Add batch size and number of channels of output
    const outputDims = [inputDims[0], inputDims[1]];

    PoolConvUtil.computeShapeHelper(
      isGlobalOperator,
      inputDims,
      outputDims,
      strides,
      dilations,
      kernelShape,
      pads,
      autoPad,
    );
    return outputDims;
  }

  /**
   * Calculate the output shape for Conv op based on input attributes. (Should be used only for Conv op)
   * @param inputDims The input tensor dimension. (inputs[0].dims)
   * @param filterDims The filter tensor dimension. (inputs[1].dims)
   * @param strides Stride along each axis.
   * @param dilations Dilation along each axis.
   * @param kernelShape The size of the kernel along each axis.
   * @param pads Padding for the beginning and ending along each axis. Updated in place when autoPad is set.
   * @param autoPad DEPRECATED attribute supported for legacy models. Specifies how to implicitly calculate pads in each
   *     dimension. Can take values NOTSET, SAME_UPPER, SAME_LOWER, or VALID.
   */
  static computeConvOutputShape(
    inputDims: readonly number[],
    filterDims: readonly number[],
    strides: number[],
    dilations: number[],
    kernelShape: number[],
    pads: number[],
    autoPad?: string,
  ): number[] {
    if (inputDims.length <= 0 || filterDims.length <= 0) {
      throw new Error('invalid input tensor dims or invalid filter tensor dims');
    }

    // Add batch size and number of channels of output
    const outputDims = [inputDims[0], filterDims[0]];

    PoolConvUtil.computeShapeHelper(false, inputDims, outputDims, strides, dilations, kernelShape, pads, autoPad);
    return outputDims;
  }

  // will compute output shapes for data dimensions ONLY (i.e.) no batch size and channels
  // called by computePoolOutputShape() and computeConvOutputShape()
  // adjust pads based on 'autoPad' attribute prior to shape computation
  private static computeShapeHelper(
    isGlobalOperator: boolean,
    inputDims: readonly number[],
    outputDims: number[],
    strides: readonly number[],
    dilations: readonly number[],
    kernelShape: readonly number[],
    pads: number[],
    autoPad?: string,
  ) {
    if (isGlobalOperator) {
      for (let dim = 0; dim < inputDims.length - 2; dim++) {
        outputDims.push(1);
      }
    } else {
      for (let dim = 0; dim < inputDims.length - 2; dim++) {
        outputDims.push(
          PoolConvUtil.adjustPadAndReturnShape(
            inputDims[dim + 2],
            strides[dim],
            dilations[dim],
            kernelShape[dim],
            pads,
            dim,
            dim + inputDims.length - 2,
            autoPad,
          ),
        );
      }
    }
  }

  // helper for computeShapeHelper() and adjustPadsBasedOnAutoPad()
  // adjusts pad value for given 'autoPad' string and computes output shape along a particular dimension
  private static adjustPadAndReturnShape(
    inSize: number,
    stride: number,
    dilation: number,
    kernel: number,
    pads: number[],
    padHeadIndex: number,
    padTailIndex: number,
    autoPad?: string,
  ): number {
    // effective kernel extent once dilation is applied
    const dkernel = dilation * (kernel - 1) + 1;
    if (autoPad && autoPad !== 'NOTSET') {
      switch (autoPad) {
        case 'VALID':
          pads[padHeadIndex] = 0;
          pads[padTailIndex] = 0;
          return Math.floor((inSize - dkernel) / stride + 1);
        case 'SAME_LOWER':
        case 'SAME_UPPER':
          if (dilation !== 1) {
            throw new Error('Dilation not supported for SAME_UPPER or SAME_LOWER');
          } else {
            // Target output size is ceil(inSize / stride). The floor mirrors the integer division of the reference
            // formula; without it a fractional size leaks into padNeeded and produces wrong pads.
            const legacyTargetSize = Math.floor((inSize + stride - 1) / stride);
            const padNeeded = (legacyTargetSize - 1) * stride + kernel - inSize;
            // for an odd padNeeded, SAME_LOWER puts the extra pad at the head and SAME_UPPER at the tail
            pads[padHeadIndex] = autoPad === 'SAME_LOWER' ? Math.floor((padNeeded + 1) / 2) : Math.floor(padNeeded / 2);
            pads[padTailIndex] = padNeeded - pads[padHeadIndex];
            return Math.floor((inSize + padNeeded - kernel) / stride + 1);
          }
        default:
          throw new Error('Unsupported AutoPad type');
      }
    } else {
      return Math.floor((inSize + pads[padHeadIndex] + pads[padTailIndex] - dkernel) / stride + 1);
    }
  }
}
503

504
export class GemmUtil {
  /**
   * Validates the operand shapes of a Gemm op and reports its dimensions.
   *
   * After applying the optional transposes, the left operand is M x K and the right operand is K x N.
   * @param leftShape shape of the left operand; must have exactly 2 dimensions
   * @param transLeft whether the left operand is transposed
   * @param rightShape shape of the right operand; must have exactly 2 dimensions
   * @param transRight whether the right operand is transposed
   * @param biasShape optional bias shape; must be unidirectionally broadcastable to [M, N]
   * @returns `[M, N, K]`
   * @throws Error if an operand is not 2-D, the K dimensions disagree, any of M/N/K is not positive, or the bias
   *     cannot be broadcast to [M, N]
   */
  static getShapeOfGemmResult(
    leftShape: readonly number[],
    transLeft: boolean,
    rightShape: readonly number[],
    transRight: boolean,
    biasShape?: readonly number[],
  ): readonly number[] {
    if (leftShape.length !== 2 || rightShape.length !== 2) {
      throw new Error('shape need to be of size 2');
    }

    const [leftRows, leftCols] = leftShape;
    const [rightRows, rightCols] = rightShape;

    const M = transLeft ? leftCols : leftRows;
    const K = transLeft ? leftRows : leftCols;
    const N = transRight ? rightRows : rightCols;
    // the right operand's contribution to the shared K dimension
    const rightK = transRight ? rightCols : rightRows;

    if (rightK !== K) {
      throw new Error('dimension mismatch');
    }
    if (M <= 0 || N <= 0 || K <= 0) {
      throw new Error('invalid shape specified');
    }
    if (biasShape && !BroadcastUtil.isValidBroadcast(biasShape, [M, N])) {
      throw new Error('gemm: invalid bias shape for broadcast');
    }

    return [M, N, K];
  }
}
556

557
/** Largest finite float32 value (FLT_MAX); presumably the default upper bound for clip ops. */
export const MAX_CLIP = 3.4028234663852886e38;
/** Most negative finite float32 value (-FLT_MAX); presumably the default lower bound for clip ops. */
export const MIN_CLIP = -MAX_CLIP;
559

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.