onnxruntime

Форк
0
73 строки · 2.2 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
import { Tensor } from '../../../tensor';
5
import { getGlsl } from '../glsl-source';
6
import { WebGLInferenceHandler } from '../inference-handler';
7
import { ProgramInfo, ProgramMetadata, TextureType } from '../types';
8

9
export const sum = (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => {
10
  validateInputs(inputs);
11

12
  const sumProgramMetadata = {
13
    name: 'Sum',
14
    inputNames: inputs.map((_v, i) => `X${i}`),
15
    inputTypes: new Array(inputs.length).fill(TextureType.unpacked),
16
  };
17

18
  const output = inferenceHandler.run(
19
    { ...sumProgramMetadata, get: () => createSumProgramInfo(inferenceHandler, inputs, sumProgramMetadata) },
20
    inputs,
21
  );
22
  return [output];
23
};
24

25
const createSumProgramInfo = (
26
  inferenceHandler: WebGLInferenceHandler,
27
  inputs: Tensor[],
28
  sumProgramMetadata: ProgramMetadata,
29
): ProgramInfo => {
30
  const glsl = getGlsl(inferenceHandler.session.backend.glContext.version);
31
  const outputShape = inputs[0].dims.slice();
32
  const sumLine = inputs.map((_v, i) => `${glsl.texture2D}(X${i},TexCoords)`).join(' + ');
33
  const shaderSource = `
34
      void main() {
35
        vec4 result = ${sumLine};
36
        ${glsl.output} = result;
37
      }
38
    `;
39
  return {
40
    ...sumProgramMetadata,
41
    output: { dims: outputShape, type: inputs[0].type, textureType: TextureType.unpacked },
42
    hasMain: true,
43
    shaderSource,
44
  };
45
};
46

47
const validateInputs = (inputs: Tensor[]): void => {
48
  if (!inputs || inputs.length === 0) {
49
    throw new Error('Sum requires inputs.');
50
  }
51

52
  const length = inputs[0].dims.length;
53
  for (let i = 1; i < inputs.length; i++) {
54
    if (length !== inputs[i].dims.length) {
55
      throw new Error('Input shapes are mismatched.');
56
    }
57

58
    for (let j = 0; j < length; j++) {
59
      if (inputs[0].dims[j] !== inputs[i].dims[j]) {
60
        throw new Error('Input shapes are not matched.');
61
      }
62
    }
63
  }
64

65
  if (inputs[0].type !== 'float32' && inputs[0].type !== 'float64') {
66
    throw new Error('Invalid input type.');
67
  }
68
  for (let i = 1; i < inputs.length; i++) {
69
    if (inputs[0].type !== inputs[i].type) {
70
      throw new Error('Input types are not matched.');
71
    }
72
  }
73
};
74

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.