onnxruntime

Форк
0
1021 строка · 34.6 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
import { DataType } from '../../../wasm-common';
5
import { ShapeUtil } from '../../util';
6
import { ProgramUniform, ProgramUniformVariableInfo } from '../types';
7

8
/**
9
 * constant value for a workgroup size.
10
 *
11
 * We definitely can do further optimization in future, but for now we use 64.
12
 *
13
 * rule of thumb: Use [a workgroup size of] 64 unless you know what GPU you are targeting or that your workload
14
 *                needs something different.
15
 *
16
 * from: https://surma.dev/things/webgpu/
17
 **/
18
// Default workgroup size; `ShaderHelperImpl.mainStart()` falls back to it when no explicit size is passed.
export const WORKGROUP_SIZE = 64;
19

20
interface IndicesHelperTypes {
  /** WGSL type used for an indices expression. */
  readonly indices: string;

  /** WGSL type of a single value. */
  readonly value: string;

  /**
   * WGSL type used to store a value in the data buffer.
   *
   * Normally identical to `value`. Some types differ, e.g. bool is stored as `u32` while its value type is
   * `vec4<bool>`.
   */
  readonly storage: string;

  /** Tensor type as represented in TensorView. */
  readonly tensor: number;
}
44

45
/**
 * A helper for generating WGSL code that manipulates the indices and data of a shader's input, output or internal
 * variable.
 *
 * It offers a unified way to generate WGSL code for index arithmetic and element access.
 *
 * Terminology used by this interface:
 * - `offset`: a uint32 value representing the offset of an element in the data buffer.
 * - `indices`: an abstraction of a multi-dimensional array's indices representing the data's index on each dimension.
 * - `value`: a value of a data element.
 *
 * Users are expected to create one instance per shader input or output and use it to generate WGSL code. The
 * following 3 exported functions create an instance:
 * - `inputVariable()`: create an indices helper instance for an input.
 * - `outputVariable()`: create an indices helper instance for an output.
 * - `internalVariable()`: create an indices helper instance for an internal variable.
 *
 * An indices helper instance offers the following:
 * - readonly basic information: `name` (the name of the variable), `usage` (whether it is an input, an output or an
 *   internal variable), `rank`, and `shape` / `strides` (the names of the WGSL variables holding shape and strides).
 * - `type`: readonly type information: `indices` (the type of indices), `value` (the type of a value at runtime),
 *   `storage` (the type of a value in storage) and `tensor` (the tensor type as represented in TensorView).
 * - conversion between offset and indices: `offsetToIndices()` and `indicesToOffset()`, plus
 *   `broadcastedIndicesToOffset()` for broadcasting.
 * - manipulation of an indices value: `indices()` to build a literal, `indicesSet()` and `indicesGet()` to write and
 *   read one dimension of an indices variable.
 * - data access: `set()`/`get()` take the indices as a parameter list, `setByIndices()`/`getByIndices()` take an
 *   indices variable, and `setByOffset()`/`getByOffset()` take an offset.
 * - `impl`: get the WGSL code of the function implementations backing the helpers mentioned above.
 */
export interface IndicesHelper {
  /**
   * get WGSL code of function implementation for the util functions.
   */
  readonly impl: () => string;

  /**
   * get type info
   */
  readonly type: IndicesHelperTypes;

  /**
   * WGSL code of an expression for getting indices from offset.
   *
   * @param varOffset - a u32 expression representing the offset.
   *
   * @returns a `type.indices` expression
   */
  readonly offsetToIndices: (varOffset: string) => string;

  /**
   * WGSL code of a `u32` expression for getting offset from indices.
   *
   * @param varIndices - a `type.indices` expression representing the indices.
   *
   * @returns a `u32` expression
   */
  readonly indicesToOffset: (varIndices: string) => string;

  /**
   * WGSL code of a `u32` expression for getting original offset from broadcasted indices.
   *
   * @param varIndices - a `type.indices` expression representing the output indices.
   * @param output - output IndicesHelper.
   *
   * @returns a `u32` expression
   */
  readonly broadcastedIndicesToOffset: (varIndices: string, output: IndicesHelper) => string;

  /**
   * WGSL code of generating an indices literal
   *
   * @param init - initial value.
   */
  readonly indices: (...init: ReadonlyArray<number | string>) => string;

  /**
   * WGSL code of a statement for setting indices.
   *
   * @param varIndices - a variable name for the indices.
   * @param idx - the index of the indices to set. can be a number or a string (WGSL `u32` expression).
   * @param value - the value to set. can be a number or a string (WGSL `u32` expression).
   *
   * @returns a WGSL statement
   */
  // The return type is `string` (not `void`): the generated statement is the whole point of the call.
  readonly indicesSet: (varIndices: string, idx: number | string, value: number | string) => string;

  /**
   * WGSL code of a `u32` expression for getting indices.
   *
   * @param varIndices - a variable name for the indices.
   * @param idx - the index of the indices to get. can be a number or a string (WGSL `u32` expression).
   *
   * @returns a `u32` expression
   */
  readonly indicesGet: (varIndices: string, idx: number | string) => string;

  /**
   * WGSL code for a statement for setting data at the given indices.
   *
   * @param indicesAndValue - an array of numbers or strings (WGSL `u32` expression) representing the indices, followed
   *     by the value to set. This array should have exactly `rank + 1` elements.
   */
  readonly set: (...indicesAndValue: ReadonlyArray<number | string>) => string;

  /**
   * WGSL code for a statement for setting data at the given indices variable.
   *
   * @param varIndices - a variable name for the indices.
   * @param value - the value to set. should be a WGSL expression.
   */
  readonly setByIndices: (varIndices: string, value: string) => string;

  /**
   * WGSL code for a statement for setting data at the given offset.
   *
   * @param offset - a number or a string (WGSL `u32` expression) representing the offset.
   * @param value - the value to set. should be a WGSL expression.
   */
  readonly setByOffset: (offset: number | string, value: string) => string;

  /**
   * WGSL code for an expression for getting data at the given indices.
   *
   * @param indices - an array of numbers or strings (WGSL `u32` expression) representing the indices.
   */
  readonly get: (...indices: ReadonlyArray<number | string>) => string;

  /**
   * WGSL code for an expression for getting data at the given indices variable.
   *
   * @param varIndices - a variable name for the indices.
   */
  readonly getByIndices: (varIndices: string) => string;

  /**
   * WGSL code for an expression for getting data at the given offset.
   *
   * @param offset - a number or a string (WGSL `u32` expression) representing the offset.
   */
  readonly getByOffset: (offset: number | string) => string;

  /**
   * name of the data variable
   */
  readonly name: string;

  /**
   * whether the helper is for an input, an output or an internal variable.
   */
  readonly usage: 'input' | 'output' | 'internal';

  /**
   * the rank of the input or output.
   */
  readonly rank: number;

  /**
   * a string representing the variable name for the shape of the input or output.
   */
  readonly shape: string;

  /**
   * a string representing the variable name for the strides of the input or output.
   */
  readonly strides: string;
}
215

216
/**
 * Map a tensor data type and a component count to the WGSL type(s) used in generated shader code.
 *
 * @param type - the tensor data type (a `DataType` value).
 * @param components - number of components per element. 3 is always rejected.
 * @returns a single WGSL type name when the storage type and the value type are the same, otherwise a tuple
 *     `[storageType, valueType]`.
 * @throws Error when `components` is 3, when a 64-bit integer type is requested with more than one component, when
 *     bool is requested with anything other than 4 components, or when the data type is not handled.
 */
const getWgslMappedType = (type: number, components: 1 | 2 | 3 | 4): string | [string, string] => {
  if (components === 3) {
    throw new Error('vec3 has same alignment as vec4, use vec4 instead');
  }

  // return type is [ storage type, runtime type ] or a single string for both
  switch (type) {
    case DataType.float16:
      return components > 1 ? `vec${components}<f16>` : 'f16';
    case DataType.float:
      return components > 1 ? `vec${components}<f32>` : 'f32';
    case DataType.int32:
      return components > 1 ? `vec${components}<i32>` : 'i32';
    case DataType.uint32:
      return components > 1 ? `vec${components}<u32>` : 'u32';
    case DataType.int64:
      if (components > 1) {
        // The message used to name uint64 here, which pointed at the wrong data type.
        throw new Error('currently not supported vecX of int64 yet');
      }
      // stored as two u32 words, exposed to shader code as i32
      return ['vec2<u32>', 'i32'];
    case DataType.uint64:
      if (components > 1) {
        throw new Error('currently not supported vecX of uint64 yet');
      }
      // stored as two u32 words, exposed to shader code as u32
      return ['vec2<u32>', 'u32'];
    case DataType.bool:
      if (components !== 4) {
        throw new Error('bool must be vec4');
      }
      // four bools are stored in one u32
      return ['u32', 'vec4<bool>'];
    case DataType.int4:
      return 'i32';
    case DataType.uint4:
      return 'u32';
    default:
      throw new Error(`Unknown data type: ${type}`);
  }
};
254

255
/**
 * Resolve the WGSL storage type for a tensor data type.
 *
 * @param type - the tensor data type.
 * @param components - number of components per element. default is 1.
 * @returns the storage half of the mapping produced by `getWgslMappedType`.
 */
export const tensorTypeToWsglStorageType = (type: DataType, components: 1 | 2 | 3 | 4 = 1) => {
  const mapped = getWgslMappedType(type, components);
  if (typeof mapped !== 'string') {
    // tuple form is [storage, value]
    return mapped[0];
  }
  return mapped;
};
259

260
/**
 * Resolve the WGSL value (runtime) type for a tensor data type.
 *
 * @param type - the tensor data type.
 * @param components - number of components per element. default is 1.
 * @returns the value half of the mapping produced by `getWgslMappedType`.
 */
export const tensorTypeToWsglValueType = (type: DataType, components: 1 | 2 | 3 | 4 = 1) => {
  const mapped = getWgslMappedType(type, components);
  if (typeof mapped !== 'string') {
    // tuple form is [storage, value]
    return mapped[1];
  }
  return mapped;
};
264

265
/**
 * Build the uniform entries describing the shape and strides of each given tensor shape.
 *
 * For every non-empty shape two `uint32` uniforms are appended, in order: the shape itself, then the strides
 * returned by `ShapeUtil.computeStrides`. Empty shapes contribute nothing.
 *
 * @param dims - the tensor shapes.
 * @returns the uniforms, in the order the shapes were given.
 */
export const createTensorShapeVariables = (...dims: ReadonlyArray<readonly number[]>): ProgramUniform[] => {
  const uniforms: ProgramUniform[] = [];
  for (const dim of dims) {
    if (dim.length === 0) {
      continue;
    }
    uniforms.push({ type: DataType.uint32, data: dim });
    uniforms.push({ type: DataType.uint32, data: ShapeUtil.computeStrides(dim) });
  }
  return uniforms;
};
277

278
/**
 * Pick the widest vector width that evenly divides the given data length.
 *
 * @param size - the data length.
 * @returns 4 when `size` is a multiple of 4, otherwise 2 when it is a multiple of 2, otherwise 1.
 */
export const getMaxComponents = (size: number) =>
  // vec3 is never chosen: it has an alignment of 16 bytes
  size % 4 === 0 ? 4 : size % 2 === 0 ? 2 : 1;
292

293
/**
 * Build a WGSL constructor expression for a scalar or a vector, e.g. `f32(0)` or `vec4<f32>(0)`.
 *
 * @param dataType - the WGSL scalar type. default is `f32`.
 * @param components - the number of components; a missing, zero or 1 value produces a scalar.
 * @param value - the WGSL expression placed inside the constructor. default is `0`.
 * @returns the constructor expression.
 */
export const fillVector = (dataType = 'f32', components?: number, value = '0') => {
  const isVector = Boolean(components) && components !== 1;
  const wgslType = isVector ? `vec${components}<${dataType}>` : dataType;
  return `${wgslType}(${value})`;
};
306

307
/**
 * Wrap a WGSL expression in a conversion to `f32` (scalar) or `vecN<f32>` (vector).
 *
 * @param dataType - the WGSL scalar type of `value`; when it is already `f32` the expression is returned untouched.
 * @param components - the number of components of `value`.
 * @param value - the WGSL expression to convert.
 * @returns the converted expression.
 */
export const castToF32 = (dataType: string, components: number, value: string) => {
  if (dataType === 'f32') {
    return value;
  }
  const targetType = components === 1 ? 'f32' : `vec${components}<f32>`;
  return `${targetType}(${value})`;
};
323

324
/**
 * Build a WGSL expression that adds up all components of a vector variable.
 *
 * @param name - the name of the variable.
 * @param components - the number of components of the variable.
 * @returns a parenthesized sum for 2, 3 or 4 components; `name` itself for any other count.
 */
export const sumVector = (name: string, components: number) => {
  switch (components) {
    case 2:
      return `(${name}.x + ${name}.y)`;
    case 3:
      return `(${name}.x + ${name}.y + ${name}.z)`;
    case 4:
      return `(${name}.x + ${name}.y + ${name}.z + ${name}.w)`;
    default:
      return name;
  }
};
340

341
/**
 * Build a WGSL expression that reads one element of a variable.
 *
 * A variable whose name starts with `uniforms.` and whose length is above 4 is addressed with nested subscripts:
 * two levels of 4 elements, or three levels covering 8 elements per outer entry when `type` is `f16`. Any other
 * variable is addressed with a plain subscript, or by name alone when its length is 1 or less.
 *
 * @param name - the name of variable.
 * @param index - the index of variable element, either a number or a WGSL expression.
 * @param length - the length of variable.
 * @param type - the type of variable, optional.
 * @returns the WGSL expression.
 */
export const getElementAt = (
  name: string,
  index: number | string,
  length: number,
  type?: UniformDataElementType,
): string => {
  const usesNestedSubscripts = name.startsWith('uniforms.') && length > 4;
  if (!usesNestedSubscripts) {
    return length > 1 ? `${name}[${index}]` : name;
  }

  const isF16 = type === 'f16';
  if (typeof index === 'number') {
    // the index is known now, so the subscripts are folded into literals
    if (isF16) {
      const withinGroup = index % 8;
      return `${name}[${Math.floor(index / 8)}][${Math.floor(withinGroup / 4)}][${withinGroup % 4}]`;
    }
    return `${name}[${Math.floor(index / 4)}][${index % 4}]`;
  }

  // the index is a WGSL expression, so the arithmetic is emitted into the shader
  return isF16
    ? `${name}[(${index}) / 8][(${index}) % 8 / 4][(${index}) % 8 % 4]`
    : `${name}[(${index}) / 4][(${index}) % 4]`;
};
372

373
/**
 * A helper function to get a IndicesHelper for a given input, output or internal variable.
 *
 * All returned members only produce WGSL text. The helper records which generated functions were requested so that
 * `impl()` emits only the implementations that are actually referenced.
 *
 * @param name - the name of the input or output.
 * @param tensorType - the tensor type of the input or output.
 * @param shapeOrRank - the tensor shape or the rank of the input or output. When a rank (number) is given, shape and
 *    strides are referenced as `uniforms.<name>_shape` / `uniforms.<name>_strides`; when a shape is given, `impl()`
 *    emits them as WGSL constants named `<name>_shape` / `<name>_strides`.
 * @param usage - the usage of the indices helper.
 * @param components - indicates the number of components of each element. 1 for scalar, 2 for vec2, 4 for vec4.
 *    3 is accepted by the type but rejected by `getWgslMappedType`.
 * @returns the IndicesHelper.
 */
const createIndicesHelper = (
  name: string,
  tensorType: number,
  shapeOrRank: number | readonly number[],
  usage: IndicesHelper['usage'],
  components: 1 | 2 | 3 | 4,
): IndicesHelper => {
  const useUniform = typeof shapeOrRank === 'number';
  const rank = useUniform ? shapeOrRank : shapeOrRank.length;
  const rankIdentity = [...new Array(rank).keys()];
  // rank 0 and 1 use a plain u32, rank 2..4 a vector, anything above an array
  const indicesType = rank < 2 ? 'u32' : rank <= 4 ? `vec${rank}<u32>` : `array<u32, ${rank}>`;
  const mappedType = getWgslMappedType(tensorType, components);
  const valueType = typeof mappedType === 'string' ? mappedType : mappedType[1];
  const storageType = typeof mappedType === 'string' ? mappedType : mappedType[0];
  const type = { indices: indicesType, value: valueType, storage: storageType, tensor: tensorType };

  // numbers become WGSL u32 literals, strings are taken as WGSL expressions as-is
  const normalizeDim = (dim: number | string): string => (typeof dim === 'string' ? dim : `${dim}u`);

  const implementationUsed = {
    offsetToIndices: false,
    indicesToOffset: false,
    broadcastedIndicesToOffset: false,
    set: false,
    setByIndices: false,
    get: false,
    getByIndices: false,
  };

  const uniformPrefix = useUniform ? 'uniforms.' : '';
  const shape = `${uniformPrefix}${name}_shape`;
  const strides = `${uniformPrefix}${name}_strides`;

  let o2iSnippet = '';
  for (let i = 0; i < rank - 1; i++) {
    o2iSnippet += `
    let dim${i} = current / ${getElementAt(strides, i, rank)};
    let rest${i} = current % ${getElementAt(strides, i, rank)};
    indices[${i}] = dim${i};
    current = rest${i};
    `;
  }
  o2iSnippet += `indices[${rank - 1}] = current;`;

  const offsetToIndicesImplementation =
    rank < 2
      ? ''
      : `
  fn o2i_${name}(offset: u32) -> ${type.indices} {
    var indices: ${type.indices};
    var current = offset;
    ${o2iSnippet}
    return indices;
  }`;

  const offsetToIndices = (varOffset: string) => {
    implementationUsed.offsetToIndices = true;
    // for rank 0 and 1 the offset already is the indices value
    return rank < 2 ? varOffset : `o2i_${name}(${varOffset})`;
  };

  const offsets: string[] = [];
  if (rank >= 2) {
    for (let i = rank - 1; i >= 0; i--) {
      offsets.push(`${getElementAt(strides, i, rank)} * (indices[${i}])`);
    }
  }

  const indicesToOffsetImplementation =
    rank < 2
      ? ''
      : `
  fn i2o_${name}(indices: ${type.indices}) -> u32 {
    return ${offsets.join('+')};
  }`;

  const indicesToOffset = (varIndices: string) => {
    implementationUsed.indicesToOffset = true;
    return rank < 2 ? varIndices : `i2o_${name}(${varIndices})`;
  };

  const indices = (...init: ReadonlyArray<number | string>) =>
    rank === 0 ? '0u' : `${type.indices}(${init.map(normalizeDim).join(',')})`;

  const indicesGet = (varIndices: string, idx: number | string) => {
    if (rank < 2) {
      return `${varIndices}`;
    } else {
      return `${getElementAt(varIndices, idx, rank)}`;
    }
  };

  // `value` accepts numbers as well as strings, matching the IndicesHelper contract
  const indicesSet = (varIndices: string, idx: number | string, value: number | string) => {
    if (rank < 2) {
      return `${varIndices}=${value};`;
    } else {
      return `${getElementAt(varIndices, idx, rank)}=${value};`;
    }
  };

  // one generated function per output helper this variable is broadcast against
  const broadcastedIndicesToOffsetImplementation: { [key: string]: string } = {};
  const broadcastedIndicesToOffset = (varIndices: string, output: IndicesHelper) => {
    implementationUsed.broadcastedIndicesToOffset = true;
    const implKey = `${output.name}broadcastedIndicesTo${name}Offset`;
    if (implKey in broadcastedIndicesToOffsetImplementation) {
      return `${implKey}(${varIndices})`;
    }
    const offsets = [];
    for (let i = rank - 1; i >= 0; i--) {
      // dimensions are aligned from the trailing side of the output
      const idx = output.indicesGet('outputIndices', i + output.rank - rank);
      offsets.push(`${indicesGet(strides, i)} * (${idx} % ${indicesGet(shape, i)})`);
    }
    broadcastedIndicesToOffsetImplementation[implKey] = `fn ${implKey}(outputIndices: ${output.type.indices}) -> u32 {
             return ${offsets.length > 0 ? offsets.join('+') : '0u'};
           }`;

    return `${implKey}(${varIndices})`;
  };

  const setByOffset = (offset: number | string, value: string) => {
    if (type.storage === type.value) {
      return `${name}[${offset}]=${value};`;
    } else if (type.storage === 'vec2<u32>' && type.value === 'i32') {
      // int64, components === 1
      return `${name}[${offset}]=vec2<u32>(u32(${value}), select(0u, 0xFFFFFFFFu, ${value} < 0));`;
    } else if (type.storage === 'vec2<u32>' && type.value === 'u32') {
      // uint64, components === 1
      return `${name}[${offset}]=vec2<u32>(u32(${value}), 0u);`;
    } else if (type.storage === 'u32' && type.value === 'vec4<bool>') {
      // bool, components === 4
      return `${name}[${offset}]=dot(vec4<u32>(0x1, 0x100, 0x10000, 0x1000000), vec4<u32>(${value}));`;
    } else {
      throw new Error(`not supported combination of storage type ${type.storage} and value type ${type.value} yet`);
    }
  };

  const getByOffset = (offset: number | string) => {
    if (type.storage === type.value) {
      return `${name}[${offset}]`;
    } else if (type.storage === 'vec2<u32>' && type.value === 'i32') {
      // int64, components === 1
      return `i32(${name}[${offset}].x)`;
    } else if (type.storage === 'vec2<u32>' && type.value === 'u32') {
      // uint64, components === 1
      return `u32(${name}[${offset}].x)`;
    } else if (type.storage === 'u32' && type.value === 'vec4<bool>') {
      // bool, components === 4
      const word = `${name}[${offset}]`;
      return `vec4<bool>(bool(${word} & 0xFFu), bool(${word} & 0xFF00u), bool(${word} & 0xFF0000u), bool(${word} & 0xFF000000u))`;
    } else {
      throw new Error(`not supported combination of storage type ${type.storage} and value type ${type.value} yet`);
    }
  };

  const getByIndicesImplementation =
    rank < 2
      ? ''
      : `
  fn get_${name}ByIndices(indices: ${type.indices}) -> ${valueType} {
    return ${getByOffset(`i2o_${name}(indices)`)};
  }`;

  const getImplementation =
    rank < 2
      ? ''
      : (() => {
          const functionParams = rankIdentity.map((i) => `d${i}: u32`).join(', ');
          const dimsParams = rankIdentity.map((i) => `d${i}`).join(', ');
          return `
  fn get_${name}(${functionParams}) -> ${valueType} {
    return get_${name}ByIndices(${indices(dimsParams)});
  }`;
        })();

  const get = (...indices: ReadonlyArray<number | string>) => {
    if (indices.length !== rank) {
      throw new Error(`indices length must be ${rank}`);
    }

    const normalizedIndices = indices.map(normalizeDim).join(',');

    if (rank === 0) {
      return getByOffset('0u');
    } else if (rank === 1) {
      // With a single index the joined string IS that index. Subscripting the string with [0] would keep only
      // its first character (e.g. `global_idx` -> `g`).
      return getByOffset(normalizedIndices);
    } else {
      implementationUsed.get = true;
      implementationUsed.getByIndices = true;
      implementationUsed.indicesToOffset = true;
      return `get_${name}(${normalizedIndices})`;
    }
  };

  const getByIndices = (varIndices: string) => {
    if (rank < 2) {
      return getByOffset(varIndices);
    } else {
      implementationUsed.getByIndices = true;
      implementationUsed.indicesToOffset = true;
      return `get_${name}ByIndices(${varIndices})`;
    }
  };

  const setByIndicesImplementation =
    rank < 2
      ? ''
      : `
  fn set_${name}ByIndices(indices: ${type.indices}, value: ${valueType}) {
    ${setByOffset(`i2o_${name}(indices)`, 'value')}
  }`;

  const setImplementation =
    rank < 2
      ? ''
      : (() => {
          const functionParams = rankIdentity.map((i) => `d${i}: u32`).join(', ');
          const dimsParams = rankIdentity.map((i) => `d${i}`).join(', ');
          return `
  fn set_${name}(${functionParams}, value: ${valueType}) {
    set_${name}ByIndices(${indices(dimsParams)}, value);
  }`;
        })();

  const set = (...indicesAndValue: ReadonlyArray<number | string>) => {
    if (indicesAndValue.length !== rank + 1) {
      throw new Error(`indices length must be ${rank}`);
    }
    // the last argument is the value, everything before it is an index
    const value = indicesAndValue[rank];
    if (typeof value !== 'string') {
      throw new Error('value must be string');
    }

    const normalizedIndices = indicesAndValue.slice(0, rank).map(normalizeDim).join(',');

    if (rank === 0) {
      return setByOffset('0u', value);
    } else if (rank === 1) {
      // Same as in get(): pass the whole index expression, not its first character.
      return setByOffset(normalizedIndices, value);
    } else {
      implementationUsed.set = true;
      implementationUsed.setByIndices = true;
      implementationUsed.indicesToOffset = true;
      return `set_${name}(${normalizedIndices}, ${value})`;
    }
  };

  const setByIndices = (varIndices: string, value: string) => {
    if (rank < 2) {
      return setByOffset(varIndices, value);
    } else {
      implementationUsed.setByIndices = true;
      implementationUsed.indicesToOffset = true;
      return `set_${name}ByIndices(${varIndices}, ${value});`;
    }
  };

  const impl = () => {
    const impls = [];
    let needShapeStrides = false;
    if (implementationUsed.offsetToIndices) {
      impls.push(offsetToIndicesImplementation);
      needShapeStrides = true;
    }
    if (implementationUsed.indicesToOffset) {
      impls.push(indicesToOffsetImplementation);
      needShapeStrides = true;
    }
    if (implementationUsed.broadcastedIndicesToOffset) {
      Object.values(broadcastedIndicesToOffsetImplementation).forEach((impl) => impls.push(impl));
      needShapeStrides = true;
    }
    if (implementationUsed.set) {
      impls.push(setImplementation);
      needShapeStrides = true;
    }
    if (implementationUsed.setByIndices) {
      impls.push(setByIndicesImplementation);
      needShapeStrides = true;
    }
    if (implementationUsed.get) {
      impls.push(getImplementation);
      needShapeStrides = true;
    }
    if (implementationUsed.getByIndices) {
      impls.push(getByIndicesImplementation);
      needShapeStrides = true;
    }
    if (!useUniform && needShapeStrides) {
      // a concrete shape was given: declare shape and strides as constants ahead of the functions that use them
      impls.unshift(
        `const ${shape} = ${type.indices}(${shapeOrRank.join(',')});`,
        `const ${strides} = ${type.indices}(${ShapeUtil.computeStrides(shapeOrRank).join(',')});`,
      );
    }
    return impls.join('\n');
  };

  return {
    impl,
    type,
    offsetToIndices,
    indicesToOffset,
    broadcastedIndicesToOffset,
    indices,
    indicesGet,
    indicesSet,
    set,
    setByOffset,
    setByIndices,
    get,
    getByOffset,
    getByIndices,
    usage,
    name,
    strides,
    shape,
    rank,
  };
};
703

704
/**
 * Create a IndicesHelper whose usage is `'input'`.
 *
 * @param name - the name of the input.
 * @param type - the tensor type of the input.
 * @param shapeOrRank - the tensor shape or the rank of the input.
 * @param components - the number of components of the input. default is 1.
 * @returns an IndicesHelper for the input.
 */
export const inputVariable = (
  name: string,
  type: number,
  shapeOrRank: number | readonly number[],
  components: 1 | 2 | 3 | 4 = 1,
): IndicesHelper => {
  const usage: IndicesHelper['usage'] = 'input';
  return createIndicesHelper(name, type, shapeOrRank, usage, components);
};
719

720
/**
 * Create a IndicesHelper whose usage is `'output'`.
 *
 * @param name - the name of the output.
 * @param type - the tensor type of the output.
 * @param shapeOrRank - the tensor shape or the rank of the output.
 * @param components - the number of components of the output. default is 1.
 * @returns an IndicesHelper for the output.
 */
export const outputVariable = (
  name: string,
  type: number,
  shapeOrRank: number | readonly number[],
  components: 1 | 2 | 3 | 4 = 1,
): IndicesHelper => {
  const usage: IndicesHelper['usage'] = 'output';
  return createIndicesHelper(name, type, shapeOrRank, usage, components);
};
735

736
/**
 * Create a IndicesHelper whose usage is `'internal'`.
 *
 * @param name - the name of the variable.
 * @param type - the tensor type of the variable.
 * @param shapeOrRank - the tensor shape or the rank of the variable.
 * @param components - the number of components of the variable. default is 1.
 * @returns an IndicesHelper for the variable.
 */
export const internalVariable = (
  name: string,
  type: number,
  shapeOrRank: number | readonly number[],
  components: 1 | 2 | 3 | 4 = 1,
): IndicesHelper => {
  const usage: IndicesHelper['usage'] = 'internal';
  return createIndicesHelper(name, type, shapeOrRank, usage, components);
};
751

752
/** Names of the element types a uniform entry can declare (see `UniformsArrayType`). */
export type UniformDataElementType = 'u32' | 'f16' | 'f32' | 'i32';
753
/** A list of uniform descriptors as taken by `ShaderHelper.registerUniforms()`; `length` is optional. */
export type UniformsArrayType = Array<{ name: string; type: UniformDataElementType; length?: number }>;
754

755
/**
 * ShaderHelper exposes the building blocks used to assemble WGSL shader source.
 */
export interface ShaderHelper {
  /**
   * Generate the opening of the WGSL main function. The caller writes the body and the closing brace.
   *
   * @example
   * const getShaderSource = (shaderHelper: ShaderHelper) => `
   *  ...
   *
   *  ${shaderHelper.mainStart()}
   *    // your code here inside main() function
   *    ...
   *  }
   * `;
   *
   * @param workgroupSize - an optional workgroup size, either a single number or an `[x, y, z]` triple. default is
   *     WORKGROUP_SIZE.
   */
  mainStart(workgroupSize?: number | [number, number, number]): string;

  /**
   * Generate the code snippet that makes an invocation return early when it is beyond the data size.
   *
   * @example
   * const getShaderSource = (shaderHelper: ShaderHelper) => `
   *  ...
   *
   *  ${shaderHelper.mainStart()}
   *    ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
   *
   *    // your code here inside main() function
   *    ...
   *  }
   * `;
   *
   * @param size - the size of the data to guard against. can be a number or a string (WGSL `u32` expression).
   */
  guardAgainstOutOfBoundsWorkgroupSizes(size: unknown): string;

  /**
   * Generate the code snippet that declares the given inputs or outputs.
   *
   * @param variables - the IndicesHelper instances of the variables to declare.
   */
  declareVariables(...variables: IndicesHelper[]): string;

  /**
   * Register a single uniform. May be called repeatedly to register several uniforms.
   *
   * @param name - the name of the uniform.
   * @param type - the type of the uniform.
   * @param length - the length of the uniform, default to 1 when it is not provided.
   */
  registerUniform(name: string, type: string, length?: number): ShaderHelper;

  /**
   * Register several uniforms at once. May be called repeatedly.
   *
   * @param uniforms - an array of uniforms. Each element is an object with the properties `name` and `type`, and
   *     optionally `length`.
   */
  registerUniforms(uniforms: UniformsArrayType): ShaderHelper;

  /**
   * Register internal variables. May be called repeatedly to register more of them.
   *
   * @param variables - the IndicesHelper instances of the internal variables.
   */
  registerInternalVariables(...variables: IndicesHelper[]): ShaderHelper;
}
827

828
/**
 * Concrete `ShaderHelper` implementation.
 *
 * An instance collects the variables and uniforms that are registered while a shader source is being generated, and
 * emits the matching WGSL snippets (entry point header, storage bindings, uniform struct).
 */
class ShaderHelperImpl implements ShaderHelper {
  /** Variables registered via `registerInternalVariables()`. */
  private internalVariables: IndicesHelper[] = [];

  /** Input/output variables declared via `declareVariables()`, in declaration order. */
  private variables: IndicesHelper[] = [];

  /** Uniforms registered so far (explicitly, or implicitly for a variable's shape/strides). */
  private uniforms: UniformsArrayType = [];

  /**
   * Binding index handed to the next declared variable. The uniform buffer, if any, is bound at whatever value this
   * counter holds when the uniform declaration is generated.
   */
  private variableIndex = 0;

  /**
   * @param normalizedDispatchGroup - the dispatch group size on x, y and z. When y and z are both 1, `mainStart()`
   *     emits the simpler 1-dimension indexing code.
   * @param limits - the device limits that the workgroup size passed to `mainStart()` is validated against.
   */
  constructor(
    private normalizedDispatchGroup: [number, number, number],
    private limits: GPUSupportedLimits,
  ) {}

  /**
   * Generate the WGSL statement that makes an invocation return early when `global_idx` is not less than `size`.
   *
   * @param size - a number (emitted as a `u32` literal) or a string that is emitted verbatim as a WGSL expression.
   * @returns the WGSL `if` statement.
   */
  guardAgainstOutOfBoundsWorkgroupSizes(size: number | string): string {
    // A numeric size needs the `u` suffix to become a WGSL u32 literal; a string is used as-is.
    const sizeInCode = typeof size === 'number' ? `${size}u` : size;
    return `if (global_idx >= ${sizeInCode}) { return; }`;
  }

  /**
   * Generate the beginning of the shader's `main` function: the `@compute @workgroup_size(...)` attributes, the
   * parameter list, the opening brace and the definition of `global_idx` (and, for 1-dimension dispatch,
   * `local_idx`).
   *
   * @param workgroupSize - either the x size (y and z default to 1) or an explicit `[x, y, z]` triple.
   * @returns the WGSL code snippet. The caller is responsible for closing the function body.
   * @throws Error if any dimension exceeds the corresponding `maxComputeWorkgroupSize{X,Y,Z}` limit, or if the
   *     product of the dimensions exceeds `maxComputeInvocationsPerWorkgroup`.
   */
  mainStart(workgroupSize: number | [number, number, number] = WORKGROUP_SIZE) {
    const workgroupSizeX = typeof workgroupSize === 'number' ? workgroupSize : workgroupSize[0];
    const workgroupSizeY = typeof workgroupSize === 'number' ? 1 : workgroupSize[1];
    const workgroupSizeZ = typeof workgroupSize === 'number' ? 1 : workgroupSize[2];

    if (
      workgroupSizeX > this.limits.maxComputeWorkgroupSizeX ||
      workgroupSizeY > this.limits.maxComputeWorkgroupSizeY ||
      workgroupSizeZ > this.limits.maxComputeWorkgroupSizeZ
    ) {
      throw new Error(
        `workgroup size [${workgroupSizeX}, ${workgroupSizeY}, ${
          workgroupSizeZ
        }] exceeds the maximum workgroup size [${this.limits.maxComputeWorkgroupSizeX}, ${
          this.limits.maxComputeWorkgroupSizeY
        }, ${this.limits.maxComputeWorkgroupSizeZ}].`,
      );
    }

    if (workgroupSizeX * workgroupSizeY * workgroupSizeZ > this.limits.maxComputeInvocationsPerWorkgroup) {
      throw new Error(
        `workgroup size [${workgroupSizeX}, ${workgroupSizeY}, ${
          workgroupSizeZ
        }] exceeds the maximum workgroup invocations ${this.limits.maxComputeInvocationsPerWorkgroup}.`,
      );
    }

    // With a 1-dimension dispatch the global index is simply global_id.x. Otherwise it is flattened from the
    // workgroup id, the number of workgroups and the local invocation index.
    const is1DimensionDispatch = this.normalizedDispatchGroup[1] === 1 && this.normalizedDispatchGroup[2] === 1;
    const paramList = is1DimensionDispatch
      ? `@builtin(global_invocation_id) global_id : vec3<u32>,
    @builtin(workgroup_id) workgroup_id : vec3<u32>,
    @builtin(local_invocation_id) local_id : vec3<u32>`
      : `@builtin(global_invocation_id) global_id : vec3<u32>,
                                             @builtin(local_invocation_id) local_id : vec3<u32>,
    @builtin(local_invocation_index) local_idx : u32,
    @builtin(workgroup_id) workgroup_id : vec3<u32>,
    @builtin(num_workgroups) num_workgroups : vec3<u32>`;
    const globalIdxDefinition = is1DimensionDispatch
      ? 'let global_idx = global_id.x; let local_idx = local_id.x;'
      : `let global_idx = (workgroup_id.z * num_workgroups[0] * num_workgroups[1] +
          workgroup_id.y * num_workgroups[0] + workgroup_id.x) * ${
            workgroupSizeX * workgroupSizeY * workgroupSizeZ
          }u + local_idx;`;

    return `@compute @workgroup_size(${workgroupSizeX}, ${workgroupSizeY}, ${workgroupSizeZ})
  fn main(${paramList}) {
    ${globalIdxDefinition}
  `;
  }

  /**
   * Register the shape/strides uniforms of a variable, when the variable refers to them through `uniforms.`.
   *
   * Variables of rank 0 never contribute uniforms.
   */
  private appendVariableUniforms(variable: IndicesHelper): void {
    if (variable.rank !== 0) {
      if (variable.shape.startsWith('uniforms.')) {
        this.uniforms.push({ name: variable.shape.replace('uniforms.', ''), type: 'u32', length: variable.rank });
      }
      if (variable.strides.startsWith('uniforms.')) {
        this.uniforms.push({ name: variable.strides.replace('uniforms.', ''), type: 'u32', length: variable.rank });
      }
    }
  }

  /**
   * Record one input/output variable and generate its storage buffer declaration.
   *
   * @param variable - the variable to declare; must not have usage `'internal'`.
   * @param bindingIndex - the binding index in bind group 0.
   * @returns the WGSL `var<storage, ...>` declaration.
   * @throws Error if the variable's usage is `'internal'`.
   */
  private declareVariable(variable: IndicesHelper, bindingIndex: number): string {
    if (variable.usage === 'internal') {
      throw new Error('cannot use internal variable with declareVariable(). use registerInternalVariables() instead.');
    }
    this.variables.push(variable);
    this.appendVariableUniforms(variable);

    // Inputs are bound read-only; any other (non-internal) usage is bound read_write.
    const access = variable.usage === 'input' ? 'read' : 'read_write';
    const storageType = variable.type.storage;
    return `@group(0) @binding(${bindingIndex}) var<storage, ${access}> ${variable.name}: array<${storageType}>;`;
  }

  /**
   * Declare multiple inputs or outputs. Each variable takes the next free binding index.
   *
   * @param variables - the variables to declare.
   * @returns the WGSL declarations, one per line.
   */
  declareVariables(...variables: IndicesHelper[]): string {
    return variables.map((v) => this.declareVariable(v, this.variableIndex++)).join('\n');
  }

  /**
   * Record one internal variable (no storage declaration is generated for it).
   *
   * @throws Error if the variable's usage is not `'internal'`.
   */
  private registerInternalVariable(variable: IndicesHelper): void {
    if (variable.usage !== 'internal') {
      throw new Error(
        'cannot use input or output variable with registerInternalVariable(). use declareVariables() instead.',
      );
    }

    this.internalVariables.push(variable);
    this.appendVariableUniforms(variable);
  }

  /**
   * Register multiple internal variables. Can be called multiple times.
   *
   * @param variables - the internal variables to register.
   * @returns this helper, to allow chaining.
   */
  registerInternalVariables(...variables: IndicesHelper[]): ShaderHelper {
    variables.forEach((v) => this.registerInternalVariable(v));
    return this;
  }

  /**
   * Register one uniform. Can be called multiple times.
   *
   * @param name - the name of the uniform.
   * @param type - the element type of the uniform.
   * @param length - the number of elements, 1 when omitted.
   * @returns this helper, to allow chaining.
   */
  registerUniform(name: string, type: UniformDataElementType, length = 1): ShaderHelper {
    this.uniforms.push({ name, type, length });
    return this;
  }

  /**
   * Register multiple uniforms. Can be called multiple times.
   *
   * @param additionalUniforms - the uniforms to append to the ones already registered.
   * @returns this helper, to allow chaining.
   */
  registerUniforms(additionalUniforms: UniformsArrayType): ShaderHelper {
    this.uniforms = this.uniforms.concat(additionalUniforms);
    return this;
  }

  /**
   * Generate the WGSL `Uniforms` struct and its binding.
   *
   * @returns an empty string when no uniform is registered; otherwise the struct declaration followed by the
   *     `var<uniform>` binding at the current variable index.
   */
  private uniformDeclaration(): string {
    if (this.uniforms.length === 0) {
      return '';
    }

    const uniformSnippets: string[] = [];
    for (const { name, type, length } of this.uniforms) {
      if (length && length > 4) {
        // More than 4 elements do not fit a single vector, so they are packed into an array: 4 elements per
        // vec4, or 8 elements per mat2x4 for f16 (presumably to satisfy WGSL's uniform array layout rules).
        if (type === 'f16') {
          uniformSnippets.push(`@align(16) ${name}:array<mat2x4<${type}>, ${Math.ceil(length / 8)}>`);
        } else {
          uniformSnippets.push(`${name}:array<vec4<${type}>, ${Math.ceil(length / 4)}>`);
        }
      } else {
        // A missing length or a length of 1 is a scalar; anything else up to 4 is a vecN.
        const typeTemp = length == null || length === 1 ? type : `vec${length}<${type}>`;
        uniformSnippets.push(`${name}:${typeTemp}`);
      }
    }

    return `
      struct Uniforms { ${uniformSnippets.join(', ')} };
      @group(0) @binding(${this.variableIndex}) var<uniform> uniforms: Uniforms;`;
  }

  /**
   * Get additional implementation that needs to be added to the shader source: the uniform declaration followed by
   * the `impl()` output of every declared variable and every internal variable.
   */
  get additionalImplementations(): string {
    return (
      this.uniformDeclaration() +
      this.variables.map((i) => i.impl()).join('\n') +
      this.internalVariables.map((i) => i.impl()).join('\n')
    );
  }

  /**
   * Get the variable info of the shader program.
   *
   * @returns `undefined` when no uniform is registered; otherwise one `[dataType, length]` pair per uniform, in
   *     registration order, with a missing length reported as 1.
   */
  get variablesInfo(): ProgramUniformVariableInfo[] | undefined {
    if (this.uniforms.length === 0) {
      return undefined;
    }

    // The two arrays are parallel: the position of the WGSL type name selects the tensor data type.
    const uniformWgslTypeToDataType = (type: UniformDataElementType) =>
      [DataType.uint32, DataType.float16, DataType.float, DataType.int32][['u32', 'f16', 'f32', 'i32'].indexOf(type)];
    return this.uniforms.map((u) => [uniformWgslTypeToDataType(u.type), u.length ?? 1]);
  }
}
995

996
/**
 * Create a shader helper for one shader program.
 *
 * @param dispatchGroup - the dispatch group size on x, y and z, forwarded to the `ShaderHelperImpl` constructor.
 * @param limits - the device limits, forwarded to the `ShaderHelperImpl` constructor.
 * @returns a new `ShaderHelperImpl` instance.
 */
export const createShaderHelper = (dispatchGroup: [number, number, number], limits: GPUSupportedLimits) =>
  new ShaderHelperImpl(dispatchGroup, limits);
998

999
/**
 * Find which dimensions of an input shape are broadcast to produce a given output shape.
 *
 * Adapted from https://github.com/tensorflow/tfjs/blob/master/tfjs-core/src/ops/broadcast_util.ts#L18-L40
 *
 * The two shapes are aligned on their trailing dimensions. An input dimension is reported when its size is 1
 * (a missing or zero size counts as 1) and the aligned output dimension is larger than 1.
 *
 * Example:
 *   inShape  = [4, 1, 3]
 *   outShape = [5, 4, 3, 3]
 *   result   = [1]   // dimension 1 (2nd dimension of input) gets broadcast 1 => 3
 *
 * @param inShape - the shape of the input.
 * @param outShape - the shape of the broadcast output.
 * @returns the 0-based indices of the broadcast input dimensions, in ascending order.
 */
export const getBroadcastDims = (inShape: readonly number[], outShape: readonly number[]): number[] => {
  // Offset that maps an input axis to the output axis it is aligned with (trailing alignment).
  const rankGap = outShape.length - inShape.length;
  const broadcastAxes: number[] = [];
  for (let axis = 0; axis < inShape.length; axis++) {
    const inSize = inShape[axis] || 1;
    // An out-of-range (negative) index yields undefined, which is treated as size 1.
    const outSize = outShape[axis + rankGap] || 1;
    if (inSize === 1 && outSize > 1) {
      broadcastAxes.push(axis);
    }
  }
  return broadcastAxes;
};
1022

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.