onnxruntime

Форк
0
97 строк · 3.0 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
import { assert } from '../../util';
5
/**
6
 * Given a non RGBA shape calculate the R version
7
 * It is assumed that the dimensions are multiples of given channels
8
 * NOTE: it is always the last dim that gets packed.
9
 * @param unpackedShape original shape to create a packed version from
10
 */
11
export function getPackedShape(unpackedShape: readonly number[]): readonly number[] {
12
  const len = unpackedShape.length;
13
  return unpackedShape.slice(0, len - 1).concat(unpackedShape[len - 1] / 4);
14
}
15

16
export async function repeatedTry(
17
  checkFn: () => boolean,
18
  delayFn = (_counter: number) => 0,
19
  maxCounter?: number,
20
): Promise<void> {
21
  return new Promise<void>((resolve, reject) => {
22
    let tryCount = 0;
23

24
    const tryFn = () => {
25
      if (checkFn()) {
26
        resolve();
27
        return;
28
      }
29

30
      tryCount++;
31

32
      const nextBackoff = delayFn(tryCount);
33

34
      if (maxCounter != null && tryCount >= maxCounter) {
35
        reject();
36
        return;
37
      }
38
      setTimeout(tryFn, nextBackoff);
39
    };
40

41
    tryFn();
42
  });
43
}
44

45
/**
46
 * Generates the function name from an input sampler name.
47
 * @param samplerName Name of the sampler.
48
 */
49
export function generateShaderFuncNameFromInputSamplerName(samplerName: string): string {
50
  assert(typeof samplerName !== 'undefined' && samplerName.length !== 0, () => 'empty string found for sampler name');
51
  return 'get' + samplerName.charAt(0).toUpperCase() + samplerName.slice(1);
52
}
53

54
/**
55
 * Generates the function name from an input sampler name at output coordinates.
56
 * @param samplerName Name of the sampler.
57
 */
58
export function generateShaderFuncNameFromInputSamplerNameAtOutCoords(samplerName: string): string {
59
  assert(typeof samplerName !== 'undefined' && samplerName.length !== 0, () => 'empty string found for sampler name');
60
  return 'get' + samplerName.charAt(0).toUpperCase() + samplerName.slice(1) + 'AtOutCoords';
61
}
62

63
/** Returns a new input shape (a copy) that has a squeezed logical shape. */
64
export function squeezeInputShape(inputShape: readonly number[], squeezedShape: number[]): number[] {
65
  // Deep copy.
66
  let newInputShape: number[] = JSON.parse(JSON.stringify(inputShape));
67
  newInputShape = squeezedShape;
68
  return newInputShape;
69
}
70

71
/** Returns a list of squeezed parameters for shader functions */
72
export function getSqueezedParams(params: string[], keptDims: number[]): string {
73
  return keptDims.map((d) => params[d]).join(', ');
74
}
75

76
/** Returns the data type for different ranks. */
77
export function getCoordsDataType(rank: number): string {
78
  if (rank <= 1) {
79
    return 'int';
80
  } else if (rank === 2) {
81
    return 'ivec2';
82
  } else if (rank === 3) {
83
    return 'ivec3';
84
  } else if (rank === 4) {
85
    return 'ivec4';
86
  } else if (rank === 5) {
87
    return 'ivec5';
88
  } else if (rank === 6) {
89
    return 'ivec6';
90
  } else {
91
    throw Error(`GPU for rank ${rank} is not yet supported`);
92
  }
93
}
94

95
export function getGlChannels(rank = 6): string[] {
96
  return ['x', 'y', 'z', 'w', 'u', 'v'].slice(0, rank);
97
}
98

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.