onnxruntime

Форк
0
250 строк · 8.6 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
import { DataType } from '../../../wasm-common';
5
import { TensorView } from '../../tensor-view';
6
import { ShapeUtil } from '../../util';
7
import { ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform } from '../types';
8

9
import {
10
  createTensorShapeVariables,
11
  getElementAt,
12
  IndicesHelper,
13
  inputVariable,
14
  outputVariable,
15
  ShaderHelper,
16
  UniformDataElementType,
17
  UniformsArrayType,
18
} from './common';
19

20
/**
 * Normalized attributes of the Pad operator as consumed by the WebGPU program builder.
 */
interface PadAttributes {
  /** Padding mode selector: 0 = constant, 1 = reflect, 2 = edge, 3 = wrap. */
  readonly mode: number;
  /** Fill value; only uploaded to the shader when `mode` is 0 (constant). */
  readonly value: number;
  /**
   * Per-axis pad amounts for an input of rank r, laid out as
   * `[begin_0, ..., begin_{r-1}, end_0, ..., end_{r-1}]`.
   */
  readonly pads: number[];
}
26

27
/**
 * Checks the Pad inputs before any GPU work is scheduled.
 *
 * The data tensor `x` (input 0) must exist and be float32 or float16. When a `pads` tensor
 * (input 1) is supplied, its first dimension must equal `2 * rank(x)`, or `2 * dims[0]` of the
 * `axes` tensor when exactly four inputs are given.
 *
 * @param inputs - the operator inputs.
 * @throws Error describing the first constraint that is violated.
 */
const validateInputs = (inputs: readonly TensorView[]): void => {
  if (!inputs || inputs.length < 1) {
    throw new Error('Too few inputs');
  }

  const [data] = inputs;
  const isFloatLike = data.dataType === DataType.float || data.dataType === DataType.float16;
  if (!isFloatLike) {
    throw new Error('Input type must be float or float16.');
  }

  if (inputs.length < 2) {
    return;
  }

  // With an `axes` input the pads only describe the listed axes; otherwise every axis of `x`.
  const padsLength = inputs[1].dims[0];
  const paddedAxisCount = inputs.length === 4 ? inputs[3].dims[0] : data.dims.length;
  if (paddedAxisCount * 2 !== padsLength) {
    throw new Error('The pads should be a 1D tensor of shape [2 * input_rank] or [2 * num_axes].');
  }
};
45

46
/**
 * Emits the WGSL statements for constant-mode padding.
 *
 * `value` is first set to `uniforms.constant_value`. The input element is only read when the
 * output index maps inside `x` along every axis. The single-iteration `for` loop exists so that a
 * `break` can skip the read as soon as one axis lands in the padded region. Axes are visited from
 * the innermost (rank - 1) to the outermost (0).
 *
 * @param output - helper for the output variable; provides per-axis index accessors and the WGSL element type.
 * @param inputRank - rank of the input tensor `x`.
 * @param padsLength - number of entries in the `pads` uniform.
 * @returns WGSL statements that assign `value`; `indices`, `value` and `x` must already be in scope.
 */
const getPadConstant = (output: IndicesHelper, inputRank: number, padsLength: number): string => {
  let perAxisChecks = '';
  for (let i = inputRank - 1; i >= 0; --i) {
    perAxisChecks += `
            k = i32(${output.indicesGet('indices', i)}) - ${getElementAt('uniforms.pads', i, padsLength)};
            if (k < 0) {
              break;
            }
            if (k >= i32(${getElementAt('uniforms.x_shape', i, inputRank)})) {
              break;
            }
            offset += k * i32(${getElementAt('uniforms.x_strides', i, inputRank)});
        `;
  }

  return `
          value = ${output.type.value}(uniforms.constant_value);
          for (var i = 0; i < 1; i++) {
            var offset = 0;
            var k = 0;
            ${perAxisChecks}
            value = x[offset];
          }
      `;
};
71

72
/**
 * Emits the WGSL statements for reflect-mode padding.
 *
 * For each axis (innermost first) the source coordinate `k = index - pad_begin` is mirrored
 * about 0 and folded into `[0, n)` using the reflection period `2 * (n - 1)`. Because of the
 * modulo, pads longer than the axis are handled too. The period is clamped to at least 1, so
 * an axis of length 1 maps every coordinate to 0 instead of evaluating `k % 0`.
 *
 * @param output - helper for the output variable; provides per-axis index accessors.
 * @param inputRank - rank of the input tensor `x`.
 * @param padsLength - number of entries in the `pads` uniform.
 * @returns WGSL statements that assign `value`; `indices`, `value` and `x` must already be in scope.
 */
const getPadReflect = (output: IndicesHelper, inputRank: number, padsLength: number): string => {
  let perAxisOffsets = '';
  for (let i = inputRank - 1; i >= 0; --i) {
    perAxisOffsets += `
                k = i32(${output.indicesGet('indices', i)}) - ${getElementAt('uniforms.pads', i, padsLength)};
                if (k < 0) {
                  k = -k;
                }
                {
                  let dim_size = i32(${getElementAt('uniforms.x_shape', i, inputRank)});
                  let reflect_period = max(2 * (dim_size - 1), 1);
                  k = k % reflect_period;
                  if (k >= dim_size) {
                    k = reflect_period - k;
                  }
                }
                offset += k * i32(${getElementAt('uniforms.x_strides', i, inputRank)});
            `;
  }

  return `
              var offset = 0;
              var k = 0;
              ${perAxisOffsets}
              value = x[offset];
          `;
};
97
};
98

99
/**
 * Emits the WGSL statements for edge-mode padding.
 *
 * For each axis (innermost first), the source coordinate `k = index - pad_begin` is clamped to
 * `[0, n - 1]`, so padded positions repeat the nearest border element of `x`.
 *
 * @param output - helper for the output variable; provides per-axis index accessors.
 * @param inputRank - rank of the input tensor `x`.
 * @param padsLength - number of entries in the `pads` uniform.
 * @returns WGSL statements that assign `value`; `indices`, `value` and `x` must already be in scope.
 */
const getPadEdge = (output: IndicesHelper, inputRank: number, padsLength: number): string => {
  let perAxisOffsets = '';
  for (let i = inputRank - 1; i >= 0; --i) {
    perAxisOffsets += `
                k = i32(${output.indicesGet('indices', i)}) - ${getElementAt('uniforms.pads', i, padsLength)};
                if (k < 0) {
                  k = 0;
                }
                if (k >= i32(${getElementAt('uniforms.x_shape', i, inputRank)})) {
                  k = i32(${getElementAt('uniforms.x_shape', i, inputRank)}) - 1;
                }
                offset += k * i32(${getElementAt('uniforms.x_strides', i, inputRank)});
            `;
  }

  return `
              var offset = 0;
              var k = 0;
              ${perAxisOffsets}
              value = x[offset];
          `;
};
121

122
/**
 * Emits the WGSL statements for wrap-mode padding.
 *
 * For each axis (innermost first), the source coordinate `k = index - pad_begin` is reduced
 * modulo the axis length `n` and shifted into `[0, n)`. The input therefore repeats as if it were
 * periodic, and this stays correct when a pad is longer than the axis itself. The previous
 * single `+n` / `-n` correction only covered pads up to `n`, and it emitted a stray `]` that made
 * the generated shader invalid.
 *
 * @param output - helper for the output variable; provides per-axis index accessors.
 * @param inputRank - rank of the input tensor `x`.
 * @param padsLength - number of entries in the `pads` uniform.
 * @returns WGSL statements that assign `value`; `indices`, `value` and `x` must already be in scope.
 */
const getPadWrap = (output: IndicesHelper, inputRank: number, padsLength: number): string => {
  let perAxisOffsets = '';
  for (let i = inputRank - 1; i >= 0; --i) {
    perAxisOffsets += `
                k = i32(${output.indicesGet('indices', i)}) - ${getElementAt('uniforms.pads', i, padsLength)};
                {
                  let dim_size = i32(${getElementAt('uniforms.x_shape', i, inputRank)});
                  // WGSL integer % truncates toward zero, so negative k yields a remainder in (-n, 0].
                  k = k % dim_size;
                  if (k < 0) {
                    k += dim_size;
                  }
                }
                offset += k * i32(${getElementAt('uniforms.x_strides', i, inputRank)});
            `;
  }

  return `
              var offset = 0;
              var k = 0;
              ${perAxisOffsets}
              value = x[offset];
          `;
};
144

145
/**
 * Returns the WGSL padding statements for the mode in `attributes.mode`.
 *
 * Modes are indexed 0 = constant, 1 = reflect, 2 = edge, 3 = wrap.
 *
 * @param output - helper for the output variable, forwarded to the mode-specific generator.
 * @param inputRank - rank of the input tensor `x`.
 * @param attributes - resolved Pad attributes; `pads.length` sizes the `pads` uniform.
 * @throws Error('Invalid mode') when the mode is not one of the four listed above.
 */
const getPadSnippet = (output: IndicesHelper, inputRank: number, attributes: PadAttributes): string => {
  const generatorsByMode: ReadonlyArray<(helper: IndicesHelper, rank: number, padsLength: number) => string> = [
    getPadConstant,
    getPadReflect,
    getPadEdge,
    getPadWrap,
  ];
  const generate = generatorsByMode[attributes.mode];
  if (generate === undefined) {
    throw new Error('Invalid mode');
  }
  return generate(output, inputRank, attributes.pads.length);
};
159

160
/**
 * Builds the WebGPU program that pads `inputs[0]` according to `attributes`.
 *
 * The output shape comes from `ShapeUtil.padShape`. Each invocation maps one output element
 * (`global_idx`) back to an input offset through the mode-specific snippet. In constant mode the
 * fill value is uploaded as the `constant_value` uniform. Its uniform type is the input element
 * type when a `constant_value` tensor was supplied, and f32 otherwise.
 *
 * @param inputs - operator inputs; only `inputs[0]` is bound to the shader, as `x`.
 * @param attributes - resolved Pad attributes (pads expanded to `2 * rank`).
 * @returns the program description handed to `ComputeContext.compute`.
 */
const createPadProgramInfo = (inputs: readonly TensorView[], attributes: PadAttributes): ProgramInfo => {
  const outputShape = ShapeUtil.padShape(inputs[0].dims.slice(), attributes.pads);
  const inputDims = inputs[0].dims;
  const outputSize = ShapeUtil.size(outputShape);
  const programUniforms: ProgramUniform[] = [
    { type: DataType.uint32, data: outputSize },
    { type: DataType.int32, data: attributes.pads },
  ];

  // Coerce to a real boolean. `a && b` yields `b` itself, so without `!!` the raw `data` value of
  // the constant tensor would be interpolated into the shader cache hint below. That would give
  // every distinct tensor its own cache entry (presumably `data` is a per-tensor handle; confirm
  // against TensorView).
  const isValueFromInput = inputs.length >= 3 && !!inputs[2].data;
  if (attributes.mode === 0) {
    programUniforms.push({ type: isValueFromInput ? inputs[2].dataType : DataType.float, data: attributes.value });
  }

  programUniforms.push(...createTensorShapeVariables(inputs[0].dims, outputShape));
  const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank'];

  const getShaderSource = (shaderHelper: ShaderHelper) => {
    const output = outputVariable('output', inputs[0].dataType, outputShape.length);
    const input = inputVariable('x', inputs[0].dataType, inputDims.length);
    const dataType = input.type.value;
    const padSnippet = getPadSnippet(output, inputDims.length, attributes);
    const uniforms: UniformsArrayType = [
      { name: 'output_size', type: 'u32' },
      { name: 'pads', type: 'i32', length: attributes.pads.length },
    ];
    if (attributes.mode === 0) {
      uniforms.push({ name: 'constant_value', type: (isValueFromInput ? dataType : 'f32') as UniformDataElementType });
    }

    return `
            ${shaderHelper.registerUniforms(uniforms).declareVariables(input, output)}
            ${shaderHelper.mainStart()}
            ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}

            let indices = ${output.offsetToIndices('global_idx')};

            var value = ${dataType}(0);
            ${padSnippet}
            output[global_idx] = value;
        }`;
  };

  return {
    name: 'Pad',
    shaderCache: { hint: `${attributes.mode}${isValueFromInput}`, inputDependencies },
    getRunData: () => ({
      outputs: [{ dims: outputShape, dataType: inputs[0].dataType }],
      dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) },
      programUniforms,
    }),
    getShaderSource,
  };
};
214

215
/**
 * Resolves the effective Pad attributes when `pads` (and optionally `constant_value` and `axes`)
 * arrive as inputs rather than as node attributes.
 *
 * @param inputs - `[x, pads?, constant_value?, axes?]`; `pads` and `axes` are read as int64.
 * @param attributes - attributes parsed from the node; returned unchanged when only `x` is given.
 * @returns attributes whose `pads` has length `2 * rank(x)`, laid out as all begin pads followed
 *   by all end pads. Axes not listed in `axes` get zero padding.
 * @throws Error if an `axes` entry lies outside `[-rank(x), rank(x) - 1]`.
 */
const createPadAttributesFromInputs = (inputs: readonly TensorView[], attributes: PadAttributes): PadAttributes => {
  if (inputs.length <= 1) {
    return attributes;
  }

  const inputRank = inputs[0].dims.length;
  const padsFromInput = inputs[1].getBigInt64Array();

  // A missing or data-less constant_value input means padding with 0. A float16 constant is read
  // through getUint16Array, i.e. as its 16-bit pattern rather than as a decoded number.
  const value =
    inputs.length >= 3 && inputs[2].data
      ? inputs[2].dataType === DataType.float16
        ? inputs[2].getUint16Array()[0]
        : inputs[2].getFloat32Array()[0]
      : 0.0;

  const fullPads = new Int32Array(2 * inputRank);
  if (inputs.length >= 4) {
    const axes = inputs[3].getBigInt64Array();
    for (let i = 0; i < axes.length; i++) {
      const rawAxis = Number(axes[i]);
      // ONNX allows negative axes counted from the back. Unnormalized, a negative index write on a
      // typed array is silently dropped, and `axis + rank` lands in the wrong begin/end slot.
      const axis = rawAxis < 0 ? rawAxis + inputRank : rawAxis;
      if (axis < 0 || axis >= inputRank) {
        throw new Error(`Pad axis ${rawAxis} is out of range for input of rank ${inputRank}.`);
      }
      fullPads[axis] = Number(padsFromInput[i]);
      fullPads[axis + inputRank] = Number(padsFromInput[i + axes.length]);
    }
  } else {
    padsFromInput.forEach((padAmount, i) => {
      fullPads[i] = Number(padAmount);
    });
  }

  return { mode: attributes.mode, value, pads: Array.from(fullPads) };
};
245

246
/**
 * Entry point for the Pad operator on the WebGPU backend.
 *
 * Validates the inputs, folds any input-provided pads / constant value / axes into the
 * attributes, and dispatches a program that binds only input 0 (`x`).
 *
 * @param context - compute context supplying the operator inputs and the dispatch function.
 * @param attributes - attributes parsed from the node.
 */
export const pad = (context: ComputeContext, attributes: PadAttributes): void => {
  const { inputs } = context;
  validateInputs(inputs);
  const resolvedAttributes = createPadAttributesFromInputs(inputs, attributes);
  const programInfo = createPadProgramInfo(inputs, resolvedAttributes);
  context.compute(programInfo, { inputs: [0] });
};
251

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.