onnxruntime

Форк
0
/
glsl-shape-utils-lib.ts 
168 строк · 5.9 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
import { GlslContext, GlslLib, GlslLibRoutine } from './glsl-definitions';
5

6
/**
 * GLSL Library responsible for data types and routines for manipulating
 * coordinates and mapping to/from tensor indices.
 *
 * For every program input `<name>` it generates GLSL helpers such as
 * `bcastIndices_<name>`, `bcastMatmulIndices_<name>`, `indicesToOffset_<name>`,
 * `offsetToIndices_<name>` (plus `_T` variants that take/return indices in
 * reversed dimension order) and `incrementIndices_<name>`.
 */
export class ShapeUtilsGlslLib extends GlslLib {
  constructor(context: GlslContext) {
    super(context);
  }

  /** Returns every generated shape-utility routine, keyed by GLSL function name. */
  getFunctions(): { [name: string]: GlslLibRoutine } {
    return {
      ...this.bcastIndex(),
      ...this.bcastMatmulIndex(),
      ...this.offsetToIndices(),
      ...this.indicesToOffset(),
      ...this.incrementIndices(),
    };
  }

  /** This library declares no custom GLSL types. */
  getCustomTypes() {
    return {};
  }

  /**
   * For each input whose unpacked rank does not exceed the output rank, emits
   * `bcastIndices_<name>(bcastedIndices, out realIndices)`. It aligns the input's
   * dimensions with the trailing output dimensions and reduces each output index
   * modulo the corresponding input extent.
   */
  protected bcastIndex(): { [name: string]: GlslLibRoutine } {
    const outputRank = this.context.outputTextureLayout.shape.length;
    const result: { [name: string]: GlslLibRoutine } = {};
    this.context.programInfo.inputNames.forEach((name, i) => {
      const shape = this.context.inputTextureLayouts[i].unpackedShape;
      if (shape.length <= outputRank) {
        const rank = shape.length;
        const dimOffset = outputRank - rank;
        const funcName = `bcastIndices_${name}`;
        let block = '';
        for (let d = 0; d < rank; ++d) {
          block += `
          realIndices[${d}] = int( mod(float(bcastedIndices[${dimOffset + d}]), ${shape[d]}.0) );
          `;
        }
        const body = `
        void ${funcName} (int bcastedIndices[${outputRank}], out int realIndices[${rank}]) {
          ${block}
        }
        `;
        result[funcName] = new GlslLibRoutine(body);
      }
    });
    return result;
  }

  /**
   * For each input with 2 <= rank <= output rank, emits
   * `bcastMatmulIndices_<name>(bcastedIndices, out realIndices)`.
   * The leading (batch) dimensions are broadcast by modulo like `bcastIndices_*`.
   * The last two (matrix) dimensions are copied straight from the last two
   * output indices.
   */
  protected bcastMatmulIndex(): { [name: string]: GlslLibRoutine } {
    const outputRank = this.context.outputTextureLayout.shape.length;
    const result: { [name: string]: GlslLibRoutine } = {};
    this.context.programInfo.inputNames.forEach((name, i) => {
      const shape = this.context.inputTextureLayouts[i].shape;
      if (!(shape.length < 2 || shape.length > outputRank)) {
        const rank = shape.length;
        const dimOffset = outputRank - rank;
        const funcName = `bcastMatmulIndices_${name}`;
        let block = '';
        for (let d = 0; d < rank - 2; ++d) {
          block += `
          realIndices[${d}] = int( mod(float(bcastedIndices[${dimOffset + d}]), ${shape[d]}.0) );
          `;
        }
        const body = `
        void ${funcName}(int bcastedIndices[${outputRank}], out int realIndices[${rank}]) {
          ${block}
          realIndices[${rank - 1}] = bcastedIndices[${outputRank - 1}];
          realIndices[${rank - 2}] = bcastedIndices[${outputRank - 2}];
        }
        `;
        result[funcName] = new GlslLibRoutine(body);
      }
    });
    return result;
  }

  /**
   * For each input, emits `indicesToOffset_<name>` (dot product of indices with
   * the input's strides) and `indicesToOffset_<name>_T`. The `_T` variant takes
   * indices in reversed dimension order and uses the reversed strides.
   */
  protected indicesToOffset(): { [name: string]: GlslLibRoutine } {
    const result: { [name: string]: GlslLibRoutine } = {};
    this.context.programInfo.inputNames.forEach((name, i) => {
      const shape = this.context.inputTextureLayouts[i].shape;
      const strides = this.context.inputTextureLayouts[i].strides;
      const rank = shape.length;
      let funcName = `indicesToOffset_${name}`;
      result[funcName] = new GlslLibRoutine(ShapeUtilsGlslLib.indexToOffsetSingle(funcName, rank, strides));
      funcName = `indicesToOffset_${name}_T`;
      result[funcName] = new GlslLibRoutine(
        ShapeUtilsGlslLib.indexToOffsetSingle(funcName, rank, strides.slice().reverse()),
      );
    });
    return result;
  }

  /**
   * Builds the GLSL source of `int <name>(int indices[rank])`, which returns
   * `sum(indices[i] * strides[i])`.
   *
   * @param name GLSL function name to emit.
   * @param rank number of indices the function accepts.
   * @param strides per-dimension multipliers, baked into the source as literals.
   */
  static indexToOffsetSingle(name: string, rank: number, strides: readonly number[]): string {
    let block = '';
    for (let i = rank - 1; i >= 0; --i) {
      block += `
        offset += indices[${i}] * ${strides[i]};
        `;
    }
    return `
      int ${name}(int indices[${rank}]) {
        int offset = 0;
        ${block}
        return offset;
      }
      `;
  }

  /**
   * For each input, emits `offsetToIndices_<name>` and its inverse-of-`_T`
   * counterpart `offsetToIndices_<name>_T`. The `_T` variant returns the same
   * indices in reversed dimension order, so it round-trips with
   * `indicesToOffset_<name>_T`.
   */
  protected offsetToIndices(): { [name: string]: GlslLibRoutine } {
    const result: { [name: string]: GlslLibRoutine } = {};
    this.context.programInfo.inputNames.forEach((name, i) => {
      const shape = this.context.inputTextureLayouts[i].shape;
      const strides = this.context.inputTextureLayouts[i].strides;
      const rank = shape.length;
      let funcName = `offsetToIndices_${name}`;
      result[funcName] = new GlslLibRoutine(ShapeUtilsGlslLib.offsetToIndicesSingle(funcName, rank, strides));
      funcName = `offsetToIndices_${name}_T`;
      // Decomposing with reversed (ascending) strides is wrong: the first division
      // by stride 1 swallows the whole offset. Instead, decompose with the original
      // strides and store the components in reversed order.
      result[funcName] = new GlslLibRoutine(
        ShapeUtilsGlslLib.offsetToIndicesSingle(funcName, rank, strides, true),
      );
    });
    return result;
  }

  /**
   * Builds the GLSL source of `void <name>(int offset, out int indices[rank])`.
   * It decomposes `offset` by successive division with `strides[0..rank-2]`; the
   * remainder becomes the last component. This assumes the strides are
   * descending, as for a row-major layout.
   *
   * @param name GLSL function name to emit.
   * @param rank number of indices produced.
   * @param strides per-dimension strides used for the decomposition.
   * @param transposed when true, component k is written to `indices[rank - 1 - k]`,
   *   i.e. the indices come out in reversed dimension order. Defaults to false,
   *   which emits exactly the same source as before.
   */
  static offsetToIndicesSingle(
    name: string,
    rank: number,
    strides: readonly number[],
    transposed = false,
  ): string {
    const slot = (k: number): number => (transposed ? rank - 1 - k : k);
    const stridesBlock: string[] = [];
    for (let k = 0; k < rank - 1; ++k) {
      stridesBlock.push(`
      indices[${slot(k)}] = offset / ${strides[k]};`);
      stridesBlock.push(`
        offset -= indices[${slot(k)}] * ${strides[k]};`);
    }
    stridesBlock.push(`
      indices[${slot(rank - 1)}] = offset;`);
    return `
      void ${name}(int offset, out int indices[${rank}]) {
        ${stridesBlock.join('')}
      }
      `;
  }

  /**
   * For each input, emits `incrementIndices_<name>(axis, indices)`. It adds 1 to
   * `indices[axis]` odometer-style: on reaching the dimension's extent the
   * component resets to 0 and the carry moves to `axis - 1`, and so on.
   * Components after `axis` are left untouched.
   */
  protected incrementIndices(): { [name: string]: GlslLibRoutine } {
    const result: { [name: string]: GlslLibRoutine } = {};
    this.context.programInfo.inputNames.forEach((name, i) => {
      const shape = this.context.inputTextureLayouts[i].shape;
      const rank = shape.length;
      const funcName = `incrementIndices_${name}`;
      let shapeInit = '';
      for (let d = 0; d < rank; ++d) {
        shapeInit += `
        shape[${d}] = ${shape[d]};`;
      }
      // `inout` (not `out`) is required: the routine reads the caller's current
      // indices, and GLSL does not initialize `out` parameters from the argument.
      const body = `
        void ${funcName}(int axis, inout int indices[${rank}]) {
          int shape[${rank}];
          ${shapeInit}
          for(int i = ${rank} -1 ; i >= 0; --i) {
            if(i > axis) continue;
            indices[i] += 1;
            if(indices[i] < shape[i]) {
              break;
            }
            indices[i] = 0;
          }
        }
        `;
      result[funcName] = new GlslLibRoutine(body);
    });
    return result;
  }
}
169

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.