onnxruntime

Форк
0
461 строка · 14.8 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
import { DataType } from '../../../wasm-common';
5
import { TensorView } from '../../tensor-view';
6
import { ShapeUtil } from '../../util';
7
import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../attribute-with-cache-key';
8
import { ComputeContext, ProgramInfo, ProgramUniform } from '../types';
9

10
import {
11
  inputVariable,
12
  outputVariable,
13
  ShaderHelper,
14
  tensorTypeToWsglValueType,
15
  UniformDataElementType,
16
  UniformsArrayType,
17
} from './common';
18

19
/** Name of a WGSL builtin/constructor, emitted as `name(a)` (e.g. `'abs'`, `'vec4<f32>'`). */
type BuiltinFunctionName = string;
/** Builds the WGSL expression for one vec4 chunk from the name of the variable holding it. */
type ElementwiseCustomExpression = (expression: string) => string;
/** Either way of describing the per-chunk computation accepted by the elementwise builders below. */
type ElementwiseFunctionCall = ElementwiseCustomExpression | BuiltinFunctionName;
22

23
/**
 * Generates the WGSL source of a unary elementwise program that walks the tensor in vec4 chunks.
 *
 * Each invocation reads one 4-component chunk of `inputData` into `a`, evaluates the expression
 * derived from `funcCall`, and writes the result to `outputData` at the same offset.
 *
 * @param shaderHelper helper used to declare uniforms/variables and to emit the entry point.
 * @param datasize total number of scalar elements; rounded up to a whole number of vec4 chunks.
 * @param inputDataType tensor data type of the input.
 * @param outputDataType tensor data type of the output.
 * @param funcCall builtin name (emitted as `name(a)`) or a callback that builds the expression from `'a'`.
 * @param additionalImplementation extra WGSL (constants / helper functions) placed before the entry point.
 * @param additionalUniformsType uniforms declared after the mandatory `vec_size` uniform.
 * @returns the complete shader source.
 */
const createElementwiseProgramShader = (
  shaderHelper: ShaderHelper,
  datasize: number,
  inputDataType: number,
  outputDataType: number,
  funcCall: ElementwiseFunctionCall,
  additionalImplementation?: string,
  additionalUniformsType?: UniformsArrayType,
): string => {
  const vecSize = Math.ceil(datasize / 4);
  const expression = typeof funcCall === 'string' ? `${funcCall}(a)` : funcCall('a');

  const input = inputVariable('inputData', inputDataType, [vecSize], 4);
  const output = outputVariable('outputData', outputDataType, [vecSize], 4);
  // `vec_size` is declared first; createElementwiseProgramInfo supplies its value first as well,
  // followed by any additional uniform values in the same order as `additionalUniformsType`.
  const uniforms: UniformsArrayType = [{ name: 'vec_size', type: 'u32' }, ...(additionalUniformsType ?? [])];

  return `
      ${shaderHelper.registerUniforms(uniforms).declareVariables(input, output)}

  ${additionalImplementation ?? ''}

  ${shaderHelper.mainStart()}
    ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.vec_size')}

    let a = ${input.getByOffset('global_idx')};
    ${output.setByOffset('global_idx', expression)}
  }`;
};
60

61
/**
 * Builds the ProgramInfo for a unary elementwise operator applied to `input`.
 *
 * The output has the same dims as `input`, and its data type is `outputDataType`, which defaults to the
 * input type. The element count is passed to the shader as the `vec_size` uniform (in vec4 chunks),
 * followed by `additionalUniforms`.
 *
 * @param input the single input tensor.
 * @param name program name.
 * @param funcCall builtin name or expression builder; see createElementwiseProgramShader.
 * @param additionalImplementation extra WGSL injected before the entry point.
 * @param cacheKey shader-cache hint; it must distinguish any value baked into the shader text.
 * @param outputDataType output tensor data type (defaults to `input.dataType`).
 * @param additionalUniforms values for the uniforms described by `additionalUniformsType`, in the same order.
 * @param additionalUniformsType declarations of the extra uniforms.
 */
const createElementwiseProgramInfo = (
  input: TensorView,
  name: string,
  funcCall: ElementwiseFunctionCall,
  additionalImplementation?: string,
  cacheKey?: string,
  outputDataType: number = input.dataType,
  additionalUniforms?: ProgramUniform[],
  additionalUniformsType?: UniformsArrayType,
): ProgramInfo => {
  const inputSize = ShapeUtil.size(input.dims);
  // Order must match the shader's uniform declarations: `vec_size` first, then the additional ones.
  const programUniforms: ProgramUniform[] = [
    { type: DataType.uint32, data: Math.ceil(inputSize / 4) },
    ...(additionalUniforms ?? []),
  ];

  return {
    name,
    shaderCache: { hint: cacheKey, inputDependencies: ['type'] },
    getShaderSource: (shaderHelper) =>
      createElementwiseProgramShader(
        shaderHelper,
        inputSize,
        input.dataType,
        outputDataType,
        funcCall,
        additionalImplementation,
        additionalUniformsType,
      ),
    getRunData: (inputTensors) => ({
      outputs: [{ dims: input.dims, dataType: outputDataType }],
      dispatchGroup: {
        x: Math.ceil(ShapeUtil.size(inputTensors[0].dims) / 64 /* workgroup size */ / 4 /* vec size */),
      },
      programUniforms,
    }),
  };
};
100

101
/** Abs: elementwise absolute value via the WGSL `abs` builtin. */
export const abs = (context: ComputeContext): void => {
  const programInfo = createElementwiseProgramInfo(context.inputs[0], 'Abs', 'abs');
  context.compute(programInfo);
};
104

105
/** Acos: elementwise arc cosine via the WGSL `acos` builtin. */
export const acos = (context: ComputeContext): void => {
  const programInfo = createElementwiseProgramInfo(context.inputs[0], 'Acos', 'acos');
  context.compute(programInfo);
};
108

109
/** Acosh: elementwise inverse hyperbolic cosine via the WGSL `acosh` builtin. */
export const acosh = (context: ComputeContext): void => {
  const programInfo = createElementwiseProgramInfo(context.inputs[0], 'Acosh', 'acosh');
  context.compute(programInfo);
};
112

113
/** Asin: elementwise arc sine via the WGSL `asin` builtin. */
export const asin = (context: ComputeContext): void => {
  const programInfo = createElementwiseProgramInfo(context.inputs[0], 'Asin', 'asin');
  context.compute(programInfo);
};
116

117
/** Asinh: elementwise inverse hyperbolic sine via the WGSL `asinh` builtin. */
export const asinh = (context: ComputeContext): void => {
  const programInfo = createElementwiseProgramInfo(context.inputs[0], 'Asinh', 'asinh');
  context.compute(programInfo);
};
120

121
/** Atan: elementwise arc tangent via the WGSL `atan` builtin. */
export const atan = (context: ComputeContext): void => {
  const programInfo = createElementwiseProgramInfo(context.inputs[0], 'Atan', 'atan');
  context.compute(programInfo);
};
124
/** Atanh: elementwise inverse hyperbolic tangent via the WGSL `atanh` builtin. */
export const atanh = (context: ComputeContext): void => {
  const programInfo = createElementwiseProgramInfo(context.inputs[0], 'Atanh', 'atanh');
  context.compute(programInfo);
};
127

128
/** Attributes of the Cast operator. */
export interface CastAttributes extends AttributeWithCacheKey {
  /** Target tensor data type; compared against `DataType` members by `cast`. */
  readonly to: number;
  /** Presumably mirrors ONNX Cast's `saturate` attribute. NOTE(review): not read by `cast` in this file. */
  readonly saturate?: boolean;
}
132

133
/** Wraps a raw attribute record as CastAttributes via createAttributeWithCacheKey; `to` is not validated here. */
export const parseCastAttributes = (attributes: Record<string, unknown>): CastAttributes =>
  createAttributeWithCacheKey(attributes as { to: number });
135

136
/**
 * Cast: converts every vec4 chunk of the input with the vec4 constructor of the target type.
 *
 * Handled targets are float16, float, uint32, int32 and bool. The output tensor takes `attributes.to`
 * as its data type.
 * NOTE(review): `attributes.saturate` is not consulted.
 *
 * @throws RangeError when `attributes.to` is not one of the handled types.
 */
export const cast = (context: ComputeContext, attributes: CastAttributes): void => {
  let func: ElementwiseFunctionCall;
  switch (attributes.to) {
    case DataType.float16:
      func = 'vec4<f16>';
      break;
    case DataType.float:
      func = 'vec4<f32>';
      break;
    case DataType.uint32:
      func = 'vec4<u32>';
      break;
    case DataType.int32:
      func = 'vec4<i32>';
      break;
    case DataType.bool:
      func = 'vec4<bool>';
      break;
    default:
      throw new RangeError(`not supported type (specified in attribute 'to' from 'Cast' operator): ${attributes.to}`);
  }
  // A string funcCall is emitted as `func(a)`, i.e. a WGSL vector conversion constructor call.
  context.compute(
    createElementwiseProgramInfo(context.inputs[0], 'Cast', func, undefined, attributes.cacheKey, attributes.to),
  );
};
161

162
/**
 * Bounds of the Clip operator.
 *
 * When derived from float16 inputs by generateClipAttributesFromInputs, these hold raw 16-bit
 * patterns rather than decoded numbers.
 */
export interface ClipAttributes extends AttributeWithCacheKey {
  /** Lower bound. */
  readonly min: number;
  /** Upper bound. */
  readonly max: number;
}
166

167
/**
 * Builds Clip bounds from the optional `min` (input 1) and `max` (input 2) tensors.
 *
 * An optional bound counts as absent when its input is missing or has `data === 0`. Absent bounds
 * fall back to the largest finite magnitude of the element type.
 * For float16 the bounds are read with getUint16Array and kept as raw 16-bit patterns, not decoded values.
 *
 * @throws Error when input 0 is neither float nor float16.
 */
const generateClipAttributesFromInputs = (inputs: readonly TensorView[]): ClipAttributes => {
  const hasMin = inputs.length >= 2 && inputs[1].data !== 0;
  const hasMax = inputs.length >= 3 && inputs[2].data !== 0;

  let min: number;
  let max: number;
  switch (inputs[0].dataType) {
    case DataType.float:
      // +/- largest finite float32 value.
      min = hasMin ? inputs[1].getFloat32Array()[0] : -3.4028234663852886e38;
      max = hasMax ? inputs[2].getFloat32Array()[0] : 3.4028234663852886e38;
      break;
    case DataType.float16:
      min = hasMin ? inputs[1].getUint16Array()[0] : 64511; // uint16(64511) <-> float16(-65504.0)
      max = hasMax ? inputs[2].getUint16Array()[0] : 31743; // uint16(31743) <-> float16(65504.0)
      break;
    default:
      throw new Error(`Unsupported data type for Clip: ${inputs[0].dataType}`);
  }

  return createAttributeWithCacheKey({ min, max });
};
188

189
/**
 * Clip: clamps every element into [min, max].
 *
 * The bounds come from `clipAttributes` when provided, otherwise from the optional 2nd and 3rd inputs.
 * They are passed to the shader as `min`/`max` uniforms of the input's element type, so they are not
 * baked into the shader text.
 * The `{ inputs: [0] }` option presumably limits the program to input 0 (the bounds are already in
 * uniforms) — TODO confirm against ComputeContext.compute.
 */
export const clip = (context: ComputeContext, clipAttributes: ClipAttributes): void => {
  const attributes = clipAttributes ? clipAttributes : generateClipAttributesFromInputs(context.inputs);
  const input = context.inputs[0];
  const dataType = tensorTypeToWsglValueType(input.dataType);
  const uniformType = dataType as UniformDataElementType;
  context.compute(
    createElementwiseProgramInfo(
      input,
      'Clip',
      (a) => `clamp(${a}, vec4<${dataType}>(uniforms.min), vec4<${dataType}>(uniforms.max))`,
      undefined,
      attributes.cacheKey,
      undefined,
      [
        { type: input.dataType, data: attributes.min },
        { type: input.dataType, data: attributes.max },
      ],
      [
        { name: 'min', type: uniformType },
        { name: 'max', type: uniformType },
      ],
    ),
    { inputs: [0] },
  );
};
212

213
/** Ceil: elementwise rounding toward +infinity via the WGSL `ceil` builtin. */
export const ceil = (context: ComputeContext): void => {
  const programInfo = createElementwiseProgramInfo(context.inputs[0], 'Ceil', 'ceil');
  context.compute(programInfo);
};
216

217
/** Cos: elementwise cosine via the WGSL `cos` builtin. */
export const cos = (context: ComputeContext): void => {
  const programInfo = createElementwiseProgramInfo(context.inputs[0], 'Cos', 'cos');
  context.compute(programInfo);
};
220

221
/** Cosh: elementwise hyperbolic cosine via the WGSL `cosh` builtin. */
export const cosh = (context: ComputeContext): void => {
  const programInfo = createElementwiseProgramInfo(context.inputs[0], 'Cosh', 'cosh');
  context.compute(programInfo);
};
224

225
/** Single scalar `alpha` attribute shared by Elu, LeakyRelu, ThresholdedRelu and QuickGelu in this file. */
export interface AlphaAttributes extends AttributeWithCacheKey {
  /** Operator-specific coefficient; it is baked into the generated WGSL as a constant. */
  readonly alpha: number;
}
228

229
/** Wraps a raw attribute record as AlphaAttributes via createAttributeWithCacheKey; `alpha` is not validated here. */
export const parseAlphaAttributes = (attributes: Record<string, unknown>): AlphaAttributes =>
  createAttributeWithCacheKey(attributes as { alpha: number });
231

232
/**
 * Elu: x for x >= 0, alpha * (exp(x) - 1) otherwise, evaluated per component.
 *
 * The helpers are named *_f32 but are instantiated with the input's WGSL element type.
 * Because `alpha` is baked into the shader text, `attributes.cacheKey` (the cache hint) must differ
 * for different alphas — presumably createAttributeWithCacheKey derives it from the attribute values.
 */
export const elu = (context: ComputeContext, attributes: AlphaAttributes): void => {
  const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType);
  context.compute(
    createElementwiseProgramInfo(
      context.inputs[0],
      'Elu',
      (a) => `elu_vf32(${a})`,
      `
  const elu_alpha_ = ${dataType}(${attributes.alpha});

  fn elu_f32(a: ${dataType}) -> ${dataType} {
  return select((exp(a) - 1.0) * elu_alpha_, a, a >= 0.0);
  }

  fn elu_vf32(v: vec4<${dataType}>) -> vec4<${dataType}> {
  return vec4(elu_f32(v.x), elu_f32(v.y), elu_f32(v.z), elu_f32(v.w));
  }`,
      attributes.cacheKey,
    ),
  );
};
253

254
/**
 * WGSL prelude defining `erf_vf32`, a vec4 approximation of the error function.
 *
 * The constants match Abramowitz & Stegun formula 7.1.26:
 * erf(|v|) ~= 1 - (r1*t + r2*t^2 + r3*t^3 + r4*t^4 + r5*t^5) * exp(-v^2), with t = 1 / (1 + r0*|v|).
 * The sign is restored with sign(v), since erf is odd.
 *
 * @param varType WGSL scalar type of the elements (e.g. 'f32' or 'f16').
 */
export const erfImpl = (varType = 'f32') => `
const r0: ${varType} = 0.3275911;
const r1: ${varType} = 0.254829592;
const r2: ${varType} = -0.284496736;
const r3: ${varType} = 1.421413741;
const r4: ${varType} = -1.453152027;
const r5: ${varType} = 1.061405429;

fn erf_vf32(v: vec4<${varType}>) -> vec4<${varType}> {
  let absv = abs(v);
  let x = 1.0 / (1.0 + r0 * absv);
  return sign(v) * (1.0 - ((((r5 * x + r4) * x + r3) * x + r2) * x + r1) * x * exp(-absv * absv));
}`;
267

268
/** Erf: elementwise error function using the `erf_vf32` helper emitted by `erfImpl`. */
export const erf = (context: ComputeContext): void => {
  const input = context.inputs[0];
  const wgslPrelude = erfImpl(tensorTypeToWsglValueType(input.dataType));
  const erfOf = (a: string) => `erf_vf32(${a})`;
  context.compute(createElementwiseProgramInfo(input, 'Erf', erfOf, wgslPrelude));
};
272

273
/** Exp: elementwise natural exponential via the WGSL `exp` builtin. */
export const exp = (context: ComputeContext): void => {
  const programInfo = createElementwiseProgramInfo(context.inputs[0], 'Exp', 'exp');
  context.compute(programInfo);
};
276

277
/** Floor: elementwise rounding toward -infinity via the WGSL `floor` builtin. */
export const floor = (context: ComputeContext): void => {
  const programInfo = createElementwiseProgramInfo(context.inputs[0], 'Floor', 'floor');
  context.compute(programInfo);
};
280

281
/**
 * Gelu (exact form): 0.5 * x * (1 + erf(x / sqrt(2))).
 *
 * 0.7071067811865475 is 1/sqrt(2). erf comes from the `erf_vf32` helper emitted by `erfImpl`.
 */
export const gelu = (context: ComputeContext): void => {
  const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType);
  context.compute(
    createElementwiseProgramInfo(
      context.inputs[0],
      'Gelu',
      (a) => `0.5 * ${a} * (1.0 + erf_vf32(${a} * 0.7071067811865475))`,
      erfImpl(dataType),
    ),
  );
};
292

293
/**
 * LeakyRelu: x for x >= 0, alpha * x otherwise.
 *
 * `alpha` is baked into the shader as the constant `leaky_relu_alpha_`, and `attributes.cacheKey` is
 * used as the cache hint.
 */
export const leakyRelu = (context: ComputeContext, attributes: AlphaAttributes): void => {
  const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType);
  context.compute(
    createElementwiseProgramInfo(
      context.inputs[0],
      'LeakyRelu',
      (a) => `select(leaky_relu_alpha_ * ${a}, ${a}, ${a} >= vec4<${dataType}>(0.0))`,
      `const leaky_relu_alpha_ = ${dataType}(${attributes.alpha});`,
      attributes.cacheKey,
    ),
  );
};
305

306
/** Not: elementwise logical negation using the WGSL `!` operator. */
export const not = (context: ComputeContext): void => {
  const negateLogically = (operand: string) => `!${operand}`;
  context.compute(createElementwiseProgramInfo(context.inputs[0], 'Not', negateLogically));
};
309

310
/** Neg: elementwise arithmetic negation using the unary `-` operator. */
export const neg = (context: ComputeContext): void => {
  const negate = (operand: string) => `-${operand}`;
  context.compute(createElementwiseProgramInfo(context.inputs[0], 'Neg', negate));
};
313

314
/** Reciprocal: elementwise 1/x. */
export const reciprocal = (context: ComputeContext): void => {
  const invert = (operand: string) => `1.0/${operand}`;
  context.compute(createElementwiseProgramInfo(context.inputs[0], 'Reciprocal', invert));
};
317

318
/** Relu: x where x > 0, otherwise 0; implemented with a component-wise WGSL `select`. */
export const relu = (context: ComputeContext): void => {
  const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType);
  context.compute(
    createElementwiseProgramInfo(
      context.inputs[0],
      'Relu',
      (a) => `select(vec4<${dataType}>(0.0), ${a}, ${a} > vec4<${dataType}>(0.0))`,
    ),
  );
};
328

329
/** Sigmoid: elementwise 1 / (1 + exp(-x)). */
export const sigmoid = (context: ComputeContext): void => {
  const sigmoidOf = (operand: string) => `(1.0 / (1.0 + exp(-${operand})))`;
  context.compute(createElementwiseProgramInfo(context.inputs[0], 'Sigmoid', sigmoidOf));
};
332

333
/** Attributes of HardSigmoid, which computes max(0, min(1, alpha * x + beta)). */
export interface HardSigmoidAttributes extends AttributeWithCacheKey {
  /** Slope applied to x. */
  readonly alpha: number;
  /** Offset added after scaling. */
  readonly beta: number;
}
337

338
/** Wraps a raw attribute record as HardSigmoidAttributes via createAttributeWithCacheKey; values are not validated here. */
export const parseHardSigmoidAttributes = (attributes: Record<string, unknown>): HardSigmoidAttributes =>
  createAttributeWithCacheKey(
    attributes as {
      alpha: number;
      beta: number;
    },
  );
345

346
/**
 * HardSigmoid: max(0, min(1, alpha * x + beta)) per element.
 *
 * `alpha` and `beta` are interpolated into the shader text as literals, and `attributes.cacheKey` is
 * used as the cache hint.
 */
export const hardSigmoid = (context: ComputeContext, attributes: HardSigmoidAttributes): void => {
  const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType);
  context.compute(
    createElementwiseProgramInfo(
      context.inputs[0],
      'HardSigmoid',
      (a) =>
        `max(vec4<${dataType}>(0.0), min(vec4<${dataType}>(1.0), ${attributes.alpha} * ${a} + vec4<${dataType}>(${attributes.beta})))`,
      undefined,
      attributes.cacheKey,
    ),
  );
};
359

360
/** Sin: elementwise sine via the WGSL `sin` builtin. */
export const sin = (context: ComputeContext): void => {
  const programInfo = createElementwiseProgramInfo(context.inputs[0], 'Sin', 'sin');
  context.compute(programInfo);
};
363

364
/** Sinh: elementwise hyperbolic sine via the WGSL `sinh` builtin. */
export const sinh = (context: ComputeContext): void => {
  const programInfo = createElementwiseProgramInfo(context.inputs[0], 'Sinh', 'sinh');
  context.compute(programInfo);
};
367

368
/** Sqrt: elementwise square root via the WGSL `sqrt` builtin. */
export const sqrt = (context: ComputeContext): void => {
  const programInfo = createElementwiseProgramInfo(context.inputs[0], 'Sqrt', 'sqrt');
  context.compute(programInfo);
};
371

372
/** Tan: elementwise tangent via the WGSL `tan` builtin. */
export const tan = (context: ComputeContext): void => {
  const programInfo = createElementwiseProgramInfo(context.inputs[0], 'Tan', 'tan');
  context.compute(programInfo);
};
375

376
/**
 * WGSL expression for tanh(a) written as sign(a) * (1 - e) / (1 + e), with e = exp(-2 * |a|).
 * The exponent is never positive, so `exp` cannot overflow for large |a|.
 */
export const tanhExpression = (a: string) => {
  const decay = `exp(-2 * abs(${a}))`;
  return `sign(${a}) * (1 - ${decay}) / (1 + ${decay})`;
};
377

378
/** Tanh: elementwise hyperbolic tangent built from `tanhExpression` instead of the WGSL builtin. */
export const tanh = (context: ComputeContext): void => {
  // TODO: revisit after https://github.com/gpuweb/gpuweb/issues/4458 is resolved
  const programInfo = createElementwiseProgramInfo(context.inputs[0], 'Tanh', tanhExpression);
  context.compute(programInfo);
};
382

383
/**
 * WGSL prelude for FastGelu: the constants referenced by `fastGeluExpression` and a vec4 `tanh_v` helper.
 *
 * fast_gelu_b ~= sqrt(2/pi) and fast_gelu_c ~= 0.044715 * sqrt(2/pi). Combined with fastGeluExpression
 * this evaluates 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).
 * `tanh_v` is built from `tanhExpression` rather than the WGSL `tanh` builtin.
 *
 * @param varType WGSL scalar type of the elements (e.g. 'f32' or 'f16').
 */
export const fastGeluImpl = (varType = 'f32') => `
const fast_gelu_a: ${varType} = 0.5;
const fast_gelu_b: ${varType} = 0.7978845608028654;
const fast_gelu_c: ${varType} = 0.035677408136300125;

fn tanh_v(v: vec4<${varType}>) -> vec4<${varType}> {
  return ${tanhExpression('v')};
}
`;
392

393
/**
 * WGSL FastGelu expression for the vec4 operand `x`: (a + a * tanh_v(x * (c * x * x + b))) * x.
 * Requires the constants and `tanh_v` emitted by `fastGeluImpl`.
 */
export const fastGeluExpression = (x: string) =>
  `(fast_gelu_a + fast_gelu_a * tanh_v(${x} * (fast_gelu_c * ${x} * ${x} + fast_gelu_b))) * ${x}`;
395

396
/**
 * FastGelu: tanh-approximated GELU using `fastGeluImpl` and `fastGeluExpression`.
 *
 * No cache hint is passed. The output type is set explicitly to the input's data type.
 */
export const fastGelu = (context: ComputeContext): void => {
  const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType);
  context.compute(
    createElementwiseProgramInfo(
      context.inputs[0],
      'FastGelu',
      fastGeluExpression,
      fastGeluImpl(dataType),
      undefined,
      context.inputs[0].dataType,
    ),
  );
};
409

410
/**
 * ThresholdedRelu: x where x > alpha, otherwise 0.
 *
 * `alpha` is baked into the shader as `thresholded_relu_alpha_`, and `attributes.cacheKey` is used as
 * the cache hint.
 *
 * @returns always 0. NOTE(review): the `number` return type is kept for caller compatibility; its
 *   meaning is not visible here.
 */
export const thresholdedRelu = (context: ComputeContext, attributes: AlphaAttributes): number => {
  const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType);
  context.compute(
    createElementwiseProgramInfo(
      context.inputs[0],
      'ThresholdedRelu',
      (a) => `select(vec4<${dataType}>(0.0), ${a}, ${a} > thresholded_relu_alpha_)`,
      `const thresholded_relu_alpha_ = vec4<${dataType}>(${attributes.alpha});`,
      attributes.cacheKey,
    ),
  );
  return 0;
};
423

424
/** Log: elementwise natural logarithm via the WGSL `log` builtin. */
export const log = (context: ComputeContext): void => {
  const programInfo = createElementwiseProgramInfo(context.inputs[0], 'Log', 'log');
  context.compute(programInfo);
};
427

428
/**
 * WGSL prelude defining `quick_gelu_impl(x) = x * sigmoid(alpha * x)` for a vec4.
 *
 * The sigmoid is evaluated per component with a branch, so that `exp` always gets a non-positive
 * argument: 1/(1+exp(-v)) for v >= 0, and 1 - 1/(1+exp(v)) for v < 0.
 * NOTE(review): the module-scope names `alpha`, `one` and `zero` are generic and could collide with other
 * WGSL injected into the same shader.
 *
 * @param varType WGSL scalar type of the elements (e.g. 'f32' or 'f16').
 * @param alpha scale applied inside the sigmoid; interpolated into the shader as a literal.
 */
export const quickGeluImpl = (varType: string, alpha: number) => `
const alpha = vec4<${varType}>(${alpha});
const one = ${varType}(1.0);
const zero = ${varType}(0.0);

fn quick_gelu_impl(x: vec4<${varType}>) -> vec4<${varType}> {
  let v = x *alpha;
  var x1 : vec4<${varType}>;
  for (var i = 0; i < 4; i = i + 1) {
    if (v[i] >= zero) {
      x1[i] = one / (one + exp(-v[i]));
    } else {
      x1[i] = one - one / (one + exp(v[i]));
    }
  }
  return x * x1;
}
`;
446

447
/** WGSL call of `quick_gelu_impl` (emitted by `quickGeluImpl`) applied to the given vec4 expression. */
export const quickGeluExpression = (operand: string): string => `quick_gelu_impl(${operand})`;
448

449
/**
 * QuickGelu: x * sigmoid(alpha * x) per element, using `quickGeluImpl` / `quickGeluExpression`.
 *
 * `alpha` is baked into the shader text, and `attributes.cacheKey` is used as the cache hint.
 */
export const quickgelu = (context: ComputeContext, attributes: AlphaAttributes): void => {
  const dType = tensorTypeToWsglValueType(context.inputs[0].dataType);
  context.compute(
    createElementwiseProgramInfo(
      context.inputs[0],
      'QuickGelu',
      quickGeluExpression,
      quickGeluImpl(dType, attributes.alpha),
      attributes.cacheKey,
      context.inputs[0].dataType,
    ),
  );
};
462

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.