onnxruntime

Форк
0
116 строк · 4.8 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
import { DataType } from '../../../wasm-common';
5
import { TensorView } from '../../tensor-view';
6
import { BroadcastUtil, ShapeUtil } from '../../util';
7
import { ComputeContext, ProgramInfo } from '../types';
8

9
import { createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper } from './common';
10

11
const createWhereOpProgramShader = (
12
  shaderHelper: ShaderHelper,
13
  inputs: readonly TensorView[],
14
  dimsOutput: readonly number[],
15
  isBroadcast: boolean,
16
  typeOutput: number,
17
) => {
18
  const output = outputVariable('output_data', typeOutput, dimsOutput.length, 4);
19
  const a = inputVariable('a_data', inputs[1].dataType, inputs[1].dims.length, 4);
20
  const b = inputVariable('b_data', inputs[2].dataType, inputs[2].dims.length, 4);
21
  const c = inputVariable('c_data', inputs[0].dataType, inputs[0].dims.length, 4);
22

23
  let assignment: string;
24
  const expression = (a: string, b: string, c: string) => `select(${b}, ${a}, ${c})`;
25
  if (!isBroadcast) {
26
    assignment = output.setByOffset(
27
      'global_idx',
28
      expression(a.getByOffset('global_idx'), b.getByOffset('global_idx'), c.getByOffset('global_idx')),
29
    );
30
  } else {
31
    const singleAssignment = (resStr: string, x: number, typeCast = '') => {
32
      const expressionA = `a_data[index_a${x}][component_a${x}]`;
33
      const expressionB = `b_data[index_b${x}][component_b${x}]`;
34
      // eslint-disable-next-line no-bitwise
35
      const expressionC = `bool(c_data[index_c${x}] & (0xffu << (component_c${x} * 8)))`;
36
      return `
37
            let output_indices${x} = ${output.offsetToIndices(`global_idx * 4u + ${x}u`)};
38
            let offset_a${x} = ${a.broadcastedIndicesToOffset(`output_indices${x}`, output)};
39
            let offset_b${x} = ${b.broadcastedIndicesToOffset(`output_indices${x}`, output)};
40
            let offset_c${x} = ${c.broadcastedIndicesToOffset(`output_indices${x}`, output)};
41
            let index_a${x} = offset_a${x} / 4u;
42
            let index_b${x} = offset_b${x} / 4u;
43
            let index_c${x} = offset_c${x} / 4u;
44
            let component_a${x} = offset_a${x} % 4u;
45
            let component_b${x} = offset_b${x} % 4u;
46
            let component_c${x} = offset_c${x} % 4u;
47
            ${resStr}[${x}] = ${typeCast}(${expression(expressionA, expressionB, expressionC)});
48
          `;
49
    };
50
    if (typeOutput === DataType.bool) {
51
      assignment = `
52
            var data = vec4<u32>(0);
53
            ${singleAssignment('data', 0, 'u32')}
54
            ${singleAssignment('data', 1, 'u32')}
55
            ${singleAssignment('data', 2, 'u32')}
56
            ${singleAssignment('data', 3, 'u32')}
57
            output_data[global_idx] = dot(vec4<u32>(0x1, 0x100, 0x10000, 0x1000000), vec4<u32>(data));`;
58
    } else {
59
      assignment = `
60
            ${singleAssignment('output_data[global_idx]', 0)}
61
            ${singleAssignment('output_data[global_idx]', 1)}
62
            ${singleAssignment('output_data[global_idx]', 2)}
63
            ${singleAssignment('output_data[global_idx]', 3)}
64
          `;
65
    }
66
  }
67

68
  return `
69
        ${shaderHelper.registerUniform('vec_size', 'u32').declareVariables(c, a, b, output)}
70
        ${shaderHelper.mainStart()}
71
        ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.vec_size')}
72
        ${assignment}
73
      }`;
74
};
75

76
const createWhereOpProgramInfo = (inputs: readonly TensorView[]): ProgramInfo => {
77
  const dimsA = inputs[1].dims;
78
  const dimsB = inputs[2].dims;
79
  const dimsC = inputs[0].dims;
80
  const outputDataType = inputs[1].dataType;
81

82
  const isBroadcast = !(ShapeUtil.areEqual(dimsA, dimsB) && ShapeUtil.areEqual(dimsB, dimsC));
83
  let outputShape = dimsA;
84
  let outputSize = ShapeUtil.size(dimsA);
85
  // TODO: deal with zero-sized tensors (eg. dims=[1,0])
86

87
  if (isBroadcast) {
88
    const calculatedShape = BroadcastUtil.calcShape(BroadcastUtil.calcShape(dimsA, dimsB, false)!, dimsC, false);
89
    if (!calculatedShape) {
90
      throw new Error("Can't perform where op on the given tensors");
91
    }
92
    outputShape = calculatedShape;
93
    outputSize = ShapeUtil.size(outputShape);
94
  }
95

96
  const vecSize = Math.ceil(outputSize / 4);
97

98
  return {
99
    name: 'Where',
100
    shaderCache: { inputDependencies: ['rank', 'rank', 'rank'] },
101
    getShaderSource: (shaderHelper) =>
102
      createWhereOpProgramShader(shaderHelper, inputs, outputShape, isBroadcast, outputDataType),
103
    getRunData: () => ({
104
      outputs: [{ dims: outputShape, dataType: outputDataType }],
105
      dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* vec size */) },
106
      programUniforms: [
107
        { type: DataType.uint32, data: vecSize },
108
        ...createTensorShapeVariables(dimsC, dimsA, dimsB, outputShape),
109
      ],
110
    }),
111
  };
112
};
113

114
export const where = (context: ComputeContext): void => {
115
  context.compute(createWhereOpProgramInfo(context.inputs));
116
};
117

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.