1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
4
import { DataType } from '../../../wasm-common';
5
import { TensorView } from '../../tensor-view';
6
import { BroadcastUtil, ShapeUtil } from '../../util';
7
import { ComputeContext, ProgramInfo } from '../types';
9
import { createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper } from './common';
11
// Name of a function that is emitted into the shader as a call of the form name((a),(b)).
type BuiltinFunctionName = string;
12
// Builds the shader expression text for one binary operation from the two operand expression texts.
type BinaryCustomExpression = (expressionA: string, expressionB: string) => string;
13
// Describes how a binary op is rendered in the shader. Visible alternatives are a single
// expression builder, or an object carrying separate scalar and vector builders.
// NOTE(review): some members of this union are not visible in this excerpt; the shader
// builder also handles a plain string, so a function-name alternative presumably exists - confirm.
type BinaryFunctionCall =
15
| BinaryCustomExpression
17
scalar: BinaryCustomExpression;
18
vector: BinaryCustomExpression;
21
/**
 * Builds the shader source text for an element-wise binary operation.
 *
 * funcCall is normalised into a scalar and a vector expression builder:
 *  - a string is wrapped as a call of the form name((a),(b)) and used for both,
 *  - a function is used for both,
 *  - otherwise its scalar and vector members are used.
 *
 * Several code paths are visible: a path where at least one input has exactly one element,
 * a path that computes broadcasted offsets per output vector, a path that reads both inputs at
 * global_idx, and a per-component path built by singleAssignment.
 * NOTE(review): parts of the parameter list and of the branching are not visible in this
 * excerpt, so the exact conditions selecting each path should be confirmed against the full file.
 *
 * @param shaderHelper helper used to declare uniforms and variables and to emit the main prologue.
 * @param dimsA dims of the first input; only its rank, size and last dimension are read here.
 * @param dimsB dims of the second input; only its rank, size and last dimension are read here.
 * @param dimsOutput dims of the output; only its rank is read here.
 * @param sharedDimensionDivisibleBy4 when true, both inputs are read as whole vectors at offset / 4.
 * @param funcCall description of the operation, see above.
 * @param additionalImplementation optional extra shader text emitted before the main function.
 */
const createBinaryOpProgramShader = (
22
shaderHelper: ShaderHelper,
23
dimsA: readonly number[],
24
dimsB: readonly number[],
25
dimsOutput: readonly number[],
28
sharedDimensionDivisibleBy4: boolean,
29
funcCall: BinaryFunctionCall,
33
additionalImplementation?: string,
35
// Normalise funcCall into one builder for scalar operands and one for vector operands.
let expressionScalar: BinaryCustomExpression;
36
let expressionVector: BinaryCustomExpression;
37
if (typeof funcCall === 'string') {
38
expressionScalar = expressionVector = (a, b) => `${funcCall}((${a}),(${b}))`;
39
} else if (typeof funcCall === 'function') {
40
expressionScalar = expressionVector = funcCall;
42
expressionScalar = funcCall.scalar;
43
expressionVector = funcCall.vector;
46
// Shader-side variables for the output and both inputs.
// The trailing argument 4 is presumably the component count per element - TODO confirm.
const output = outputVariable('outputData', typeOutput, dimsOutput.length, 4);
47
const a = inputVariable('aData', typeA, dimsA.length, 4);
48
const b = inputVariable('bData', typeB, dimsB.length, 4);
50
let assignment: string;
53
// Shape facts that decide how each operand can be read.
const isAOneElement = ShapeUtil.size(dimsA) === 1;
54
const isBOneElement = ShapeUtil.size(dimsB) === 1;
55
const aLastDimDivisibleBy4 = dimsA.length > 0 && dimsA[dimsA.length - 1] % 4 === 0;
56
const bLastDimDivisibleBy4 = dimsB.length > 0 && dimsB[dimsB.length - 1] % 4 === 0;
57
if (isAOneElement || isBOneElement) {
58
assignment = output.setByOffset(
61
isAOneElement ? `${a.type.value}(${a.getByOffset('0')}.x)` : a.getByOffset('global_idx'),
62
isBOneElement ? `${b.type.value}(${b.getByOffset('0')}.x)` : b.getByOffset('global_idx'),
67
let outputIndices = ${output.offsetToIndices('global_idx * 4u')};
68
let offsetA = ${a.broadcastedIndicesToOffset('outputIndices', output)};
69
let offsetB = ${b.broadcastedIndicesToOffset('outputIndices', output)};
73
sharedDimensionDivisibleBy4 || aLastDimDivisibleBy4
74
? a.getByOffset('offsetA / 4u')
75
: `${a.type.value}(${a.getByOffset('offsetA / 4u')}[offsetA % 4u])`,
76
sharedDimensionDivisibleBy4 || bLastDimDivisibleBy4
77
? b.getByOffset('offsetB / 4u')
78
: `${b.type.value}(${b.getByOffset('offsetB / 4u')}[offsetB % 4u])`,
84
assignment = output.setByOffset(
86
expressionVector(a.getByOffset('global_idx'), b.getByOffset('global_idx')),
91
throw new Error('no necessary to use scalar implementation for element-wise binary op implementation.');
94
// Emits the shader text computing component x of the result: the output index of element
// global_idx * 4 + x is mapped to broadcasted offsets in both inputs, each offset is split into
// a vector index (offset / 4) and a component (offset % 4), and the scalar expression is stored
// into resStr[x], optionally wrapped in typeCast.
const singleAssignment = (resStr: string, x: number, typeCast = '') => {
95
const expressionA = `aData[indexA${x}][componentA${x}]`;
96
const expressionB = `bData[indexB${x}][componentB${x}]`;
98
let outputIndices${x} = ${output.offsetToIndices(`global_idx * 4u + ${x}u`)};
99
let offsetA${x} = ${a.broadcastedIndicesToOffset(`outputIndices${x}`, output)};
100
let offsetB${x} = ${b.broadcastedIndicesToOffset(`outputIndices${x}`, output)};
101
let indexA${x} = offsetA${x} / 4u;
102
let indexB${x} = offsetB${x} / 4u;
103
let componentA${x} = offsetA${x} % 4u;
104
let componentB${x} = offsetB${x} % 4u;
105
${resStr}[${x}] = ${typeCast}(${expressionScalar(expressionA, expressionB)});
108
if (typeOutput === DataType.bool) {
110
var data = vec4<u32>(0);
111
${singleAssignment('data', 0, 'u32')}
112
${singleAssignment('data', 1, 'u32')}
113
${singleAssignment('data', 2, 'u32')}
114
${singleAssignment('data', 3, 'u32')}
115
outputData[global_idx] = dot(vec4<u32>(0x1, 0x100, 0x10000, 0x1000000), vec4<u32>(data));`;
118
${singleAssignment('outputData[global_idx]', 0)}
119
${singleAssignment('outputData[global_idx]', 1)}
120
${singleAssignment('outputData[global_idx]', 2)}
121
${singleAssignment('outputData[global_idx]', 3)}
127
${shaderHelper.registerUniform('vec_size', 'u32').declareVariables(a, b, output)}
129
${additionalImplementation ?? ''}
131
${shaderHelper.mainStart()}
132
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.vec_size')}
137
/**
 * Assembles the program description (hint, input dependencies, shader source callback,
 * outputs, dispatch size and uniforms) for an element-wise binary operation on a and b.
 *
 * The output shape starts as a.dims and is replaced by the result of
 * BroadcastUtil.calcShape; an Error is thrown when that result is falsy.
 * NOTE(review): several parameters and branch headers are not visible in this excerpt;
 * presumably the shape calculation only runs when the input shapes differ - confirm.
 *
 * @param funcCall description of the operation, forwarded to the shader builder.
 * @param additionalImplementation optional extra shader text, forwarded to the shader builder.
 * @param outputDataType data type of the output tensor; defaults to the data type of a.
 * @throws Error when BroadcastUtil.calcShape yields no shape for the two inputs.
 */
const createBinaryOpProgramInfo = (
142
funcCall: BinaryFunctionCall,
143
additionalImplementation?: string,
144
outputDataType: number = a.dataType,
146
// Any difference between the two input shapes counts as a broadcast.
const isBroadcast = !ShapeUtil.areEqual(a.dims, b.dims);
147
let outputShape = a.dims;
148
let outputSize = ShapeUtil.size(a.dims);
150
let vectorize = false;
151
let sharedDimensionDivisibleBy4 = false;
153
// TODO: deal with zero-sized tensors (eg. dims=[1,0])
154
// Flags collected here are joined into the program hint further down.
const cacheKeyAux = [isBroadcast];
156
const calculatedShape = BroadcastUtil.calcShape(a.dims, b.dims, false);
157
if (!calculatedShape) {
158
throw new Error("Can't perform binary op on the given tensors");
160
outputShape = calculatedShape;
161
outputSize = ShapeUtil.size(outputShape);
162
const isAOneElement = ShapeUtil.size(a.dims) === 1;
163
const isBOneElement = ShapeUtil.size(b.dims) === 1;
164
const aLastDimDivisibleBy4 = a.dims.length > 0 && a.dims[a.dims.length - 1] % 4 === 0;
165
const bLastDimDivisibleBy4 = b.dims.length > 0 && b.dims[b.dims.length - 1] % 4 === 0;
166
cacheKeyAux.push(isAOneElement);
167
cacheKeyAux.push(isBOneElement);
168
cacheKeyAux.push(aLastDimDivisibleBy4);
169
cacheKeyAux.push(bLastDimDivisibleBy4);
170
// check whether vectorize can be enabled
171
// Dimensions are visited from the innermost outwards; a dimension missing from the
// lower-rank input is treated as 1. sharedDimension accumulates the product of dimA.
// NOTE(review): the condition guarding the multiplication is not visible in this excerpt.
let sharedDimension = 1;
172
for (let i = 1; i < outputShape.length; i++) {
173
const dimA = a.dims[a.dims.length - i] ?? 1;
174
const dimB = b.dims[b.dims.length - i] ?? 1;
176
sharedDimension *= dimA;
181
if (sharedDimension % 4 === 0) {
182
sharedDimensionDivisibleBy4 = true;
184
} else if (isAOneElement || isBOneElement || aLastDimDivisibleBy4 || bLastDimDivisibleBy4) {
191
cacheKeyAux.push(vectorize);
196
hint: cacheKey + cacheKeyAux.map((x) => x.toString()).join('_'),
197
inputDependencies: ['rank', 'rank'],
199
getShaderSource: (shaderHelper) =>
200
createBinaryOpProgramShader(
207
sharedDimensionDivisibleBy4,
212
additionalImplementation,
215
outputs: [{ dims: outputShape, dataType: outputDataType }],
216
// The dispatch size divides the element count by 64 and by 4, rounded up.
dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* component size */) },
218
// uint32 value: output element count divided by 4, rounded up
// (presumably the vec_size uniform read by the shader - confirm).
{ type: DataType.uint32, data: Math.ceil(ShapeUtil.size(outputShape) / 4) },
219
...createTensorShapeVariables(a.dims, b.dims, outputShape),
226
// NOTE(review): the declaration line of this function is not visible in this excerpt; the name
// runBinaryOp is inferred from the exported operators that call it with a ComputeContext first.
// The visible body builds a program description with createBinaryOpProgramInfo, forwarding
// additionalImplementation; what is done with the result is not visible here.
context: ComputeContext,
228
funcCall: BinaryFunctionCall,
229
additionalImplementation?: string,
231
outputDataType?: number,
234
createBinaryOpProgramInfo(
240
additionalImplementation,
246
/**
 * Element-wise addition: calls runBinaryOp with the label 'Add' and an expression
 * builder that joins both operand expressions with the + operator.
 */
export const add = (context: ComputeContext): void => {
247
runBinaryOp(context, 'Add', (a, b) => `${a}+${b}`);
250
/**
 * Element-wise division: calls runBinaryOp with the label 'Div' and an expression
 * builder that joins both operand expressions with the / operator.
 */
export const div = (context: ComputeContext): void => {
251
runBinaryOp(context, 'Div', (a, b) => `${a}/${b}`);
254
/**
 * Element-wise equality comparison. The scalar builder wraps the == comparison in u32(...),
 * the vector builder wraps it in vec4<u32>(...).
 * NOTE(review): the call that consumes these builders is not visible in this excerpt;
 * presumably they are passed to runBinaryOp with a bool output type - confirm.
 */
export const equal = (context: ComputeContext): void => {
258
{ scalar: (a, b) => `u32(${a}==${b})`, vector: (a, b) => `vec4<u32>(${a}==${b})` },
265
/**
 * Element-wise multiplication: calls runBinaryOp with the label 'Mul' and an expression
 * builder that joins both operand expressions with the * operator.
 */
export const mul = (context: ComputeContext): void => {
266
runBinaryOp(context, 'Mul', (a, b) => `${a}*${b}`);
269
/**
 * Element-wise power. The scalar builder calls the shader function pow_custom and the vector
 * builder calls pow_vector_custom; both functions are defined in the shader text below, with
 * their operand and result type taken from the first input.
 *
 * Visible behaviour of pow_custom: it special-cases b == 0 (the value returned there is not
 * visible in this excerpt); for negative a with a non-integer b it returns pow(a, b) computed
 * in f32 (marked NaN in the shader); otherwise it computes pow(abs(a), b) in f32 and multiplies
 * it by a factor chosen with select from sign(a) and 1.
 * pow_vector_custom applies pow_custom to each of the four components.
 * NOTE(review): the surrounding call and the use of roundStr are on lines not visible here;
 * presumably roundStr is the wrapper interpolated around the final pow call - confirm.
 */
export const pow = (context: ComputeContext): void => {
270
// Shader value type of the first input, used as operand and result type in the shader helpers.
const type = inputVariable('input', context.inputs[0].dataType, context.inputs[0].dims).type.value;
271
// 'round' only for the i32 type, otherwise an empty string.
const roundStr = type === 'i32' ? 'round' : '';
275
{ scalar: (a, b) => `pow_custom(${a},${b})`, vector: (a, b) => `pow_vector_custom(${a},${b})` },
277
fn pow_custom(a : ${type}, b : ${type}) -> ${type} {
278
if (b == ${type}(0.0)) {
280
} else if (a < ${type}(0.0) && f32(b) != floor(f32(b))) {
281
return ${type}(pow(f32(a), f32(b))); // NaN
283
return select(sign(a), ${type}(1.0), round(f32(abs(b) % ${type}(2.0))) != 1.0) * ${type}(${
285
}(pow(f32(abs(a)), f32(b))));
287
fn pow_vector_custom(a : vec4<${type}>, b : vec4<${type}>) -> vec4<${type}> {
288
// TODO: implement vectorized pow
289
return vec4<${type}>(pow_custom(a.x, b.x), pow_custom(a.y, b.y), pow_custom(a.z, b.z), pow_custom(a.w, b.w));
295
/**
 * Element-wise subtraction: calls runBinaryOp with the label 'Sub' and an expression
 * builder that joins both operand expressions with the - operator.
 */
export const sub = (context: ComputeContext): void => {
296
runBinaryOp(context, 'Sub', (a, b) => `${a}-${b}`);
299
/**
 * Element-wise greater-than comparison. The scalar builder wraps the > comparison in u32(...),
 * the vector builder wraps it in vec4<u32>(...).
 * NOTE(review): the call that consumes these builders is not visible in this excerpt;
 * presumably they are passed to runBinaryOp with a bool output type - confirm.
 */
export const greater = (context: ComputeContext): void => {
303
{ scalar: (a, b) => `u32(${a}>${b})`, vector: (a, b) => `vec4<u32>(${a}>${b})` },
310
/**
 * Element-wise less-than comparison. The scalar builder wraps the < comparison in u32(...),
 * the vector builder wraps it in vec4<u32>(...).
 * NOTE(review): the call that consumes these builders is not visible in this excerpt;
 * presumably they are passed to runBinaryOp with a bool output type - confirm.
 */
export const less = (context: ComputeContext): void => {
314
{ scalar: (a, b) => `u32(${a}<${b})`, vector: (a, b) => `vec4<u32>(${a}<${b})` },
321
/**
 * Element-wise greater-or-equal comparison. The scalar builder wraps the >= comparison in
 * u32(...), the vector builder wraps it in vec4<u32>(...).
 * NOTE(review): the call that consumes these builders is not visible in this excerpt;
 * presumably they are passed to runBinaryOp with a bool output type - confirm.
 */
export const greaterOrEqual = (context: ComputeContext): void => {
325
{ scalar: (a, b) => `u32(${a}>=${b})`, vector: (a, b) => `vec4<u32>(${a}>=${b})` },
332
/**
 * Element-wise less-or-equal comparison. The scalar builder wraps the <= comparison in
 * u32(...), the vector builder wraps it in vec4<u32>(...).
 * NOTE(review): the call that consumes these builders is not visible in this excerpt;
 * presumably they are passed to runBinaryOp with a bool output type - confirm.
 */
export const lessOrEqual = (context: ComputeContext): void => {
336
{ scalar: (a, b) => `u32(${a}<=${b})`, vector: (a, b) => `vec4<u32>(${a}<=${b})` },