onnxruntime

Форк
0
/
texture-layout-strategy.ts 
233 строки · 8.3 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
import { Logger } from '../../instrument';
5
import { assert } from '../../util';
6

7
/**
 * Optional hints a caller can pass to a layout strategy when requesting
 * texture dimensions. Every field may be omitted.
 */
export interface WidthHeightPrefs {
  /**
   * Axis at which the shape is split in two. The strategies in this file
   * multiply the dimensions from this axis onward into the first element of
   * the result and the dimensions before it into the second element. The hint
   * is dropped when either product exceeds the maximum texture size.
   */
  breakAxis?: number;
  /**
   * Requests a packed layout. PreferLogicalStrategy rounds the two innermost
   * dimensions up to even numbers and halves the dimensions it returns.
   */
  isPacked?: boolean;
  /**
   * When set, PreferLogicalStrategy swaps the two elements of its result.
   */
  reverseWH?: boolean;
}
13
/**
 * A TextureLayoutStrategy is a plan for mapping an n-dimensional array onto a
 * 2D texture (and back).
 */
export interface TextureLayoutStrategy {
  /**
   * Computes the two texture dimensions to use for a tensor.
   *
   * @param shape dimensions of the tensor; an empty array denotes a scalar
   * @param prefs optional layout hints
   * @returns a pair of texture dimensions
   */
  computeTextureWH(shape: readonly number[], prefs?: WidthHeightPrefs): [number, number];
}
20

21
/**
 * This strategy tries to find the minimal max(W,H) that fulfills
 * (W * H == totalSize), i.e. the texture holds exactly as many elements as the
 * tensor.
 */
export class AlwaysKeepOriginalSizeStrategy implements TextureLayoutStrategy {
  constructor(public maxTextureSize: number) {}

  /**
   * Computes texture dimensions whose product equals the element count.
   *
   * @param shape dimensions of the tensor; an empty array denotes a scalar
   * @param prefs optional hints; only `breakAxis` is consulted here
   * @returns `[1, 1]` for a scalar, the `breakAxis` split when it fits,
   *     otherwise `[width, totalSize / width]` for the first divisor found at
   *     or above `floor(sqrt(totalSize))`
   * @throws Error when no divisor below `maxTextureSize` exists, or when the
   *     resulting height would exceed `maxTextureSize`
   */
  computeTextureWH(shape: readonly number[], prefs?: WidthHeightPrefs): [number, number] {
    // scalar tensor
    if (shape.length === 0) {
      return [1, 1];
    }
    const maxTextureSize = this.maxTextureSize;
    if (prefs && prefs.breakAxis !== undefined) {
      // check to see if dims fit
      const wsize = prefs.breakAxis >= shape.length ? 1 : shape.slice(prefs.breakAxis).reduce((a, b) => a * b);
      const hsize = prefs.breakAxis <= 0 ? 1 : shape.slice(0, prefs.breakAxis).reduce((a, b) => a * b);
      if (wsize > maxTextureSize || hsize > maxTextureSize) {
        // ignore preferences
        // continue with default layout
        Logger.verbose(
          'TextureLayout',
          `Given width/height preferences were unattainable: shape:${shape}, breakAxis:${prefs.breakAxis}`,
        );
      } else {
        return [wsize, hsize];
      }
    }
    const totalSize = shape.reduce((a, b) => a * b);

    // Start at the square root and walk upwards until a divisor is found.
    let width = Math.floor(Math.sqrt(totalSize));

    for (; width < maxTextureSize && width < totalSize; width++) {
      if (totalSize % width === 0) {
        break;
      }
    }

    // Note: a total size of 0 also lands here, because 0 % 0 is NaN.
    if (width >= maxTextureSize || totalSize % width !== 0) {
      throw new Error(`The given dimensions are outside this GPU's boundaries: ${shape}`);
    }

    // Because the search starts at floor(sqrt(totalSize)), the matching height
    // can be larger than the width (by up to 2), so it has to be validated
    // against the texture limit as well.
    const height = totalSize / width;
    if (height > maxTextureSize) {
      throw new Error(`The given dimensions are outside this GPU's boundaries: ${shape}`);
    }
    return [width, height];
  }
}
63

64
/**
 * Layout strategy that prefers texture dimensions derived directly from the
 * logical shape of the tensor (after dropping size-1 dimensions), and only
 * falls back to a roughly square layout when the logical dimensions do not
 * fit within the maximum texture size.
 */
export class PreferLogicalStrategy implements TextureLayoutStrategy {
  constructor(public maxTextureSize: number) {}

  /**
   * Returns the result of `computeTexture`, with both dimensions halved when
   * `prefs.isPacked` is set and the two dimensions swapped when
   * `prefs.reverseWH` is set.
   */
  computeTextureWH(shape: readonly number[], prefs?: WidthHeightPrefs): [number, number] {
    const dims = this.computeTexture(shape, prefs);
    if (prefs && prefs.isPacked) {
      dims[0] /= 2;
      dims[1] /= 2;
    }
    return prefs && prefs.reverseWH ? [dims[1], dims[0]] : dims;
  }

  /**
   * Computes the dimensions before the packed halving / reversal applied by
   * `computeTextureWH`.
   *
   * @param shape dimensions of the tensor; an empty array denotes a scalar
   * @param prefs optional hints (`breakAxis` and `isPacked` are consulted)
   */
  computeTexture(shape: readonly number[], prefs?: WidthHeightPrefs): [number, number] {
    const packed = Boolean(prefs && prefs.isPacked);

    // A scalar occupies one texel: 2x2 channels when packed, 1x1 otherwise.
    if (shape.length === 0) {
      return packed ? [2, 2] : [1, 1];
    }

    let limit = this.maxTextureSize;

    if (prefs && prefs.breakAxis !== undefined) {
      // Honour the requested split only if both halves fit.
      const product = (dims: readonly number[]): number => dims.reduce((a, b) => a * b);
      const wsize = prefs.breakAxis >= shape.length ? 1 : product(shape.slice(prefs.breakAxis));
      const hsize = prefs.breakAxis <= 0 ? 1 : product(shape.slice(0, prefs.breakAxis));
      const tooLarge = wsize > limit || hsize > limit;
      if (!tooLarge) {
        return [wsize, hsize];
      }
      // Preference cannot be satisfied; report it and use the default layout.
      Logger.verbose(
        'TextureLayout',
        `Given width/height preferences were unattainable: shape:${shape}, breakAxis:${prefs.breakAxis}`,
      );
    }

    let logShape = shape.slice(0);
    if (packed) {
      limit *= 2;

      // Values can share a texel only when they come from adjacent pairs of
      // rows/cols within the same batch, so an odd inner dimension is counted
      // as the next even number: the texels holding its last row/col are half
      // empty.
      const firstInnerAxis = logShape.length - 2;
      logShape = logShape.map((dim, idx) => (idx >= firstInnerAxis && dim % 2 !== 0 ? dim + 1 : dim));

      // A packed texture is at least 2 channels high (one texel).
      if (logShape.length === 1) {
        logShape = [2, logShape[0]];
      }
    }

    // A rank-2 logical shape is kept as is so that it matches the physical one.
    if (logShape.length !== 2) {
      logShape = squeezeShape(logShape).newShape;
    }

    const size = sizeFromShape(logShape);
    const rank = logShape.length;

    if (rank <= 1 && size <= limit) {
      return [1, size];
    }
    if (rank === 2 && logShape[0] <= limit && logShape[1] <= limit) {
      return logShape as [number, number];
    }
    if (rank === 3) {
      const [d0, d1, d2] = logShape;
      // Prefer folding the leading dimensions together, then the trailing ones.
      if (d0 * d1 <= limit && d2 <= limit) {
        return [d0 * d1, d2];
      }
      if (d0 <= limit && d1 * d2 <= limit) {
        return [d0, d1 * d2];
      }
    } else if (rank === 4) {
      const [d0, d1, d2, d3] = logShape;
      if (d0 * d1 * d2 <= limit && d3 <= limit) {
        return [d0 * d1 * d2, d3];
      }
      if (d0 <= limit && d1 * d2 * d3 <= limit) {
        return [d0, d1 * d2 * d3];
      }
    }

    if (packed) {
      // For packed textures `size` counts channels. To keep the inner
      // dimensions even, the square-ish shape is computed in texel units
      // (size / 4) and then scaled back to channel units.
      return sizeToSquarishShape(size / 4).map((d) => d * 2) as [number, number];
    }
    return sizeToSquarishShape(size);
  }
}
160

161
/**
 * Removes size-1 dimensions from a shape.
 *
 * Without `axis` (or with an empty `axis` array) every dimension equal to 1 is
 * dropped. With `axis`, only the listed axes are dropped; size-1 dimensions
 * that are not listed are kept.
 *
 * @param shape the shape to squeeze
 * @param axis optional axes to squeeze; normalized through `parseAxisParam`
 * @returns `newShape`, the remaining dimensions, and `keptDims`, the indices
 *     into `shape` those dimensions came from
 * @throws Error when a listed axis has a dimension other than 1
 */
export function squeezeShape(shape: number[], axis?: number[]): { newShape: number[]; keptDims: number[] } {
  const newShape: number[] = [];
  const keptDims: number[] = [];
  const isEmptyArray = axis != null && Array.isArray(axis) && axis.length === 0;
  // The walk below consumes the axes in ascending order, so they must be
  // sorted numerically. The default Array.prototype.sort compares elements as
  // strings, which would place 10 before 2.
  const axes = axis == null || isEmptyArray ? null : parseAxisParam(axis, shape).sort((a, b) => a - b);
  let j = 0;
  for (let i = 0; i < shape.length; ++i) {
    if (axes != null) {
      if (axes[j] === i && shape[i] !== 1) {
        throw new Error(`Can't squeeze axis ${i} since its dim '${shape[i]}' is not 1`);
      }
      // A size-1 dimension that was not requested stays in the result.
      if ((axes[j] == null || axes[j] > i) && shape[i] === 1) {
        newShape.push(shape[i]);
        keptDims.push(i);
      }
      // Advance past the axis once the walk has reached it.
      if (axes[j] <= i) {
        j++;
      }
    }
    if (shape[i] !== 1) {
      newShape.push(shape[i]);
      keptDims.push(i);
    }
  }
  return { newShape, keptDims };
}
187

188
/**
 * Normalizes an axis argument into a list of non-negative axis indices.
 *
 * @param axis a single axis or a list of axes; a nullish value selects every
 *     axis of `shape`
 * @param shape the shape the axes refer to; only its length (the rank) is used
 * @returns the axes as an array, with each negative axis `a` replaced by
 *     `rank + a`
 */
export function parseAxisParam(axis: number | number[], shape: number[]): number[] {
  const rank = shape.length;

  // Bring every accepted input form into a fresh array of axes.
  let axes: number[];
  if (axis == null) {
    axes = shape.map((_dim, index) => index);
  } else if (Array.isArray(axis)) {
    axes = axis.slice();
  } else {
    axes = [axis];
  }

  // Every axis has to lie in [-rank, rank).
  const allInRange = axes.every((ax) => ax >= -rank && ax < rank);
  assert(allInRange, () => `All values in axis param must be in range [-${rank}, ${rank}) but got axis ${axes}`);

  // Every axis has to be a whole number.
  const allIntegers = axes.every(isInt);
  assert(allIntegers, () => `All values in axis param must be integers but got axis ${axes}`);

  // Negative axes count from the end.
  return axes.map((ax) => (ax < 0 ? rank + ax : ax));
}
206
/**
 * Tells whether a number has no fractional part.
 *
 * @param a the number to test
 * @returns true when the remainder of `a` divided by 1 is zero
 */
export function isInt(a: number): boolean {
  const fractionalPart = a % 1;
  return fractionalPart === 0;
}
209
/**
 * Computes the number of elements described by a shape.
 *
 * @param shape the dimensions to multiply together
 * @returns the product of all dimensions; 1 for an empty shape (a scalar)
 */
export function sizeFromShape(shape: number[]): number {
  // Folding from 1 covers the scalar case without a separate branch.
  return shape.reduce((product, dim) => product * dim, 1);
}
220
/**
 * Extracts the row and column counts from a shape.
 *
 * @param shape a non-empty list of dimensions
 * @returns `[rows, cols]`, where `cols` is the last dimension and `rows` is
 *     the second-to-last dimension, or 1 when the shape has a single dimension
 * @throws Error when `shape` is empty
 */
export function getRowsCols(shape: number[]): [number, number] {
  const rank = shape.length;
  if (rank === 0) {
    throw Error('Cannot get rows and columns of an empty shape array.');
  }

  const cols = shape[rank - 1];
  const rows = rank > 1 ? shape[rank - 2] : 1;
  return [rows, cols];
}
227
/**
 * Finds a roughly square pair of dimensions able to hold `size` elements.
 *
 * @param size the number of elements to accommodate
 * @returns `[width, height]` with `width = ceil(sqrt(size))` and
 *     `height = ceil(size / width)`; `[0, 0]` when `size` is 0
 */
export function sizeToSquarishShape(size: number): [number, number] {
  const width = Math.ceil(Math.sqrt(size));
  // A size of 0 gives a width of 0; dividing by it would produce NaN.
  if (width === 0) {
    return [0, 0];
  }
  return [width, Math.ceil(size / width)];
}
231
/**
 * Computes the product of the leading dimensions of a shape, i.e. everything
 * except the last `dimsToSkip` dimensions.
 *
 * @param shape the full shape
 * @param dimsToSkip how many trailing dimensions to leave out (default 2)
 * @returns the product of the remaining leading dimensions; 1 when none remain
 */
export function getBatchDim(shape: number[], dimsToSkip = 2): number {
  // Clamp at 0: a negative end index would make slice count from the end of
  // the array and keep dimensions that should have been skipped.
  const leadingCount = Math.max(0, shape.length - dimsToSkip);
  return sizeFromShape(shape.slice(0, leadingCount));
}
234

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.