onnxruntime

Форк
0
367 строк · 11.4 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../../../attribute-with-cache-key';
5
import { Graph } from '../../../graph';
6
import { Tensor } from '../../../tensor';
7
import { MAX_CLIP, MIN_CLIP } from '../../../util';
8
import { FunctionType, GlslValueFunction } from '../glsl-definitions';
9
import { getGlsl } from '../glsl-source';
10
import { WebGLInferenceHandler } from '../inference-handler';
11
import { ProgramInfo, ProgramInfoLoader, ProgramMetadata, TextureType } from '../types';
12

13
/** Value-based GLSL function `abs_` (float and vec4) forwarding to the GLSL built-in `abs`. */
export function glslAbs(): GlslValueFunction {
  const builtin = 'abs';
  return glslBuiltinUnary(builtin);
}
16
/** Value-based GLSL function `acos_` (float and vec4) forwarding to the GLSL built-in `acos`. */
export function glslAcos(): GlslValueFunction {
  const builtin = 'acos';
  return glslBuiltinUnary(builtin);
}
19
/** Value-based GLSL function `asin_` (float and vec4) forwarding to the GLSL built-in `asin`. */
export function glslAsin(): GlslValueFunction {
  const builtin = 'asin';
  return glslBuiltinUnary(builtin);
}
22
/** Value-based GLSL function `atan_` (float and vec4) forwarding to the GLSL built-in `atan`. */
export function glslAtan(): GlslValueFunction {
  const builtin = 'atan';
  return glslBuiltinUnary(builtin);
}
25
/** Value-based GLSL function `ceil_` (float and vec4) forwarding to the GLSL built-in `ceil`. */
export function glslCeil(): GlslValueFunction {
  const builtin = 'ceil';
  return glslBuiltinUnary(builtin);
}
28
/** Value-based GLSL function `cos_` (float and vec4) forwarding to the GLSL built-in `cos`. */
export function glslCos(): GlslValueFunction {
  const builtin = 'cos';
  return glslBuiltinUnary(builtin);
}
31
/**
 * Value-based GLSL function `elu_` (float and vec4).
 *
 * Returns `a` for `a >= 0` and `alpha * (exp(a) - 1)` otherwise. The vec4 overload applies the
 * scalar form component-wise.
 *
 * @param alpha Scale for the negative branch, emitted into the shader as a GLSL `const float`.
 */
export function glslElu(alpha: number): GlslValueFunction {
  const name = 'elu';
  const fn = `${name}_`;
  return {
    body: `
  const float alpha = float(${alpha});

  float ${fn}(float a) {
    return a >= 0.0 ? a: (exp(a) - 1.0) * alpha;
  }
  vec4 ${fn}(vec4 v) {
    return vec4(${fn}(v.x), ${fn}(v.y), ${fn}(v.z), ${fn}(v.w));
  }
  `,
    name,
    type: FunctionType.ValueBased,
  };
}
45
/** Value-based GLSL function `exp_` (float and vec4) forwarding to the GLSL built-in `exp`. */
export function glslExp(): GlslValueFunction {
  const builtin = 'exp';
  return glslBuiltinUnary(builtin);
}
48
/** Value-based GLSL function `floor_` (float and vec4) forwarding to the GLSL built-in `floor`. */
export function glslFloor(): GlslValueFunction {
  const builtin = 'floor';
  return glslBuiltinUnary(builtin);
}
51
/**
 * Value-based GLSL function `clip_` (float and vec4) that clamps its argument to `[min, max]`.
 *
 * @param min Lower bound, emitted into the shader as a GLSL `const float`.
 * @param max Upper bound, emitted into the shader as a GLSL `const float`.
 */
export function glslClip(min: number, max: number): GlslValueFunction {
  const name = 'clip';
  // The bound constants are prefixed with the function name. Declaring globals named plain `min`/`max`
  // would hide the GLSL built-in min()/max() functions for any code placed after this snippet in the
  // same shader (e.g. when the snippet is embedded into another program).
  const body = `
  const float ${name}_min = float(${min});
  const float ${name}_max = float(${max});

  float ${name}_(float a) {
    return clamp(a, ${name}_min, ${name}_max);
  }
  vec4 ${name}_(vec4 v) {
    return clamp(v, ${name}_min, ${name}_max);
  }
  `;
  return { body, name, type: FunctionType.ValueBased };
}
66
/**
 * Value-based GLSL function `identity_` (float and vec4) that returns its argument unchanged.
 *
 * The name was previously misspelled `indentity`. It is used both as the GLSL function prefix and
 * as the program name, so the misspelling leaked into program metadata.
 */
export function glslIdentity(): GlslValueFunction {
  const name = 'identity';
  const body = `
  float ${name}_(float a) {
    return a;
  }
  vec4 ${name}_(vec4 v) {
    return v;
  }
  `;
  return { body, name, type: FunctionType.ValueBased };
}
78
/**
 * Value-based GLSL function `leakyRelu_` (float and vec4).
 *
 * Returns `a * alpha` for negative `a` and `a` otherwise. The vec4 overload applies the scalar form
 * component-wise.
 *
 * @param alpha Slope for negative inputs, emitted into the shader as a GLSL `const float`.
 */
export function glslLeakyRelu(alpha: number): GlslValueFunction {
  const name = 'leakyRelu';
  const fn = `${name}_`;
  return {
    body: `
  const float alpha = float(${alpha});

  float ${fn}(float a) {
    return a < 0.0 ? a * alpha : a;
  }
  vec4 ${fn}(vec4 v) {
    return vec4(${fn}(v.x), ${fn}(v.y), ${fn}(v.z), ${fn}(v.w));
  }
  `,
    name,
    type: FunctionType.ValueBased,
  };
}
92
/** Value-based GLSL function `log_` (float and vec4) forwarding to the GLSL built-in `log`. */
export function glslLog(): GlslValueFunction {
  const builtin = 'log';
  return glslBuiltinUnary(builtin);
}
95
/** Value-based GLSL function `neg_` (float and vec4) returning the arithmetic negation of its argument. */
export function glslNeg(): GlslValueFunction {
  const name = 'neg';
  const fn = `${name}_`;
  return {
    body: `
  float ${fn}(float a) {
    return -a;
  }
  vec4 ${fn}(vec4 v) {
    return -v;
  }
  `,
    name,
    type: FunctionType.ValueBased,
  };
}
107
/**
 * Value-based GLSL function `not_`: logical negation.
 *
 * Overloads are emitted for float, bool, vec4 and bvec4. The float/vec4 forms convert each value
 * with `bool(...)`, negate it, and convert back to float.
 */
export function glslNot(): GlslValueFunction {
  const name = 'not';
  const fn = `${name}_`;
  return {
    body: `
  float ${fn}(float a) {
    return float( ! bool(a) );
  }
  bool ${fn}(bool a) {
    return !a;
  }
  vec4 ${fn}(vec4 v) {
    return vec4(!bool(v.x), !bool(v.y), !bool(v.z), !bool(v.w));
  }
  bvec4 ${fn}(bvec4 v) {
    return bvec4(!v.x, !v.y, !v.z, !v.w);
  }
  `,
    name,
    type: FunctionType.ValueBased,
  };
}
125
/** Value-based GLSL function `sin_` (float and vec4) forwarding to the GLSL built-in `sin`. */
export function glslSin(): GlslValueFunction {
  const builtin = 'sin';
  return glslBuiltinUnary(builtin);
}
128
/** Value-based GLSL function `relu_` (float and vec4) computing `max(x, 0.0)`. */
export function glslRelu(): GlslValueFunction {
  const name = 'relu';
  const fn = `${name}_`;
  return {
    body: `
  float ${fn}(float a) {
    return max( a, 0.0 );
  }
  vec4 ${fn}(vec4 v) {
    return max( v, 0.0 );
  }
  `,
    name,
    type: FunctionType.ValueBased,
  };
}
140
/** Value-based GLSL function `sigmoid_` (float and vec4) computing `1 / (1 + exp(-x))`. */
export function glslSigmoid(): GlslValueFunction {
  const name = 'sigmoid';
  const fn = `${name}_`;
  return {
    body: `
  float ${fn}(float a) {
    return 1.0 / (1.0 + exp(-a));
  }
  vec4 ${fn}(vec4 v) {
    return 1.0 / (1.0 + exp(-v));
  }
  `,
    name,
    type: FunctionType.ValueBased,
  };
}
152
/** Value-based GLSL function `sqrt_` (float and vec4) forwarding to the GLSL built-in `sqrt`. */
export function glslSqrt(): GlslValueFunction {
  const builtin = 'sqrt';
  return glslBuiltinUnary(builtin);
}
155
/** Value-based GLSL function `tan_` (float and vec4) forwarding to the GLSL built-in `tan`. */
export function glslTan(): GlslValueFunction {
  const builtin = 'tan';
  return glslBuiltinUnary(builtin);
}
158
/**
 * Value-based GLSL function `tanh_` (float and vec4).
 *
 * Clamps the input to [-10, 10], then evaluates `(e^(2x) - 1) / (e^(2x) + 1)`. The clamp presumably
 * keeps `exp(2x)` from overflowing in float precision.
 */
export function glslTanh(): GlslValueFunction {
  const name = 'tanh';
  const fn = `${name}_`;
  return {
    body: `
  float ${fn}(float a) {
    a = clamp(a, -10., 10.);
    a = exp(2.*a);
    return (a - 1.) / (a + 1.);
  }
  vec4 ${fn}(vec4 v) {
    v = clamp(v, -10., 10.);
    v = exp(2.*v);
    return (v - 1.) / (v + 1.);
  }
  `,
    name,
    type: FunctionType.ValueBased,
  };
}
174
/**
 * Wraps a GLSL built-in as a value-based function named `<name>_` with float and vec4 overloads.
 *
 * @param name A GLSL built-in callable with both a `float` and a `vec4` argument.
 */
function glslBuiltinUnary(name: string): GlslValueFunction {
  const wrapper = `${name}_`;
  return {
    body: `
  float ${wrapper}(float a) {
    return ${name}(a);
  }
  vec4 ${wrapper}(vec4 v) {
    return ${name}(v);
  }
  `,
    name,
    type: FunctionType.ValueBased,
  };
}
185

186
/////
187
/////
188
/////
189

190
/**
 * Builds the full program for a unary element-wise op.
 *
 * The shader samples texture `A` at `TexCoords`, applies `<glslFunc.name>_` to the sampled vec4, and
 * writes the result. The output has the same dims and type as `input`. Its texture layout follows
 * `handler.session.pack`.
 */
const createElementwiseProgramInfo = (
  handler: WebGLInferenceHandler,
  metadata: ProgramMetadata,
  input: Tensor,
  glslFunc: GlslValueFunction,
): ProgramInfo => {
  const { session } = handler;
  const textureType = session.pack ? TextureType.packed : TextureType.unpacked;
  const glsl = getGlsl(session.backend.glContext.version);
  const output = { dims: input.dims, type: input.type, textureType };
  const shaderSource = `
     ${glslFunc.body}
     void main() {
       vec4 v = ${glsl.texture2D}(A, TexCoords);
       v = ${glslFunc.name}_(v);
       ${glsl.output} = v;
     }
     `;
  return { ...metadata, output, shaderSource, hasMain: true };
};
212

213
/**
 * Returns a lazy program loader for a unary element-wise op.
 *
 * The metadata declares a single input named `A`, whose texture layout follows
 * `handler.session.pack`. `cacheKey` is forwarded as the program `cacheHint`. The full program is
 * only built when `get()` is called.
 */
const createElementwiseProgramInfoLoader = (
  handler: WebGLInferenceHandler,
  input: Tensor,
  glslFunc: GlslValueFunction,
  cacheKey?: string,
): ProgramInfoLoader => {
  const inputType = handler.session.pack ? TextureType.packed : TextureType.unpacked;
  const metadata = {
    name: glslFunc.name,
    inputTypes: [inputType],
    inputNames: ['A'],
    cacheHint: cacheKey,
  };
  const get = () => createElementwiseProgramInfo(handler, metadata, input, glslFunc);
  return { ...metadata, get };
};
223

224
/** Runs element-wise Abs on `inputs[0]`. */
export const abs = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => {
  const loader = createElementwiseProgramInfoLoader(handler, inputs[0], glslAbs());
  return [handler.run(loader, inputs)];
};
227

228
/** Runs element-wise Acos on `inputs[0]`. */
export const acos = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => {
  const loader = createElementwiseProgramInfoLoader(handler, inputs[0], glslAcos());
  return [handler.run(loader, inputs)];
};
231

232
/** Runs element-wise Asin on `inputs[0]`. */
export const asin = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => {
  const loader = createElementwiseProgramInfoLoader(handler, inputs[0], glslAsin());
  return [handler.run(loader, inputs)];
};
235

236
/** Runs element-wise Atan on `inputs[0]`. */
export const atan = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => {
  const loader = createElementwiseProgramInfoLoader(handler, inputs[0], glslAtan());
  return [handler.run(loader, inputs)];
};
239

240
/** Clip bounds plus the cache key inherited from `AttributeWithCacheKey`. */
export interface ClipAttributes extends AttributeWithCacheKey {
  /** Lower clamp bound. */
  readonly min: number;
  /** Upper clamp bound. */
  readonly max: number;
}
244

245
/**
 * Runs Clip on `inputs[0]` using the bounds in `attributes`.
 *
 * The attribute cache key is forwarded to the program loader as its cache hint.
 */
export const clip = (handler: WebGLInferenceHandler, inputs: Tensor[], attributes: ClipAttributes): Tensor[] => {
  const clipFunc = glslClip(attributes.min, attributes.max);
  const loader = createElementwiseProgramInfoLoader(handler, inputs[0], clipFunc, attributes.cacheKey);
  return [handler.run(loader, inputs)];
};
256

257
/** Reads the `min`/`max` float attributes of a Clip node, defaulting to `MIN_CLIP`/`MAX_CLIP`. */
export const parseClipAttributes = (node: Graph.Node): ClipAttributes => {
  const { attributes } = node;
  const min = attributes.getFloat('min', MIN_CLIP);
  const max = attributes.getFloat('max', MAX_CLIP);
  return createAttributeWithCacheKey({ min, max });
};
262

263
/**
 * Clip variant whose bounds come from extra inputs instead of node attributes.
 *
 * The bounds are resolved by `generateClipAttributesFromInputs`. Only `inputs[0]` is passed on
 * to `clip`.
 */
export const clipV11 = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => {
  const bounds = generateClipAttributesFromInputs(handler, inputs);
  const [data] = inputs;
  return clip(handler, [data], bounds);
};
267

268
/**
 * Derives Clip bounds from the optional inputs `inputs[1]` (min) and `inputs[2]` (max).
 *
 * Each bound is read independently, so a node that supplies only `min` (two inputs) honours it.
 * Previously `min` was ignored unless all three inputs were present. A missing bound falls back to
 * `MIN_CLIP`/`MAX_CLIP`.
 *
 * @throws Error if a supplied bound is not a session initializer. The bounds are baked into the
 *   generated shader, so they must be constant.
 */
const generateClipAttributesFromInputs = (handler: WebGLInferenceHandler, inputs: Tensor[]): ClipAttributes => {
  // Returns the first element of optional input `index`, or `fallback` when it was not supplied.
  const readBound = (index: number, fallback: number): number => {
    if (inputs.length <= index) {
      return fallback;
    }
    const bound = inputs[index];
    if (!handler.session.isInitializer(bound.dataId)) {
      throw new Error('dynamic clip attributes are not allowed');
    }
    return bound.numberData[0];
  };

  const min = readBound(1, MIN_CLIP);
  const max = readBound(2, MAX_CLIP);
  return createAttributeWithCacheKey({ min, max });
};
280

281
/** Runs element-wise Ceil on `inputs[0]`. */
export const ceil = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => {
  const loader = createElementwiseProgramInfoLoader(handler, inputs[0], glslCeil());
  return [handler.run(loader, inputs)];
};
284

285
/** Runs element-wise Cos on `inputs[0]`. */
export const cos = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => {
  const loader = createElementwiseProgramInfoLoader(handler, inputs[0], glslCos());
  return [handler.run(loader, inputs)];
};
288

289
/** Elu parameters plus the cache key inherited from `AttributeWithCacheKey`. */
export interface EluAttributes extends AttributeWithCacheKey {
  /** Scale applied to `exp(x) - 1` for negative inputs. */
  readonly alpha: number;
}
292

293
/** Runs Elu on `inputs[0]`, forwarding the attribute cache key as the program cache hint. */
export const elu = (handler: WebGLInferenceHandler, inputs: Tensor[], attributes: EluAttributes): Tensor[] => {
  const eluFunc = glslElu(attributes.alpha);
  const loader = createElementwiseProgramInfoLoader(handler, inputs[0], eluFunc, attributes.cacheKey);
  return [handler.run(loader, inputs)];
};
299

300
/** Reads the `alpha` float attribute of an Elu node, defaulting to 1.0. */
export const parseEluAttributes = (node: Graph.Node): EluAttributes => {
  const alpha = node.attributes.getFloat('alpha', 1.0);
  return createAttributeWithCacheKey({ alpha });
};
302

303
/** Runs element-wise Exp on `inputs[0]`. */
export const exp = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => {
  const loader = createElementwiseProgramInfoLoader(handler, inputs[0], glslExp());
  return [handler.run(loader, inputs)];
};
306

307
/** Runs element-wise Floor on `inputs[0]`. */
export const floor = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => {
  const loader = createElementwiseProgramInfoLoader(handler, inputs[0], glslFloor());
  return [handler.run(loader, inputs)];
};
310

311
/** Runs Identity on `inputs[0]` through a pass-through shader. */
export const identity = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => {
  const loader = createElementwiseProgramInfoLoader(handler, inputs[0], glslIdentity());
  return [handler.run(loader, inputs)];
};
314

315
/** LeakyRelu parameters plus the cache key inherited from `AttributeWithCacheKey`. */
export interface LeakyReluAttributes extends AttributeWithCacheKey {
  /** Slope applied to negative inputs. */
  readonly alpha: number;
}
318

319
/** Runs LeakyRelu on `inputs[0]`, forwarding the attribute cache key as the program cache hint. */
export const leakyRelu = (
  handler: WebGLInferenceHandler,
  inputs: Tensor[],
  attributes: LeakyReluAttributes,
): Tensor[] => {
  const leakyFunc = glslLeakyRelu(attributes.alpha);
  const loader = createElementwiseProgramInfoLoader(handler, inputs[0], leakyFunc, attributes.cacheKey);
  return [handler.run(loader, inputs)];
};
329

330
/** Reads the `alpha` float attribute of a LeakyRelu node, defaulting to 0.01. */
export const parseLeakyReluAttributes = (node: Graph.Node): LeakyReluAttributes => {
  const alpha = node.attributes.getFloat('alpha', 0.01);
  return createAttributeWithCacheKey({ alpha });
};
332

333
/** Runs element-wise Log on `inputs[0]`. */
export const log = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => {
  const loader = createElementwiseProgramInfoLoader(handler, inputs[0], glslLog());
  return [handler.run(loader, inputs)];
};
336

337
/** Runs element-wise Neg on `inputs[0]`. */
export const neg = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => {
  const loader = createElementwiseProgramInfoLoader(handler, inputs[0], glslNeg());
  return [handler.run(loader, inputs)];
};
340

341
/** Runs element-wise logical Not on `inputs[0]`. */
export const not = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => {
  const loader = createElementwiseProgramInfoLoader(handler, inputs[0], glslNot());
  return [handler.run(loader, inputs)];
};
344

345
/** Runs element-wise Relu on `inputs[0]`. */
export const relu = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => {
  const loader = createElementwiseProgramInfoLoader(handler, inputs[0], glslRelu());
  return [handler.run(loader, inputs)];
};
348

349
/** Runs element-wise Sigmoid on `inputs[0]`. */
export const sigmoid = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => {
  const loader = createElementwiseProgramInfoLoader(handler, inputs[0], glslSigmoid());
  return [handler.run(loader, inputs)];
};
352

353
/** Runs element-wise Sin on `inputs[0]`. */
export const sin = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => {
  const loader = createElementwiseProgramInfoLoader(handler, inputs[0], glslSin());
  return [handler.run(loader, inputs)];
};
356

357
/** Runs element-wise Sqrt on `inputs[0]`. */
export const sqrt = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => {
  const loader = createElementwiseProgramInfoLoader(handler, inputs[0], glslSqrt());
  return [handler.run(loader, inputs)];
};
360

361
/** Runs element-wise Tan on `inputs[0]`. */
export const tan = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => {
  const loader = createElementwiseProgramInfoLoader(handler, inputs[0], glslTan());
  return [handler.run(loader, inputs)];
};
364

365
/** Runs element-wise Tanh on `inputs[0]`. */
export const tanh = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => {
  const loader = createElementwiseProgramInfoLoader(handler, inputs[0], glslTanh());
  return [handler.run(loader, inputs)];
};
368

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.