1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
4
import { DataType } from '../../../wasm-common';
5
import { ShapeUtil } from '../../util';
6
import { ProgramUniform, ProgramUniformVariableInfo } from '../types';
9
* constant value for a workgroup size.
11
* We definitely can do further optimization in future, but for now we use 64.
13
* rule of thumb: Use [a workgroup size of] 64 unless you know what GPU you are targeting or that your workload
14
* needs something different.
16
* from: https://surma.dev/things/webgpu/
18
export const WORKGROUP_SIZE = 64;
20
interface IndicesHelperTypes {
22
* WGSL type of indices expression
24
readonly indices: string;
27
* WGSL type of a value
29
readonly value: string;
32
* WGSL type of storage type representing a value
34
* This is usually the same to `value`, but for some type (eg. bool), we need to use `u32` as storage type for
35
* value type `vec4<bool>`
37
readonly storage: string;
40
* tensor type as represented in TensorView
42
readonly tensor: number;
46
* A helper class for generating WGSL code for manipulating indices and data for a shader's input or output.
48
* This class is designed to offer a unified way to generate WGSL code for manipulating indices and data for a shader's
51
* The following is a list of terminologies used in this class:
52
* - `offset`: a uint32 value representing the offset of an element in the data buffer.
53
* - `indices`: an abstraction of a multi-dimensional array's indices representing the data's index on each dimension.
54
* - `value`: a value of a data element.
56
* Users are expected to create an instance of this class for each shader's input or output, and use the instance to
57
* generate WGSL code for manipulating indices and data. The following 2 exported functions are for users to call to
58
* create an instance of an indices helper:
59
* - `inputVariable()`: create an indices helper instance for an input.
60
* - `outputVariable()`: create an indices helper instance for an output.
61
* - `internalVariable()`: create an indices helper instance for an internal variable.
63
* An indices helper instance contains helper functions for the following operations:
64
* - access readonly basic information, including: `name`(the name of the input or output), `usage`(whether it's an
65
* input, an output or an internal variable) and `shape`(the passed in shape).
66
* - `type`: access readonly type information, including: `indices`(the type of indices), `value`(the type of value at
67
* runtime), `storage`(the type of value at storage) and `tensor`(the tensor type as represented in TensorView).
68
* - generate WGSL code for getting indices from offset. Use `offsetToIndices()` for WGSL code snippet to calculate
69
* indices from offset, and use `indicesToOffset()` for WGSL code snippet to calculate offset from indices.
70
* - to manipulate an instance of indices, use `setIndices()` and `getIndices()` to set and get the indices on an
72
* - to manipulate data, use `set()`/`get()` to access data at the given indices from parameter list, use
73
* `setByIndices()`/`getByIndices()` to access data at the given indices from an indices variable, and use
74
* `setByOffset()`/`getByOffset()` to access data at the given offset.
75
* - `impl`: get WGSL code of function implementation for the util functions mentioned above.
77
export interface IndicesHelper {
79
* get WGSL code of function implementation for the util functions.
82
readonly impl: () => string;
87
readonly type: IndicesHelperTypes;
90
* WGSL code of a expression for getting indices from offset.
92
* @param varOffset - a u32 expression representing the offset.
94
* @returns an `type.indices` expression
96
readonly offsetToIndices: (varOffset: string) => string;
99
* WGSL code of an `u32` expression for getting offset from indices.
101
* @param varIndices - a `type.indices` expression representing the indices.
103
* @returns an `u32` expression
105
readonly indicesToOffset: (varIndices: string) => string;
108
* WGSL code of an `u32` expression for getting original offset from broadcasted indices.
110
* @param varIndices - a `type.indices` expression representing the output indices.
111
* @param output - output IndicesHelper.
113
* @returns an `u32` expression
115
readonly broadcastedIndicesToOffset: (varIndices: string, output: IndicesHelper) => string;
118
* WGSL code of generating an indices literal
120
* @param init - initial value.
122
readonly indices: (...init: ReadonlyArray<number | string>) => string;
125
* WGSL code of a statement for setting indices.
127
* @param varIndices - a variable name for the indices.
128
* @param idx - the index of the indices to set. can be a number or a string (WGSL `u32` expression).
129
* @param value - the value to set. can be a number or a string (WGSL `u32` expression).
131
* @returns a WGSL statement
133
readonly indicesSet: (varIndices: string, idx: number | string, value: number | string) => void;
136
* WGSL code of an `u32` expression for getting indices.
138
* @param varIndices - a variable name for the indices.
139
* @param idx - the index of the indices to get. can be a number or a string (WGSL `u32` expression).
141
* @returns an `u32` expression
143
readonly indicesGet: (varIndices: string, idx: number | string) => string;
146
* WGSL code for a statement for setting data at the given indices.
148
* @param indicesAndValue - an array of numbers or strings (WGSL `u32` expression) representing the indices, followed
149
* by the value to set. This array should have exactly `shape.length + 1` elements.
151
readonly set: (...indicesAndValue: ReadonlyArray<number | string>) => string;
154
* WGSL code for a statement for setting data at the given indices variable.
156
* @param varIndices - a variable name for the indices.
157
* @param value - the value to set. should be a WGSL expression.
159
readonly setByIndices: (varIndices: string, value: string) => string;
162
* WGSL code for a statement for setting data at the given offset.
164
* @param offset - a number or a string (WGSL `u32` expression) representing the offset.
165
* @param value - the value to set. should be a WGSL expression.
167
readonly setByOffset: (offset: number | string, value: string) => string;
170
* WGSL code for an expression for getting data at the given indices.
172
* @param indices - an array of numbers or strings (WGSL `u32` expression) representing the indices.
174
readonly get: (...indices: ReadonlyArray<number | string>) => string;
177
* WGSL code for an expression for getting data at the given indices variable.
179
* @param varIndices - a variable name for the indices.
181
readonly getByIndices: (varIndices: string) => string;
184
* WGSL code for an expression for getting data at the given offset.
186
* @param offset - a number or a string (WGSL `u32` expression) representing the offset.
188
readonly getByOffset: (offset: number | string) => string;
191
* name of the data variable
193
readonly name: string;
196
* whether the helper is for an input, an output or an internal variable.
198
readonly usage: 'input' | 'output' | 'internal';
201
* the rank of the input or output.
203
readonly rank: number;
206
* a string representing the variable name for the shape of the input or output.
208
readonly shape: string;
211
* a string representing the variable name for the strides of the input or output.
213
readonly strides: string;
216
const getWgslMappedType = (type: number, components: 1 | 2 | 3 | 4): string | [string, string] => {
217
if (components === 3) {
218
throw new Error('vec3 has same alignment as vec4, use vec4 instead');
221
// return type is [ storage type, runtime type ] or a single string for both
223
case DataType.float16:
224
return components > 1 ? `vec${components}<f16>` : 'f16';
226
return components > 1 ? `vec${components}<f32>` : 'f32';
228
return components > 1 ? `vec${components}<i32>` : 'i32';
229
case DataType.uint32:
230
return components > 1 ? `vec${components}<u32>` : 'u32';
232
if (components > 1) {
233
throw new Error('currently not supported vecX of uint64 yet');
235
return ['vec2<u32>', 'i32'];
236
case DataType.uint64:
237
if (components > 1) {
238
throw new Error('currently not supported vecX of uint64 yet');
240
return ['vec2<u32>', 'u32'];
242
if (components !== 4) {
243
throw new Error('bool must be vec4');
245
return ['u32', 'vec4<bool>'];
251
throw new Error(`Unknown data type: ${type}`);
255
export const tensorTypeToWsglStorageType = (type: DataType, components: 1 | 2 | 3 | 4 = 1) => {
256
const mappedType = getWgslMappedType(type, components);
257
return typeof mappedType === 'string' ? mappedType : mappedType[0];
260
export const tensorTypeToWsglValueType = (type: DataType, components: 1 | 2 | 3 | 4 = 1) => {
261
const mappedType = getWgslMappedType(type, components);
262
return typeof mappedType === 'string' ? mappedType : mappedType[1];
265
export const createTensorShapeVariables = (...dims: ReadonlyArray<readonly number[]>): ProgramUniform[] => {
266
const programUniforms: ProgramUniform[] = [];
267
dims.forEach((dim) => {
268
if (dim.length !== 0) {
269
programUniforms.push(
270
{ type: DataType.uint32, data: dim },
271
{ type: DataType.uint32, data: ShapeUtil.computeStrides(dim) },
275
return programUniforms;
279
* A helper function to get maximum vector size for specified data length
282
export const getMaxComponents = (size: number) => {
283
// we cannot use vec3 type since it has alignment of 16 bytes
284
if (size % 4 === 0) {
286
} else if (size % 2 === 0) {
294
* A helper function that initializes variable as a scalar or vector. e.g. f32(0) or vec4f(0,0,0,0)
299
export const fillVector = (dataType = 'f32', components?: number, value = '0') => {
300
if (!components || components === 1) {
301
return `${dataType}(${value})`;
304
return `vec${components}<${dataType}>(${value})`;
308
* A helper function that casts value or vector to f32
313
export const castToF32 = (dataType: string, components: number, value: string) => {
314
if (dataType === 'f32') {
317
if (components === 1) {
318
return `f32(${value})`;
321
return `vec${components}<f32>(${value})`;
325
* A helper function that returns scalar or sums all components of a vector
329
export const sumVector = (name: string, components: number) => {
330
if (components === 4) {
331
return `(${name}.x + ${name}.y + ${name}.z + ${name}.w)`;
332
} else if (components === 2) {
333
return `(${name}.x + ${name}.y)`;
334
} else if (components === 3) {
335
return `(${name}.x + ${name}.y + ${name}.z)`;
342
* A helper function that returns variable element at index.
343
* @param name - the name of variable.
344
* @param index - the index of variable element.
345
* @param length - the length of variable.
346
* @param type - the type of variable, optional.
348
export const getElementAt = (
350
index: number | string,
352
type?: UniformDataElementType,
354
if (name.startsWith('uniforms.') && length > 4) {
355
if (typeof index === 'string') {
356
if (type === 'f16') {
357
return `${name}[(${index}) / 8][(${index}) % 8 / 4][(${index}) % 8 % 4]`;
359
return `${name}[(${index}) / 4][(${index}) % 4]`;
362
if (type === 'f16') {
363
return `${name}[${Math.floor(index / 8)}][${Math.floor((index % 8) / 4)}][${(index % 8) % 4}]`;
365
return `${name}[${Math.floor(index / 4)}][${index % 4}]`;
369
return length > 1 ? `${name}[${index}]` : name;
374
* A helper function to get a IndicesHelper for a given input or output.
376
* @param name - the name of the input or output.
377
* @param tensorType - the tensor type of the input or output.
378
* @param shapeOrRank - the tensor shape or the rank of the input or output.
379
* @param usage - the usage of the indices helper.
380
* @param components - indicates the number of components of each element. 1 for scalar, 2 for vec2, 3 for vec3, 4 for
383
const createIndicesHelper = (
386
shapeOrRank: number | readonly number[],
387
usage: IndicesHelper['usage'],
388
components: 1 | 2 | 3 | 4,
390
const useUniform = typeof shapeOrRank === 'number';
391
const rank = useUniform ? shapeOrRank : shapeOrRank.length;
392
const rankIdentity = [...new Array(rank).keys()];
393
const indicesType = rank < 2 ? 'u32' : rank <= 4 ? `vec${rank}<u32>` : `array<u32, ${rank}>`;
394
const mappedType = getWgslMappedType(tensorType, components);
395
const valueType = typeof mappedType === 'string' ? mappedType : mappedType[1];
396
const storageType = typeof mappedType === 'string' ? mappedType : mappedType[0];
397
const type = { indices: indicesType, value: valueType, storage: storageType, tensor: tensorType };
399
const normalizeDim = (dim: number | string): string => (typeof dim === 'string' ? dim : `${dim}u`);
401
const implementationUsed = {
402
offsetToIndices: false,
403
indicesToOffset: false,
404
broadcastedIndicesToOffset: false,
411
const uniformPrefix = useUniform ? 'uniforms.' : '';
412
const shape = `${uniformPrefix}${name}_shape`;
413
const strides = `${uniformPrefix}${name}_strides`;
416
for (let i = 0; i < rank - 1; i++) {
418
let dim${i} = current / ${getElementAt(strides, i, rank)};
419
let rest${i} = current % ${getElementAt(strides, i, rank)};
420
indices[${i}] = dim${i};
424
o2iSnippet += `indices[${rank - 1}] = current;`;
426
const offsetToIndicesImplementation =
430
fn o2i_${name}(offset: u32) -> ${type.indices} {
431
var indices: ${type.indices};
432
var current = offset;
437
const offsetToIndices = (varOffset: string) => {
438
implementationUsed.offsetToIndices = true;
439
return rank < 2 ? varOffset : `o2i_${name}(${varOffset})`;
442
const offsets: string[] = [];
444
for (let i = rank - 1; i >= 0; i--) {
445
offsets.push(`${getElementAt(strides, i, rank)} * (indices[${i}])`);
449
const indicesToOffsetImplementation =
453
fn i2o_${name}(indices: ${type.indices}) -> u32 {
454
return ${offsets.join('+')};
457
const indicesToOffset = (varIndices: string) => {
458
implementationUsed.indicesToOffset = true;
459
return rank < 2 ? varIndices : `i2o_${name}(${varIndices})`;
462
const indices = (...init: ReadonlyArray<number | string>) =>
463
rank === 0 ? '0u' : `${type.indices}(${init.map(normalizeDim).join(',')})`;
465
const indicesGet = (varIndices: string, idx: number | string) => {
467
return `${varIndices}`;
469
return `${getElementAt(varIndices, idx, rank)}`;
473
const indicesSet = (varIndices: string, idx: number | string, value: string) => {
475
return `${varIndices}=${value};`;
477
return `${getElementAt(varIndices, idx, rank)}=${value};`;
481
const broadcastedIndicesToOffsetImplementation: { [key: string]: string } = {};
482
const broadcastedIndicesToOffset = (varIndices: string, output: IndicesHelper) => {
483
implementationUsed.broadcastedIndicesToOffset = true;
484
const implKey = `${output.name}broadcastedIndicesTo${name}Offset`;
485
if (implKey in broadcastedIndicesToOffsetImplementation) {
486
return `${implKey}(${varIndices})`;
489
for (let i = rank - 1; i >= 0; i--) {
490
const idx = output.indicesGet('outputIndices', i + output.rank - rank);
491
offsets.push(`${indicesGet(strides, i)} * (${idx} % ${indicesGet(shape, i)})`);
493
broadcastedIndicesToOffsetImplementation[implKey] = `fn ${implKey}(outputIndices: ${output.type.indices}) -> u32 {
494
return ${offsets.length > 0 ? offsets.join('+') : '0u'};
497
return `${implKey}(${varIndices})`;
500
const setByOffset = (offset: number | string, value: string) =>
502
if (type.storage === type.value) {
503
return `${name}[${offset}]=${value};`;
504
} else if (type.storage === 'vec2<u32>' && type.value === 'i32') {
505
// int64, components === 1
506
return `${name}[${offset}]=vec2<u32>(u32(${value}), select(0u, 0xFFFFFFFFu, ${value} < 0));`;
507
} else if (type.storage === 'vec2<u32>' && type.value === 'u32') {
508
// uint64, components === 1
509
return `${name}[${offset}]=vec2<u32>(u32(${value}), 0u);`;
510
} else if (type.storage === 'u32' && type.value === 'vec4<bool>') {
511
// bool, components === 4
512
return `${name}[${offset}]=dot(vec4<u32>(0x1, 0x100, 0x10000, 0x1000000), vec4<u32>(${value}));`;
514
throw new Error(`not supported combination of storage type ${type.storage} and value type ${type.value} yet`);
518
const getByOffset = (offset: number | string) =>
520
if (type.storage === type.value) {
521
return `${name}[${offset}]`;
522
} else if (type.storage === 'vec2<u32>' && type.value === 'i32') {
523
// int64, components === 1
524
return `i32(${name}[${offset}].x)`;
525
} else if (type.storage === 'vec2<u32>' && type.value === 'u32') {
526
// uint64, components === 1
527
return `u32(${name}[${offset}].x)`;
528
} else if (type.storage === 'u32' && type.value === 'vec4<bool>') {
529
// bool, components === 4
530
return `vec4<bool>(bool(${name}[${offset}] & 0xFFu), bool(${name}[${offset}] & 0xFF00u), bool(${name}[${
532
}] & 0xFF0000u), bool(${name}[${offset}] & 0xFF000000u))`;
534
throw new Error(`not supported combination of storage type ${type.storage} and value type ${type.value} yet`);
538
const getByIndicesImplementation =
542
fn get_${name}ByIndices(indices: ${type.indices}) -> ${valueType} {
543
return ${getByOffset(`i2o_${name}(indices)`)};
546
const getImplementation =
550
const functionParams = rankIdentity.map((i) => `d${i}: u32`).join(', ');
551
const dimsParams = rankIdentity.map((i) => `d${i}`).join(', ');
553
fn get_${name}(${functionParams}) -> ${valueType} {
554
return get_${name}ByIndices(${indices(dimsParams)});
558
const get = (...indices: ReadonlyArray<number | string>) => {
559
if (indices.length !== rank) {
560
throw new Error(`indices length must be ${rank}`);
563
const normalizedIndices = indices.map(normalizeDim).join(',');
566
return getByOffset('0u');
567
} else if (rank === 1) {
568
return getByOffset(normalizedIndices[0]);
570
implementationUsed.get = true;
571
implementationUsed.getByIndices = true;
572
implementationUsed.indicesToOffset = true;
573
return `get_${name}(${normalizedIndices})`;
577
const getByIndices = (varIndices: string) => {
579
return getByOffset(varIndices);
581
implementationUsed.getByIndices = true;
582
implementationUsed.indicesToOffset = true;
583
return `get_${name}ByIndices(${varIndices})`;
587
const setByIndicesImplementation =
591
fn set_${name}ByIndices(indices: ${type.indices}, value: ${valueType}) {
592
${setByOffset(`i2o_${name}(indices)`, 'value')}
595
const setImplementation =
599
const functionParams = rankIdentity.map((i) => `d${i}: u32`).join(', ');
600
const dimsParams = rankIdentity.map((i) => `d${i}`).join(', ');
602
fn set_${name}(${functionParams}, value: ${valueType}) {
603
set_${name}ByIndices(${indices(dimsParams)}, value);
607
const set = (...indicesAndValue: ReadonlyArray<number | string>) => {
608
if (indicesAndValue.length !== rank + 1) {
609
throw new Error(`indices length must be ${rank}`);
611
const value = indicesAndValue[rank];
612
if (typeof value !== 'string') {
613
throw new Error('value must be string');
616
const normalizedIndices = indicesAndValue.slice(0, rank).map(normalizeDim).join(',');
619
return setByOffset('0u', value);
620
} else if (rank === 1) {
621
return setByOffset(normalizedIndices[0], value);
623
implementationUsed.set = true;
624
implementationUsed.setByIndices = true;
625
implementationUsed.indicesToOffset = true;
626
return `set_${name}(${normalizedIndices}, ${value})`;
630
const setByIndices = (varIndices: string, value: string) => {
632
return setByOffset(varIndices, value);
634
implementationUsed.setByIndices = true;
635
implementationUsed.indicesToOffset = true;
636
return `set_${name}ByIndices(${varIndices}, ${value});`;
642
let needShapeStrides = false;
643
if (implementationUsed.offsetToIndices) {
644
impls.push(offsetToIndicesImplementation);
645
needShapeStrides = true;
647
if (implementationUsed.indicesToOffset) {
648
impls.push(indicesToOffsetImplementation);
649
needShapeStrides = true;
651
if (implementationUsed.broadcastedIndicesToOffset) {
652
Object.values(broadcastedIndicesToOffsetImplementation).forEach((impl) => impls.push(impl));
653
needShapeStrides = true;
655
if (implementationUsed.set) {
656
impls.push(setImplementation);
657
needShapeStrides = true;
659
if (implementationUsed.setByIndices) {
660
impls.push(setByIndicesImplementation);
661
needShapeStrides = true;
663
if (implementationUsed.get) {
664
impls.push(getImplementation);
665
needShapeStrides = true;
667
if (implementationUsed.getByIndices) {
668
impls.push(getByIndicesImplementation);
669
needShapeStrides = true;
671
if (!useUniform && needShapeStrides) {
673
`const ${shape} = ${type.indices}(${shapeOrRank.join(',')});`,
674
`const ${strides} = ${type.indices}(${ShapeUtil.computeStrides(shapeOrRank).join(',')});`,
677
return impls.join('\n');
685
broadcastedIndicesToOffset,
705
* Create a IndicesHelper for an input.
707
* @param name - the name of the input.
708
* @param type - the tensor type of the input.
709
* @param shapeOrRank - the tensor shape or the rank of the input.
710
* @param components - the number of components of the input. available values are 1, 2, 3, 4. default is 1.
711
* @returns an IndicesHelper for the input.
713
export const inputVariable = (
716
shapeOrRank: number | readonly number[],
717
components: 1 | 2 | 3 | 4 = 1,
718
): IndicesHelper => createIndicesHelper(name, type, shapeOrRank, 'input', components);
721
* Create a IndicesHelper for an output.
723
* @param name - the name of the output.
724
* @param type - the tensor type of the output.
725
* @param shapeOrRank - the tensor shape or the rank of the output.
726
* @param components - the number of components of the output. available values are 1, 2, 3, 4. default is 1.
727
* @returns an IndicesHelper for the output.
729
export const outputVariable = (
732
shapeOrRank: number | readonly number[],
733
components: 1 | 2 | 3 | 4 = 1,
734
): IndicesHelper => createIndicesHelper(name, type, shapeOrRank, 'output', components);
737
* Create a IndicesHelper for an internal variable.
739
* @param name - the name of the variable.
740
* @param type - the tensor type of the variable.
741
* @param shapeOrRank - the tensor shape or the rank of the variable.
742
* @param components - the number of components of the variable. available values are 1, 2, 3, 4. default is 1.
743
* @returns an IndicesHelper for the variable.
745
export const internalVariable = (
748
shapeOrRank: number | readonly number[],
749
components: 1 | 2 | 3 | 4 = 1,
750
): IndicesHelper => createIndicesHelper(name, type, shapeOrRank, 'internal', components);
752
export type UniformDataElementType = 'u32' | 'f16' | 'f32' | 'i32';
753
export type UniformsArrayType = Array<{ name: string; type: UniformDataElementType; length?: number }>;
756
* A ShaderHelper is a helper class for generating WGSL code.
758
export interface ShaderHelper {
760
* A helper function to generate the start of main function in WGSL source code.
763
* const getShaderSource = (shaderHelper: ShaderHelper) => `
766
* ${shaderHelper.mainStart()}
767
* // your code here inside main() function
772
* @param workgroupSize - an optional workgroup size. default is WORKGROUP_SIZE.
774
mainStart(workgroupSize?: number | [number, number, number]): string;
777
* A helper function to generate the code snippet for guarding against out-of-bounds size.
780
* const getShaderSource = (shaderHelper: ShaderHelper) => `
783
* ${shaderHelper.mainStart()}
784
* ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
786
* // your code here inside main() function
791
* @param size - the size of the data to guard against. can be a number or a string (WGSL `u32` expression).
793
guardAgainstOutOfBoundsWorkgroupSizes(size: unknown): string;
796
* A helper function to generate the code snippet for declaring multiple inputs or outputs.
798
* @param variables - an array of IndicesHelper for the variables.
800
declareVariables(...variables: IndicesHelper[]): string;
803
* A helper function to register one uniform. Can be called multiple times to register multiple uniforms.
805
* @param name - the name of the uniform.
806
* @param type - the type of the uniform.
807
* @param length - the length of the uniform, default to 1 when it is not provided.
809
registerUniform(name: string, type: string, length?: number): ShaderHelper;
812
* A helper function to register multiple uniforms. Can be called multiple times to register multiple uniforms.
814
* @param uniforms - an array of uniforms. Each element of the array is an object with 2 properties: `name` and
817
registerUniforms(uniforms: UniformsArrayType): ShaderHelper;
820
* A helper function to register multiple internal variables. Can be called multiple times to register multiple
821
* internal variables.
823
* @param variables - an array of IndicesHelper for the variables.
825
registerInternalVariables(...variables: IndicesHelper[]): ShaderHelper;
828
class ShaderHelperImpl implements ShaderHelper {
830
private normalizedDispatchGroup: [number, number, number],
831
private limits: GPUSupportedLimits,
834
guardAgainstOutOfBoundsWorkgroupSizes(size: number | string): string {
835
// Guard against out-of-bounds work group sizes
836
const sizeInCode = typeof size === 'number' ? `${size}u` : size;
837
return `if (global_idx >= ${sizeInCode}) { return; }`;
840
mainStart(workgroupSize: number | [number, number, number] = WORKGROUP_SIZE) {
841
const workgroupSizeX = typeof workgroupSize === 'number' ? workgroupSize : workgroupSize[0];
842
const workgroupSizeY = typeof workgroupSize === 'number' ? 1 : workgroupSize[1];
843
const workgroupSizeZ = typeof workgroupSize === 'number' ? 1 : workgroupSize[2];
846
workgroupSizeX > this.limits.maxComputeWorkgroupSizeX ||
847
workgroupSizeY > this.limits.maxComputeWorkgroupSizeY ||
848
workgroupSizeZ > this.limits.maxComputeWorkgroupSizeZ
851
`workgroup size [${workgroupSizeX}, ${workgroupSizeY}, ${
853
}] exceeds the maximum workgroup size [${this.limits.maxComputeWorkgroupSizeX}, ${
854
this.limits.maxComputeWorkgroupSizeY
855
}, ${this.limits.maxComputeWorkgroupSizeZ}].`,
859
if (workgroupSizeX * workgroupSizeY * workgroupSizeZ > this.limits.maxComputeInvocationsPerWorkgroup) {
861
`workgroup size [${workgroupSizeX}, ${workgroupSizeY}, ${
863
}] exceeds the maximum workgroup invocations ${this.limits.maxComputeInvocationsPerWorkgroup}.`,
867
const is1DimensionDispatch = this.normalizedDispatchGroup[1] === 1 && this.normalizedDispatchGroup[2] === 1;
868
const paramList = is1DimensionDispatch
869
? `@builtin(global_invocation_id) global_id : vec3<u32>,
870
@builtin(workgroup_id) workgroup_id : vec3<u32>,
871
@builtin(local_invocation_id) local_id : vec3<u32>`
872
: `@builtin(global_invocation_id) global_id : vec3<u32>,
873
@builtin(local_invocation_id) local_id : vec3<u32>,
874
@builtin(local_invocation_index) local_idx : u32,
875
@builtin(workgroup_id) workgroup_id : vec3<u32>,
876
@builtin(num_workgroups) num_workgroups : vec3<u32>`;
877
const globalIdxDefinition = is1DimensionDispatch
878
? 'let global_idx = global_id.x; let local_idx = local_id.x;'
879
: `let global_idx = (workgroup_id.z * num_workgroups[0] * num_workgroups[1] +
880
workgroup_id.y * num_workgroups[0] + workgroup_id.x) * ${
881
workgroupSizeX * workgroupSizeY * workgroupSizeZ
884
return `@compute @workgroup_size(${workgroupSizeX}, ${workgroupSizeY}, ${workgroupSizeZ})
885
fn main(${paramList}) {
886
${globalIdxDefinition}
890
private appendVariableUniforms(variable: IndicesHelper): void {
891
if (variable.rank !== 0) {
892
if (variable.shape.startsWith('uniforms.')) {
893
this.uniforms.push({ name: variable.shape.replace('uniforms.', ''), type: 'u32', length: variable.rank });
895
if (variable.strides.startsWith('uniforms.')) {
896
this.uniforms.push({ name: variable.strides.replace('uniforms.', ''), type: 'u32', length: variable.rank });
901
private declareVariable(variable: IndicesHelper, bindingIndex: number): string {
902
if (variable.usage === 'internal') {
903
throw new Error('cannot use internal variable with declareVariable(). use registerInternalVariables() instead.');
905
this.variables.push(variable);
906
this.appendVariableUniforms(variable);
908
const access = variable.usage === 'input' ? 'read' : 'read_write';
909
const storageType = variable.type.storage;
910
return `@group(0) @binding(${bindingIndex}) var<storage, ${access}> ${variable.name}: array<${storageType}>;`;
913
declareVariables(...variables: IndicesHelper[]): string {
914
return variables.map((v) => this.declareVariable(v, this.variableIndex++)).join('\n');
917
private registerInternalVariable(variable: IndicesHelper): void {
918
if (variable.usage !== 'internal') {
920
'cannot use input or output variable with registerInternalVariable(). use declareVariables() instead.',
924
this.internalVariables.push(variable);
925
this.appendVariableUniforms(variable);
928
registerInternalVariables(...variables: IndicesHelper[]): ShaderHelper {
929
variables.forEach((v) => this.registerInternalVariable(v));
933
registerUniform(name: string, type: UniformDataElementType, length = 1): ShaderHelper {
934
this.uniforms.push({ name, type, length });
938
registerUniforms(additionalUniforms: UniformsArrayType): ShaderHelper {
939
this.uniforms = this.uniforms.concat(additionalUniforms);
943
private internalVariables: IndicesHelper[] = [];
944
private variables: IndicesHelper[] = [];
945
private uniforms: UniformsArrayType = [];
946
private uniformDeclaration(): string {
947
if (this.uniforms.length === 0) {
951
const uniformSnippets: string[] = [];
952
for (const { name, type, length } of this.uniforms) {
953
if (length && length > 4) {
954
if (type === 'f16') {
955
uniformSnippets.push(`@align(16) ${name}:array<mat2x4<${type}>, ${Math.ceil(length / 8)}>`);
957
uniformSnippets.push(`${name}:array<vec4<${type}>, ${Math.ceil(length / 4)}>`);
960
const typeTemp = length == null || length === 1 ? type : `vec${length}<${type}>`;
961
uniformSnippets.push(`${name}:${typeTemp}`);
966
struct Uniforms { ${uniformSnippets.join(', ')} };
967
@group(0) @binding(${this.variableIndex}) var<uniform> uniforms: Uniforms;`;
969
private variableIndex = 0;
972
* Get additional implementation that needs to be added to the shader source.
974
get additionalImplementations(): string {
976
this.uniformDeclaration() +
977
this.variables.map((i) => i.impl()).join('\n') +
978
this.internalVariables.map((i) => i.impl()).join('\n')
983
* Get the variable info of the shader program.
985
get variablesInfo(): ProgramUniformVariableInfo[] | undefined {
986
if (this.uniforms.length === 0) {
990
const uniformWgslTypeToDataType = (type: UniformDataElementType) =>
991
[DataType.uint32, DataType.float16, DataType.float, DataType.int32][['u32', 'f16', 'f32', 'i32'].indexOf(type)];
992
return this.uniforms.map((u) => [uniformWgslTypeToDataType(u.type), u.length ?? 1]);
996
export const createShaderHelper = (dispatchGroup: [number, number, number], limits: GPUSupportedLimits) =>
997
new ShaderHelperImpl(dispatchGroup, limits);
1000
* This function comes from https://github.com/tensorflow/tfjs/blob/master/tfjs-core/src/ops/broadcast_util.ts#L18-L40
1001
* Returns the dimensions in the input shape that are broadcasted to
1002
* produce the provided output shape.
1004
* The returned dimensions are 0-indexed and sorted. An example:
1005
* inShape = [4, 1, 3]
1006
* outShape = [5, 4, 3, 3]
1007
* result = [1]. Dimension 1 (2nd dimension of input) gets broadcasted 1 => 3.
1009
export const getBroadcastDims = (inShape: readonly number[], outShape: readonly number[]): number[] => {
1010
const inRank = inShape.length;
1011
const dims: number[] = [];
1012
for (let i = 0; i < inRank; i++) {
1013
const dim = inRank - 1 - i;
1014
const a = inShape[dim] || 1;
1015
const b = outShape[outShape.length - 1 - i] || 1;
1016
if (b > 1 && a === 1) {