1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
4
import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../../../attribute-with-cache-key';
5
import { Graph } from '../../../graph';
6
import { NUMBER_TYPES, OperatorImplementation, OperatorInitialization } from '../../../operators';
7
import { Tensor } from '../../../tensor';
8
import { ShapeUtil } from '../../../util';
9
import { WebGLInferenceHandler } from '../inference-handler';
10
import { ProgramInfo, TextureType } from '../types';
12
/**
 * Parameters of a Slice: positional start/end indices and the axes they apply to.
 * Entry i of `starts`/`ends` is applied to axis `axes[i]`.
 */
export interface SliceAttributes extends AttributeWithCacheKey {
13
// Axes sliced by starts/ends; an empty array is treated by createSliceProgramInfo as "every input dimension".
readonly axes: number[];
14
// End indices (exclusive upper bound of the output extent along each axis).
readonly ends: number[];
15
// Start indices; added to the output coordinate along each sliced axis in the generated shader.
readonly starts: number[];
18
// Program metadata shared by both Slice variants. It is spread into the run() request and into the
// ProgramInfo returned by createSliceProgramInfo. The single input is read as an unpacked texture.
const sliceProgramMetadata = {
21
inputTypes: [TextureType.unpacked],
24
/**
 * Slice operator, attribute-based form: starts/ends/axes come from node attributes
 * (see parseSliceAttributes). For the form that reads them from tensor inputs, see sliceV10.
 *
 * Validates that there is exactly one numeric input, then runs the WebGL slice program on it.
 */
export const slice: OperatorImplementation<SliceAttributes> = (
25
inferenceHandler: WebGLInferenceHandler,
27
attributes: SliceAttributes,
29
validateInputs(inputs);
30
const output = inferenceHandler.run(
32
...sliceProgramMetadata,
33
// Cache hint derived from the attributes, so different slice parameters get distinct programs.
cacheHint: attributes.cacheKey,
34
// Builds the program from input 0 and the attributes; presumably only called on a cache miss — confirm in run().
get: () => createSliceProgramInfo(inferenceHandler, inputs[0], attributes),
41
/**
 * Reads Slice's `starts`, `ends` and optional `axes` integer attributes from a graph node.
 * `axes` defaults to an empty array, which later means "all dimensions".
 * The result carries a cache key built by createAttributeWithCacheKey.
 */
export const parseSliceAttributes: OperatorInitialization<SliceAttributes> = (node: Graph.Node): SliceAttributes => {
42
const starts = node.attributes.getInts('starts');
43
const ends = node.attributes.getInts('ends');
44
const axes = node.attributes.getInts('axes', []);
45
return createAttributeWithCacheKey({ starts, ends, axes });
48
/**
 * Builds the WebGL program for Slice (step 1 only).
 *
 * The steps are:
 * - Resolve the axes and normalize them against the input rank.
 * - Clamp/normalize each start and end against that axis's dimension size.
 * - Compute the output shape as `end - start` per sliced axis.
 * - Emit a GLSL `process()` that offsets each output coordinate by the axis's start index.
 */
const createSliceProgramInfo = (
49
_inferenceHandler: WebGLInferenceHandler,
51
attributes: SliceAttributes,
53
// Empty axes => slice every input dimension, in order.
// NOTE(review): the ONNX Slice spec defaults axes to [0, ..., starts.length - 1], not to every input dim.
// If starts/ends have fewer entries than the input rank, the loop below reads starts[i]/ends[i] as
// undefined and produces NaN dimensions. Consider mapping over attributes.starts instead.
const axes = attributes.axes.length === 0 ? input.dims.slice(0).map((_val, i) => i) : attributes.axes;
54
const normalizedAxes = ShapeUtil.normalizeAxes(axes, input.dims.length);
55
// Starts past the last element clamp to the dimension size. Everything else goes through normalizeAxis,
// which presumably wraps negative indices.
// NOTE(review): confirm normalizeAxis also clamps values below -dim to 0, as the ONNX Slice spec requires.
const starts = attributes.starts.map((start, i) => {
56
if (start > input.dims[normalizedAxes[i]] - 1) {
57
return input.dims[normalizedAxes[i]];
59
return ShapeUtil.normalizeAxis(start, input.dims[normalizedAxes[i]]);
61
// Ends use the same clamping rule as starts.
const ends = attributes.ends.map((end, i) => {
62
if (end > input.dims[normalizedAxes[i]] - 1) {
63
return input.dims[normalizedAxes[i]];
65
return ShapeUtil.normalizeAxis(end, input.dims[normalizedAxes[i]]);
68
// Copy so the input's dims array is not mutated when the sliced extents are written below.
const outputShape = input.dims.slice();
70
// One GLSL statement per sliced axis, mapping an output coordinate back to an input coordinate.
const sliceOps: string[] = [];
71
for (let i = 0; i < normalizedAxes.length; i++) {
72
// NOTE(review): not clamped at 0. If end < start this yields a negative dimension,
// whereas ONNX expects an empty (size-0) output.
outputShape[normalizedAxes[i]] = ends[i] - starts[i];
74
// Shift the output coordinate along this axis by the start offset.
sliceOps.push(`outputIdx[${normalizedAxes[i]}] += ${starts[i]};`);
75
} // else { sliceOps.push(`outputIdx[${normalizedAxes[i]}] += 0;`); }
78
const rank = outputShape.length;
79
// GLSL process() body: applies the per-axis start offsets (sliceOps) to the output coordinate.
const shaderSource = `
80
float process(int outputIdx[${rank}]) {
81
${sliceOps.join('\n ')}
85
...sliceProgramMetadata,
86
output: { dims: outputShape, type: input.type, textureType: TextureType.unpacked },
91
/**
 * Validates the inputs of the attribute-based Slice.
 * @throws Error if there is not exactly one input.
 * @throws Error if the input's data type is not one of NUMBER_TYPES.
 */
const validateInputs = (inputs: Tensor[]): void => {
92
if (!inputs || inputs.length !== 1) {
93
throw new Error('Slice requires 1 input.');
95
if (NUMBER_TYPES.indexOf(inputs[0].type) === -1) {
96
throw new Error('Invalid input type.');
100
/**
 * Slice operator, input-based form: starts, ends, and optional axes and steps arrive as tensor inputs 1-4.
 * The inputs are validated and converted to SliceAttributes; inputs 1-4 must be initializers
 * (see generateSliceAttributesFromInputs). The same WebGL slice program as `slice` then runs on input 0.
 */
export const sliceV10 = (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => {
validateInputsV10(inputs);
// Turn the constant index inputs into attributes, including a cache key built from their values.
const attributes = generateSliceAttributesFromInputs(inferenceHandler, inputs);
const output = inferenceHandler.run(
105
...sliceProgramMetadata,
106
cacheHint: attributes.cacheKey,
107
get: () => createSliceProgramInfo(inferenceHandler, inputs[0], attributes),
114
/**
 * Builds SliceAttributes from the constant index inputs of an input-based (V10) Slice.
 * @throws Error if starts/ends, or the axes/steps inputs when present, are not session initializers.
 * @throws Error if a steps input (index 4) contains any value other than 1.
 */
const generateSliceAttributesFromInputs = (
115
inferenceHandler: WebGLInferenceHandler,
117
): SliceAttributes => {
// Index values must be known at program-build time, so runtime-computed tensors are rejected.
119
!inferenceHandler.session.isInitializer(inputs[1].dataId) ||
120
!inferenceHandler.session.isInitializer(inputs[2].dataId) ||
121
(inputs.length >= 4 && !inferenceHandler.session.isInitializer(inputs[3].dataId)) ||
122
(inputs.length >= 5 && !inferenceHandler.session.isInitializer(inputs[4].dataId))
124
throw new Error('dynamic slice attributes are not allowed');
127
// The shader only adds a start offset per axis, so only unit steps can be expressed.
if (inputs.length >= 5 && inputs[4].integerData.some((i: number) => i !== 1)) {
128
throw new Error('currently non-1 steps is not supported for Slice');
131
const starts = Array.from(inputs[1].integerData);
132
const ends = Array.from(inputs[2].integerData);
133
// Missing axes input => empty array, interpreted downstream as "all dimensions".
const axes = inputs.length >= 4 ? Array.from(inputs[3].integerData) : [];
134
// Key from the concrete index values, so identical slices share a cached program.
const cacheKey = `${axes};${starts};${ends}`;
135
return { starts, ends, axes, cacheKey };
138
/**
 * Validates the inputs of the input-based (V10) Slice.
 * @throws Error if there are fewer than 3 or more than 5 inputs.
 * @throws Error if starts (1), ends (2), or the axes (3) / steps (4) inputs when present
 *   are not 1-D int32 tensors.
 */
const validateInputsV10 = (inputs: Tensor[]): void => {
139
if (!inputs || inputs.length < 3 || inputs.length > 5) {
140
throw new Error('Invalid input number.');
142
// starts
if (inputs[1].type !== 'int32' || inputs[1].dims.length !== 1) {
143
throw new Error('Invalid input type.');
145
// ends
if (inputs[2].type !== 'int32' || inputs[2].dims.length !== 1) {
146
throw new Error('Invalid input type.');
148
// optional axes
if (inputs.length >= 4 && (inputs[3].type !== 'int32' || inputs[3].dims.length !== 1)) {
149
throw new Error('Invalid input type.');
151
// optional steps
if (inputs.length >= 5 && (inputs[4].type !== 'int32' || inputs[4].dims.length !== 1)) {
152
throw new Error('Invalid input type.');