1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
4
import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../../../attribute-with-cache-key';
5
import { Graph } from '../../../graph';
6
import { OperatorImplementation, OperatorInitialization } from '../../../operators';
7
import { Tensor } from '../../../tensor';
8
import { ShapeUtil } from '../../../util';
9
import { getGlsl, Glsl } from '../glsl-source';
10
import { WebGLInferenceHandler } from '../inference-handler';
11
import { ProgramInfo, TextureType } from '../types';
13
// Attributes consumed by padV2 / createPadProgramInfo. padV2 receives them
// from node attributes (parsePadAttributesV2); padV11 builds them from its
// constant pads/value inputs (generatePadAttributesFromInputs).
export interface PadAttributes extends AttributeWithCacheKey {
14
// Padding mode. getPadFunction dispatches on it to the constant / reflect /
// edge shader generators and otherwise throws 'Invalid mode'.
// parsePadAttributesV2 defaults it to 'constant'.
readonly mode: string;
15
// Pad amounts, passed to ShapeUtil.padShape for the output shape. The shaders
// subtract pads[i] from output index i, so pads[i] is the leading pad of
// axis i; where the trailing pads live is not visible here (presumably
// pads[i + rank], per the ONNX layout — confirm).
readonly pads: number[];
16
// Fill value; only forwarded to the constant-mode shader generator.
readonly value: number;
19
// Program metadata shared by the Pad programs; padV2 spreads it into the
// program descriptor it passes to inferenceHandler.run.
const padProgramMetadata = {
22
// One input, read from an unpacked texture.
inputTypes: [TextureType.unpacked],
25
// Pad kernel, V2 variant: pads / mode / value come from PadAttributes.
// Checks that there is exactly one float32/float64 input, then runs the pad
// program on it, keyed for caching by attributes.cacheKey.
export const padV2: OperatorImplementation<PadAttributes> = (
26
inferenceHandler: WebGLInferenceHandler,
28
attributes: PadAttributes,
30
validateInputsV2(inputs);
31
const output = inferenceHandler.run(
33
...padProgramMetadata,
34
// Cache hint for the compiled program: identical attributes share a key.
cacheHint: attributes.cacheKey,
35
// Lazily builds the ProgramInfo (presumably only invoked on a program-cache
// miss — confirm against inferenceHandler.run).
get: () => createPadProgramInfo(inferenceHandler, inputs[0], attributes),
42
// Reads Pad attributes from the graph node: 'mode' (default 'constant'),
// 'value' (default 0.0) and 'pads' (no default supplied here), and wraps them
// with a cache key.
export const parsePadAttributesV2: OperatorInitialization<PadAttributes> = (node: Graph.Node): PadAttributes => {
43
const mode = node.attributes.getString('mode', 'constant');
44
const value = node.attributes.getFloat('value', 0.0);
45
const pads = node.attributes.getInts('pads');
46
return createAttributeWithCacheKey({ mode, value, pads });
49
// Pad kernel, V11 variant: the pads (and optional fill value) arrive as
// inputs rather than attributes, and the node attribute is just the mode
// string. Converts them to PadAttributes and delegates to padV2 on the data
// tensor alone.
export const padV11: OperatorImplementation<string> = (
50
inferenceHandler: WebGLInferenceHandler,
54
validateInputsV11(inputs);
55
// NOTE(review): local name is misspelled ("attrubutes"); harmless.
const attrubutes = generatePadAttributesFromInputs(inferenceHandler, inputs, mode);
56
// Only inputs[0] is forwarded; padV2's validateInputsV2 checks its dtype.
return padV2(inferenceHandler, [inputs[0]], attrubutes);
59
// The V11 parser reads only the 'mode' attribute (default 'constant');
// pads and value are taken from inputs at run time by padV11.
export const parsePadAttributesV11: OperatorInitialization<string> = (node: Graph.Node): string =>
60
node.attributes.getString('mode', 'constant');
62
// Builds PadAttributes for padV11 from its inputs: inputs[1] supplies the pads
// (its integer data) and the optional inputs[2] supplies the fill value (its
// first float element, or 0.0 when absent). Both must be session
// initializers, otherwise this throws.
const generatePadAttributesFromInputs = (
63
inferenceHandler: WebGLInferenceHandler,
68
!inferenceHandler.session.isInitializer(inputs[1].dataId) ||
69
(inputs.length >= 3 && !inferenceHandler.session.isInitializer(inputs[2].dataId))
71
// Pads and the fill value are interpolated into the generated shader source,
// so values that could change between runs cannot be supported.
throw new Error('dynamic pad attributes are not allowed');
74
const pads = Array.from(inputs[1].integerData);
75
const value = inputs.length >= 3 ? inputs[2].floatData[0] : 0.0;
77
return createAttributeWithCacheKey({ mode, pads, value });
80
// Builds the ProgramInfo for the pad shader. The output shape is the input
// shape grown by attributes.pads; the output keeps the input's data type and
// uses an unpacked texture.
const createPadProgramInfo = (
81
inferenceHandler: WebGLInferenceHandler,
83
attributes: PadAttributes,
85
// dims are copied before padShape sees them (presumably so the input
// tensor's own dims are never modified — confirm padShape mutates).
const outputShape = ShapeUtil.padShape(input.dims.slice(), attributes.pads);
86
const rank = outputShape.length;
87
// Mode-specific GLSL that defines padA(int m[rank]).
const padFunction = getPadFunction(inferenceHandler, input, attributes);
88
// Shader entry: process() receives a rank-sized output index; its body is
// not visible here (presumably it calls padA — confirm).
const shaderSource = `
90
float process(int[${rank}] indices) {
96
inputTypes: [TextureType.unpacked],
97
output: { dims: outputShape, type: input.type, textureType: TextureType.unpacked },
102
// Input validation for padV2: exactly one input, whose type must be float32
// or float64. Throws an Error otherwise.
const validateInputsV2 = (inputs: Tensor[]): void => {
103
if (!inputs || inputs.length !== 1) {
104
throw new Error('Pad requires 1 input');
106
if (inputs[0].type !== 'float32' && inputs[0].type !== 'float64') {
107
throw new Error('Invalid input type.');
111
// Input validation for padV11: 2 or 3 inputs; the pads input (inputs[1])
// must be int32 and the optional value input (inputs[2]) must not be a
// string tensor. inputs[0] is not checked here because padV11 forwards it
// to padV2, whose validateInputsV2 requires float32/float64.
const validateInputsV11 = (inputs: Tensor[]): void => {
112
if (!inputs || (inputs.length !== 2 && inputs.length !== 3)) {
113
throw new Error('Pad requires 2 or 3 inputs');
115
if (inputs[1].type !== 'int32') {
116
throw new Error('Invalid input type.');
118
if (inputs.length >= 3 && inputs[2].type === 'string') {
119
throw new Error('Invalid input type.');
123
// Returns the GLSL source of padA() for attributes.mode. The input's
// unpacked-texture width/height and its strides are computed here and
// passed to the mode-specific generator, which bakes them into the shader.
const getPadFunction = (inferenceHandler: WebGLInferenceHandler, input: Tensor, attributes: PadAttributes): string => {
124
const glsl = getGlsl(inferenceHandler.session.backend.glContext.version);
125
const [width, height] = inferenceHandler.calculateTextureWidthAndHeight(input.dims, TextureType.unpacked);
126
const strides = ShapeUtil.computeStrides(input.dims);
128
// The case labels are not visible here; the branch targets below are
// presumably 'constant', 'reflect' and 'edge' respectively — confirm.
switch (attributes.mode) {
130
// Constant fill is the only variant that needs attributes.value.
return getPadConstant(glsl, input.dims, strides, width, height, attributes.pads, attributes.value);
132
return getPadReflect(glsl, input.dims, strides, width, height, attributes.pads);
134
return getPadEdge(glsl, input.dims, strides, width, height, attributes.pads);
136
// Unrecognized mode (presumably the default branch).
throw new Error('Invalid mode');
140
// Generates GLSL padA() for constant padding. For every axis (emitted by a JS
// loop at generation time, last axis first) the output index is shifted by
// the leading pad; if it falls outside [0, shape[i]) the fill value is
// returned, otherwise k * strides[i] is added to a flat input offset. That
// offset is then mapped to texture coordinates and sampled from texture A.
// NOTE(review): `value` is interpolated as `float(${value})`; a NaN or
// infinite value would render as invalid GLSL — confirm callers exclude it.
const getPadConstant = (
142
shape: readonly number[],
143
strides: readonly number[],
149
const rank = shape.length;
151
// Per-axis GLSL is generated from the innermost axis outward.
for (let i = rank - 1; i >= 0; --i) {
153
k = m[${i}] - ${pads[i]};
154
if (k < 0) return constant;
155
if (k >= ${shape[i]}) return constant;
156
offset += k * ${strides[i]};
160
float padA(int m[${rank}]) {
161
const float constant = float(${value});
165
vec2 coords = offsetToCoords(offset, ${width}, ${height});
166
float value = getColorAsFloat(${glsl.texture2D}(A, coords));
172
// Generates GLSL padA() for reflect padding. Per axis, the index shifted by
// the leading pad is mirrored: negated if below 0, folded modulo
// 2 * (shape[i] - 1), then reflected back if it is >= shape[i]. The result
// times strides[i] is accumulated into a flat input offset, which is mapped
// to texture coordinates and sampled from texture A.
// NOTE(review): when shape[i] == 1, 2 * (shape[i] - 1) is 0 and the shader
// evaluates mod(float(k), 0.0), whose result GLSL leaves unspecified —
// confirm that size-1 axes are rejected or special-cased elsewhere.
const getPadReflect = (
174
shape: readonly number[],
175
strides: readonly number[],
180
const rank = shape.length;
183
// Per-axis GLSL is generated from the innermost axis outward.
for (let i = rank - 1; i >= 0; --i) {
185
k = m[${i}] - ${pads[i]};
186
if (k < 0) { k = -k; }
188
const int _2n_1 = ${2 * (shape[i] - 1)};
189
k = int( mod( float(k), float(_2n_1) ) ) ;
190
if(k >= ${shape[i]}) { k = _2n_1 - k; }
192
offset += k * ${strides[i]};
196
float padA(int m[${rank}]) {
200
vec2 coords = offsetToCoords(offset, ${width}, ${height});
201
float value = getColorAsFloat(${glsl.texture2D}(A, coords));
209
// NOTE(review): the declaration line of this generator is not visible here;
// judging by getPadFunction's call it is presumably getPadEdge.
// Generates GLSL padA() for edge padding. Per axis, the index shifted by the
// leading pad is clamped to the last element (shape[i] - 1) when it reaches
// or passes shape[i]. The result times strides[i] is accumulated into a flat
// input offset, which is mapped to texture coordinates and sampled from
// texture A.
shape: readonly number[],
210
strides: readonly number[],
215
const rank = shape.length;
218
// Per-axis GLSL is generated from the innermost axis outward.
for (let i = rank - 1; i >= 0; --i) {
220
k = m[${i}] - ${pads[i]};
222
if (k >= ${shape[i]}) k = ${shape[i] - 1};
223
offset += k * ${strides[i]};
227
float padA(int m[${rank}]) {
231
vec2 coords = offsetToCoords(offset, ${width}, ${height});
232
float value = getColorAsFloat(${glsl.texture2D}(A, coords));