onnxruntime

Форк
0
/
program-manager.ts 
205 строк · 7.3 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
import { env } from 'onnxruntime-common';
5

6
import { Logger, Profiler } from '../../instrument';
7

8
import { GlslPreprocessor } from './glsl-preprocessor';
9
import { getVertexShaderSource } from './glsl-source';
10
import { TextureLayoutStrategy } from './texture-layout-strategy';
11
import { Artifact, ProgramInfo, ProgramVariable, TextureData, TextureLayout, VariableInfo } from './types';
12
import { WebGLContext } from './webgl-context';
13

14
/**
 * ProgramManager is the main class behind running computations.
 * It builds ProgramInfo's into Artifacts: each ProgramInfo is preprocessed into a
 * fragment shader and compiled into a WebGL Program (cached as an Artifact).
 * It uses an artifact to run the computation by calling draw on
 * the WebGL drawing buffer.
 * ProgramManager automatically maps (binds) input variables to their
 * corresponding locations in the binary program.
 */
export class ProgramManager {
  /** Artifact cache keyed by caller-chosen keys. */
  repo: Map<unknown, Artifact>; // this should be per-session object
  /** Vertex shader shared by every program; compiled lazily by the first `compile` call. */
  vertexShader: WebGLShader;
  /** Set once vertex attributes have been bound; `run` binds them only on the first call. */
  attributesBound: boolean;

  constructor(
    public profiler: Readonly<Profiler>,
    public glContext: WebGLContext,
    public textureLayoutStrategy: TextureLayoutStrategy,
  ) {
    this.repo = new Map();
    this.attributesBound = false;
  }

  /** Returns the cached artifact stored under `key`, or `undefined` if there is none. */
  getArtifact(key: unknown): Artifact | undefined {
    return this.repo.get(key);
  }

  /** Caches `artifact` under `key`, replacing any previous entry. */
  setArtifact(key: unknown, artifact: Artifact): void {
    this.repo.set(key, artifact);
  }

  /**
   * Runs a built artifact. It activates the artifact's program and binds the output
   * framebuffer, the vertex attributes (first run only) and the uniforms/input textures.
   * It then issues the draw call.
   *
   * @param buildArtifact artifact produced by `build`
   * @param inputs textures bound, in order, to the program's sampler uniforms
   * @param output texture attached to the framebuffer as the render target
   * @throws rethrows any binding error after logging the program's shader source
   */
  run(buildArtifact: Artifact, inputs: TextureData[], output: TextureData): void {
    this.profiler.event(
      'op',
      `ProgramManager.run ${buildArtifact.programInfo.name ?? 'unknown kernel'}`,
      () => {
        const gl = this.glContext.gl;
        const program = buildArtifact.program;
        gl.useProgram(program);
        try {
          this.bindOutput(output);
          if (!this.attributesBound) {
            this.bindAttributes(buildArtifact.attribLocations);
          }
          this.bindUniforms(buildArtifact.uniformLocations, buildArtifact.programInfo.variables ?? [], inputs);
        } catch (err) {
          // Dump the shader source so a binding failure can be matched to the program that caused it.
          Logger.error('ProgramManager', buildArtifact.programInfo.shaderSource);
          throw err;
        }
        this.profiler.event('backend', 'GlContext.draw()', () => {
          this.glContext.draw();
        });
      },
      this.glContext,
    );
  }

  /**
   * Releases GL resources: the shared vertex shader (if compiled) and the program of every
   * cached artifact. The cache is emptied afterwards so that deleted programs are no longer
   * returned by `getArtifact`.
   */
  dispose(): void {
    if (this.vertexShader) {
      this.glContext.deleteShader(this.vertexShader);
    }
    this.repo.forEach((a) => this.glContext.deleteProgram(a.program));
    this.repo.clear();
  }

  /**
   * Preprocesses `programInfo` into a fragment shader and compiles it into a program.
   * It then resolves the program's uniform and attribute locations.
   * The returned artifact is NOT cached here; callers use `setArtifact` for that.
   */
  build(programInfo: ProgramInfo, inputTextureLayouts: TextureLayout[], outputTextureLayout: TextureLayout): Artifact {
    return this.profiler.event('backend', 'ProgramManager.build', () => {
      const preprocessor = new GlslPreprocessor(this.glContext, programInfo, inputTextureLayouts, outputTextureLayout);
      const fragScript = preprocessor.preprocess();
      const program = this.compile(fragScript);
      const artifact = {
        programInfo,
        program,
        uniformLocations: this.getUniformLocations(
          program,
          preprocessor.context.programInfo.inputNames,
          preprocessor.context.programInfo.variables,
        ),
        attribLocations: this.getAttribLocations(program),
      };
      return artifact;
    });
  }

  /**
   * Compiles `fragShaderScript` and links it with the shared vertex shader into a program.
   * The vertex shader is compiled and cached on first use.
   * The per-program fragment shader is always released, even if program creation throws.
   */
  protected compile(fragShaderScript: string): WebGLProgram {
    if (!this.vertexShader) {
      Logger.verbose('ProgramManager', 'Compiling and caching Vertex shader for the first time');
      const vertexShaderScript = getVertexShaderSource(this.glContext.version);
      this.vertexShader = this.glContext.compileShader(vertexShaderScript, this.glContext.gl.VERTEX_SHADER);
    }
    if (env.debug) {
      Logger.verbose(
        'ProgramManager',
        `FragShader:
${fragShaderScript}
`,
      );
    }
    const fragShader = this.glContext.compileShader(fragShaderScript, this.glContext.gl.FRAGMENT_SHADER);
    try {
      return this.glContext.createProgram(this.vertexShader, fragShader);
    } finally {
      // Only the fragment shader is program-specific; the vertex shader stays cached for reuse.
      this.glContext.deleteShader(fragShader);
    }
  }

  /** Attaches the output texture to the framebuffer using the texture's width/height. */
  bindOutput(td: TextureData): void {
    const width = td.width;
    const height = td.height;
    Logger.verbose(
      'ProgramManager',
      `Binding output texture to Framebuffer: w/h=${width}/${height}, shape=${td.shape}, type=${td.tensor.type}`,
    );
    this.glContext.attachFramebuffer(td.texture, width, height);
  }

  /** Sets up the `position` / `textureCoord` vertex attributes and marks them as bound. */
  bindAttributes(attribLocations: Artifact.AttribLocations): void {
    const positionHandle = attribLocations.position;
    const textureCoordHandle = attribLocations.textureCoord;
    this.glContext.setVertexAttributes(positionHandle, textureCoordHandle);
    this.attributesBound = true;
  }

  /**
   * Uploads uniform values and binds input textures.
   *
   * `sampler2D` uniforms consume `textures` in order, each on the next texture unit starting at 0.
   * `float` / `int` uniforms take their data from the variable with the same name in `variables`.
   * They use the vector upload (`uniform1fv` / `uniform1iv`) when `arrayLength` is set.
   *
   * @throws Error if a non-sampler uniform has no data, if there are fewer textures than
   *   sampler uniforms, or if the uniform type is unsupported
   */
  bindUniforms(
    uniformLocations: Artifact.UniformLocations,
    variables: ProgramVariable[],
    textures: TextureData[],
  ): void {
    const gl = this.glContext.gl;
    let texturePosition = 0;
    for (const { name, type, location, arrayLength } of uniformLocations) {
      const value = variables.find((v) => v.name === name)?.data;
      // Nullish check rather than truthiness: 0 is a perfectly valid scalar uniform value.
      if (type !== 'sampler2D' && value == null) {
        throw new Error(`variable '${name}' does not have data defined in program info`);
      }
      switch (type) {
        case 'sampler2D': {
          const texture = textures[texturePosition];
          if (!texture) {
            throw new Error(`no input texture provided for sampler '${name}' (texture unit ${texturePosition})`);
          }
          this.bindTexture(texture, location, texturePosition);
          texturePosition++;
          break;
        }
        case 'float':
          if (arrayLength) {
            gl.uniform1fv(location, value as number[]);
          } else {
            gl.uniform1f(location, value as number);
          }
          break;
        case 'int':
          if (arrayLength) {
            gl.uniform1iv(location, value as number[]);
          } else {
            gl.uniform1i(location, value as number);
          }
          break;
        default:
          throw new Error(`Uniform not implemented: ${type}`);
      }
    }
  }

  /** Binds `td`'s texture to texture unit `position` and points the sampler uniform at it. */
  bindTexture(td: TextureData, uniformHandle: WebGLUniformLocation, position: number): void {
    this.glContext.bindTextureToUniform(td.texture, position, uniformHandle);
  }

  /** Looks up the `position` and `textureCoord` attribute locations in `program`. */
  getAttribLocations(program: WebGLProgram): Artifact.AttribLocations {
    return {
      position: this.getAttribLocation(program, 'position'),
      textureCoord: this.getAttribLocation(program, 'textureCoord'),
    };
  }

  /**
   * Resolves uniform locations in `program`.
   * It emits one `sampler2D` entry per sampler name, followed by one entry per variable.
   * Each variable entry copies the variable's fields and adds its location.
   *
   * @throws Error (via `getUniformLocation`) if any name is not an active uniform of `program`
   */
  getUniformLocations(
    program: WebGLProgram,
    samplers?: string[],
    variables?: VariableInfo[],
  ): Artifact.UniformLocations {
    const uniformLocations: Artifact.UniformLocations = [];
    if (samplers) {
      for (const sampler of samplers) {
        uniformLocations.push({
          name: sampler,
          type: 'sampler2D',
          location: this.getUniformLocation(program, sampler),
        });
      }
    }
    if (variables) {
      for (const variable of variables) {
        uniformLocations.push({ ...variable, location: this.getUniformLocation(program, variable.name) });
      }
    }
    return uniformLocations;
  }

  /**
   * Returns the location of uniform `name` in `program`.
   * WebGL returns null for names that are not active uniforms. That includes uniforms the
   * shader compiler optimized away because they are unused.
   *
   * @throws Error if the uniform is not found
   */
  getUniformLocation(program: WebGLProgram, name: string): WebGLUniformLocation {
    const gl = this.glContext.gl;
    const reference = gl.getUniformLocation(program, name);
    if (reference === null) {
      throw new Error(`Uniform ${name} not found.`);
    }
    return reference;
  }

  /** Returns the location of attribute `name` in `program` (WebGL reports -1 when it is not active). */
  getAttribLocation(program: WebGLProgram, name: string): number {
    return this.glContext.gl.getAttribLocation(program, name);
  }
}
206

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.