onnxruntime

Форк
0
/
program-manager.ts 
138 строк · 5.3 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
import { TRACE_FUNC_BEGIN, TRACE_FUNC_END } from 'onnxruntime-common';
5

6
import { WebGpuBackend } from '../backend-webgpu';
7
import { LOG_DEBUG } from '../log';
8

9
import { createShaderHelper } from './ops/common';
10
import { Artifact, GpuData, ProgramInfo } from './types';
11

12
/**
 * ProgramManager builds and runs WebGPU compute programs.
 *
 * - `build()` compiles a {@link ProgramInfo} into an {@link Artifact}: it generates WGSL through a shader helper,
 *   prepends `enable f16;` when the device supports `shader-f16`, and creates a compute pipeline (entry point `main`)
 *   whose bind group layout is derived automatically (`layout: 'auto'`).
 * - `getArtifact()` / `setArtifact()` read and write a cache of built artifacts keyed by a caller-supplied key.
 * - `run()` binds one program's inputs, outputs and optional uniform buffer to bind group 0 and records a dispatch on
 *   the backend's compute pass encoder (and, while the backend is capturing, in the session's captured command list).
 * - `normalizeDispatchGroupSize()` reshapes a dispatch so that no dimension exceeds the device's
 *   `maxComputeWorkgroupsPerDimension` limit.
 */
export class ProgramManager {
  /** Cache of built artifacts, keyed by an opaque key chosen by the caller. */
  repo: Map<unknown, Artifact>; // this should be per-session object
  /**
   * Initialized to `false` and not otherwise used by this class.
   * NOTE(review): looks like a leftover from the WebGL program manager; kept because it is a public field.
   */
  attributesBound: boolean;

  constructor(private backend: WebGpuBackend) {
    this.repo = new Map();
    this.attributesBound = false;
  }

  /** Returns the cached artifact for `key`, or `undefined` if none has been stored. */
  getArtifact(key: unknown): Artifact | undefined {
    return this.repo.get(key);
  }

  /** Stores (or replaces) the cached artifact for `key`. */
  setArtifact(key: unknown, artifact: Artifact): void {
    this.repo.set(key, artifact);
  }

  /**
   * Records one dispatch of a built program on the backend's current compute pass.
   *
   * @param buildArtifact - artifact returned by `build()`; its pipeline and bind group layout 0 are used.
   * @param inputs - input buffers, bound in order starting at binding 0.
   * @param outputs - output buffers, bound immediately after the inputs.
   * @param dispatchGroup - workgroup counts passed to `dispatchWorkgroups`.
   * @param uniformBufferBinding - optional uniform resource, bound after the outputs when present.
   * @throws Error if the backend is capturing but holds no captured command list for the current session.
   */
  run(
    buildArtifact: Artifact,
    inputs: GpuData[],
    outputs: GpuData[],
    dispatchGroup: [number, number, number],
    uniformBufferBinding: GPUBindingResource | undefined,
  ): void {
    TRACE_FUNC_BEGIN(buildArtifact.programInfo.name);
    const device = this.backend.device;
    const computePassEncoder = this.backend.getComputePassEncoder();
    // Timestamp slots 2n and 2n + 1 bracket the n-th pending dispatch.
    this.backend.writeTimestamp(this.backend.pendingDispatchNumber * 2);

    // Binding indices are assigned sequentially: inputs, then outputs, then the optional uniform buffer.
    // NOTE(review): presumably this must match the binding order emitted by the generated shader — confirm.
    const entries: GPUBindGroupEntry[] = [];
    for (const input of inputs) {
      entries.push({ binding: entries.length, resource: { buffer: input.buffer } });
    }
    for (const output of outputs) {
      entries.push({ binding: entries.length, resource: { buffer: output.buffer } });
    }
    if (uniformBufferBinding) {
      entries.push({ binding: entries.length, resource: uniformBufferBinding });
    }
    const bindGroup = device.createBindGroup({
      layout: buildArtifact.computePipeline.getBindGroupLayout(0),
      entries,
      label: buildArtifact.programInfo.name,
    });

    if (this.backend.sessionStatus === 'capturing') {
      // While capturing, also remember everything needed to replay this dispatch later.
      const commandInfo = {
        kernelId: this.backend.currentKernelId!,
        computePipeline: buildArtifact.computePipeline,
        bindGroup,
        dispatchGroup,
      };
      const sessionId = this.backend.currentSessionId!;
      const sessionCommandList = this.backend.capturedCommandList.get(sessionId);
      if (!sessionCommandList) {
        throw new Error(`[WebGPU] no captured command list found for session ${sessionId}`);
      }
      sessionCommandList.push(commandInfo);
    }

    computePassEncoder.setPipeline(buildArtifact.computePipeline);
    computePassEncoder.setBindGroup(0, bindGroup);
    computePassEncoder.dispatchWorkgroups(...dispatchGroup);
    this.backend.writeTimestamp(this.backend.pendingDispatchNumber * 2 + 1);
    this.backend.pendingDispatchNumber++;

    // Close the pass once the dispatch budget is reached, or after every dispatch when queryType is 'at-passes'.
    if (
      this.backend.pendingDispatchNumber >= this.backend.maxDispatchNumber ||
      this.backend.queryType === 'at-passes'
    ) {
      this.backend.endComputePass();
    }
    // The budget is re-read here (after endComputePass) exactly as before, rather than cached above.
    if (this.backend.pendingDispatchNumber >= this.backend.maxDispatchNumber) {
      this.backend.flush();
    }
    TRACE_FUNC_END(buildArtifact.programInfo.name);
  }

  /** Currently a no-op: cached artifacts in `repo` are not released here. */
  dispose(): void {}

  /**
   * Compiles `programInfo` into a WebGPU compute pipeline.
   *
   * @param programInfo - supplies the program name (used as the GPU object label) and the WGSL body via
   *   `getShaderSource`.
   * @param normalizedDispatchGroupSize - dispatch shape (see `normalizeDispatchGroupSize`) given to the shader helper.
   * @returns the artifact holding the pipeline and the helper's uniform variable info.
   */
  build(programInfo: ProgramInfo, normalizedDispatchGroupSize: [number, number, number]): Artifact {
    TRACE_FUNC_BEGIN(programInfo.name);
    const device = this.backend.device;
    // Enable the WGSL f16 extension whenever the device exposes the 'shader-f16' feature.
    const extensions: string[] = [];
    if (device.features.has('shader-f16')) {
      extensions.push('enable f16;');
    }
    const shaderHelper = createShaderHelper(normalizedDispatchGroupSize, device.limits);
    const userCode = programInfo.getShaderSource(shaderHelper);
    const code = `${extensions.join('\n')}\n${shaderHelper.additionalImplementations}\n${userCode}`;
    const shaderModule = device.createShaderModule({ code, label: programInfo.name });
    LOG_DEBUG('verbose', () => `[WebGPU] ${programInfo.name} shader code: ${code}`);

    const computePipeline = device.createComputePipeline({
      compute: { module: shaderModule, entryPoint: 'main' },
      layout: 'auto',
      label: programInfo.name,
    });

    TRACE_FUNC_END(programInfo.name);
    return { programInfo, computePipeline, uniformVariablesInfo: shaderHelper.variablesInfo };
  }

  /**
   * Converts a program's dispatch group (a number or `{ x, y?, z? }`) into an `[x, y, z]` triple that respects the
   * device's `maxComputeWorkgroupsPerDimension` limit.
   *
   * If every dimension is within the limit, the sizes are returned unchanged (a missing or zero `y`/`z` becomes 1).
   * Otherwise the total workgroup count is spread over the smallest square `[d, d, 1]` with `d * d >= total`, or, if
   * that does not fit, the smallest cube `[d, d, d]` with `d ** 3 >= total`. A reshaped grid may contain more
   * workgroups than requested.
   *
   * @throws Error if even the cube does not fit within the limit.
   */
  normalizeDispatchGroupSize(
    dispatchGroup: ReturnType<ProgramInfo['getRunData']>['dispatchGroup'],
  ): [number, number, number] {
    const x = typeof dispatchGroup === 'number' ? dispatchGroup : dispatchGroup.x;
    const y = typeof dispatchGroup === 'number' ? 1 : dispatchGroup.y || 1;
    const z = typeof dispatchGroup === 'number' ? 1 : dispatchGroup.z || 1;
    const limitPerDimension = this.backend.device.limits.maxComputeWorkgroupsPerDimension;
    if (x <= limitPerDimension && y <= limitPerDimension && z <= limitPerDimension) {
      return [x, y, z];
    }
    const size = x * y * z;
    const squareSide = ProgramManager.ceilIntegerRoot(size, 2);
    if (squareSide <= limitPerDimension) {
      return [squareSide, squareSide, 1];
    }
    const cubeSide = ProgramManager.ceilIntegerRoot(size, 3);
    if (cubeSide > limitPerDimension) {
      throw new Error('Total dispatch size exceeds WebGPU maximum.');
    }
    return [cubeSide, cubeSide, cubeSide];
  }

  /**
   * Returns the smallest integer `root` with `root ** exponent >= size`.
   *
   * `Math.sqrt` / `Math.cbrt` results are implementation-approximated, so the floating-point estimate is corrected
   * with exact integer comparisons. Otherwise an estimate one too small would dispatch too few workgroups, and one
   * too large could exceed the per-dimension limit needlessly. Non-finite estimates (NaN/Infinity) are returned
   * unchanged.
   */
  private static ceilIntegerRoot(size: number, exponent: 2 | 3): number {
    let root = Math.ceil(exponent === 2 ? Math.sqrt(size) : Math.cbrt(size));
    if (!Number.isFinite(root)) {
      return root;
    }
    while (root > 0 && (root - 1) ** exponent >= size) {
      root--;
    }
    while (root ** exponent < size) {
      root++;
    }
    return root;
  }
}
139

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.