1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
4
// TODO: this is the same naive implementation we use for reduce that has
5
// performance limitations when the reduced axis is long. Need to add
6
// an optimized code path for this.
8
import { DataType } from '../../../wasm-common';
9
import { TensorView } from '../../tensor-view';
10
import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../attribute-with-cache-key';
11
import { ComputeContext } from '../types';
13
import { createReduceProgramInfo, ReduceOp } from './reduce';
15
/**
 * Validates the inputs of the ArgMin/ArgMax kernels.
 *
 * @param inputs Operator inputs; one or two tensors are accepted. Only the first input's
 *   data type is checked.
 * @throws Error if there are no inputs or more than two, or if the first input's data type
 *   is not `DataType.float`.
 */
const validateInputs = (inputs: readonly TensorView[]): void => {
  if (!inputs || inputs.length === 0 || inputs.length > 2) {
    throw new Error('ArgMinMaxOp op requires 1 or 2 inputs.');
  }
  if (inputs[0].dataType !== DataType.float) {
    throw new Error('Invalid input type.');
  }
};
24
/**
 * Attributes shared by the ArgMin and ArgMax kernels. The inherited `cacheKey` is passed to
 * `createReduceProgramInfo` as the program `hint`.
 *
 * NOTE(review): this copy of the interface is truncated — its closing brace (and possibly further
 * fields) is missing; restore it from upstream.
 */
export interface ArgMinMaxAttributes extends AttributeWithCacheKey {
27
// When > 0 the kernels compare with `<=` (ArgMin) / `>=` (ArgMax), so an equal value replaces the
// current best — intended to select the last index among ties (TODO confirm iteration order).
selectLastIndex: number;
30
/**
 * ArgMin kernel: for each output element (`global_idx`), records in `best_index` the index of
 * the smallest input value seen along the reduced axes, and builds the program through
 * `createReduceProgramInfo`.
 *
 * The WGSL comparison is `<` by default. When `attributes.selectLastIndex > 0` it becomes `<=`,
 * so an equal value also replaces the current best. That is intended to pick the last tied
 * index; it presumes the reduce loop visits indices in ascending order (TODO confirm in reduce.ts).
 *
 * NOTE(review): this copy of the function is incomplete:
 * - `idxZero` is used but never declared.
 * - Closing braces and the `return [` that opens the snippet list are missing.
 * - The `if` template literal is never closed.
 * - The `createReduceProgramInfo(` call is unterminated; `argMinMaxOp` is presumably one of its
 *   missing arguments.
 * Restore the missing lines from upstream before building.
 *
 * @param context Compute context; its `inputs` are validated before the program is built.
 * @param attributes ArgMin attributes; `cacheKey` is passed as the program `hint`.
 */
export const argMin = (context: ComputeContext, attributes: ArgMinMaxAttributes): void => {
31
validateInputs(context.inputs);
32
// Produces the WGSL snippets for one reduction; `axes` lists the reduced input axes.
const argMinMaxOp: ReduceOp = (input, output, axes) => {
34
// Point `input_indices` at the first element of the reduction: zero every reduced axis
// (every axis when `axes` is empty).
for (let k = 0; k < input.rank; k++) {
35
if (axes.indexOf(k) >= 0 || axes.length === 0) {
36
idxZero.push(`input_indices[${k}] = 0;`); // first element
40
// Snippet list consumed by the reduce template (roles presumably: index setup, accumulator
// declarations, per-element body, ..., result write — TODO confirm against reduce.ts).
`${idxZero.join('\n')}`,
41
// Seed the running minimum with the first reduced element and its index 0.
`var value = ${input.getByIndices('input_indices')};\nvar best_index : i32 = 0;`,
42
// Per-element body: keep the smaller value and its index (`last_index` is provided by the
// reduce template — presumably the current position along the reduced axis).
`if (${input.getByIndices('input_indices')} ${attributes.selectLastIndex > 0 ? '<=' : '<'} value) {
43
value = ${input.getByIndices('input_indices')};
44
best_index = i32(last_index);
47
output.setByOffset('global_idx', 'best_index'),
52
createReduceProgramInfo(
54
{ hint: attributes.cacheKey, inputDependencies: ['rank'] },
65
/**
 * ArgMax kernel: for each output element (`global_idx`), records in `best_index` the index of
 * the largest input value seen along the reduced axes, and builds the program through
 * `createReduceProgramInfo`.
 *
 * The WGSL comparison is `>` by default. When `attributes.selectLastIndex > 0` it becomes `>=`,
 * so an equal value also replaces the current best. That is intended to pick the last tied
 * index; it presumes the reduce loop visits indices in ascending order (TODO confirm in reduce.ts).
 *
 * NOTE(review): this copy of the function is incomplete:
 * - `idxZero` is used but never declared.
 * - Closing braces and the `return [` that opens the snippet list are missing.
 * - The `if` template literal is never closed.
 * - The `createReduceProgramInfo(` call is unterminated; `argMinMaxOp` is presumably one of its
 *   missing arguments.
 * Restore the missing lines from upstream before building.
 *
 * @param context Compute context; its `inputs` are validated before the program is built.
 * @param attributes ArgMax attributes; `cacheKey` is passed as the program `hint`.
 */
export const argMax = (context: ComputeContext, attributes: ArgMinMaxAttributes): void => {
66
validateInputs(context.inputs);
67
// Produces the WGSL snippets for one reduction; `axes` lists the reduced input axes.
const argMinMaxOp: ReduceOp = (input, output, axes) => {
69
// Point `input_indices` at the first element of the reduction: zero every reduced axis
// (every axis when `axes` is empty).
for (let k = 0; k < input.rank; k++) {
70
if (axes.indexOf(k) >= 0 || axes.length === 0) {
71
idxZero.push(`input_indices[${k}] = 0;`); // first element
75
// Snippet list consumed by the reduce template (roles presumably: index setup, accumulator
// declarations, per-element body, ..., result write — TODO confirm against reduce.ts).
`${idxZero.join('\n')}`,
76
// Seed the running maximum with the first reduced element and its index 0.
`var value = ${input.getByIndices('input_indices')};\nvar best_index : i32 = 0;`,
77
// Per-element body: keep the larger value and its index (`last_index` is provided by the
// reduce template — presumably the current position along the reduced axis).
`if (${input.getByIndices('input_indices')} ${attributes.selectLastIndex > 0 ? '>=' : '>'} value) {
78
value = ${input.getByIndices('input_indices')};
79
best_index = i32(last_index);
82
output.setByOffset('global_idx', 'best_index'),
87
createReduceProgramInfo(
89
{ hint: attributes.cacheKey, inputDependencies: ['rank'] },
100
/**
 * Builds typed ArgMin/ArgMax attributes, including a cache key, from the raw attribute bag.
 *
 * @param attributes Raw operator attributes. They are cast, not validated, so they are assumed
 *   to already hold every `ArgMinMaxAttributes` field other than `cacheKey` — TODO confirm that
 *   the caller guarantees this.
 * @returns The result of `createAttributeWithCacheKey`, which supplies the `cacheKey`.
 */
export const parseArgMinMaxAttributes = (attributes: Record<string, unknown>): ArgMinMaxAttributes =>
  createAttributeWithCacheKey(attributes as Omit<ArgMinMaxAttributes, keyof AttributeWithCacheKey>);