onnxruntime

Форк
0
341 строка · 11.1 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
import { DataType } from '../../../wasm-common';
5
import { TensorView } from '../../tensor-view';
6
import { BroadcastUtil, ShapeUtil } from '../../util';
7
import { ComputeContext, ProgramInfo } from '../types';
8

9
import { createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper } from './common';
10

11
/** Name of a two-argument WGSL builtin; it is emitted as `name((a),(b))`. */
type BuiltinFunctionName = string;

/** Builds a WGSL expression from the WGSL expressions of the two operands. */
type BinaryCustomExpression = (expressionA: string, expressionB: string) => string;

/** Separate expression builders for one lane (`scalar`) and for a whole vec4 value (`vector`). */
interface ScalarAndVectorExpressions {
  scalar: BinaryCustomExpression;
  vector: BinaryCustomExpression;
}

/**
 * How a binary operator is rendered into WGSL: a builtin name, one builder used for both the
 * scalar and the vector form, or distinct builders for each form.
 */
type BinaryFunctionCall = BuiltinFunctionName | BinaryCustomExpression | ScalarAndVectorExpressions;
20

21
/**
 * Splits a `BinaryFunctionCall` into the builder used for single lanes (`scalar`) and the builder
 * used for whole vec4 values (`vector`).
 *
 * A builtin name is rendered as `name((lhs),(rhs))`; a plain function is reused for both forms.
 */
const resolveBinaryExpressions = (
  funcCall: BinaryFunctionCall,
): { scalar: BinaryCustomExpression; vector: BinaryCustomExpression } => {
  if (typeof funcCall === 'string') {
    const callBuiltin: BinaryCustomExpression = (lhs, rhs) => `${funcCall}((${lhs}),(${rhs}))`;
    return { scalar: callBuiltin, vector: callBuiltin };
  }
  if (typeof funcCall === 'function') {
    return { scalar: funcCall, vector: funcCall };
  }
  return funcCall;
};

/**
 * Generates the WGSL source of an element-wise binary operator.
 *
 * All three buffers are declared with 4 components, so `global_idx` addresses one vec4 slot of the
 * output (four consecutive output elements). Depending on the flags, the body is:
 *  - vectorized, same shapes: read both inputs at slot `global_idx`;
 *  - vectorized, broadcast: splat a single-element input, or map the slot's first output element
 *    back into each input and read a whole vec4 or splat one component;
 *  - scalar (broadcast only): compute each of the four lanes independently.
 *
 * @param shaderHelper - used to declare uniforms/variables and emit the entry point.
 * @param dimsA - shape of input A.
 * @param dimsB - shape of input B.
 * @param dimsOutput - output shape.
 * @param vectorize - whether the vec4 code path is used.
 * @param doBroadcast - whether A and B have different shapes.
 * @param sharedDimensionDivisibleBy4 - whether the trailing dimensions shared by A and B span a
 *   multiple of 4 elements (see createBinaryOpProgramInfo).
 * @param funcCall - how the operator is rendered into WGSL.
 * @param typeA - data type of A.
 * @param typeB - data type of B.
 * @param typeOutput - data type of the output.
 * @param additionalImplementation - extra WGSL emitted before the entry point.
 * @returns the WGSL shader source.
 * @throws Error if the scalar path is requested without broadcasting.
 */
const createBinaryOpProgramShader = (
  shaderHelper: ShaderHelper,
  dimsA: readonly number[],
  dimsB: readonly number[],
  dimsOutput: readonly number[],
  vectorize: boolean,
  doBroadcast: boolean,
  sharedDimensionDivisibleBy4: boolean,
  funcCall: BinaryFunctionCall,
  typeA: number,
  typeB: number,
  typeOutput: number,
  additionalImplementation?: string,
) => {
  const { scalar: scalarOp, vector: vectorOp } = resolveBinaryExpressions(funcCall);

  const outVar = outputVariable('outputData', typeOutput, dimsOutput.length, 4);
  const lhs = inputVariable('aData', typeA, dimsA.length, 4);
  const rhs = inputVariable('bData', typeB, dimsB.length, 4);

  let assignment: string;
  if (!vectorize) {
    if (!doBroadcast) {
      throw new Error('no necessary to use scalar implementation for element-wise binary op implementation.');
    }

    // WGSL computing lane `lane` of the current output slot on its own: map the output element
    // back to each input's flat offset, then split that offset into a vec4 slot and a component.
    const laneAssignment = (target: string, lane: number, typeCast = '') => {
      const laneA = `aData[indexA${lane}][componentA${lane}]`;
      const laneB = `bData[indexB${lane}][componentB${lane}]`;
      return `
            let outputIndices${lane} = ${outVar.offsetToIndices(`global_idx * 4u + ${lane}u`)};
            let offsetA${lane} = ${lhs.broadcastedIndicesToOffset(`outputIndices${lane}`, outVar)};
            let offsetB${lane} = ${rhs.broadcastedIndicesToOffset(`outputIndices${lane}`, outVar)};
            let indexA${lane} = offsetA${lane} / 4u;
            let indexB${lane} = offsetB${lane} / 4u;
            let componentA${lane} = offsetA${lane} % 4u;
            let componentB${lane} = offsetB${lane} % 4u;
            ${target}[${lane}] = ${typeCast}(${scalarOp(laneA, laneB)});
          `;
    };

    if (typeOutput === DataType.bool) {
      // Bool output is packed into a single u32 per slot, one byte per lane.
      assignment = `
            var data = vec4<u32>(0);
            ${laneAssignment('data', 0, 'u32')}
            ${laneAssignment('data', 1, 'u32')}
            ${laneAssignment('data', 2, 'u32')}
            ${laneAssignment('data', 3, 'u32')}
            outputData[global_idx] = dot(vec4<u32>(0x1, 0x100, 0x10000, 0x1000000), vec4<u32>(data));`;
    } else {
      assignment = `
            ${laneAssignment('outputData[global_idx]', 0)}
            ${laneAssignment('outputData[global_idx]', 1)}
            ${laneAssignment('outputData[global_idx]', 2)}
            ${laneAssignment('outputData[global_idx]', 3)}
          `;
    }
  } else if (!doBroadcast) {
    // Identical shapes: output slot `global_idx` comes from slot `global_idx` of each input.
    assignment = outVar.setByOffset(
      'global_idx',
      vectorOp(lhs.getByOffset('global_idx'), rhs.getByOffset('global_idx')),
    );
  } else {
    const aIsSingle = ShapeUtil.size(dimsA) === 1;
    const bIsSingle = ShapeUtil.size(dimsB) === 1;

    if (aIsSingle || bIsSingle) {
      // A single-element input has its first component splatted across all four lanes.
      const slotOrSplat = (operand: typeof lhs, single: boolean) =>
        single ? `${operand.type.value}(${operand.getByOffset('0')}.x)` : operand.getByOffset('global_idx');
      assignment = outVar.setByOffset('global_idx', vectorOp(slotOrSplat(lhs, aIsSingle), slotOrSplat(rhs, bIsSingle)));
    } else {
      const lastDimDivisibleBy4 = (dims: readonly number[]) => dims.length > 0 && dims[dims.length - 1] % 4 === 0;
      // Read a whole vec4 when the flags say the four output lanes map onto one input slot;
      // otherwise splat the single component found at the mapped offset.
      const readOperand = (operand: typeof lhs, offset: string, wholeSlot: boolean) =>
        wholeSlot
          ? operand.getByOffset(`${offset} / 4u`)
          : `${operand.type.value}(${operand.getByOffset(`${offset} / 4u`)}[${offset} % 4u])`;

      const indicesExpr = outVar.offsetToIndices('global_idx * 4u');
      const offsetAExpr = lhs.broadcastedIndicesToOffset('outputIndices', outVar);
      const offsetBExpr = rhs.broadcastedIndicesToOffset('outputIndices', outVar);
      const store = outVar.setByOffset(
        'global_idx',
        vectorOp(
          readOperand(lhs, 'offsetA', sharedDimensionDivisibleBy4 || lastDimDivisibleBy4(dimsA)),
          readOperand(rhs, 'offsetB', sharedDimensionDivisibleBy4 || lastDimDivisibleBy4(dimsB)),
        ),
      );
      assignment = `
            let outputIndices = ${indicesExpr};
            let offsetA = ${offsetAExpr};
            let offsetB = ${offsetBExpr};
            ${store}
          `;
    }
  }

  return `
        ${shaderHelper.registerUniform('vec_size', 'u32').declareVariables(lhs, rhs, outVar)}

        ${additionalImplementation ?? ''}

        ${shaderHelper.mainStart()}
        ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.vec_size')}
        ${assignment}
      }`;
};
136

137
/**
 * Builds the `ProgramInfo` for an element-wise binary operator.
 *
 * When the input shapes differ, the output shape is their broadcast shape. The function then
 * decides whether the vec4 shader path is usable. Each invocation handles four output elements.
 *
 * @param name - program name.
 * @param cacheKey - prefix of the shader-cache hint.
 * @param a - first input.
 * @param b - second input.
 * @param funcCall - how the operator is rendered into WGSL.
 * @param additionalImplementation - extra WGSL emitted before the shader entry point.
 * @param outputDataType - output data type; defaults to the data type of `a`.
 * @returns the program description passed to `ComputeContext.compute`.
 * @throws Error if the shapes cannot be broadcast together.
 */
const createBinaryOpProgramInfo = (
  name: string,
  cacheKey: string,
  a: TensorView,
  b: TensorView,
  funcCall: BinaryFunctionCall,
  additionalImplementation?: string,
  outputDataType: number = a.dataType,
): ProgramInfo => {
  const isBroadcast = !ShapeUtil.areEqual(a.dims, b.dims);
  const sizeA = ShapeUtil.size(a.dims);

  let outputShape = a.dims;
  let outputSize = sizeA;
  // Same-shape inputs always take the vec4 path; for broadcasting it is decided below.
  let vectorize = !isBroadcast;
  let sharedDimensionDivisibleBy4 = false;

  // TODO: deal with zero-sized tensors (eg. dims=[1,0])
  const hintParts = [isBroadcast];
  if (isBroadcast) {
    const broadcastShape = BroadcastUtil.calcShape(a.dims, b.dims, false);
    if (!broadcastShape) {
      throw new Error("Can't perform binary op on the given tensors");
    }
    outputShape = broadcastShape;
    outputSize = ShapeUtil.size(outputShape);

    const aIsSingle = sizeA === 1;
    const bIsSingle = ShapeUtil.size(b.dims) === 1;
    const aLastDimDivisibleBy4 = a.dims.length > 0 && a.dims[a.dims.length - 1] % 4 === 0;
    const bLastDimDivisibleBy4 = b.dims.length > 0 && b.dims[b.dims.length - 1] % 4 === 0;
    hintParts.push(aIsSingle, bIsSingle, aLastDimDivisibleBy4, bLastDimDivisibleBy4);

    // Product of the trailing dimensions (innermost first) on which a and b agree. Missing leading
    // dimensions of the lower-rank input count as 1. The scan stops before the outermost output
    // dimension.
    let sharedSize = 1;
    for (let fromEnd = 1; fromEnd < outputShape.length; fromEnd++) {
      const dimA = a.dims[a.dims.length - fromEnd] ?? 1;
      const dimB = b.dims[b.dims.length - fromEnd] ?? 1;
      if (dimA !== dimB) {
        break;
      }
      sharedSize *= dimA;
    }

    sharedDimensionDivisibleBy4 = sharedSize % 4 === 0;
    vectorize =
      sharedDimensionDivisibleBy4 || aIsSingle || bIsSingle || aLastDimDivisibleBy4 || bLastDimDivisibleBy4;
  }
  hintParts.push(vectorize);

  // Number of vec4 slots in the output; exposed to the shader as `uniforms.vec_size`.
  const vecSize = Math.ceil(outputSize / 4);

  return {
    name,
    shaderCache: {
      hint: cacheKey + hintParts.join('_'),
      inputDependencies: ['rank', 'rank'],
    },
    getShaderSource: (shaderHelper) =>
      createBinaryOpProgramShader(
        shaderHelper,
        a.dims,
        b.dims,
        outputShape,
        vectorize,
        isBroadcast,
        sharedDimensionDivisibleBy4,
        funcCall,
        a.dataType,
        b.dataType,
        outputDataType,
        additionalImplementation,
      ),
    getRunData: () => ({
      outputs: [{ dims: outputShape, dataType: outputDataType }],
      dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* component size */) },
      programUniforms: [
        { type: DataType.uint32, data: vecSize },
        ...createTensorShapeVariables(a.dims, b.dims, outputShape),
      ],
    }),
  };
};
224

225
/**
 * Runs a binary operator on the first two inputs of `context`.
 *
 * @param context - compute context supplying the inputs and executing the program.
 * @param name - program name.
 * @param funcCall - how the operator is rendered into WGSL.
 * @param additionalImplementation - extra WGSL emitted before the shader entry point.
 * @param cacheKey - shader-cache hint prefix; empty when omitted.
 * @param outputDataType - output data type; when omitted, the first input's data type is used.
 */
const runBinaryOp = (
  context: ComputeContext,
  name: string,
  funcCall: BinaryFunctionCall,
  additionalImplementation?: string,
  cacheKey?: string,
  outputDataType?: number,
): void => {
  const [lhs, rhs] = context.inputs;
  const programInfo = createBinaryOpProgramInfo(
    name,
    cacheKey ?? '',
    lhs,
    rhs,
    funcCall,
    additionalImplementation,
    outputDataType,
  );
  context.compute(programInfo);
};
245

246
/** Add: element-wise `A + B`, broadcasting when the shapes differ. */
export const add = (context: ComputeContext): void => {
  const sum: BinaryCustomExpression = (lhs, rhs) => `${lhs}+${rhs}`;
  runBinaryOp(context, 'Add', sum);
};
249

250
/** Div: element-wise `A / B`, broadcasting when the shapes differ. */
export const div = (context: ComputeContext): void => {
  const quotient: BinaryCustomExpression = (lhs, rhs) => `${lhs}/${rhs}`;
  runBinaryOp(context, 'Div', quotient);
};
253

254
/** Equal: element-wise `A == B`, producing a bool tensor. */
export const equal = (context: ComputeContext): void => {
  const compare = '==';
  runBinaryOp(
    context,
    'Equal',
    {
      scalar: (lhs, rhs) => `u32(${lhs}${compare}${rhs})`,
      vector: (lhs, rhs) => `vec4<u32>(${lhs}${compare}${rhs})`,
    },
    undefined,
    undefined,
    DataType.bool,
  );
};
264

265
/** Mul: element-wise `A * B`, broadcasting when the shapes differ. */
export const mul = (context: ComputeContext): void => {
  const product: BinaryCustomExpression = (lhs, rhs) => `${lhs}*${rhs}`;
  runBinaryOp(context, 'Mul', product);
};
268

269
/**
 * Pow: element-wise `A ^ B`, broadcasting when the shapes differ.
 *
 * Emits a custom WGSL `pow_custom` instead of calling `pow` directly. It behaves as follows:
 *  - it returns 1 for a zero exponent;
 *  - a negative base with a non-integer exponent goes straight to `pow` (NaN, per the inline note);
 *  - otherwise it computes `pow(|a|, b)` and multiplies by `sign(a)` when `|b| % 2` rounds to 1.
 * For `i32` the floating-point result is rounded before the cast back.
 * The vec4 variant currently calls `pow_custom` lane by lane.
 */
export const pow = (context: ComputeContext): void => {
  const base = context.inputs[0];
  const wgslType = inputVariable('input', base.dataType, base.dims).type.value;
  const roundStr = wgslType === 'i32' ? 'round' : '';
  const powImplementation = `
    fn pow_custom(a : ${wgslType}, b : ${wgslType}) -> ${wgslType} {
      if (b == ${wgslType}(0.0)) {
        return ${wgslType}(1.0);
      } else if (a < ${wgslType}(0.0) && f32(b) != floor(f32(b))) {
        return ${wgslType}(pow(f32(a), f32(b))); // NaN
      }
      return select(sign(a), ${wgslType}(1.0), round(f32(abs(b) % ${wgslType}(2.0))) != 1.0) * ${wgslType}(${roundStr}(pow(f32(abs(a)), f32(b))));
    }
    fn pow_vector_custom(a : vec4<${wgslType}>, b : vec4<${wgslType}>) -> vec4<${wgslType}> {
      // TODO: implement vectorized pow
      return vec4<${wgslType}>(pow_custom(a.x, b.x), pow_custom(a.y, b.y), pow_custom(a.z, b.z), pow_custom(a.w, b.w));
    }
      `;
  runBinaryOp(
    context,
    'Pow',
    {
      scalar: (lhs, rhs) => `pow_custom(${lhs},${rhs})`,
      vector: (lhs, rhs) => `pow_vector_custom(${lhs},${rhs})`,
    },
    powImplementation,
  );
};
294

295
/** Sub: element-wise `A - B`, broadcasting when the shapes differ. */
export const sub = (context: ComputeContext): void => {
  const difference: BinaryCustomExpression = (lhs, rhs) => `${lhs}-${rhs}`;
  runBinaryOp(context, 'Sub', difference);
};
298

299
/** Greater: element-wise `A > B`, producing a bool tensor. */
export const greater = (context: ComputeContext): void => {
  const compare = '>';
  runBinaryOp(
    context,
    'Greater',
    {
      scalar: (lhs, rhs) => `u32(${lhs}${compare}${rhs})`,
      vector: (lhs, rhs) => `vec4<u32>(${lhs}${compare}${rhs})`,
    },
    undefined,
    undefined,
    DataType.bool,
  );
};
309

310
/** Less: element-wise `A < B`, producing a bool tensor. */
export const less = (context: ComputeContext): void => {
  const compare = '<';
  runBinaryOp(
    context,
    'Less',
    {
      scalar: (lhs, rhs) => `u32(${lhs}${compare}${rhs})`,
      vector: (lhs, rhs) => `vec4<u32>(${lhs}${compare}${rhs})`,
    },
    undefined,
    undefined,
    DataType.bool,
  );
};
320

321
/** GreaterOrEqual: element-wise `A >= B`, producing a bool tensor. */
export const greaterOrEqual = (context: ComputeContext): void => {
  const compare = '>=';
  runBinaryOp(
    context,
    'GreaterOrEqual',
    {
      scalar: (lhs, rhs) => `u32(${lhs}${compare}${rhs})`,
      vector: (lhs, rhs) => `vec4<u32>(${lhs}${compare}${rhs})`,
    },
    undefined,
    undefined,
    DataType.bool,
  );
};
331

332
/** LessOrEqual: element-wise `A <= B`, producing a bool tensor. */
export const lessOrEqual = (context: ComputeContext): void => {
  const compare = '<=';
  runBinaryOp(
    context,
    'LessOrEqual',
    {
      scalar: (lhs, rhs) => `u32(${lhs}${compare}${rhs})`,
      vector: (lhs, rhs) => `vec4<u32>(${lhs}${compare}${rhs})`,
    },
    undefined,
    undefined,
    DataType.bool,
  );
};
342

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.