onnxruntime

Форк
0
84 строки · 3.0 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
import { env } from 'onnxruntime-common';
5

6
import { DataType } from '../../../wasm-common';
7
import { ComputeContext, ProgramInfo, ProgramUniform } from '../types';
8

9
import {
10
  createTensorShapeVariables,
11
  outputVariable,
12
  ShaderHelper,
13
  UniformDataElementType,
14
  UniformsArrayType,
15
} from './common';
16

17
const validateInputsContent = (start: number, limit: number, delta: number): void => {
18
  const sameStartLimit = start === limit;
19
  const increasingRangeNegativeStep = start < limit && delta < 0;
20
  const decreasingRangePositiveStep = start > limit && delta > 0;
21

22
  if (sameStartLimit || increasingRangeNegativeStep || decreasingRangePositiveStep) {
23
    throw new Error("Range these inputs' contents are invalid.");
24
  }
25
};
26

27
const createRangeProgramInfo = (start: number, limit: number, delta: number, dataType: DataType): ProgramInfo => {
28
  const numElements = Math.abs(Math.ceil((limit - start) / delta));
29
  const outputShape: number[] = [numElements];
30
  const outputSize = numElements;
31
  const programUniforms: ProgramUniform[] = [
32
    { type: DataType.uint32, data: outputSize },
33
    { type: dataType, data: start },
34
    { type: dataType, data: delta },
35
    ...createTensorShapeVariables(outputShape),
36
  ];
37

38
  const getShaderSource = (shaderHelper: ShaderHelper) => {
39
    const output = outputVariable('output', dataType, outputShape.length);
40
    const wgslType = output.type.value;
41
    const uniforms: UniformsArrayType = [
42
      { name: 'outputSize', type: 'u32' },
43
      { name: 'start', type: wgslType as UniformDataElementType },
44
      { name: 'delta', type: wgslType as UniformDataElementType },
45
    ];
46
    return `
47
        ${shaderHelper.registerUniforms(uniforms).declareVariables(output)}
48
        ${shaderHelper.mainStart()}
49
        ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')}
50
        output[global_idx] = uniforms.start + ${wgslType}(global_idx) * uniforms.delta;
51
      }`;
52
  };
53

54
  return {
55
    name: 'Range',
56
    shaderCache: { hint: `${dataType}` },
57
    getShaderSource,
58
    getRunData: () => ({
59
      outputs: [{ dims: outputShape, dataType }],
60
      dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) },
61
      programUniforms,
62
    }),
63
  };
64
};
65

66
export const range = (context: ComputeContext): void => {
67
  let start = 0;
68
  let limit = 0;
69
  let delta = 0;
70
  if (context.inputs[0].dataType === DataType.int32) {
71
    start = context.inputs[0].getInt32Array()[0];
72
    limit = context.inputs[1].getInt32Array()[0];
73
    delta = context.inputs[2].getInt32Array()[0];
74
  } else if (context.inputs[0].dataType === DataType.float) {
75
    start = context.inputs[0].getFloat32Array()[0];
76
    limit = context.inputs[1].getFloat32Array()[0];
77
    delta = context.inputs[2].getFloat32Array()[0];
78
  }
79
  if (env.webgpu.validateInputContent) {
80
    validateInputsContent(start, limit, delta);
81
  }
82

83
  context.compute(createRangeProgramInfo(start, limit, delta, context.inputs[0].dataType), { inputs: [] });
84
};
85

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.