onnxruntime

Форк
0
/
backend-webgpu.ts 
915 строк · 33.1 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
import { Env, Tensor, TRACE, TRACE_FUNC_BEGIN, TRACE_FUNC_END } from 'onnxruntime-common';
5

6
import { DataType, tensorDataTypeEnumToString } from '../wasm-common';
7

8
import { configureLogger, LOG_DEBUG } from './log';
9
import { createView, TensorView } from './tensor-view';
10
import { createGpuDataManager, downloadGpuData, GpuDataManager } from './webgpu/gpu-data-manager';
11
import { RunFunction, WEBGPU_OP_RESOLVE_RULES } from './webgpu/op-resolve-rules';
12
import { ProgramManager } from './webgpu/program-manager';
13
import {
14
  AdapterInfo,
15
  ComputeContext,
16
  GpuArchitecture,
17
  GpuData,
18
  GpuVendor,
19
  ProgramInfo,
20
  ProgramInputTensorInfoDependency,
21
  SessionState,
22
  TimestampQuery,
23
} from './webgpu/types';
24

25
/**
 * One recorded GPU command. Lists of these are stored per session in `capturedCommandList`.
 */
interface CommandInfo {
  /** ID of the kernel the command belongs to. */
  readonly kernelId: number;
  /** The compute pipeline of the command. */
  readonly computePipeline: GPUComputePipeline;
  /** The bind group of the command. */
  readonly bindGroup: GPUBindGroup;
  /** The three dispatch group sizes of the command. */
  readonly dispatchGroup: [number, number, number];
}
31

32
/**
 * Everything the backend keeps about a created kernel (see `createKernel` / `computeKernel`).
 */
interface KernelInfo {
  /** The kernel type, used to look up the implementation in the op resolve rules. */
  readonly kernelType: string;
  /** The kernel name, used in log and error messages. */
  readonly kernelName: string;
  /** The function that is invoked to compute the kernel. */
  readonly kernelEntry: RunFunction;
  /**
   * A pair of [attribute parser, attribute]. While the parser is set, the attribute has not been parsed yet; once
   * the attribute is parsed, the parser slot is set to `undefined` and the second slot holds the parsed result.
   */
  readonly attributes: [((attribute: unknown) => unknown) | undefined, unknown];
}
38

39
/**
 * Information recorded for one program run, kept until its profiling data is reported.
 */
interface PendingKernelInfo {
  /** ID of the kernel that ran the program. */
  readonly kernelId: number;
  /** Name of the program that was run. */
  readonly programName: string;
  /** The input tensor views passed to the program. */
  readonly inputTensorViews: readonly TensorView[];
  /** The output tensor views created for the program. */
  readonly outputTensorViews: readonly TensorView[];
}
45

46
/**
 * Build the part of a program cache key that describes the input tensors.
 *
 * Each input contributes one segment, selected by its dependency kind:
 * - 'none': an empty segment
 * - 'type': `<dataType>`
 * - 'rank': `<dataType>;<rank>`
 * - 'dims': `<dataType>;<dim0>,<dim1>,...`
 *
 * @param inputTensors the input tensor views of the program.
 * @param inputDependencies the dependency kind of each input, index-aligned with `inputTensors`.
 * @returns the segments of all inputs joined with '|'.
 * @throws Error if the two arrays differ in length, or if a dependency kind is not one of the four above.
 */
const getProgramInputTensorInfoDependencyKey = (
  inputTensors: readonly TensorView[],
  inputDependencies: readonly ProgramInputTensorInfoDependency[],
): string => {
  if (inputDependencies.length !== inputTensors.length) {
    throw new Error(
      `inputDependencies length ${inputDependencies.length} is not equal to inputTensors length ${
        inputTensors.length
      }.`,
    );
  }

  return inputTensors
    .map((tensor, i) => {
      const type = tensor.dataType;
      switch (inputDependencies[i]) {
        case 'none':
          return '';
        case 'type':
          return `${type}`;
        case 'rank':
          return `${type};${tensor.dims.length}`;
        case 'dims':
          return `${type};${tensor.dims.join(',')}`;
        default:
          throw new Error(`unsupported input dependency: ${inputDependencies[i]}`);
      }
    })
    .join('|');
};
87

88
/**
 * Get a unique key representing the program, derived from the program info and the input shapes and types.
 *
 * The key is much shorter than the shader source but contains all the information needed to identify a program:
 * if two keys are the same, the shader sources should be the same, so the built program can be reused.
 *
 * Key format:
 * `<PROGRAM_NAME>[<PROGRAM_CUSTOM_CACHE_HINT>]:<is1DimensionDispatch>:<INPUTS_INFO_0>|<INPUTS_INFO_1>|...`
 * The `[...]` part is only present when the program specifies a non-empty shader cache hint.
 *
 * @param programInfo the program info; its name and optional `shaderCache` settings are used.
 * @param inputTensors the input tensor views of the program.
 * @param is1DimensionDispatch whether the program is dispatched in one dimension only.
 * @returns the unique key.
 */
const getProgramInfoUniqueKey = (
  programInfo: ProgramInfo,
  inputTensors: readonly TensorView[],
  is1DimensionDispatch: boolean,
): string => {
  const cacheHint = programInfo.shaderCache?.hint;
  const hintPart = cacheHint ? `[${programInfo.shaderCache!.hint}]` : '';

  // when the program does not declare its input dependencies, every input is keyed by its full dims.
  const dependencies =
    programInfo.shaderCache?.inputDependencies ??
    new Array<ProgramInputTensorInfoDependency>(inputTensors.length).fill('dims');
  const inputsPart = getProgramInputTensorInfoDependencyKey(inputTensors, dependencies);

  return `${programInfo.name}${hintPart}:${is1DimensionDispatch}:${inputsPart}`;
};
116

117
/**
 * Holds the `architecture` and `vendor` strings copied from a GPU adapter info object and offers equality checks
 * against them.
 */
class AdapterInfoImpl implements AdapterInfo {
  readonly architecture?: string;
  readonly vendor?: string;

  /**
   * @param adapterInfo the adapter info to copy from. When it is falsy, both fields stay `undefined`.
   */
  constructor(adapterInfo: GPUAdapterInfo) {
    if (!adapterInfo) {
      return;
    }
    this.architecture = adapterInfo.architecture;
    this.vendor = adapterInfo.vendor;
  }

  /**
   * @returns `true` if the stored architecture is strictly equal to the given one.
   */
  isArchitecture(architecture: GpuArchitecture): boolean {
    return architecture === this.architecture;
  }

  /**
   * @returns `true` if the stored vendor is strictly equal to the given one.
   */
  isVendor(vendor: GpuVendor): boolean {
    return vendor === this.vendor;
  }
}
136

137
/**
138
 * this class is designed to store status and being used as a singleton for JSEP. It will be passed to jsepInit() as
139
 * the first parameter so that it is stored for future use.
140
 */
141
export class WebGpuBackend {
142
  adapterInfo: AdapterInfoImpl;
143
  device: GPUDevice;
144
  /**
145
   * an instance of GpuDataManager to manage a GpuDataId -> GpuBuffer mapping
146
   */
147
  gpuDataManager: GpuDataManager;
148
  /**
149
   * an instance of ProgramManager to build and run WebGPU compute shader program, and manage a ProgramKey -> Program
150
   * artifacts mapping
151
   */
152
  programManager: ProgramManager;
153

154
  /**
155
   * the ID of the session that is currently being run.
156
   * `null` means no session is being run.
157
   * only valid when session.run is executed.
158
   */
159
  currentSessionId: number | null = null;
160

161
  /**
162
   * the ID of the kernel that is currently being computed (from the CPU code's perspective).
163
   * `null` means no kernel is being computed.
164
   * only one kernel can be computed at a moment.
165
   */
166
  currentKernelId: number | null = null;
167
  /**
168
   * a list of temporary GPU data for the current kernel. It is released when the kernel finishes its computation.
169
   */
170
  private temporaryData: GpuData[];
171
  /**
172
   * a KernelID -> a GPU data list, which stores persistent GPU data owned by the specific kernel.
173
   */
174
  private kernelPersistentData: Map<number, GpuData[]>;
175
  /**
176
   * a KernelID -> a custom data, which stores custom data owned by the specific kernel.
177
   */
178
  private kernelCustomData: Map<number, { [key: string]: unknown }>;
179
  /**
   * Get the custom data object of the kernel that is currently being computed. An empty object is created and
   * stored on first access.
   *
   * @throws Error if no kernel is currently being computed.
   */
  get currentKernelCustomData(): { [key: string]: unknown } {
    const kernelId = this.currentKernelId;
    if (kernelId === null) {
      throw new Error('currentKernelCustomData(): currentKernelId is null. (should not happen)');
    }

    const existing = this.kernelCustomData.get(kernelId);
    if (existing) {
      return existing;
    }

    const created: { [key: string]: unknown } = {};
    this.kernelCustomData.set(kernelId, created);
    return created;
  }
195

196
  // KernelID -> kernelInfo mapping
197
  kernels: Map<number, KernelInfo>;
198
  private commandEncoder: GPUCommandEncoder | null = null;
199
  private computePassEncoder: GPUComputePassEncoder | null = null;
200
  maxDispatchNumber = 16;
201
  pendingDispatchNumber = 0;
202

203
  // info of kernels pending submission for a single batch
204
  private pendingKernels: PendingKernelInfo[] = [];
205
  // queryReadBuffer -> pendingKernels mapping for all the batches
206
  private pendingQueries: Map<GPUBuffer, PendingKernelInfo[]> = new Map();
207
  private queryResolveBuffer?: GPUBuffer;
208
  private querySet?: GPUQuerySet;
209
  private queryTimeBase?: bigint;
210
  queryType: TimestampQuery;
211

212
  env: Env;
213
  sessionStatus: SessionState = 'default';
214
  /**
215
   * a SessionID -> CommandInfo[] mapping. It's used to record all GPU commands for corresponding session.
216
   */
217
  capturedCommandList: Map<number, CommandInfo[]> = new Map();
218

219
  /**
220
   * a SessionID -> PendingKernelInfo[] mapping for profiling.
221
   */
222
  private capturedPendingKernels: Map<number, PendingKernelInfo[]> = new Map();
223

224
  /**
225
   * a SessionID -> a Map of (InputOutputIndex -> [ID, GPUBuffer]) mapping.
226
   */
227
  sessionExternalDataMapping: Map<number, Map<number, [number, GPUBuffer]>> = new Map();
228

229
  /**
   * Initialize the backend: request a GPU device from the adapter, then create the GPU data manager, the program
   * manager and the per-kernel maps.
   *
   * The device is requested with the adapter's own values for the limits listed below, plus a timestamp query
   * feature and `shader-f16` when the adapter has them. The device and the adapter are then exposed as read-only,
   * non-configurable properties on `env.webgpu`.
   *
   * @param env the environment object; it is stored and its `logLevel` and `debug` flags configure the logger.
   * @param adapter the GPU adapter to request the device from.
   */
  async initialize(env: Env, adapter: GPUAdapter): Promise<void> {
    this.env = env;

    const requiredFeatures: GPUFeatureName[] = [];
    // prefer the Chromium experimental in-pass timestamp query; fall back to the standard timestamp query.
    if (adapter.features.has('chromium-experimental-timestamp-query-inside-passes')) {
      requiredFeatures.push('chromium-experimental-timestamp-query-inside-passes' as GPUFeatureName);
    } else if (adapter.features.has('timestamp-query')) {
      requiredFeatures.push('timestamp-query');
    }
    if (adapter.features.has('shader-f16')) {
      requiredFeatures.push('shader-f16');
    }

    const adapterLimits = adapter.limits;
    const deviceDescriptor: GPUDeviceDescriptor = {
      requiredLimits: {
        maxComputeWorkgroupStorageSize: adapterLimits.maxComputeWorkgroupStorageSize,
        maxComputeWorkgroupsPerDimension: adapterLimits.maxComputeWorkgroupsPerDimension,
        maxStorageBufferBindingSize: adapterLimits.maxStorageBufferBindingSize,
        maxBufferSize: adapterLimits.maxBufferSize,
        maxComputeInvocationsPerWorkgroup: adapterLimits.maxComputeInvocationsPerWorkgroup,
        maxComputeWorkgroupSizeX: adapterLimits.maxComputeWorkgroupSizeX,
        maxComputeWorkgroupSizeY: adapterLimits.maxComputeWorkgroupSizeY,
        maxComputeWorkgroupSizeZ: adapterLimits.maxComputeWorkgroupSizeZ,
      },
      requiredFeatures,
    };

    this.device = await adapter.requestDevice(deviceDescriptor);
    // use the `info` attribute when it is available, otherwise query the adapter info asynchronously.
    this.adapterInfo = new AdapterInfoImpl(adapter.info || (await adapter.requestAdapterInfo()));
    this.gpuDataManager = createGpuDataManager(this);
    this.programManager = new ProgramManager(this);
    this.kernels = new Map();
    this.kernelPersistentData = new Map();
    this.kernelCustomData = new Map();

    // set up flags for logger
    configureLogger(env.logLevel!, !!env.debug);

    // TODO: set up flags

    // only validation errors are reported here; other uncaptured errors are ignored by this handler.
    this.device.onuncapturederror = (ev) => {
      if (!(ev.error instanceof GPUValidationError)) {
        return;
      }
      // eslint-disable-next-line no-console
      console.error(`An uncaught WebGPU validation error was raised: ${ev.error.message}`);
    };

    // expose the device and the adapter on env.webgpu as read-only properties.
    const exposedProperties: Array<[string, unknown]> = [
      ['device', this.device],
      ['adapter', adapter],
    ];
    for (const [name, value] of exposedProperties) {
      Object.defineProperty(this.env.webgpu, name, {
        value,
        writable: false,
        enumerable: true,
        configurable: false,
      });
    }

    // init queryType, which is necessary for InferenceSession.create
    this.setQueryType();
  }
291

292
  /**
   * Release the resources held by the backend: destroy the timestamp query set if one exists, then dispose the GPU
   * data manager.
   */
  dispose(): void {
    const querySet = this.querySet;
    if (querySet !== undefined) {
      querySet.destroy();
    }
    this.gpuDataManager.dispose();
  }
298

299
  /**
   * Get the active command encoder, creating one from the device if none is active.
   */
  getCommandEncoder(): GPUCommandEncoder {
    const encoder = this.commandEncoder ?? this.device.createCommandEncoder();
    this.commandEncoder = encoder;
    return encoder;
  }
305

306
  /**
   * Get the active compute pass encoder, beginning a new compute pass on the command encoder if none is active.
   *
   * When `queryType` is 'at-passes', the new pass writes its begin/end timestamps into the query set at the two
   * slots that belong to the current pending dispatch.
   */
  getComputePassEncoder(): GPUComputePassEncoder {
    if (this.computePassEncoder) {
      return this.computePassEncoder;
    }

    const commandEncoder = this.getCommandEncoder();
    const descriptor: GPUComputePassDescriptor = {};
    if (this.queryType === 'at-passes') {
      // each pending dispatch owns two consecutive query slots: begin and end.
      const beginIndex = this.pendingDispatchNumber * 2;
      descriptor.timestampWrites = {
        querySet: this.querySet!,
        beginningOfPassWriteIndex: beginIndex,
        endOfPassWriteIndex: beginIndex + 1,
      };
    }

    const passEncoder = commandEncoder.beginComputePass(descriptor);
    this.computePassEncoder = passEncoder;
    return passEncoder;
  }
323

324
  /**
   * End the active compute pass, if any, and forget its encoder.
   */
  endComputePass(): void {
    const passEncoder = this.computePassEncoder;
    if (!passEncoder) {
      return;
    }
    passEncoder.end();
    this.computePassEncoder = null;
  }
330

331
  /**
   * Submit all recorded GPU commands to the device queue.
   *
   * Does nothing when there is no active command encoder. Otherwise the active compute pass is ended, the command
   * buffer is submitted, and the command encoder and the pending dispatch counter are reset.
   *
   * When timestamp queries are enabled (`queryType` is not 'none'), the timestamps of this batch are resolved and
   * copied into a mappable buffer before submission. After submission the buffer is read back asynchronously and,
   * for each pending kernel, the profiling data is passed to `env.webgpu.profiling.ondata` if it is set, or printed
   * to the console otherwise. The read buffer is always unmapped, removed from `pendingQueries` and destroyed once
   * the read-back has finished, even if reporting throws.
   */
  flush(): void {
    const commandEncoder = this.commandEncoder;
    if (!commandEncoder) {
      return;
    }

    TRACE_FUNC_BEGIN();

    this.endComputePass();
    let queryReadBuffer: GPUBuffer | undefined;
    if (this.queryType !== 'none') {
      // each pending dispatch owns 2 timestamps (begin and end), and each timestamp takes 8 bytes.
      const queryCount = this.pendingDispatchNumber * 2;
      const queryByteLength = queryCount * 8;

      commandEncoder.resolveQuerySet(this.querySet!, 0, queryCount, this.queryResolveBuffer!, 0);

      queryReadBuffer = this.device.createBuffer(
        // eslint-disable-next-line no-bitwise
        { size: queryByteLength, usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST },
      );

      this.pendingQueries.set(queryReadBuffer, this.pendingKernels);
      this.pendingKernels = [];
      commandEncoder.copyBufferToBuffer(this.queryResolveBuffer!, 0, queryReadBuffer, 0, queryByteLength);
    }

    this.device.queue.submit([commandEncoder.finish()]);
    this.gpuDataManager.refreshPendingBuffers();
    this.commandEncoder = null;
    this.pendingDispatchNumber = 0;

    if (queryReadBuffer) {
      const readBuffer = queryReadBuffer;
      // drop the bookkeeping entry and free the GPU memory of the read buffer.
      const releaseReadBuffer = (): void => {
        this.pendingQueries.delete(readBuffer);
        readBuffer.destroy();
      };

      void readBuffer.mapAsync(GPUMapMode.READ).then(
        () => {
          try {
            const mappedData = new BigUint64Array(readBuffer.getMappedRange());
            const pendingKernels = this.pendingQueries.get(readBuffer)!;
            for (let i = 0; i < mappedData.length / 2; i++) {
              const pendingKernelInfo = pendingKernels[i];
              const kernelId = pendingKernelInfo.kernelId;
              const kernelInfo = this.kernels.get(kernelId)!;
              const kernelType = kernelInfo.kernelType;
              const kernelName = kernelInfo.kernelName;
              const programName = pendingKernelInfo.programName;
              const inputTensorViews = pendingKernelInfo.inputTensorViews;
              const outputTensorViews = pendingKernelInfo.outputTensorViews;
              const startTimeU64 = mappedData[i * 2];
              const endTimeU64 = mappedData[i * 2 + 1];

              // the first start timestamp ever seen becomes the base that all reported times are relative to.
              if (typeof this.queryTimeBase === 'undefined') {
                this.queryTimeBase = startTimeU64;
              }

              const startTime = Number(startTimeU64 - this.queryTimeBase);
              const endTime = Number(endTimeU64 - this.queryTimeBase);

              if (!Number.isSafeInteger(startTime) || !Number.isSafeInteger(endTime)) {
                throw new RangeError('incorrect timestamp range');
              }

              if (this.env.webgpu.profiling?.ondata) {
                this.env.webgpu.profiling.ondata({
                  version: 1,
                  inputsMetadata: inputTensorViews.map((value) => ({
                    dims: value.dims,
                    dataType: tensorDataTypeEnumToString(value.dataType),
                  })),
                  outputsMetadata: outputTensorViews.map((value) => ({
                    dims: value.dims,
                    dataType: tensorDataTypeEnumToString(value.dataType),
                  })),
                  kernelId,
                  kernelType,
                  kernelName,
                  programName,
                  startTime,
                  endTime,
                });
              } else {
                // if no callback is provided, print the profiling message to console
                let inputShapes = '';
                inputTensorViews.forEach((value, index) => {
                  inputShapes += `input[${index}]: [${value.dims}] | ${tensorDataTypeEnumToString(value.dataType)}, `;
                });
                let outputShapes = '';
                outputTensorViews.forEach((value, index) => {
                  outputShapes += `output[${index}]: [${value.dims}] | ${tensorDataTypeEnumToString(
                    value.dataType,
                  )}, `;
                });
                // eslint-disable-next-line no-console
                console.log(
                  `[profiling] kernel "${kernelId}|${kernelType}|${kernelName}|${programName}" ${inputShapes}${
                    outputShapes
                  }execution time: ${endTime - startTime} ns`,
                );
              }
              TRACE('GPU', `${programName}::${startTimeU64}::${endTimeU64}`);
            }
          } finally {
            // always clean up, so that a failure while reporting does not leak a mapped buffer.
            readBuffer.unmap();
            releaseReadBuffer();
          }
        },
        (reason: unknown) => {
          // mapping failed: the buffer is of no use any more. clean up and keep the rejection visible.
          releaseReadBuffer();
          throw reason;
        },
      );
    }
    TRACE_FUNC_END();
  }
440

441
  /**
   * run a WebGPU program.
   *
   * @param program a ProgramInfo instance
   * @param inputTensorViews a TensorView array. each element represents a value that already exists in GPU.
   * @param outputIndices an indices array. each element can be either -1 (temporary data), -2 (persistent data),
   * -3 (placeholder: no output is created for it) or an index to the kernel's output. an empty array means that
   * each program output maps to the kernel's output of the same index.
   * @param createKernelOutput a callback function that creates a value for the kernel's output with the given index
   * @param createIntermediateOutput a callback function that creates a value as an intermediate value, either
   * temporary or persistent (owned by the current kernel)
   * @param outputCount the exclusive upper bound of a valid kernel output index
   * @returns a TensorView array representing the result.
   * @throws Error if an input or output has no GPU data, if the output indices are invalid, if zero-sized tensors are
   * mixed with non-zero-sized ones, or if the uniforms do not match the ones the program was built with.
   */
  run(
    program: ProgramInfo,
    inputTensorViews: readonly TensorView[],
    outputIndices: readonly number[],
    createKernelOutput: (index: number, dataType: number, dims: readonly number[]) => TensorView,
    createIntermediateOutput: (dataType: number, dims: readonly number[]) => TensorView,
    outputCount: number,
  ): TensorView[] {
    TRACE_FUNC_BEGIN(program.name);
    // create info for inputs
    const inputDatas: GpuData[] = [];
    for (let i = 0; i < inputTensorViews.length; ++i) {
      const data = inputTensorViews[i].data;
      // if tensor view data is 0, it means the input is a zero-sized tensor, and there is no GPU data for it.
      if (data === 0) {
        continue;
      }
      const gpuData = this.gpuDataManager.get(data);
      if (!gpuData) {
        throw new Error(`no GPU data for input: ${data}`);
      }
      inputDatas.push(gpuData);
    }

    const { outputs, dispatchGroup, programUniforms } = program.getRunData(inputTensorViews);

    // check output indices
    const validatedOutputIndices = outputIndices.length === 0 ? outputs.map((_, i) => i) : outputIndices;
    if (validatedOutputIndices.length !== outputs.length) {
      throw new Error(`Output size ${validatedOutputIndices.length} must be equal to ${outputs.length}.`);
    }

    // create info for outputs
    const outputTensorViews: TensorView[] = [];
    const outputDatas: GpuData[] = [];
    for (let i = 0; i < outputs.length; ++i) {
      // value -1 and -2 are used for creating temporary and persistent outputs.
      // value -3 is used for placeholder output. So -3, -2, -1 and 0, 1, 2, ... are valid
      // output indices. see type definition of ComputeContextInputsOutputsMapping for more details.
      if (
        !Number.isInteger(validatedOutputIndices[i]) ||
        validatedOutputIndices[i] < -3 ||
        validatedOutputIndices[i] >= outputCount
      ) {
        throw new Error(`Invalid output index: ${validatedOutputIndices[i]}`);
      }
      if (validatedOutputIndices[i] === -3) {
        continue;
      }
      const isTemporary = validatedOutputIndices[i] === -1;
      const isPersistent = validatedOutputIndices[i] === -2;
      const tensorView =
        isTemporary || isPersistent
          ? createIntermediateOutput(outputs[i].dataType, outputs[i].dims)
          : createKernelOutput(validatedOutputIndices[i], outputs[i].dataType, outputs[i].dims);
      outputTensorViews.push(tensorView);
      // if tensor view data is 0, it means the output is zero-sized tensor, and there is no GPU data for it.
      if (tensorView.data === 0) {
        continue;
      }
      const gpuData = this.gpuDataManager.get(tensorView.data);
      if (!gpuData) {
        throw new Error(`no GPU data for output: ${tensorView.data}`);
      }
      if (isTemporary) {
        this.temporaryData.push(gpuData);
      }
      if (isPersistent) {
        let persistentData = this.kernelPersistentData.get(this.currentKernelId!);
        if (!persistentData) {
          persistentData = [];
          this.kernelPersistentData.set(this.currentKernelId!, persistentData);
        }
        persistentData.push(gpuData);
      }
      outputDatas.push(gpuData);
    }

    // when there are any zero-sized tensor in the inputs or outputs, we should report error unless all outputs are
    // zero-sized tensors.
    if (inputDatas.length !== inputTensorViews.length || outputDatas.length !== outputTensorViews.length) {
      // if all outputs are zero-sized tensors, there is no need to run the program.
      if (outputDatas.length === 0) {
        TRACE_FUNC_END(program.name);
        return outputTensorViews;
      }
      // if some outputs are zero-sized tensors, report an error.
      //
      // TODO: so far we don't see any use case that outputs include both zero-sized tensors and non-zero-sized tensors.
      // If we see such use case, we need to make a change here to support it.
      throw new Error(
        `Program ${program.name} has zero-sized tensor(s) in inputs or outputs. This is not supported now.`,
      );
    }

    // load uniforms
    // TODO: add cache for uniform (is it necessary?)
    //
    let uniformBufferBinding: GPUBindingResource | undefined;
    if (programUniforms) {
      let currentOffset = 0;
      // offsets[i] is the byte offset of programUniforms[i] in the uniform buffer. An empty uniform takes no space
      // and is recorded as `undefined`, so that the indices stay aligned with programUniforms in the write loop
      // below.
      const offsets: Array<number | undefined> = [];

      programUniforms.forEach((v) => {
        const data = typeof v.data === 'number' ? [v.data] : v.data;
        if (data.length === 0) {
          offsets.push(undefined);
          return;
        }
        // https://www.w3.org/TR/WGSL/#alignof
        const sizeOfElement = v.type === DataType.float16 ? 2 : 4;
        let sizeOfVecOrMat;
        let baseAlignment;
        if (v.type === DataType.float16) {
          baseAlignment = data.length > 4 ? 16 : data.length > 2 ? 8 : data.length * sizeOfElement;
          sizeOfVecOrMat = data.length > 4 ? 16 : sizeOfElement * data.length;
        } else {
          baseAlignment = data.length <= 2 ? data.length * sizeOfElement : 16;
          sizeOfVecOrMat = 16;
        }
        currentOffset = Math.ceil(currentOffset / baseAlignment) * baseAlignment;
        offsets.push(currentOffset);
        // For non-float16 type, when data.length > 4, the uniform variable is of type array<vec4<i32|u32|f32>,N>, where
        // N = Math.ceil(data.length / 4) and SizeOf(vec4<i32|u32|f32>) = 16. The total byte length is N *
        // SizeOf(vec4<i32|u32|f32>). For float16 type, when data.length > 4, the uniform variable is of type
        // array<mat2x4<f16>,N>, where N = Math.ceil(data.length / 8) and SizeOf(mat2x4<f16>) = 16. The total byte
        // length is N * SizeOf(mat2x4<f16>).
        const elementPerVecOrMat = v.type === DataType.float16 ? 8 : 4;
        currentOffset +=
          data.length > 4 ? Math.ceil(data.length / elementPerVecOrMat) * sizeOfVecOrMat : data.length * sizeOfElement;
      });

      // Meet alignment of struct here: https://www.w3.org/TR/WGSL/#alignment-and-size. For simplicity, set
      // maxAlignmentOfField to 16 since the underlying buffer has been rounded up to 16.
      const maxAlignmentOfField = 16;
      currentOffset = Math.ceil(currentOffset / maxAlignmentOfField) * maxAlignmentOfField;
      const arrayBuffer = new ArrayBuffer(currentOffset);
      programUniforms.forEach((v, i) => {
        const offset = offsets[i];
        if (offset === undefined) {
          // empty uniform: there is nothing to write.
          return;
        }
        const data = typeof v.data === 'number' ? [v.data] : v.data;
        if (v.type === DataType.int32) {
          new Int32Array(arrayBuffer, offset, data.length).set(data);
        } else if (v.type === DataType.uint32) {
          new Uint32Array(arrayBuffer, offset, data.length).set(data);
        } else if (v.type === DataType.float16) {
          new Uint16Array(arrayBuffer, offset, data.length).set(data);
        } else if (v.type === DataType.float) {
          new Float32Array(arrayBuffer, offset, data.length).set(data);
        } else {
          throw new Error(`Unsupported uniform type: ${tensorDataTypeEnumToString(v.type)}`);
        }
      });

      const uniformBufferData =
        // eslint-disable-next-line no-bitwise
        this.gpuDataManager.create(currentOffset, GPUBufferUsage.COPY_DST | GPUBufferUsage.UNIFORM);
      this.device.queue.writeBuffer(uniformBufferData.buffer, 0, arrayBuffer, 0, currentOffset);
      this.gpuDataManager.release(uniformBufferData.id);
      uniformBufferBinding = { offset: 0, size: currentOffset, buffer: uniformBufferData.buffer };
    }

    const normalizedDispatchGroup = this.programManager.normalizeDispatchGroupSize(dispatchGroup);
    const is1DimensionDispatch = normalizedDispatchGroup[1] === 1 && normalizedDispatchGroup[2] === 1;
    // get program info
    const key = getProgramInfoUniqueKey(program, inputTensorViews, is1DimensionDispatch);
    let artifact = this.programManager.getArtifact(key);
    if (!artifact) {
      artifact = this.programManager.build(program, normalizedDispatchGroup);
      this.programManager.setArtifact(key, artifact);
      LOG_DEBUG('info', () => `[artifact] key: ${key}, programName: ${program.name}`);
    }

    // validate uniform variables
    if (programUniforms && artifact.uniformVariablesInfo) {
      if (programUniforms.length !== artifact.uniformVariablesInfo.length) {
        throw new Error(
          `Uniform variables count mismatch: expect ${artifact.uniformVariablesInfo.length}, got ${
            programUniforms.length
          } in program "${artifact.programInfo.name}".`,
        );
      }
      for (let i = 0; i < programUniforms.length; i++) {
        const uniform = programUniforms[i];
        const actualType = uniform.type;
        const actualLength = typeof uniform.data === 'number' ? 1 : uniform.data.length;
        const [type, length] = artifact.uniformVariablesInfo[i];
        if (actualType !== type || actualLength !== length) {
          throw new Error(
            `Uniform variable ${i} mismatch: expect type ${type} with size ${length}, got type ${
              actualType
            } with size ${actualLength} in program "${artifact.programInfo.name}".`,
          );
        }
      }
    }

    LOG_DEBUG(
      'info',
      () =>
        `[ProgramManager] run "${program.name}" (key=${key}) with ${normalizedDispatchGroup[0]}x${
          normalizedDispatchGroup[1]
        }x${normalizedDispatchGroup[2]}`,
    );

    // record the kernel for profiling and/or for the session that is being captured.
    if (this.queryType !== 'none' || this.sessionStatus === 'capturing') {
      const pendingKernelInfo: PendingKernelInfo = {
        kernelId: this.currentKernelId!,
        programName: artifact.programInfo.name,
        inputTensorViews,
        outputTensorViews,
      };
      this.pendingKernels.push(pendingKernelInfo);

      if (this.sessionStatus === 'capturing') {
        const sessionPendingKernels = this.capturedPendingKernels.get(this.currentSessionId!);
        sessionPendingKernels!.push(pendingKernelInfo);
      }
    }

    this.programManager.run(artifact, inputDatas, outputDatas, normalizedDispatchGroup, uniformBufferBinding);

    TRACE_FUNC_END(program.name);
    return outputTensorViews;
  }
675

676
  /**
   * Upload data to the GPU data with the given ID. Delegates to the GPU data manager.
   *
   * @param gpuDataId the ID of the target GPU data.
   * @param data the bytes to upload.
   */
  upload(gpuDataId: number, data: Uint8Array): void {
    const manager = this.gpuDataManager;
    manager.upload(gpuDataId, data);
  }
679

680
  /**
   * Copy GPU data from one GPU data ID to another. Delegates to the GPU data manager.
   *
   * @param src the ID of the source GPU data.
   * @param dst the ID of the destination GPU data.
   */
  memcpy(src: number, dst: number): void {
    const manager = this.gpuDataManager;
    manager.memcpy(src, dst);
  }
683

684
  /**
   * Download the GPU data with the given ID. Delegates to the GPU data manager.
   *
   * @param gpuDataId the ID of the GPU data to download.
   * @param getTargetBuffer a getter for the target buffer. A getter is used instead of the buffer itself because
   * the underlying buffer may change after the async function is called; the getter makes sure the buffer is
   * up-to-date.
   */
  async download(gpuDataId: number, getTargetBuffer: () => Uint8Array): Promise<void> {
    const manager = this.gpuDataManager;
    await manager.download(gpuDataId, getTargetBuffer);
  }
689

690
  /**
   * Create a GPU data of the given size through the GPU data manager.
   *
   * @param size the size passed to the GPU data manager.
   * @returns the ID of the created GPU data.
   */
  alloc(size: number): number {
    const gpuData = this.gpuDataManager.create(size);
    return gpuData.id;
  }
693

694
  /**
   * Release the GPU data with the given ID through the GPU data manager.
   *
   * @param ptr the ID of the GPU data to release.
   * @returns the value returned by the GPU data manager's `release()`.
   */
  free(ptr: number): number {
    const released = this.gpuDataManager.release(ptr);
    return released;
  }
697

698
  /**
   * Create a kernel and register it under the given kernel ID.
   *
   * The implementation is looked up in the op resolve rules by kernel type. The attribute is stored unparsed,
   * together with the attribute parser of the rule (if any); it is parsed when the kernel is first computed.
   *
   * @param kernelType the kernel type used to resolve the implementation.
   * @param kernelId the ID to register the kernel under.
   * @param attribute the raw attribute of the kernel.
   * @param kernelName the name of the kernel.
   * @throws Error if there is no implementation for the kernel type.
   */
  createKernel(kernelType: string, kernelId: number, attribute: unknown, kernelName: string): void {
    const resolveRule = WEBGPU_OP_RESOLVE_RULES.get(kernelType);
    if (!resolveRule) {
      throw new Error(`kernel not implemented: ${kernelType}`);
    }

    const kernelEntry = resolveRule[0];
    const parseAttribute = resolveRule[1];
    this.kernels.set(kernelId, {
      kernelType,
      kernelName,
      kernelEntry,
      attributes: [parseAttribute, attribute],
    });
  }
712

713
  /**
   * Release a kernel: release all persistent GPU data owned by it, then remove its custom data and its registration.
   *
   * @param kernelId the ID of the kernel to release.
   */
  releaseKernel(kernelId: number): void {
    const ownedData = this.kernelPersistentData.get(kernelId);
    if (ownedData) {
      ownedData.forEach((gpuData) => {
        this.gpuDataManager.release(gpuData.id);
      });
      this.kernelPersistentData.delete(kernelId);
    }

    this.kernelCustomData.delete(kernelId);
    this.kernels.delete(kernelId);
  }
725

726
  /**
   * Compute a kernel.
   *
   * The kernel's attribute is parsed on the first call. While the kernel entry runs, `currentKernelId` is set to the
   * kernel ID; afterwards all temporary GPU data created during the run is released and `currentKernelId` is reset.
   * When `env.debug` is set, the run is wrapped in a GPU validation error scope.
   *
   * @param kernelId the ID of the kernel to compute.
   * @param context the compute context that is passed to the kernel entry.
   * @param errors an array to which error promises are appended: one resolved with a message if the kernel entry
   * throws, and (in debug mode) one that resolves to a message if a GPU validation error occurred, or `null` if not.
   * @returns 0 if the kernel entry finished without throwing, 1 otherwise.
   * @throws Error if the kernel was not created, if another kernel is currently being computed, or if the attribute
   * parser throws.
   */
  computeKernel(kernelId: number, context: ComputeContext, errors: Array<Promise<string | null>>): number {
    const kernel = this.kernels.get(kernelId);
    if (!kernel) {
      throw new Error(`kernel not created: ${kernelId}`);
    }
    const kernelType = kernel.kernelType;
    const kernelName = kernel.kernelName;
    const kernelEntry = kernel.kernelEntry;
    const attributes = kernel.attributes;
    if (this.currentKernelId !== null) {
      throw new Error(`kernel "[${kernelType}] ${kernelName}" is not allowed to be called recursively`);
    }
    this.currentKernelId = kernelId;

    // parse attributes if necessary
    if (attributes[0]) {
      try {
        attributes[1] = attributes[0](attributes[1]);
      } catch (e) {
        // a throwing parser must not leave the backend marked as "computing a kernel". otherwise every later call
        // of computeKernel() would be rejected as a recursive call.
        this.currentKernelId = null;
        throw e;
      }
      attributes[0] = undefined;
    }

    LOG_DEBUG('info', () => `[WebGPU] Start to run kernel "[${kernelType}] ${kernelName}"...`);

    const useErrorScope = this.env.debug;

    this.temporaryData = [];
    try {
      if (useErrorScope) {
        this.device.pushErrorScope('validation');
      }

      kernelEntry(context, attributes[1]);
      return 0; // ORT_OK
    } catch (e) {
      errors.push(Promise.resolve(`[WebGPU] Kernel "[${kernelType}] ${kernelName}" failed. ${e}`));
      return 1; // ORT_FAIL
    } finally {
      if (useErrorScope) {
        errors.push(
          this.device
            .popErrorScope()
            .then((err) =>
              err ? `GPU validation error for kernel "[${kernelType}] ${kernelName}": ${err.message}` : null,
            ),
        );
      }

      // temporary data only lives for the duration of a single kernel computation.
      for (const data of this.temporaryData) {
        this.gpuDataManager.release(data.id);
      }
      this.temporaryData = [];
      this.currentKernelId = null;
    }
  }
779

780
  // #region external buffer
781
  /**
   * Register an external GPU buffer for an input/output index of a session.
   *
   * If a buffer was registered for the same session and index before, it is passed to the GPU data manager as the
   * previous buffer.
   *
   * @param sessionId the ID of the session.
   * @param index the input/output index the buffer is registered for.
   * @param buffer the external GPU buffer.
   * @param size the size passed to the GPU data manager together with the buffer.
   * @returns the GPU data ID assigned by the GPU data manager.
   */
  registerBuffer(sessionId: number, index: number, buffer: GPUBuffer, size: number): number {
    let bufferByIndex = this.sessionExternalDataMapping.get(sessionId);
    if (bufferByIndex === undefined) {
      bufferByIndex = new Map();
      this.sessionExternalDataMapping.set(sessionId, bufferByIndex);
    }

    const previousBuffer = bufferByIndex.get(index)?.[1];
    const id = this.gpuDataManager.registerExternalBuffer(buffer, size, previousBuffer);
    bufferByIndex.set(index, [id, buffer]);
    return id;
  }
793
  /**
   * Unregisters every external buffer recorded for the given session and drops the session's mapping.
   *
   * Does nothing when the session has no recorded mapping.
   *
   * @param sessionId - the session whose external buffers should be unregistered.
   */
  unregisterBuffers(sessionId: number): void {
    const slots = this.sessionExternalDataMapping.get(sessionId);
    if (!slots) {
      return;
    }

    // each entry is a tuple of [GPU data ID, GPU buffer]; the buffer is what gets unregistered.
    slots.forEach((entry) => {
      this.gpuDataManager.unregisterExternalBuffer(entry[1]);
    });
    this.sessionExternalDataMapping.delete(sessionId);
  }
800
  /**
   * Looks up the GPU buffer behind a GPU data ID.
   *
   * @param gpuDataId - the ID to look up in the GPU data manager.
   * @returns the `buffer` of the GPU data entry.
   * @throws Error when the GPU data manager has no entry for the ID.
   */
  getBuffer(gpuDataId: number): GPUBuffer {
    const entry = this.gpuDataManager.get(gpuDataId);
    if (entry) {
      return entry.buffer;
    }
    throw new Error(`no GPU data for buffer: ${gpuDataId}`);
  }
807
  /**
   * Creates a function that downloads the content of a GPU buffer on demand.
   *
   * Nothing is downloaded until the returned function is invoked; each invocation performs a new download.
   *
   * @param gpuBuffer - the GPU buffer to read from.
   * @param size - the size passed through to `downloadGpuData`.
   * @param type - the tensor data type used to create the typed view over the downloaded data.
   * @returns an async function resolving to a view of the downloaded data.
   */
  createDownloader(
    gpuBuffer: GPUBuffer,
    size: number,
    type: Tensor.GpuBufferDataTypes,
  ): () => Promise<Tensor.DataType> {
    const download = async (): Promise<Tensor.DataType> => {
      const downloaded = await downloadGpuData(this, gpuBuffer, size);
      return createView(downloaded.buffer, type);
    };
    return download;
  }
817
  // #endregion
818
  /**
   * Writes a timestamp into the query set at the given index through the current compute pass encoder.
   *
   * This is a no-op unless the query type is 'inside-passes'.
   *
   * @param index - the query index inside the query set.
   */
  writeTimestamp(index: number): void {
    if (this.queryType === 'inside-passes') {
      // `writeTimestamp` is not part of the typed compute pass encoder interface, hence the cast.
      // eslint-disable-next-line @typescript-eslint/no-explicit-any
      (this.computePassEncoder as any).writeTimestamp(this.querySet, index);
    }
  }
826
  /**
   * Determines which timestamp query type to use and lazily creates the query resources.
   *
   * The query type is reset to 'none' first. It is only upgraded when profiling mode is 'default' or tracing is
   * enabled (`env.trace`, falling back to `env.wasm.trace` when `env.trace` is undefined), and the device supports
   * one of the timestamp query features. The query set and its resolve buffer are created once, the first time a
   * query type other than 'none' is selected.
   */
  setQueryType(): void {
    this.queryType = 'none';

    // bail out when neither profiling nor tracing is requested.
    if (
      this.env.webgpu.profiling?.mode !== 'default' &&
      !(typeof this.env.trace === 'undefined' ? this.env.wasm.trace : this.env.trace)
    ) {
      return;
    }

    // prefer the inside-passes feature; fall back to the standard timestamp-query feature.
    const supportedFeatures = this.device.features;
    if (supportedFeatures.has('chromium-experimental-timestamp-query-inside-passes')) {
      this.queryType = 'inside-passes';
    } else if (supportedFeatures.has('timestamp-query')) {
      this.queryType = 'at-passes';
    }

    // nothing to allocate when no query type is available or the query set already exists.
    if (this.queryType === 'none' || typeof this.querySet !== 'undefined') {
      return;
    }

    // two queries are reserved per dispatch; the resolve buffer reserves 8 bytes per query.
    const queryCount = this.maxDispatchNumber * 2;
    this.querySet = this.device.createQuerySet({ type: 'timestamp', count: queryCount });
    this.queryResolveBuffer = this.device.createBuffer({
      size: queryCount * 8,
      // eslint-disable-next-line no-bitwise
      usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.QUERY_RESOLVE,
    });
  }
850

851
  /**
   * Switches the backend into the 'capturing' session status.
   *
   * Ensures the current session has (possibly empty) lists for captured commands and captured pending kernels,
   * then flushes outstanding work before changing the status.
   */
  captureBegin(): void {
    LOG_DEBUG('info', 'captureBegin');
    const sessionId = this.currentSessionId!;
    if (!this.capturedCommandList.has(sessionId)) {
      this.capturedCommandList.set(sessionId, []);
    }
    if (!this.capturedPendingKernels.has(sessionId)) {
      this.capturedPendingKernels.set(sessionId, []);
    }
    // anything still pending must be flushed before the status changes.
    this.flush();
    this.sessionStatus = 'capturing';
  }
863
  /**
   * Leaves capture mode: flushes outstanding work and restores the 'default' session status.
   */
  captureEnd(): void {
    LOG_DEBUG('info', 'captureEnd');
    // anything still pending must be flushed before the status changes.
    this.flush();
    this.sessionStatus = 'default';
  }
869
  /**
   * Re-encodes every command captured for the current session.
   *
   * The session status is 'replaying' for the duration of the call and is restored to 'default' at the end.
   * Each captured command is dispatched with its recorded pipeline, bind group and dispatch group. When a query
   * type is active, the matching captured pending kernel is queued alongside it.
   */
  replay(): void {
    LOG_DEBUG('info', 'replay');
    this.sessionStatus = 'replaying';
    const sessionId = this.currentSessionId!;
    const recordedCommands = this.capturedCommandList.get(sessionId);
    const recordedKernels = this.capturedPendingKernels.get(sessionId);
    const commandCount = recordedCommands!.length;
    this.pendingKernels = [];
    for (let commandIndex = 0; commandIndex < commandCount; commandIndex++) {
      const passEncoder = this.getComputePassEncoder();
      const recorded = recordedCommands![commandIndex];

      // each dispatch owns two consecutive timestamp slots: one before and one after the dispatch.
      const beginQueryIndex = this.pendingDispatchNumber * 2;
      this.writeTimestamp(beginQueryIndex);
      passEncoder.setPipeline(recorded.computePipeline);
      passEncoder.setBindGroup(0, recorded.bindGroup);
      passEncoder.dispatchWorkgroups(...recorded.dispatchGroup);
      this.writeTimestamp(beginQueryIndex + 1);
      this.pendingDispatchNumber++;

      if (this.queryType !== 'none') {
        this.pendingKernels.push(recordedKernels![commandIndex]);
      }
      // with 'at-passes' queries the compute pass is ended after every dispatch.
      if (this.pendingDispatchNumber >= this.maxDispatchNumber || this.queryType === 'at-passes') {
        this.endComputePass();
      }
      // submit once the number of pending dispatches reaches the limit.
      if (this.pendingDispatchNumber >= this.maxDispatchNumber) {
        this.flush();
      }
    }
    // anything still pending must be flushed before the status changes.
    this.flush();
    this.sessionStatus = 'default';
  }
899

900
  /**
   * Cleans up everything this backend tracks for a session that is being released.
   *
   * Unregisters the session's external buffers, drops its captured commands and captured pending kernels, and
   * notifies the GPU data manager.
   *
   * @param sessionId - the session being released.
   */
  onReleaseSession(sessionId: number): void {
    this.unregisterBuffers(sessionId);
    // `Map.prototype.delete` is a no-op for a missing key, so no existence check is needed.
    this.capturedCommandList.delete(sessionId);
    this.capturedPendingKernels.delete(sessionId);
    this.gpuDataManager.onReleaseSession(sessionId);
  }
910

911
  /**
   * Prepares the backend for a run of the given session.
   *
   * Records the session as the current one and re-evaluates the timestamp query type.
   *
   * @param sessionId - the session that is about to run.
   */
  onRunStart(sessionId: number): void {
    this.currentSessionId = sessionId;
    this.setQueryType();
  }
915
}
916

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.