4
import { DataType } from '../../../wasm-common';
5
import { TensorView } from '../../tensor-view';
6
import { ShapeUtil } from '../../util';
7
import { ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform } from '../types';
10
createTensorShapeVariables,
16
UniformDataElementType,
20
interface PadAttributes {
22
readonly mode: number;
23
readonly value: number;
24
readonly pads: number[];
27
const validateInputs = (inputs: readonly TensorView[]): void => {
28
if (!inputs || inputs.length < 1) {
29
throw new Error('Too few inputs');
31
if (inputs[0].dataType !== DataType.float && inputs[0].dataType !== DataType.float16) {
32
throw new Error('Input type must be float or float16.');
35
if (inputs.length >= 2) {
36
let validPads = inputs[0].dims.length * 2 === inputs[1].dims[0];
37
if (inputs.length === 4) {
38
validPads = inputs[3].dims[0] * 2 === inputs[1].dims[0];
41
throw new Error('The pads should be a 1D tensor of shape [2 * input_rank] or [2 * num_axes].');
46
const getPadConstant = (output: IndicesHelper, inputRank: number, padsLength: number): string => {
48
for (let i = inputRank - 1; i >= 0; --i) {
50
k = i32(${output.indicesGet('indices', i)}) - ${getElementAt('uniforms.pads', i, padsLength)};
54
if (k >= i32(${getElementAt('uniforms.x_shape', i, inputRank)})) {
57
offset += k * i32(${getElementAt('uniforms.x_strides', i, inputRank)});
62
value = ${output.type.value}(uniforms.constant_value);
63
for (var i = 0; i < 1; i++) {
72
const getPadReflect = (output: IndicesHelper, inputRank: number, padsLength: number): string => {
74
for (let i = inputRank - 1; i >= 0; --i) {
76
k = i32(${output.indicesGet('indices', i)}) - ${getElementAt('uniforms.pads', i, padsLength)};
81
let _2n_1 = 2 * (i32(${getElementAt('uniforms.x_shape', i, inputRank)}) - 1);
83
if(k >= i32(${getElementAt('uniforms.x_shape', i, inputRank)})) {
87
offset += k * i32(${getElementAt('uniforms.x_strides', i, inputRank)});
99
const getPadEdge = (output: IndicesHelper, inputRank: number, padsLength: number): string => {
101
for (let i = inputRank - 1; i >= 0; --i) {
103
k = i32(${output.indicesGet('indices', i)}) - ${getElementAt('uniforms.pads', i, padsLength)};
107
if (k >= i32(${getElementAt('uniforms.x_shape', i, inputRank)})) {
108
k = i32(${getElementAt('uniforms.x_shape', i, inputRank)}) - 1;
110
offset += k * i32(${getElementAt('uniforms.x_strides', i, inputRank)});
122
const getPadWrap = (output: IndicesHelper, inputRank: number, padsLength: number): string => {
124
for (let i = inputRank - 1; i >= 0; --i) {
126
k = i32(${output.indicesGet('indices', i)}) - ${getElementAt('uniforms.pads', i, padsLength)};
128
k += i32(${getElementAt('uniforms.x_shape', i, inputRank)}]);
130
if (k >= i32(${getElementAt('uniforms.x_shape', i, inputRank)})) {
131
k -= i32(${getElementAt('uniforms.x_shape', i, inputRank)});
133
offset += k * i32(${getElementAt('uniforms.x_strides', i, inputRank)});
145
const getPadSnippet = (output: IndicesHelper, inputRank: number, attributes: PadAttributes): string => {
146
switch (attributes.mode) {
148
return getPadConstant(output, inputRank, attributes.pads.length);
150
return getPadReflect(output, inputRank, attributes.pads.length);
152
return getPadEdge(output, inputRank, attributes.pads.length);
154
return getPadWrap(output, inputRank, attributes.pads.length);
156
throw new Error('Invalid mode');
160
const createPadProgramInfo = (inputs: readonly TensorView[], attributes: PadAttributes): ProgramInfo => {
161
const outputShape = ShapeUtil.padShape(inputs[0].dims.slice(), attributes.pads);
162
const inputDims = inputs[0].dims;
163
const outputSize = ShapeUtil.size(outputShape);
164
const programUniforms: ProgramUniform[] = [
165
{ type: DataType.uint32, data: outputSize },
166
{ type: DataType.int32, data: attributes.pads },
169
const isValueFromInput = inputs.length >= 3 && inputs[2].data;
170
if (attributes.mode === 0) {
171
programUniforms.push({ type: isValueFromInput ? inputs[2].dataType : DataType.float, data: attributes.value });
174
programUniforms.push(...createTensorShapeVariables(inputs[0].dims, outputShape));
175
const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank'];
177
const getShaderSource = (shaderHelper: ShaderHelper) => {
178
const output = outputVariable('output', inputs[0].dataType, outputShape.length);
179
const input = inputVariable('x', inputs[0].dataType, inputDims.length);
180
const dataType = input.type.value;
181
const padSnippet = getPadSnippet(output, inputDims.length, attributes);
182
const uniforms: UniformsArrayType = [
183
{ name: 'output_size', type: 'u32' },
184
{ name: 'pads', type: 'i32', length: attributes.pads.length },
186
if (attributes.mode === 0) {
187
uniforms.push({ name: 'constant_value', type: (isValueFromInput ? dataType : 'f32') as UniformDataElementType });
191
${shaderHelper.registerUniforms(uniforms).declareVariables(input, output)}
192
${shaderHelper.mainStart()}
193
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}
195
let indices = ${output.offsetToIndices('global_idx')};
197
var value = ${dataType}(0);
199
output[global_idx] = value;
205
shaderCache: { hint: `${attributes.mode}${isValueFromInput}`, inputDependencies },
207
outputs: [{ dims: outputShape, dataType: inputs[0].dataType }],
208
dispatchGroup: { x: Math.ceil(ShapeUtil.size(outputShape) / 64 ) },
215
const createPadAttributesFromInputs = (inputs: readonly TensorView[], attributes: PadAttributes): PadAttributes => {
216
if (inputs.length > 1) {
217
const bigInt64Pads = inputs[1].getBigInt64Array();
219
inputs.length >= 3 && inputs[2].data
220
? inputs[2].dataType === DataType.float16
221
? inputs[2].getUint16Array()[0]
222
: inputs[2].getFloat32Array()[0]
225
const inputRank = inputs[0].dims.length;
226
const updatePads = new Int32Array(2 * inputRank).fill(0);
227
if (inputs.length >= 4) {
228
const axes = inputs[3].getBigInt64Array();
229
for (let i = 0; i < axes.length; i++) {
230
updatePads[Number(axes[i])] = Number(bigInt64Pads[i]);
231
updatePads[Number(axes[i]) + inputRank] = Number(bigInt64Pads[i + axes.length]);
234
bigInt64Pads.forEach((v, i) => (updatePads[Number(i)] = Number(v)));
237
const pads: number[] = [];
238
updatePads.forEach((v) => pads.push(v));
240
return { mode: attributes.mode, value, pads };
246
export const pad = (context: ComputeContext, attributes: PadAttributes): void => {
247
validateInputs(context.inputs);
248
const updatedAttributes = createPadAttributesFromInputs(context.inputs, attributes);
249
context.compute(createPadProgramInfo(context.inputs, updatedAttributes), { inputs: [0] });