onnxruntime

Форк
0
323 строки · 9.6 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
import { Tensor } from '../../../tensor';
5
import { BroadcastUtil, ShapeUtil } from '../../../util';
6
import { FunctionType, GlslValueFunction } from '../glsl-definitions';
7
import { getGlsl } from '../glsl-source';
8
import { WebGLInferenceHandler } from '../inference-handler';
9
import { ProgramInfo, ProgramInfoLoader, TextureType } from '../types';
10

11
/**
 * GLSL value function `add_`: element-wise sum, with a scalar (float) and a vec4 overload.
 */
export function glslAdd(): GlslValueFunction {
  const fnName = 'add_';
  return {
    body: `
  float ${fnName}(float a, float b) {
    return a + b;
  }
  vec4 ${fnName}(vec4 v1, vec4 v2) {
    return v1 + v2;
  }
  `,
    name: fnName,
    type: FunctionType.ValueBased,
  };
}
/**
 * GLSL value function `div_`: element-wise quotient `a / b`, with float and vec4 overloads.
 */
export function glslDiv(): GlslValueFunction {
  const fnName = 'div_';
  return {
    body: `
  float ${fnName}(float a, float b) {
    return a / b;
  }
  vec4 ${fnName}(vec4 v1, vec4 v2) {
    return v1 / v2;
  }
  `,
    name: fnName,
    type: FunctionType.ValueBased,
  };
}
/**
 * GLSL value function `mul_`: element-wise product, with float and vec4 overloads.
 */
export function glslMul(): GlslValueFunction {
  const fnName = 'mul_';
  return {
    body: `
  float ${fnName}(float a, float b) {
    return a * b;
  }
  vec4 ${fnName}(vec4 v1, vec4 v2) {
    return v1 * v2;
  }
  `,
    name: fnName,
    type: FunctionType.ValueBased,
  };
}
/**
 * GLSL value function `sub_`: element-wise difference `a - b`, with float and vec4 overloads.
 */
export function glslSub(): GlslValueFunction {
  const fnName = 'sub_';
  return {
    body: `
  float ${fnName}(float a, float b) {
    return a - b;
  }
  vec4 ${fnName}(vec4 v1, vec4 v2) {
    return v1 - v2;
  }
  `,
    name: fnName,
    type: FunctionType.ValueBased,
  };
}
/**
 * GLSL value function `equal_`: yields 1.0 where the operands compare equal and 0.0
 * otherwise. The vec4 overload compares component-wise with the built-in `equal()`.
 */
export function glslEqual(): GlslValueFunction {
  const fnName = 'equal_';
  return {
    body: `
  float ${fnName}(float a, float b) {
    return float(a == b);
  }
  vec4 ${fnName}(vec4 v1, vec4 v2) {
    return vec4(equal(v1, v2));
  }
  `,
    name: fnName,
    type: FunctionType.ValueBased,
  };
}
/**
 * GLSL value function `greater_`: yields 1.0 where `a > b` and 0.0 otherwise.
 * The vec4 overload builds the result one component at a time.
 */
export function glslGreater(): GlslValueFunction {
  const fnName = 'greater_';
  return {
    body: `
  float ${fnName}(float a, float b) {
    return float(a > b);
  }
  vec4 ${fnName}(vec4 v1, vec4 v2) {
    return vec4( v1.r > v2.r ,
      v1.g > v2.g,
      v1.b > v2.b,
      v1.a > v2.a );
  }
  `,
    name: fnName,
    type: FunctionType.ValueBased,
  };
}
/**
 * GLSL value function `less_`: yields 1.0 where `a < b` and 0.0 otherwise.
 * The vec4 overload builds the result one component at a time.
 */
export function glslLess(): GlslValueFunction {
  const fnName = 'less_';
  return {
    body: `
  float ${fnName}(float a, float b) {
    return float(a < b);
  }
  vec4 ${fnName}(vec4 v1, vec4 v2) {
    return vec4( v1.r < v2.r ,
                v1.g < v2.g,
                v1.b < v2.b,
                v1.a < v2.a );
  }
  `,
    name: fnName,
    type: FunctionType.ValueBased,
  };
}
/**
 * GLSL value function `and_`: logical AND of the operands after converting them with
 * `bool()` / `bvec4()`; the result is 1.0 for true and 0.0 for false.
 */
export function glslAnd(): GlslValueFunction {
  const fnName = 'and_';
  return {
    body: `
  float ${fnName}(float a, float b) {
    return float( bool(a) && bool(b) );
  }
  vec4 ${fnName}(vec4 v1, vec4 v2) {
    bvec4 b1 = bvec4(v1);
    bvec4 b2 = bvec4(v2);
    return vec4( b1.r && b2.r ,
                b1.g && b2.g,
                b1.b && b2.b,
                b1.a && b2.a );
  }
  `,
    name: fnName,
    type: FunctionType.ValueBased,
  };
}
/**
 * GLSL value function `or_`: logical OR of the operands after converting them with
 * `bool()` / `bvec4()`; the result is 1.0 for true and 0.0 for false.
 */
export function glslOr(): GlslValueFunction {
  const fnName = 'or_';
  return {
    body: `
  float ${fnName}(float a, float b) {
    return float( bool(a) || bool(b) );
  }
  vec4 ${fnName}(vec4 v1, vec4 v2) {
    bvec4 b1 = bvec4(v1);
    bvec4 b2 = bvec4(v2);
    return vec4( b1.r || b2.r ,
                b1.g || b2.g,
                b1.b || b2.b,
                b1.a || b2.a );
  }
  `,
    name: fnName,
    type: FunctionType.ValueBased,
  };
}
/**
 * GLSL value function `xor_`: logical XOR (`^^`) of the operands after converting them
 * with `bool()` / `bvec4()`; the result is 1.0 for true and 0.0 for false.
 */
export function glslXor(): GlslValueFunction {
  const fnName = 'xor_';
  return {
    body: `
  float ${fnName}(float a, float b) {
    return float( bool(a) ^^ bool(b) );
  }
  vec4 ${fnName}(vec4 v1, vec4 v2) {
    bvec4 b1 = bvec4(v1);
    bvec4 b2 = bvec4(v2);
    return vec4( b1.r ^^ b2.r ,
                b1.g ^^ b2.g,
                b1.b ^^ b2.b,
                b1.a ^^ b2.a );
  }
  `,
    name: fnName,
    type: FunctionType.ValueBased,
  };
}
/**
 * GLSL value function `pow_` implementing ONNX `Pow` element-wise.
 *
 * The GLSL built-in `pow(x, y)` has undefined results for `x < 0` (and for
 * `x == 0, y <= 0`), so a bare wrapper around it produced garbage/NaN for inputs such as
 * `(-2)^2`. This version:
 * - returns 1.0 when the exponent is 0 (including `0^0`, matching `std::pow`);
 * - for a negative base with an integral exponent, computes `|a|^b` and negates it when
 *   the exponent is odd (`mod(b, 2.0)` is in [0, 2) even for negative `b`);
 * - otherwise defers to the built-in. A negative base with a fractional exponent has no
 *   real result, and `0^negative` remains whatever the driver's `pow` returns.
 * The vec4 overload applies the scalar rule to each component.
 */
export function glslPow(): GlslValueFunction {
  const name = 'pow_';
  const body = `
  float ${name}(float a, float b) {
    if (b == 0.0) {
      return 1.0;
    }
    if (a < 0.0 && floor(b) == b) {
      float magnitude = pow(-a, b);
      return mod(b, 2.0) == 0.0 ? magnitude : -magnitude;
    }
    return pow(a, b);
  }
  vec4 ${name}(vec4 v1, vec4 v2) {
    return vec4(
      ${name}(v1.r, v2.r),
      ${name}(v1.g, v2.g),
      ${name}(v1.b, v2.b),
      ${name}(v1.a, v2.a)
      );
  }
  `;
  return { body, name, type: FunctionType.ValueBased };
}
/**
 * GLSL value function `prelu_` implementing ONNX `PRelu`: yields `a * b` where `a < 0`
 * (with `b` as the slope) and `a` otherwise. The vec4 overload applies the same rule
 * to each component.
 */
export function glslPRelu(): GlslValueFunction {
  const name = 'prelu_';
  const body = `
  float ${name}(float a, float b) {
    return a < 0.0 ? a * b: a;
  }
  vec4 ${name}(vec4 v1, vec4 v2) {
    return vec4(
      v1.r < 0.0 ? v1.r * v2.r: v1.r,
      v1.g < 0.0 ? v1.g * v2.g: v1.g,
      v1.b < 0.0 ? v1.b * v2.b: v1.b,
      v1.a < 0.0 ? v1.a * v2.a: v1.a
      );
  }
  `;
  return { body, name, type: FunctionType.ValueBased };
}

/**
 * Wraps `createBinaryProgramInfo` in a ProgramInfoLoader; the ProgramInfo (and its shader
 * source) is only built when the loader's `get()` is invoked.
 *
 * @param handler inference handler whose session decides packed vs. unpacked textures.
 * @param inputs the two operand tensors, bound to the shader inputs `A` and `B`.
 * @param glslFunc value function applied to each pair of operand values.
 * @param outputTensorType element type of the result; defaults to the type of `inputs[0]`.
 * @param cacheKey optional value forwarded as the loader's `cacheHint`.
 */
const createBinaryProgramInfoLoader = (
  handler: WebGLInferenceHandler,
  inputs: Tensor[],
  glslFunc: GlslValueFunction,
  outputTensorType: Tensor.DataType = inputs[0].type,
  cacheKey?: string,
): ProgramInfoLoader => {
  const { pack } = handler.session;
  const inputTextureType = pack ? TextureType.packed : TextureType.unpacked;
  const loader: ProgramInfoLoader = {
    name: glslFunc.name,
    inputNames: ['A', 'B'],
    inputTypes: [inputTextureType, inputTextureType],
    cacheHint: cacheKey,
    get: () => createBinaryProgramInfo(handler, inputs, glslFunc, outputTensorType),
  };
  return loader;
};

/**
 * Builds the WebGL ProgramInfo for an element-wise binary op over shader inputs `A` and `B`.
 *
 * - Equal input shapes: a self-contained `main()` samples both textures at the same
 *   `TexCoords` and writes `glslFunc(v1, v2)`; the output keeps the shape of `inputs[0]`.
 * - Different shapes: the output shape comes from `BroadcastUtil.calcShape`. A packed
 *   session uses the `getAAtOutCoords()` / `getBAtOutCoords()` accessors inside its own
 *   `main()`. An unpacked session defines `process()`, which maps the output index to
 *   per-input indices through `bcastIndices_A` / `bcastIndices_B`.
 *
 * @throws Error when the two input shapes cannot be broadcast together.
 */
const createBinaryProgramInfo = (
  handler: WebGLInferenceHandler,
  inputs: Tensor[],
  glslFunc: GlslValueFunction,
  outputTensorType: Tensor.DataType = inputs[0].type,
): ProgramInfo => {
  const usePackedTexture = handler.session.pack;
  const textureType = usePackedTexture ? TextureType.packed : TextureType.unpacked;
  const aDims = inputs[0].dims;
  const bDims = inputs[1].dims;

  if (ShapeUtil.areEqual(aDims, bDims)) {
    // Same shape: each output texel only depends on the input texels at the same coordinate.
    const glsl = getGlsl(handler.session.backend.glContext.version);
    const shaderSource = `
    ${glslFunc.body}
    void main() {
      vec4 v1 = ${glsl.texture2D}(A, TexCoords);
      vec4 v2 = ${glsl.texture2D}(B, TexCoords);
      vec4 result = ${glslFunc.name}(v1, v2);
      ${glsl.output} = result;
    }
    `;
    return {
      name: glslFunc.name,
      inputNames: ['A', 'B'],
      inputTypes: [textureType, textureType],
      output: { dims: aDims, type: outputTensorType, textureType },
      shaderSource,
      hasMain: true,
    };
  }

  const outputShape = BroadcastUtil.calcShape(aDims, bDims, false);
  if (!outputShape) {
    throw new Error("Can't perform binary op on the given tensors");
  }

  const glsl = getGlsl(handler.session.backend.glContext.version);
  let shaderSource: string;
  if (usePackedTexture) {
    shaderSource = `
      ${glslFunc.body}
      void main() {
        vec4 a = getAAtOutCoords();
        vec4 b = getBAtOutCoords();
        vec4 result = ${glslFunc.name}(a, b);
        ${glsl.output} = result;
      }`;
  } else {
    // A rank-0 (scalar) input is addressed through a one-element index array pinned to 0.
    const aIsScalar = aDims.length === 0;
    const bIsScalar = bDims.length === 0;
    const aRank = aIsScalar ? 1 : aDims.length;
    const bRank = bIsScalar ? 1 : bDims.length;
    const aBcast = aIsScalar ? 'aindices[0] = 0;' : 'bcastIndices_A(indices, aindices);';
    const bBcast = bIsScalar ? 'bindices[0] = 0;' : 'bcastIndices_B(indices, bindices);';
    shaderSource = `
      ${glslFunc.body}
      float process(int indices[${outputShape.length}]) {
        int aindices[${aRank}];
        int bindices[${bRank}];
        ${aBcast}
        ${bBcast}
        return ${glslFunc.name}(_A(aindices), _B(bindices));
      }`;
  }

  return {
    name: glslFunc.name,
    inputNames: ['A', 'B'],
    inputTypes: [textureType, textureType],
    output: { dims: outputShape, type: outputTensorType, textureType },
    shaderSource,
    hasMain: usePackedTexture,
  };
};

/**
 * Shared driver for the exported element-wise binary ops: runs one binary WebGL program
 * over `inputs` and returns its single output tensor wrapped in an array.
 * When `outputType` is omitted, the loader's default (the type of `inputs[0]`) applies.
 */
const runBinaryOp = (
  handler: WebGLInferenceHandler,
  inputs: Tensor[],
  glslFunc: GlslValueFunction,
  outputType?: Tensor.DataType,
): Tensor[] => [handler.run(createBinaryProgramInfoLoader(handler, inputs, glslFunc, outputType), inputs)];

/** Element-wise `A + B`; output type follows `inputs[0]`. */
export const add = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] =>
  runBinaryOp(handler, inputs, glslAdd());

/** Element-wise logical AND; produces a `bool` tensor. */
export const and = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] =>
  runBinaryOp(handler, inputs, glslAnd(), 'bool');

/** Element-wise `A / B`; output type follows `inputs[0]`. */
export const div = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] =>
  runBinaryOp(handler, inputs, glslDiv());

/** Element-wise equality test; produces a `bool` tensor. */
export const equal = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] =>
  runBinaryOp(handler, inputs, glslEqual(), 'bool');

/** Element-wise `A > B`; produces a `bool` tensor. */
export const greater = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] =>
  runBinaryOp(handler, inputs, glslGreater(), 'bool');

/** Element-wise `A < B`; produces a `bool` tensor. */
export const less = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] =>
  runBinaryOp(handler, inputs, glslLess(), 'bool');

/** Element-wise `A * B`; output type follows `inputs[0]`. */
export const mul = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] =>
  runBinaryOp(handler, inputs, glslMul());

/** Element-wise logical OR; produces a `bool` tensor. */
export const or = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] =>
  runBinaryOp(handler, inputs, glslOr(), 'bool');

/** Element-wise `A ^ B` (power); output type follows `inputs[0]`. */
export const pow = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] =>
  runBinaryOp(handler, inputs, glslPow());

/** Element-wise PRelu with `inputs[1]` as the slope; output type follows `inputs[0]`. */
export const pRelu = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] =>
  runBinaryOp(handler, inputs, glslPRelu());

/** Element-wise `A - B`; output type follows `inputs[0]`. */
export const sub = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] =>
  runBinaryOp(handler, inputs, glslSub());

/** Element-wise logical XOR; produces a `bool` tensor. */
export const xor = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] =>
  runBinaryOp(handler, inputs, glslXor(), 'bool');

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.